Compare commits

...

71 Commits

Author SHA1 Message Date
Evan
e8078f5a0e wow 2026-01-20 15:07:03 +00:00
rltakashige
8b709e68b2 Mark slow tests as slow (#1220)
2026-01-20 15:03:46 +00:00
Evan Quiney
4da6eeb11f fix a test broken by #1204 (#1219)
bad merge broke a test - fix it
2026-01-20 14:56:20 +00:00
Evan
3d2eee4884 quiet localhost log
this log is just noise - remove it
2026-01-20 14:51:26 +00:00
Evan
116558839e don't clear mdns discovered connections
pingers currently remove mdns discovered connections - these systems
should be independent
2026-01-20 14:46:20 +00:00
Evan Quiney
d4f551c602 Simplify model cards (#1204)
## Motivation

We have a lot of unneeded data in the model card - let's just keep the
necessary stuff and add back more data when we need it

## Test Plan

EXO still runs! (pipeline on 2)

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-20 11:01:19 +00:00
Alex Cheema
176ab5ba40 Add GLM-4.7-Flash model cards (4bit, 5bit, 6bit, 8bit) (#1214)
## Motivation

Add support for GLM-4.7-Flash, a lighter variant of GLM-4.7 with the
`glm4_moe_lite` architecture. These models are smaller and faster while
maintaining good performance.

## Changes

1. **Added 4 new model cards** for GLM-4.7-Flash variants:
   - `glm-4.7-flash-4bit` (~18 GB)
   - `glm-4.7-flash-5bit` (~21 GB)
   - `glm-4.7-flash-6bit` (~25 GB)
   - `glm-4.7-flash-8bit` (~32 GB)

   All variants have:
   - `n_layers`: 47 (vs 91 in GLM-4.7)
   - `hidden_size`: 2048 (vs 5120 in GLM-4.7)
   - `supports_tensor`: True (native `shard()` method)

2. **Bumped mlx from 0.30.1 to 0.30.3** - required by mlx-lm 0.30.4

3. **Updated mlx-lm from 0.30.2 to 0.30.4** - adds `glm4_moe_lite`
architecture support

4. **Added type ignores** in `auto_parallel.py` for stricter type
annotations in new mlx-lm

5. **Fixed EOS token IDs** for GLM-4.7-Flash - uses different tokenizer
with IDs `[154820, 154827, 154829]` vs other GLM models' `[151336,
151329, 151338]`

6. **Renamed `MLX_IBV_DEVICES` to `MLX_JACCL_DEVICES`** - env var name
changed in new mlx

## Why It Works

The model cards follow the same pattern as existing GLM-4.7 models.
Tensor parallel support is enabled because GLM-4.7-Flash implements the
native `shard()` method in mlx-lm 0.30.4, which is automatically
detected in `auto_parallel.py`.

GLM-4.7-Flash uses a new tokenizer with different special token IDs.
Without the correct EOS tokens, generation wouldn't stop properly.

## Test Plan

### Manual Testing
Tested generation with GLM-4.7-Flash-4bit - now correctly stops at EOS
tokens.

### Automated Testing
- `basedpyright`: 0 errors
- `ruff check`: All checks passed
- `pytest`: 162/162 tests pass (excluding pre-existing
`test_distributed_fix.py` timeout failures)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 03:58:09 +00:00
rltakashige
f5e6aa82d2 Load layers individually (#1211)
## Motivation

Certain models hang at model loading in tensor parallel. 

Hopefully closes #1205 

## Changes

- Load layer by layer for tensor parallel sharding
- Move eval_with_timeout to auto_parallel.py to resolve circular import.

## Why It Works

The naive way to fix this is to load the model with lazy=False and
then shard it for tensor parallel. However, this requires the entire
model to be loaded into memory.

Instead, we can load layer by layer and shard after loading. This adds a
very small memory footprint, but it is negligible.

I tried loading layer by layer after the sharding, and this allowed
model loading but got stuck at warming up.
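
A rough sketch of the idea (helper names and the exact shard() signature are illustrative, not the actual code):

```python
import mlx.core as mx

def load_and_shard(model, weights: dict[str, mx.array], group) -> None:
    # Materialize one transformer block at a time so the full unsharded
    # model never has to sit in memory at once.
    for i, layer in enumerate(model.layers):
        prefix = f"model.layers.{i}."
        layer.load_weights(
            [(k[len(prefix):], v) for k, v in weights.items() if k.startswith(prefix)],
            strict=False,
        )
        mx.eval(layer.parameters())  # realize this layer's weights
        layer.shard(group)           # then shard, dropping non-local slices
```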

## Test Plan

### Manual Testing
GPT OSS loads with TP and FAST SYNCH. Kimi does too.

### Automated Testing
We need to run a suite of exo_bench before merging this!
2026-01-20 03:26:51 +00:00
Alex Cheema
39f0ed6018 Prepend <think> tag to stream for thinking models like GLM-4.7 (#1186)
## Motivation

For thinking models like GLM-4.7, the `<think>` tag is inserted by the
tokenizer's `apply_chat_template()` into the **prompt** (input). The
model generates tokens starting *after* this tag, so `<think>` never
appears in the streamed output. The frontend expects
`<think>...</think>` tags to extract and display thinking content.

**Log evidence:**
```
[gMASK]<sop><|system|>...<|user|>...<|assistant|><think>
```
The prompt ends with `<think>`, so the model generates content after it,
never returning the opening tag.

## Changes

- Added `detect_thinking_prompt_suffix()` helper function in
`utils_mlx.py` to detect if a prompt ends with `<think>` tag
- Added `parse_thinking_models()` generator wrapper in `runner.py` that
prepends the thinking tag to the output stream
- Modified the main generation loop to use the thinking wrapper for
non-GptOssModel models when a thinking prefix is detected
- Updated test mocks to handle the new `apply_chat_template` call

## Why It Works

The solution follows the same pattern as `parse_gpt_oss()` - a generator
wrapper that transforms the output stream. When the chat template ends
with `<think>`, we prepend this tag to the first generated token so the
frontend receives the complete `<think>...</think>` structure it
expects.
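
A minimal sketch of that wrapper (the real implementation is `parse_thinking_models()` in `runner.py`; this only shows the shape):

```python
from typing import Iterator

def prepend_think_tag(tokens: Iterator[str]) -> Iterator[str]:
    yield "<think>"    # re-emit the tag the chat template consumed
    yield from tokens  # then stream the model's output unchanged
```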

## Test Plan

### Manual Testing
- Run exo: `uv run exo`
- Send a chat request to GLM-4.7:
  ```bash
curl http://localhost:52415/v1/chat/completions -H "Content-Type:
application/json" -d '{
    "model": "mlx-community/GLM-4.7-8bit-gs32",
    "messages": [{"role": "user", "content": "What is 2+2?"}],
    "stream": true
  }'
  ```
- Verify the streamed response starts with `<think>` tag
- Verify the frontend dashboard correctly shows the thinking section
collapsed

### Automated Testing
- All 72 worker tests pass: `uv run pytest src/exo/worker/`
- Type checker passes: `uv run basedpyright`
- Linter passes: `uv run ruff check`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-19 19:44:51 +00:00
Alex Cheema
ee43b598fe Split NodePerformanceProfile into granular state mappings (#1209)
## Motivation

The current `NodePerformanceProfile` is a monolithic object where every
update (even 1-second memory updates) replaces the entire profile,
touching unrelated data. Different fields update at vastly different
frequencies:

| Data | Update Frequency |
|------|------------------|
| Memory, System | 1 second |
| Thunderbolt | 5 seconds |
| Network interfaces | 10 seconds |
| Friendly name | 60 seconds |
| Model/Chip ID | Once at startup |

## Changes

Split into separate state mappings so each data type updates
independently:

- `node_identities`: Static and slow-changing data (model_id, chip_id,
friendly_name)
- `node_memory`: RAM and swap usage
- `node_system`: GPU usage, temperature, power, CPU metrics
- `node_network`: Network interface information
- `node_thunderbolt`: Thunderbolt interface identifiers

Added a backwards-compatible `node_profiles` property that reconstructs
`NodePerformanceProfile` from the granular mappings for dashboard
compatibility.

**Files modified:**
- `src/exo/shared/types/profiling.py` - Added `NodeIdentity`,
`NodeNetworkInfo`, `NodeThunderboltInfo` types
- `src/exo/shared/types/state.py` - Added 5 new mappings +
`node_profiles` property
- `src/exo/shared/apply.py` - Updated `apply_node_gathered_info` and
`apply_node_timed_out`

## Why It Works

Each info type now writes only to its specific mapping, avoiding
unnecessary updates to unrelated data. The `MacThunderboltConnections`
handler reads from `node_thunderbolt` instead of the old `node_profiles`
for RDMA connection mapping. The backwards-compatible property ensures
the dashboard continues to work unchanged.
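
A hedged sketch of that property (the type names are from this PR; `NodeId` and the constructor wiring are illustrative, not the actual `state.py` code):

```python
@property
def node_profiles(self) -> dict[NodeId, NodePerformanceProfile]:
    # Rebuilt on read so the dashboard keeps its old monolithic view.
    return {
        nid: NodePerformanceProfile(
            identity=self.node_identities[nid],
            memory=self.node_memory.get(nid),
            system=self.node_system.get(nid),
            network=self.node_network.get(nid),
            thunderbolt=self.node_thunderbolt.get(nid),
        )
        for nid in self.node_identities
    }
```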

## Test Plan

### Manual Testing
- Start exo and verify dashboard shows node info
- Verify memory/GPU updates stream correctly
- Check that node timeout properly cleans up all mappings

### Automated Testing
- All 162 existing tests pass
- basedpyright: 0 errors
- ruff check: All checks passed
- nix fmt: Applied

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 18:24:15 +00:00
rltakashige
5fd55594c9 Wrap pipeline models for explicit mx.depends between cache and logits (#1206)
## Motivation

GPU timeouts often occur when the prompt size > prefill_step_size. They
also happen for seemingly random models.

## Changes

Add mx.depends so the cache depends on the logits.
All-gather at the model level rather than the layer level, reducing the
amount of data sent.

## Why It Works

mlx_lm's prefill loop only evaluates cache state, not logits.
When prompt > prefill_step_size, the all_gather is never evaluated,
causing GPU timeout.
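
The trick in miniature (model/cache variable names are illustrative; mx.depends is the real MLX primitive):

```python
import mlx.core as mx

logits = model(tokens, cache=cache)    # the logits embed the all_gather
states = [s for c in cache for s in c.state]
states = mx.depends(states, [logits])  # tie the cache state to the logits
# mlx_lm's prefill loop evaluates only the cache, but that now forces
# the all_gather too, so no rank stalls waiting on its peers.
mx.eval(states)
```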

## Test Plan

### Automated Testing
Added failing test cases and then resolved them.
2026-01-19 17:49:42 +00:00
Jake Hillion
5ab1f8b3e2 NetworkSetupHelper: detect stale startup script content
The daemonAlreadyInstalled() function checked that the startup script
file existed and validated plist properties, but did not compare the
actual script content. If the setupScript constant was updated in a new
app version, the stale on-disk script would not be detected or replaced.

Added a guard clause that reads the installed script from disk and
compares it against the expected setupScript content (with whitespace
normalization). When content differs, the function returns false,
triggering the reinstallation flow with an admin privileges prompt.
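
The comparison, sketched in Python for brevity (the actual implementation is Swift in NetworkSetupHelper):

```python
def script_is_stale(installed: str, expected: str) -> bool:
    normalize = lambda s: " ".join(s.split())  # collapse all whitespace
    return normalize(installed) != normalize(expected)
```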

Test plan:
- Installed on a cluster that had the previous network config. Got the
  popup asking for permissions. After accepting I could run Kimi K2
  Thinking Tensor RDMA on all 4 nodes.
2026-01-19 17:36:15 +00:00
Evan Quiney
2202685c3e refactor all information sources (including ipless rdma discovery) (#928)
## Motivation

Information gathering is tightly coupled to MacMon - we should start
generalizing our information sources so we can add more in future.

## Changes

Added a new system to gather any information. Currently, it is attached
to the Worker - though this is mostly to keep the data processing logic
simple. It could be made independent quite easily.

I also refactored topology to include different kinds of connections as
we can gather RDMA connections without having a pre-existing socket
connection, and made the relevant placement updates. We should no longer
need the network locations script in the app.

Other sources of information now include:
- static node information like "model" and "chip" (macos, "Unknown"
fallback)
- device friendly name (macos, falls back to device hostname)
- network interfaces + ips (cross platform)
- thunderbolt interfaces (macos)
- thunderbolt connections (macos)
- RAM usage (cross platform)
- per-device configuration written to EXO_HOME/config.toml

## Limitations

Model and Chip are not cross platform concepts.

We do not differentiate between unified and non-unified memory systems.

A lot of this data collection is based on simple timers. Watching the SC
store on macOS is the correct way to gather some of this information,
but requires a detour into Rust on macOS.

## Why It Works

The InfoGatherer is a generic subsystem which returns a union of metric
datatypes. It writes them to an event, which is applied to state. It is
currently re-spawned with the worker so each cluster receives the
correct information.

As for topology, macOS identifies TB ports with a uuid in
SPThunderboltDataType, and also stores remote uuids if it can find them.
These changes read that data with system_profiler - hopefully not so
often as to cause a notable performance impact (though this should be
tuned), but frequently enough for moderate responsiveness.
As we can identify TB connections between devices without needing ips
attached to each interface, we can remove the network setup script
(almost) completely.

## Test Plan

### Manual Testing
Spawn RDMA instances without enabling DHCP on the RDMA interfaces.

### Automated Testing
Updated the current master and shared tests to cover the topology
refactor and new events.

---------

Co-authored-by: Sami Khan <smsak99@gmail.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Jake Hillion <jake@hillion.co.uk>
2026-01-19 16:58:09 +00:00
Andrei Onel
ce3ad391b1 Update README.md with some changes from release 1.0.61 (#1157)
Updated README.md with documentation for four new features:

- added a "Benchmarking" section documenting the exo-bench tool for
measuring model performance across different placement configurations
- documented the custom namespace feature for cluster isolation in the
macOS app section
- added a "Configuration Options" subsection explaining the --no-worker
CLI flag for coordinator-only nodes
- added a "File Locations (Linux)" subsection documenting XDG Base
Directory Specification compliance on Linux systems

Issue #930
2026-01-19 16:43:18 +00:00
Jake Hillion
fb0151630d shard_downloader: make on_progress callback async
The on_progress callback was synchronous but always invoked from async
contexts, forcing the use of send_nowait() which could raise WouldBlock
if the channel buffer was full, potentially dropping progress updates.

Changed the callback type from `Callable[[ShardMetadata,
RepoDownloadProgress], None]` to return a coroutine, updated all
implementations to be async, and replaced send_nowait() with await
send() in the worker's download progress handler.

This allows proper backpressure handling when sending download progress
events through the channel, eliminating the "Footgun!" that was
previously documented in the code.
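
A sketch of the new callback shape (ShardMetadata and RepoDownloadProgress are the real types; the channel name is illustrative):

```python
from typing import Awaitable, Callable

OnProgress = Callable[["ShardMetadata", "RepoDownloadProgress"], Awaitable[None]]

async def on_progress(shard: "ShardMetadata", progress: "RepoDownloadProgress") -> None:
    # awaiting send() applies backpressure; send_nowait() could raise WouldBlock
    await progress_channel.send((shard, progress))  # progress_channel: an anyio send channel
```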

Test plan:
- Built a DMG and ran it on one node. All existing models showed as
  downloaded.
- Downloaded a new model. The progress bar on the download page worked.
- Downloaded another new model. The progress bar on the home page
  worked.
2026-01-19 16:19:37 +00:00
Alex Cheema
346b13e2c9 Enhance LaTeX rendering in dashboard markdown (#1197)
## Motivation

When models output LaTeX-formatted math proofs, the dashboard was not
rendering them correctly. Issues included:
- `\documentclass`, `\begin{document}`, `\usepackage` showing as raw
text
- `$...$` inline math with complex expressions (like `\frac`, `\ldots`)
not rendering due to markdown escaping backslashes
- `\begin{align*}...\end{align*}` and other math environments showing as
raw text
- `\emph{...}`, `\textbf{...}` LaTeX formatting commands not being
converted
- `$\require{...}$` (MathJax-specific) causing KaTeX errors
- `\begin{proof}...\end{proof}` showing as raw text

## Changes

Enhanced `MarkdownContent.svelte` with comprehensive LaTeX support:

**Math extraction before markdown processing:**
- Extract `$...$`, `$$...$$`, `\(...\)`, `\[...\]` into placeholders
before markdown processes the text
- Use alphanumeric placeholders (`MATHPLACEHOLDERINLINE0END`) that won't
be interpreted as HTML tags
- Restore and render with KaTeX after markdown processing

**LaTeX document command removal:**
- Strip `\documentclass{...}`, `\usepackage{...}`, `\begin{document}`,
`\end{document}`
- Strip `\maketitle`, `\title{...}`, `\author{...}`, `\date{...}`
- Strip `\require{...}` (MathJax-specific, not KaTeX)
- Replace `tikzpicture` environments with `[diagram]` placeholder
- Strip `\label{...}` cross-reference commands

**LaTeX math environments:**
- Convert `\begin{align*}`, `\begin{equation}`, `\begin{gather}`, etc.
to display math blocks

**LaTeX text formatting:**
- `\emph{...}` and `\textit{...}` → `<em>...</em>`
- `\textbf{...}` → `<strong>...</strong>`
- `\texttt{...}` → `<code>...</code>`
- `\underline{...}` → `<u>...</u>`

**LaTeX environments styling:**
- `\begin{proof}...\end{proof}` → styled proof block with QED symbol
- `\begin{theorem}`, `\begin{lemma}`, etc. → styled theorem blocks

**Display math enhancements:**
- Wrapped in styled container with subtle gold border
- "LaTeX" label and copy button appear on hover
- Dark theme KaTeX color overrides for better readability
- Custom scrollbar for overflow

## Why It Works

The key insight is that markdown processing was escaping backslashes in
LaTeX before KaTeX could see them. By extracting all math expressions
into alphanumeric placeholders *before* markdown runs, then restoring
them *after*, the LaTeX content passes through to KaTeX unmodified.

Using purely alphanumeric placeholders like `MATHPLACEHOLDERINLINE0END`
instead of `<<MATH_INLINE_0>>` prevents markdown from interpreting them
as HTML tags and stripping them.
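
The extract/restore round-trip, sketched in Python (the real code lives in MarkdownContent.svelte; only inline `$...$` is handled here for brevity):

```python
import re

def protect_math(text: str) -> tuple[str, list[str]]:
    stash: list[str] = []
    def grab(m: re.Match) -> str:
        stash.append(m.group(0))
        return f"MATHPLACEHOLDERINLINE{len(stash) - 1}END"  # plain alphanumerics
    return re.sub(r"\$[^$\n]+\$", grab, text), stash

def restore_math(html: str, stash: list[str]) -> str:
    # after markdown runs, put the untouched LaTeX back for KaTeX to render
    for i, expr in enumerate(stash):
        html = html.replace(f"MATHPLACEHOLDERINLINE{i}END", expr)
    return html
```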

## Test Plan

### Manual Testing
- Hardware: Any machine with the dashboard
- What you did:
  - Ask model to "write a proof in latex"
  - Verify inline math like `$x \in S$` renders correctly
- Verify display math like `\begin{align*}...\end{align*}` renders as
block
  - Verify `\documentclass`, `\begin{document}` are stripped (not shown)
  - Verify `\emph{...}` converts to italics
  - Verify copy button works on display math blocks
- Test edge cases: `$5` (currency) stays as text, `\$50` (escaped)
becomes `$50`

Before:
<img width="799" height="637" alt="Screenshot 2026-01-19 at 11 51 22 AM"
src="https://github.com/user-attachments/assets/62a705b8-b3c2-47b8-afd0-5d0c1b240e44"
/>

After:
<img width="809" height="642" alt="Screenshot 2026-01-19 at 11 46 58 AM"
src="https://github.com/user-attachments/assets/4f35fa1d-333c-4285-bc68-58a50f8f148e"
/>


### Automated Testing
- Dashboard builds successfully with `npm run build`
- Existing functionality preserved

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:50:41 +00:00
rltakashige
ea0588429b Custom mlx layer composition (#1201)
## Motivation

With a single pipeline layer, PipelineFirstLayer gets composed with
PipelineLastLayer.

## Test Plan

### Automated Testing
Made failing tests. Fixed them!
2026-01-19 12:36:25 +00:00
rltakashige
73b3f87e07 Set swa_idx and ga_idx for single layer (#1202)
## Motivation

The layer types list does not contain either "sliding_attention" or
"full_attention" for pipeline parallel (single layer).

## Test Plan

### Manual Testing
Manually tested a single layer of GPT OSS. Doesn't crash.

2026-01-19 12:31:11 +00:00
Evan Quiney
746589ba6b tidy: remove context manager from api (#1199) 2026-01-19 11:58:13 +00:00
rltakashige
f82f862fd7 Fix several issues with placement (#1200)
## Motivation

Uneven placements were causing issues for some users with lopsided
setups. While fixing that, I ran into another issue: impossible
allocation of memory.

## Changes

- Allocate at least 1 layer per device.
- Catch overallocation of memory with an error.
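
A toy sketch of those two rules (not the actual placement algorithm):

```python
def place_layers(n_layers: int, mem_gb: list[float], gb_per_layer: float) -> list[int]:
    if n_layers < len(mem_gb):
        raise ValueError("more devices than layers")  # 1-layer floor impossible
    alloc = [1] * len(mem_gb)                         # at least 1 layer per device
    for _ in range(n_layers - len(mem_gb)):           # hand out the rest greedily
        free = [m - a * gb_per_layer for m, a in zip(mem_gb, alloc)]
        alloc[free.index(max(free))] += 1
    for a, m in zip(alloc, mem_gb):
        if a * gb_per_layer > m:
            raise ValueError("placement over-allocates memory on a device")
    return alloc
```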


## Test Plan

### Manual Testing
Tested that GPT OSS is placed correctly.

### Automated Testing
Added breaking tests in the first commit. Resolved with new placement
algorithm in the second one.
2026-01-19 11:52:35 +00:00
Alex Cheema
7ff937d8a1 Add dashboard screenshots to README (#1185)
## Motivation

The README showcases exo's features and benchmarks but doesn't show what
the dashboard actually looks like. Adding a screenshot helps users
understand what they'll get when they run exo.

## Changes

- Added dashboard screenshot to `docs/imgs/dashboard-cluster-view.png`:
Shows the cluster topology view with 4 × 512GB M3 Ultra Mac Studio
running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)
- Added a new "Dashboard" section to README.md below Features,
displaying the screenshot with caption

## Why It Works

Visual documentation helps users understand what exo offers before they
install it. The screenshot demonstrates the cluster management
capabilities.

## Test Plan

### Manual Testing
- Verified image renders correctly in GitHub markdown preview

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 10:43:27 +00:00
Evan Quiney
d19bf02404 re-raise exceptions in the runner (#1198)
## Motivation

Runners that crash can swallow errors - we should re-raise. Also the
exception handler annoyed me.

## Changes

The try: except in the runner's chat now re-raises.
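
The pattern in miniature (names illustrative):

```python
async def run_chat(request) -> None:
    try:
        await handle_chat(request)  # the runner's chat loop (illustrative)
    except Exception:
        report_crash()              # whatever cleanup/reporting happens here
        raise                       # re-raise so the error isn't swallowed
```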
2026-01-19 10:35:23 +00:00
rltakashige
618cee5223 Resolve test event ordering flakiness (#1194)
## Motivation

The mp sender occasionally does not have time to flush its events before
collect() is called, making the event ordering test fail.

## Changes

- Replace mp_channel with a simple collector for the event ordering test
- Also suppress the `<frozen importlib._bootstrap>:488: DeprecationWarning:
builtin type SwigPyObject has no __module__ attribute` warning


## Test Plan

### Automated Testing
Ran the test 100 times without it failing.
2026-01-18 20:33:20 +00:00
Antonio Lujano Luna
9c29eb7d48 Add proxy and custom SSL certificate support for corporate networks (#1189)
Support HTTPS_PROXY/HTTP_PROXY environment variables for proxy
configuration and SSL_CERT_FILE for custom CA certificates, enabling use
in corporate environments with SSL inspection.

## Motivation
Users in corporate environments often need to route traffic through HTTP
proxies and use custom CA certificates for SSL inspection. Without this
support, exo cannot download models in these network configurations.

## Changes
- Added `HTTPS_PROXY`/`HTTP_PROXY` environment variable support to
`create_http_session()` in `download_utils.py`
- Added `SSL_CERT_FILE` environment variable support for custom CA
certificate bundles, falling back to certifi's default bundle

## Why It Works
- `aiohttp.ClientSession` natively supports the `proxy` parameter for
routing requests through HTTP proxies
- `ssl.create_default_context(cafile=...)` accepts a custom CA bundle
path, allowing corporate CAs to be trusted
- Using environment variables is consistent with the codebase's existing
configuration patterns (e.g., `EXO_HOME`, `HF_ENDPOINT`)
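
One way to wire both together, as a sketch (assumes aiohttp and certifi; the actual create_http_session may differ):

```python
import os
import ssl

import aiohttp
import certifi

def create_http_session() -> aiohttp.ClientSession:
    cafile = os.environ.get("SSL_CERT_FILE", certifi.where())
    ssl_ctx = ssl.create_default_context(cafile=cafile)
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl_ctx),
        trust_env=True,  # honor HTTP_PROXY/HTTPS_PROXY from the environment
    )
```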

## Test Plan
### Manual Testing
- Set `HTTPS_PROXY` environment variable and verified model downloads
route through proxy
- Set `SSL_CERT_FILE` to custom CA bundle and verified SSL verification
succeeds with corporate SSL inspection

### Automated Testing
- No automated tests added; this change is configuration-only and does
not alter existing behavior when environment variables are unset
2026-01-18 12:05:50 +00:00
Alex Cheema
c5158bee53 Add pre-commit checks documentation to AGENTS.md (#1184)
## Motivation

CI failures can be avoided by running checks locally before committing.
This adds clear documentation to AGENTS.md so that AI agents (and
humans) know exactly which checks must pass before pushing code.

## Changes

Added a new "Pre-Commit Checks (REQUIRED)" section to AGENTS.md that:
- Lists all 4 required checks (basedpyright, ruff, nix fmt, pytest)
- Provides a one-liner to run all checks in sequence
- Notes that `nix fmt` changes must be staged before committing
- Explains that CI runs `nix flake check` which verifies everything

## Why It Works

Clear documentation prevents CI failures by ensuring contributors run
checks locally first. The one-liner command makes it easy to run all
checks before committing.

## Test Plan

### Manual Testing
- Verified the documented commands work correctly

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 21:50:24 +00:00
rltakashige
5c8a237940 Handle model timeouts (#1177)
- Add eval with a timeout.
- Add fast synch flag

## Motivation

Because of the experimental FAST SYNCH flag, some models may not work.
This PR catches when this occurs and allows users to specify a run
without fast synch

## Changes

- Adds a flag to enable or disable fast synch (--fast-synch and
--no-fast-synch)
- Adds a heuristic timeout
- Reduces exo_bench default timeout to 10 minutes.

## Why It Works

The heuristic timeout assumes normal loading times on Mac devices: 60 +
(model size in GB) / 5 seconds. E.g. DeepSeek takes up to 120 seconds to
load on tensor parallel, so its timeout is set to 60 + 120 = 180s.

We could raise this value if necessary.
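
The heuristic as a one-liner (assumed helper name):

```python
def load_timeout_seconds(model_size_gb: float) -> float:
    # 60s base + 1s per 5 GB of weights; e.g. ~600 GB DeepSeek -> 180s
    return 60.0 + model_size_gb / 5.0
```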

## Test Plan

### Manual Testing
Catches that GPT OSS fails to load in Tensor RDMA.
Can launch GPT OSS with the --no-fast-synch flag.

**GPT OSS 20B**
TP with fast synch
<img width="3064" height="456" alt="image"
src="https://github.com/user-attachments/assets/f6e25cd8-8621-4e99-99fe-292ee05c4035"
/>

TP without fast synch
<img width="3098" height="496" alt="image"
src="https://github.com/user-attachments/assets/d36453d9-6686-4cfe-aa7c-a7d458369d4d"
/>
[Note: the performance is really not great as fast synch is off]

(As a sanity check)
PP with fast synch
<img width="3124" height="496" alt="image"
src="https://github.com/user-attachments/assets/e97d4547-c6fa-483d-badb-4b371b900b4c"
/>

PP without fast synch
<img width="3078" height="508" alt="image"
src="https://github.com/user-attachments/assets/b2e20dfd-4b0e-4295-8a92-417dfe745c28"
/>

PP without RDMA
<img width="3070" height="498" alt="image"
src="https://github.com/user-attachments/assets/a8509d68-0aef-4cda-bca5-a67d39a0801e"
/>

TP without RDMA
<img width="3068" height="496" alt="image"
src="https://github.com/user-attachments/assets/b5691429-89f4-4369-bcf2-8fde2ad7154a"
/>
2026-01-16 20:25:12 +00:00
rltakashige
745343c705 Return error responses for Chat Completions (#1173)
- Error chunks
- Use error handling in exo_bench.py

## Motivation

Return when an error occurs so that generation stops. Adding timeouts is
a separate TODO for model loading and chat completions.

## Changes

- Return HTTP exceptions as JSON responses in an OpenAI compatible
format.
- Context manager for generation to catch and return error messages.
- Use error handling in exo_bench.py.

## Test Plan

### Manual Testing
Manually tested that exo_bench returns on failures within and outside
generation

2026-01-16 19:24:37 +00:00
Alex Cheema
5e28664c41 Fix draft release detection (attempt 3) (#1176)
## Motivation

Previous fix still failed in CI. Suspecting permissions issue with
GITHUB_TOKEN not being able to see draft releases via API.

## Changes

1. Add explicit `permissions: contents: write` to the job
2. Use `gh release list` first to check if draft exists (this uses a
different code path that might work better)
3. Add debug echo statements

## Test Plan

Delete v1.0.63 tag and re-push after merging.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:26:06 +00:00
Alex Cheema
ae0a804ccb Fix draft release detection query (#1175)
## Motivation

Fixes the draft release detection that failed on the v1.0.63 release
attempt.

## Changes

The jq query was piped to `head -1` which truncated multi-line JSON
output to just `{`, causing the empty check to fail.

Changed to use `first // empty` in jq instead.

## Test Plan

Tested locally:
```bash
GITHUB_REF_NAME="v1.0.63"
gh api repos/exo-explore/exo/releases --jq "[.[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")] | first // empty"
# Returns the full draft release JSON (2711 chars)
```

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:05:24 +00:00
Alex Cheema
07cf2c1aa1 Add GitHub releases with Sparkle release notes integration (#1172)
## Motivation

Closes #1140

Currently releases are uploaded to S3 for Sparkle updates but there's no
GitHub Release created, and Sparkle update dialogs don't show release
notes. Users have no visibility into what changed.

## Changes

- Added release workflow documentation comment at top of `build-app.yml`
- Added "Fetch release notes for Sparkle" step that converts markdown
from draft GitHub release to HTML
- Added "Inject release notes into appcast" step that embeds HTML in
appcast.xml with CDATA
- Added "Publish GitHub Release" step that attaches DMG and publishes
the draft

## Why It Works

- Sparkle's `<description>` tag supports HTML wrapped in CDATA for
rendering in update dialogs
- GitHub's markdown API (`/markdown`) converts the release notes to HTML
with proper formatting
- Draft releases allow writing polished notes before the build, then the
workflow publishes them automatically
- The workflow fails if no draft release exists, ensuring release notes
are always provided

## Test Plan

### Manual Testing
1. Create a draft GitHub release for a new tag with markdown release
notes
2. Push the tag to trigger the workflow
3. Verify the GitHub release is published with DMG attached
4. Download appcast.xml from S3 and verify
`<description><![CDATA[...]]></description>` contains HTML
5. Test Sparkle update dialog on macOS to confirm release notes appear

### Automated Testing
No automated tests added - this is CI workflow configuration.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 16:47:33 +00:00
Evan
83c5285a80 reduce logs
previous commits' logs were too verbose - this tones them down a bit
2026-01-16 14:05:47 +00:00
Evan Quiney
39ee2bf7bd switch from synchronous threaded pinging to an async implementation (#1170)
still seeing churn in our networking - let's properly rate limit it

## changes

added a persistent httpx AsyncClient with a max connections limit
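
roughly this shape (connection limits illustrative):

```python
import httpx

client = httpx.AsyncClient(
    limits=httpx.Limits(max_connections=16),  # cap concurrent pings
    timeout=httpx.Timeout(3.0),
)

async def ping(url: str) -> bool:
    try:
        return (await client.get(url)).status_code == 200
    except httpx.HTTPError:
        return False
```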

## testing

deployed on cluster, discovery VASTLY more stable (the only deleted
edges were those discovered by mdns)
2026-01-16 13:20:03 +00:00
Sami Khan
991adfbd6f fix local network warning (#1136)
## Motivation

Local network warning banner was showing on fresh install even though
mDNS was working. The check would fail before the user had a chance to
grant permission via the macOS prompt.

## Changes

- Added `hasWorkedBefore` flag persisted in UserDefaults
- Only show warning if permission previously worked but now doesn't

## Why It Works

On fresh install, the check may fail (no permission yet), but
`hasWorkedBefore` is false so no warning shows. Once the user grants
permission and a check succeeds, we record it. Future failures (zombie
permission after restart) will show the warning since `hasWorkedBefore`
is now true.

## Test Plan

### Manual Testing
Run locally

### Automated Testing
N/A
2026-01-16 13:10:50 +00:00
rltakashige
4b3de6b984 Fix exo bench for transformers 5.x (#1168)
## Motivation
Prompt Sizer was broken: transformers 5.x tokenizers create BatchEncoding
objects, essentially a dictionary of `{input_ids: [...]}`, instead of a
plain list of input ids.
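
The compatibility shim, roughly (helper name is ours):

```python
def to_input_ids(encoded) -> list[int]:
    if isinstance(encoded, list):
        return encoded                 # transformers 4.x: plain list of ids
    return list(encoded["input_ids"])  # transformers 5.x: BatchEncoding mapping
```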

## Test Plan

### Manual Testing
Tested that exo bench runs as expected.

2026-01-16 12:39:22 +00:00
Evan
c8de3b90ea quiet rust logs
rust logs were too verbose - now only warnings propagate to python

entirely happy not to merge this and to clean up rust logging instead,
but this felt saner right now
2026-01-16 12:34:28 +00:00
Sami Khan
6e6567a802 resolve issue #1070 (#1076)
## Motivation

https://github.com/exo-explore/exo/issues/1070

## Changes

Added check in ChatForm.svelte to reset selectedChatModel when it no
longer matches any running instance.

## Why It Works

The $effect now detects when the selected model is stale (not in
availableModels()) and resets to the first available model.

## Test Plan

### Manual Testing

1. Create instance of Model A → Delete it → Create instance of Model B →
Chat
2. Verify request goes to Model B (not Model A)

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-15 20:00:41 +00:00
rltakashige
a735dad667 Parse GPT OSS in runner (#1160)
## Motivation

Simplification of API + moving model specific code to the runner


## Test Plan

### Manual Testing
Tested that GPT OSS outputs are parsed correctly on the dashboard.

2026-01-15 19:53:55 +00:00
rltakashige
aaf4e36bc3 FIX GPT OSS (#1165)
## Motivation

Adds several unmerged fixes for GPT OSS.
Also adds GPT OSS 20B MXFP4 Q8 instead of Q4 for numerical stability (Q4
is unstable for MLX LM too).


## Test Plan

### Manual Testing
Manually tested. No further gibberish responses.

### Automated Testing
Ran EXO Bench - pipeline, tensor and single node work on both 20B and
120B models
2026-01-15 19:20:17 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps this is a connection
timing out issue? let's also re-try after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
Evan Quiney
c22dad8a7d dashboard: add peer: true to package lock (#1162)
this happens every time i run npm install - let's upstream it

## testing
dashboard builds and renders
2026-01-15 17:01:43 +00:00
Evan
4bc4d50685 rust: remove dead code
the system custodian has been made unnecessary with the swift app - we
can remove it

## testing
everything still builds
2026-01-15 16:51:46 +00:00
Jake Hillion
e0aab46fd8 model_cards.py: clean up commented out code
Clean up the commented out code and make sure the comments are unified.
Carrying around the commented out code means people making changes to
model_cards are supposed to update it, but that's not clear and won't be
picked up by type checking etc. Drop it for now - it's in the git
history.

Also make the rest of the comments a bit more uniform, and place
comments about a specific model card inside the model card (instead of
above) so they don't get lost when code is added/moved around.

Test plan:
- my eyes
2026-01-15 13:21:58 +00:00
Evan Quiney
82ba42bae9 add glm-47, minimax-m21 (#1147)
Adds support for GLM 4.7 and MiniMax M2.1.

Manual testing:
Tensor + Pipeline execution of both models.

Closes #1141 and #1142
2026-01-14 16:33:17 +00:00
Jake Hillion
3671528fa4 nix: add dashboard build with dream2nix
Continue working towards a fully Nix based build by building the
dashboard with Nix. Continuing the theme of using the existing lock
files, use dream2nix to parse the lock file and build the tree of
dependency derivations.

dream2nix doesn't like the bundleDependencies, so we apply a small patch
to the lock file that drops all dependencies that are bundled. This
should ideally be contributed upstream but that can be done later.

Use this new dashboard build in the build-app CI workflow, meaning
future macOS apps will include this reproducible dashboard.

Test plan:
- Built a DMG, shipped to a cluster, loaded in a browser with no cache
  and the dashboard looks good.

- Directory layout is as expected:
```
$ nix build .#dashboard
$ find result/
...
result/_app/immutable/entry
result/_app/immutable/entry/app.CTPAnMjf.js
result/_app/immutable/entry/start.fUSEa-2O.js
result/_app/immutable/nodes
result/_app/immutable/nodes/3.DqQr1Obm.js
result/_app/immutable/nodes/0.DgEY44RO.js
result/_app/immutable/nodes/2.BjZg_lJh.js
result/_app/immutable/nodes/1.D6vGUYYT.js
result/_app/env.js
result/_app/version.json
result/exo-logo.png
result/favicon.ico
result/index.html
```
2026-01-14 15:58:16 +01:00
Jake Hillion
e6434ec446 nix: add Rust builds with crane and fenix
The Rust workspace lacked Nix build support, making it difficult to
build packages reproducibly or run checks in CI.

Added a flake-parts module at rust/parts.nix that uses crane for Rust
builds and fenix for the nightly toolchain. The source filter isolates
rust/ and root Cargo files to prevent Python/docs changes from
triggering Rust rebuilds. Exports packages (system_custodian,
exo_pyo3_bindings wheel, exo-rust-workspace) and checks (cargo-nextest,
cargo-doc) for all three target platforms.

The devShell now uses inputsFrom to inherit build dependencies from the
workspace package, removing the need for manual pkg-config/openssl setup.

Test plan:
- Ran `nix flake check` successfully
- Built `nix build ".#checks.x86_64-linux.cargo-nextest"` and tests pass
- Built `nix build ".#exo_pyo3_bindings"` and wheel is produced
2026-01-14 11:52:29 +00:00
Jake Hillion
bdb43e1dbb nix: drop noisy echos from devshell
Drop all the printing when entering a devshell. It's annoying, and not a
super accurate description of how to develop exo anyway.
2026-01-14 10:04:57 +00:00
Jake Hillion
e4a01e2b0e chore(deps): nix lock file maintenance
Update nix flake inputs. Add a second input, as Swift is currently broken
in nixpkgs on Linux for `swift-format` and we want `nix fmt` to continue
being reproducible everywhere.
2026-01-13 19:57:14 +01:00
Evan Quiney
1200a7db64 Add tensor sharding for GPT-OSS (#1144)
## Motivation

GPT OSS did not previously support tensor sharding

## Changes

Add GPT sharding support in tensor_auto_parallel.
Code is mostly @rltakashige's

## Test Plan

### Manual Testing
Tested GPT-OSS - MLX Fast Sync causes issues in Tensor RDMA - this is a general problem at the moment.
2026-01-13 17:25:52 +00:00
Evan Quiney
47ceb54bc1 up the rlimit (#1148)
Fixes #1117 
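
presumably along these lines (which limit gets raised is a guess here; treat as illustrative):

```python
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))  # raise soft to the hard cap
```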

Manual testing:
Launched 100 instances. worked. yay.
2026-01-13 15:00:54 +00:00
Jake Hillion
f8112fdf25 nix: convert to flake-parts
Preparing to add a flake-parts module for Rust builds. The flake-utils
library doesn't support the module system needed for cleanly separating
the Rust build configuration.

Converted from flake-utils to flake-parts, switching to the treefmt-nix
flakeModule import pattern. The devShell and formatter outputs remain
functionally equivalent.

Test plan:
- Ran `nix flake check` successfully
- Verified `nix develop` provides the same environment
2026-01-13 15:06:44 +01:00
Alex Cheema
e388f59480 docs: add AGENTS.md for AI coding agents guidance (#1132)
## Motivation

Add documentation to help AI coding agents (Claude Code, Cursor, GitHub
Copilot, etc.) understand the exo codebase and contribute effectively.

## Changes

- Add `AGENTS.md` with guidance for AI agents working on the codebase
- Add symlink `CLAUDE.md -> AGENTS.md` for backwards compatibility with
Claude Code

## Why It Works

`AGENTS.md` is becoming a standard convention for AI agent instructions.
The symlink ensures Claude Code (which looks for `CLAUDE.md`) continues
to work while supporting the broader `AGENTS.md` convention.

## Test Plan

### Manual Testing
- Verified symlink works correctly

### Automated Testing
- N/A (documentation only)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 13:05:47 +00:00
Alex Cheema
e5e74e1eef Upgrade mlx-lm to 0.30.2 with transformers 5.x compatibility (#1125)
## Motivation

Upgrade mlx-lm to version 0.30.2 which requires transformers 5.0.0rc2 as
a prerelease dependency. This enables support for newer models like Kimi
K2 Thinking while maintaining compatibility with existing models.

The transformers 5.x release includes breaking changes that affect
custom tokenizers like Kimi's TikTokenTokenizer, requiring compatibility
fixes.

## Changes

### Core Changes
- **mlx-lm upgrade**: Bump to 0.30.2 with locked exact versions for
mlx/mlx-lm to prevent breaking changes
- **transformers 5.x compatibility**: Enable prerelease transformers
dependency

### Kimi K2 Tokenizer Fixes
- Add `bytes_to_unicode` monkey-patch to restore function moved in
transformers 5.0.0rc2
- Load `TikTokenTokenizer` directly instead of via `AutoTokenizer` to
bypass transformers 5.x bug with `auto_map` fallback
- Patch `encode()` to use tiktoken directly with `allowed_special="all"`
to handle special tokens from chat templates

### Other Changes
- Dashboard: Show disk usage for completed model downloads
- CI: Add `workflow_dispatch` trigger to build-app workflow
- Docs: Add basic API documentation

### Testing
- Add comprehensive tokenizer unit tests for all supported models
- Tests verify encode/decode, special token handling, and chat template
encoding

## Why It Works

**bytes_to_unicode issue**: transformers 5.0.0rc2 moved
`bytes_to_unicode` from `transformers.models.gpt2.tokenization_gpt2` to
`transformers.convert_slow_tokenizer`. Kimi's `tokenization_kimi.py`
imports from the old location. The monkey-patch restores it at module
load time.

**AutoTokenizer issue**: transformers 5.x has a bug where
`tokenizer_class_from_name('TikTokenTokenizer')` returns `None` for
custom tokenizers with `auto_map`. Loading the tokenizer directly
bypasses this.

**encode() issue**: transformers 5.x's `pad()` method fails for slow
tokenizers. Using tiktoken's encode directly with
`allowed_special="all"` avoids this path and properly handles special
tokens like `<|im_user|>` from chat templates.
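
The first fix as a sketch (the import paths are the ones named above; the guard is illustrative):

```python
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode

# restore the symbol at its pre-5.x location so tokenization_kimi.py
# can keep importing it from there
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
    gpt2_tokenization.bytes_to_unicode = bytes_to_unicode
```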

## Test Plan

### Manual Testing
- Hardware: 2x Mac Studios connected via Thunderbolt 5 (mike22 and
james21)
- Tested Kimi K2 Thinking, GPT-OSS-120B, GPT-OSS-20B, LLama-3.1-8B-bf16, qwen3-30B-A3B-8bit model with pipeline parallelism across both
nodes
- Verified warmup inference completes successfully
- Verified chat completions work with special tokens

### Automated Testing
- Added `test_tokenizers.py` with 31 tests covering:
- Basic encode/decode for all model families (deepseek, kimi, llama,
qwen, gpt-oss, glm)
  - Special token encoding (critical for chat templates)
  - Chat template application and encoding
  - Kimi-specific and GLM-specific edge cases
- All tests pass: `uv run pytest
src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py`

### Failing Tests
RDMA with all models.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-01-13 12:06:04 +00:00
Jake Hillion
b968d6f0a0 ci: remove old commented out job 2026-01-13 12:42:04 +01:00
Jake Hillion
3bfffd9b4f ci: build all Nix outputs on all platforms and push to cachix
The CI was only running `nix flake check` on ubuntu-latest, missing
builds for other platforms and not caching packages or devShells.

Added a matrix-based `nix-build` job that runs on macos-26 (aarch64-darwin),
ubuntu-latest (x86_64-linux), and ubuntu-24.04-arm (aarch64-linux). Each
job enumerates all packages and devShells via `nix flake show --json`,
builds them in a single `nix build` call for parallelization, then runs
`nix flake check`. The cachix-action pushes all built outputs automatically.

This ensures all Nix outputs are built and cached for every supported
platform, speeding up local development and CI runs.

Test plan:
- Tested jq enumeration command locally, correctly outputs devShell paths
- Verified xargs pipeline works with the enumerated outputs
2026-01-13 12:37:12 +01:00
Jake Hillion
007eb80029 nix: enable cachix
Enable cachix and push to it in the pipeline.yml workflow. This won't
cache a huge amount yet but will automatically extend our caching as we
build more of the repo with Nix in CI. It can also be used by local
users by accepting our cache to improve the speed of local builds.

Test plan:
- CI
2026-01-12 17:24:59 +01:00
Jake Hillion
8d7b6789b3 dashboard: show disk usage for completed models
The downloads dashboard showed "Completed" for finished model downloads
but provided no indication of how much disk space each model or the
total models on a node were using.

Added total_bytes field to DownloadCompleted type so the size is
preserved when a download completes. Updated the dashboard to display
the model size next to "Completed" status (e.g., "Completed (251.1GB)")
and a total disk usage line below the model count for each node (e.g.,
"502.2GB on disk").

Test plan:
- Ran unit tests for download apply and planning logic
- Type checked all modified files with basedpyright
2026-01-12 16:34:29 +01:00
Jake Hillion
3c5b7ea670 ci: add workflow_dispatch trigger to build-app
Build app is the most convenient way to get a DMG for testing, but
currently it's a bit limited. You have to push to test-app every time
which is far from ideal and requires a bit too much force pushing for my
liking.

Add the workflow_dispatch trigger. This adds a button in the actions UI
to trigger a workflow for a named branch, which means you can use your
normal dev branch instead of having to push to test-app. We'll leave
that behaviour there for now too, though it may change in future.

Filter on `"${{ github.event_name }}" == "workflow_dispatch"` and set
those to alpha as well. Will verify by pushing the first version from
`main` just in case. Unfortunately we do have to merge this before we
can test it.

Test plan:
- Looking really hard.
2026-01-12 12:14:21 +01:00
PG
b74a610537 Add a basic documentation to the api interface (#1122)
## Motivation

Adds basic api documentation

## Changes

- Add docs/api.md
- Modify README.md
2026-01-11 18:44:40 +00:00
Jake Hillion
18c4e49f91 nix: put treefmt in devshell
treefmt is useful to be able to access directly for some formatters like
`jj fix`. Expose it in the devshell.

Test plan:
- Used with `jj fix` on a large branch. It worked.
2026-01-09 17:53:50 +01:00
Sami Khan
d85b5d3781 feat: uninstall button (#1077)
## Motivation

https://github.com/exo-explore/exo/issues/1075

## Changes

- Added in-app "Uninstall" option under Advanced menu that cleanly
removes all system components
- Added NetworkSetupHelper.uninstall() to remove LaunchDaemon, scripts,
logs, and restore network settings
- Added LaunchAtLoginHelper.disable() to unregister from login items
- Created standalone uninstall-exo.sh script for users who already
deleted the app
- Added uninstall documentation to README

<img width="386" height="577" alt="image"
src="https://github.com/user-attachments/assets/6bbcd18a-992a-409d-8791-ed5e13bbcfe0"
/>
<img width="372" height="432" alt="image"
src="https://github.com/user-attachments/assets/ee76b45d-c111-4807-ab28-3f2f20e01140"
/>


## Why It Works

The in-app uninstaller runs a privileged shell script (via AppleScript)
to launchctl bootout the daemon, remove files, and restore the
"Automatic" network location. The standalone script provides the same
cleanup for users who already deleted the app.

## Test Plan

### Manual Testing
Hardware: MacBook Pro
- Built and ran app, verified LaunchDaemon and network location were
created
- Used in-app Uninstall, verified all components removed and network
restored to Automatic
- Rebuilt app, quit normally, ran sudo ./uninstall-exo.sh, verified same
cleanup

### Automated Testing
N/A

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-01-09 14:49:08 +00:00
Evan Quiney
caafc48693 Forward tools to the models chat template properly (#1106)
We did not properly forward tools to the chat template before. This is not a full tool calling impl - but it should improve things slightly.

## Changes made

Pass tools to the HF tokenizer's chat template.
Join message chunks into a larger message (opencode does this sometimes -
we were ignoring them before).
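
Roughly, as a sketch (helper name is ours; `tools` is the standard kwarg on HF's apply_chat_template):

```python
def render_prompt(tokenizer, messages: list[dict], tools: list[dict] | None = None) -> str:
    merged = []
    for m in messages:
        content = m["content"]
        if isinstance(content, list):  # opencode-style chunked content
            content = "".join(part.get("text", "") for part in content)
        merged.append({"role": m["role"], "content": content})
    return tokenizer.apply_chat_template(
        merged, tools=tools, tokenize=False, add_generation_prompt=True
    )
```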

## Future work

We need to parse the model output and normalise the return format to be compatible with the openai api.
2026-01-09 13:28:41 +00:00
Evan
cca8c9984a cleanup unused dependencies
we have a lot of dependencies we have no intent of using. kill them with
fire!

## testing
exo still launches and does the worst inference known to man on my Qwen3
instance. tests pass too!!
2026-01-09 13:11:58 +00:00
Sami Khan
d1e88def42 scrollbars fixed (#1113)
## Motivation

Fixes https://github.com/exo-explore/exo/issues/1107 - Horizontal
scrollbar always appears in instances section, and vertical scrollbar
appears too early (with just 1-2 instances on large screens).


## Changes

- Added overflow-x-hidden to remove horizontal scrollbar
- Added xl:max-h-96 for responsive vertical height (384px on xl+ screens
vs 288px default)
- Added py-px to accommodate corner accent decorations that extend 1px
outside cards

## Why It Works

- overflow-x-hidden prevents horizontal scroll regardless of content
- Larger max-height on xl screens fits 2 instances without scrollbar;
3rd triggers it
- 1px vertical padding accommodates the -top-px/-bottom-px positioned
corner accents that caused tiny overflow

## Test Plan

### Manual Testing
<img width="1190" height="868" alt="image"
src="https://github.com/user-attachments/assets/2a582328-5b4f-4490-a488-52106f2e85ef"
/>

### Automated Testing
N/A
2026-01-09 12:51:05 +00:00
Sami Khan
59e7594e34 UNKNOWN to PREPARING (#1112)
## Motivation

The "UNKNOWN" status shown when first launching an instance is confusing
and unhelpful. "PREPARING" better describes what's actually happening.

![telegram-cloud-photo-size-4-5981245965962251168-x](https://github.com/user-attachments/assets/65b0802b-fb64-4fa7-bff7-c13757035b3a)


## Changes

- Renamed status from "UNKNOWN" to "PREPARING" in dashboard
(+page.svelte)
- Renamed unknown state to preparing in macOS app
(InstanceViewModel.swift, InstanceRowView.swift)

## Why It Works

The status appears when an instance exists but runners haven't reported
status yet. "PREPARING" accurately describes this transitional state.

## Test Plan

### Manual Testing
Hardware: MacBook Pro
<img width="319" height="200" alt="image"
src="https://github.com/user-attachments/assets/9a1c3caf-026d-47ea-80d1-63c6e41d93aa"
/>

### Automated Testing
N/A
2026-01-09 11:46:51 +00:00
Chris A
c65320acd3 Fix mlx seed (#1094)

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-09 01:40:15 +00:00
Jake Hillion
b9a78f6f3a ci: compute CURRENT_PROJECT_VERSION from semver
Previous Sparkle builds were cut from a different repo with different
build numbers, breaking version ordering. Users aren't receiving updates
because CFBundleVersion values don't reflect the actual version sequence.

Added a step to compute the build version deterministically from semver:
PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR).
Release versions use prerelease=999 to ensure they're always higher than
their prereleases (e.g., 1.0.61 > 1.0.61-alpha.3).

This ensures consistent version ordering across repos, allowing Sparkle
to correctly identify and deliver updates to users.

Test plan:
- Verified formula with test script:

```sh
compute_version() {
  VERSION="$1"
  BASE_VERSION="${VERSION%%-*}"
  MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
  MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
  PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)

  if [[ "$VERSION" == *-* ]]; then
    PRERELEASE_PART="${VERSION#*-}"
    PRERELEASE_NUM="${PRERELEASE_PART##*.}"
    if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
      PRERELEASE_NUM=0
    fi
  else
    PRERELEASE_NUM=999
  fi

  BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
  printf "%-20s -> %12s\n" "$VERSION" "$BUILD_VERSION"
}

compute_version "1.0.61-alpha.2"
compute_version "1.0.61-alpha.3"
compute_version "1.0.61"
compute_version "1.0.62-alpha.1"
compute_version "1.1.0-alpha.1"
compute_version "2.0.0-alpha.1"
compute_version "0.0.0-alpha.0"
compute_version "0.0.1-alpha.1"
compute_version "1.2.3"
compute_version "1.2.3-beta.5"
```

- Output:

```sh
Version              -> Build Number
----------------------------------------
1.0.61-alpha.2       ->   1000061002
1.0.61-alpha.3       ->   1000061003
1.0.61               ->   1000061999
1.0.62-alpha.1       ->   1000062001
1.1.0-alpha.1        ->   1001000001
2.0.0-alpha.1        ->   2000000001
0.0.0-alpha.0        ->            0
0.0.1-alpha.1        ->         1001
1.2.3                ->   1002003999
1.2.3-beta.5         ->   1002003005
```

- Confirmed ordering: alpha.2 < alpha.3 < release < next-alpha
2026-01-08 19:52:33 +01:00
Jake Hillion
8f7f0e893a ci: avoid uploading alpha appcasts
Currently alpha appcasts get uploaded. It turns out these overwrite the
standard appcast, so even though no one will update to the alpha
channel, everyone will miss regular updates while the latest build was
an alpha one.

Ideally we should combine the source of truth for both the alpha and
release channels, but as no one is using the alpha channel yet let's
stop uploading it for now.

Test plan:

![eyes](https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExeGNwdDk0dmdscjlkZnd6eGxhcjJzdDBsYndmc2t2cnlpZDNxZnZhYSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/gKHGnB1ml0moQdjhEJ/giphy.gif)
2026-01-08 18:52:10 +01:00
Alex Cheema
4759b09d4c Use presigned URLs for bug report uploads (#1109)
## Motivation

Previously we hardcoded AWS credentials into the app.
This is not good practice.

## Changes

Use presigned URLs instead.

## Why It Works

Presigned URLs are an S3 feature for this kind of thing. They provide an
expiring presigned URL with certain permissions. In this case we have a
presigned URL with `s3:PutObject` permission that expires after 5
minutes. The client uses this presigned URL to upload a bug report
instead of using its own credentials to sign a request. This also
simplifies a lot of the Swift code.

## Test Plan

### Manual Testing
On a single MacBook, I downloaded the app and sent a bug report. It
worked and appeared in the bucket.
2026-01-08 17:17:48 +00:00
Alex Cheema
ca680185f3 Display RDMA debug info in macOS app. (#1072)
## Motivation

Often users are running into issues with RDMA. See
https://github.com/exo-explore/exo/issues?q=is%3Aissue%20rdma
Having some debug info in the macOS app will help to debug these issues.

## Changes

Displays output of the following commands in the debug info section of
the macOS app:

1. `rdma_ctl status`
2. `ibv_devices`
3. `ibv_devinfo`

## Why It Works

It displays RDMA debug info in the debug info section of the macOS app.

## Test Plan

### Manual Testing
We need to make a new build of the macOS app and check the output under
the following conditions:

1. No RDMA enabled.
2. RDMA enabled but no devices connected over TB5.
3. RDMA enabled and devices connected over TB5.
2026-01-08 15:17:00 +00:00
Jake Hillion
383309e24e fmt: add typescript formatting
Add typescript auto formatting with Prettier and treefmt-nix. Added a
.prettierrc to useTabs, which isn't the default, to reduce churn. The
rest looks okay and will be checked by CI.

Test plan:
- CI
2026-01-08 13:47:27 +00:00
Jake Hillion
55463a9806 fmt: add swift formatting
Swift code currently has no auto formatting. Add `swift-format` to the
`treefmt-nix` config to get this formatted.

As our existing Swift code uses 4-space formatting instead of the
default 2-space, also add a custom `.swift-format`.

Test plan:
- CI
2026-01-08 13:34:45 +00:00
130 changed files with 8954 additions and 5159 deletions


@@ -1,6 +1,18 @@
 name: Build EXO macOS DMG
+# Release workflow:
+# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
+# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
+# 3. This workflow builds, signs, and notarizes the DMG
+# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
+# 5. DMG and appcast.xml are uploaded to S3
+# 6. The draft GitHub Release is published with the DMG attached
+#
+# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
+# If no draft exists, a release is auto-created with generated notes.
 on:
+  workflow_dispatch:
   push:
     tags:
       - "v*"
@@ -10,14 +22,17 @@ on:
 jobs:
   build-macos-app:
     runs-on: "macos-26"
+    permissions:
+      contents: write
     env:
-      SPARKLE_VERSION: 2.8.1
+      SPARKLE_VERSION: 2.9.0-beta.1
       SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
       SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
       SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
       SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
       SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
       SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
+      EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
       AWS_REGION: ${{ secrets.AWS_REGION }}
       EXO_BUILD_NUMBER: ${{ github.run_number }}
       EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
@@ -34,7 +49,7 @@ jobs:
       - name: Derive release version from tag
         run: |
-          if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
+          if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
             VERSION="0.0.0-alpha.0"
             echo "IS_ALPHA=true" >> $GITHUB_ENV
           else
@@ -47,6 +62,32 @@ jobs:
           fi
           echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
+      - name: Compute build version from semver
+        run: |
+          VERSION="$RELEASE_VERSION"
+          # Extract major.minor.patch (strip prerelease suffix)
+          BASE_VERSION="${VERSION%%-*}"
+          MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
+          MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
+          PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)
+          # Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
+          if [[ "$VERSION" == *-* ]]; then
+            PRERELEASE_PART="${VERSION#*-}"
+            PRERELEASE_NUM="${PRERELEASE_PART##*.}"
+            # Default to 0 if not a number
+            if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
+              PRERELEASE_NUM=0
+            fi
+          else
+            PRERELEASE_NUM=999
+          fi
+          # Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
+          BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
+          echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
+          echo "Computed build version: $BUILD_VERSION from $VERSION"
       - name: Ensure tag commit is on main
         if: github.ref_type == 'tag'
         run: |
@@ -59,6 +100,52 @@
             exit 1
           fi
+      - name: Fetch and validate release notes
+        if: github.ref_type == 'tag'
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          # Find draft release by name using gh release list (more reliable with default token)
+          echo "Looking for draft release named '$GITHUB_REF_NAME'..."
+          DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
+          if [[ -z "$DRAFT_EXISTS" ]]; then
+            if [[ "$IS_ALPHA" == "true" ]]; then
+              echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
+              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
+              exit 0
+            fi
+            echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
+            echo "Please create a draft release with release notes before pushing the tag."
+            exit 1
+          fi
+          # Fetch full release details via API to get body and ID
+          echo "Found draft release, fetching details..."
+          RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -85,11 +172,22 @@ jobs:
uv python install
uv sync --locked
- name: Install Nix
uses: cachix/install-nix-action@v31
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Configure Cachix
uses: cachix/cachix-action@v14
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build dashboard
run: |
cd dashboard
npm ci
npm run build
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
mkdir -p dashboard/build
cp -r "$DASHBOARD_OUT"/* dashboard/build/
- name: Install Sparkle CLI
run: |
@@ -162,11 +260,12 @@ jobs:
-configuration Release \
-derivedDataPath build \
MARKETING_VERSION="$RELEASE_VERSION" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
EXO_BUILD_TAG="$RELEASE_VERSION" \
EXO_BUILD_COMMIT="$GITHUB_SHA" \
SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
mkdir -p ../../output
@@ -264,6 +363,28 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -294,5 +415,28 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
if [[ "$IS_ALPHA" != "true" ]]; then
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache

View File

@@ -20,6 +20,12 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
@@ -88,9 +94,19 @@ jobs:
- uses: ./.github/actions/typecheck
nix-flake-check:
name: Check Nix flake
runs-on: ubuntu-latest
nix:
name: Build and check (${{ matrix.system }})
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: macos-26
system: aarch64-darwin
- runner: ubuntu-latest
system: x86_64-linux
- runner: ubuntu-24.04-arm
system: aarch64-linux
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -101,83 +117,20 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Run nix flake check
run: |
nix flake check
shell: bash
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
# ci:
# needs: typecheck
# runs-on: ubuntu-latest
# permissions:
# contents: read
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# token: ${{ secrets.GITHUB_TOKEN }}
# lfs: true
#
# - name: Configure git user
# run: |
# git config --local user.email "github-actions@users.noreply.github.com"
# git config --local user.name "github-actions bot"
# shell: bash
#
# - name: Pull LFS files
# run: |
# echo "Pulling Git LFS files..."
# git lfs pull
# shell: bash
#
# - name: Setup EXO_HOME and API_PORT
# run: |
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
# # Generate random port (macOS compatible method)
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
# echo "Created EXO_HOME: $EXO_HOME"
# echo "Generated API_PORT: $API_PORT"
# shell: bash
#
# - name: Setup Nix Environment
# run: |
# echo "Checking for nix installation..."
#
# # Check if nix binary exists directly
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
# echo "PATH=$PATH" >> $GITHUB_ENV
# nix --version
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
# echo "Found nix profile script, sourcing..."
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
# nix --version
# elif command -v nix >/dev/null 2>&1; then
# echo "Nix already in PATH"
# nix --version
# else
# echo "Nix not found. Debugging info:"
# echo "Contents of /nix/var/nix/profiles/default/:"
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
# exit 1
# fi
# shell: bash
#
# - uses: ./.github/actions/lint-check
#
# - uses: ./.github/actions/unit-test
#
# - name: Cleanup EXO_HOME
# run: |
# echo "Cleaning up EXO_HOME: $EXO_HOME"
# rm -rf "$EXO_HOME"
# shell: bash
# if: always()
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
[
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
] | .[]
' | xargs nix build
- name: Run nix flake check
run: nix flake check

1
.gitignore vendored
View File

@@ -16,6 +16,7 @@ digest.txt
*.xcuserdatad/
**/.DS_Store
app/EXO/build/
dist/
# rust

View File

@@ -0,0 +1,156 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: Optional[int]
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
attention_bias: bool
class DeepseekV3Attention(nn.Module):
config: ModelArgs
hidden_size: int
num_heads: int
max_position_embeddings: int
rope_theta: float
q_lora_rank: Optional[int]
qk_rope_head_dim: int
kv_lora_rank: int
v_head_dim: int
qk_nope_head_dim: int
q_head_dim: int
scale: float
q_proj: nn.Linear
q_a_proj: nn.Linear
q_a_layernorm: nn.RMSNorm
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
o_proj: nn.Linear
rope: Any
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3MLP(nn.Module):
config: ModelArgs
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class MoEGate(nn.Module):
config: ModelArgs
top_k: int
norm_topk_prob: bool
n_routed_experts: Optional[int]
routed_scaling_factor: float
n_group: int
topk_group: int
weight: mx.array
e_score_correction_bias: mx.array
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class DeepseekV3MoE(nn.Module):
config: ModelArgs
num_experts_per_tok: int
switch_mlp: SwitchGLU
gate: MoEGate
shared_experts: DeepseekV3MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class DeepseekV3DecoderLayer(nn.Module):
self_attn: DeepseekV3Attention
mlp: DeepseekV3MLP | DeepseekV3MoE
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3Model(nn.Module):
vocab_size: int
embed_tokens: nn.Embedding
layers: list[DeepseekV3DecoderLayer]
norm: nn.RMSNorm
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: DeepseekV3Model
lm_head: nn.Linear
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[DeepseekV3DecoderLayer]: ...

View File

@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
def __call__(self, x, gate): ...
class SwitchGLU(nn.Module):
gate_proj: SwitchLinear
up_proj: SwitchLinear
down_proj: SwitchLinear
activation: SwiGLU
def __init__(
self,
input_dims: int,

View File

@@ -4,6 +4,7 @@ This type stub file was generated by pyright.
from functools import partial
from pathlib import Path
from typing import Any
from transformers import PreTrainedTokenizerFast
@@ -103,37 +104,55 @@ class TokenizerWrapper:
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
def add_eos_token(self, token: str): # -> None:
...
@property
def has_thinking(self): # -> bool:
...
@property
def think_start(self): # -> str | None:
...
@property
def think_end(self): # -> str | None:
...
@property
def has_tool_calling(self): # -> bool:
...
@property
def tool_call_start(self): # -> str | None:
...
@property
def tool_call_end(self): # -> str | None:
...
@property
def detokenizer(self): # -> NaiveStreamingDetokenizer:
"""
Get a stateful streaming detokenizer.
"""
def __getattr__(self, attr): # -> set[Any] | Any:
...
def __setattr__(self, attr, value): # -> None:
...
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
def __init__(
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,
tool_call_end: str | None = ...,
) -> None: ...
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
def apply_chat_template(
self,
messages: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: Any = None,
**kwargs: Any,
) -> str: ...
def get_vocab(self) -> dict[str, int]: ...
def add_eos_token(self, token: str) -> None: ...
@property
def has_thinking(self) -> bool: ...
@property
def think_start(self) -> str | None: ...
@property
def think_end(self) -> str | None: ...
@property
def has_tool_calling(self) -> bool: ...
@property
def tool_call_start(self) -> str | None: ...
@property
def tool_call_end(self) -> str | None: ...
@property
def detokenizer(self) -> NaiveStreamingDetokenizer:
"""Get a stateful streaming detokenizer."""
def __getattr__(self, attr: str) -> Any: ...
def __setattr__(self, attr: str, value: Any) -> None: ...
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load_tokenizer(
def load(
model_path: Path,
tokenizer_config_extra=...,
return_tokenizer=...,
eos_token_ids=...,
) -> (
TokenizerWrapper
| type[SPMStreamingDetokenizer]
| partial[SPMStreamingDetokenizer]
| type[BPEStreamingDetokenizer]
| type[NaiveStreamingDetokenizer]
):
tokenizer_config_extra: dict[str, Any] | None = None,
eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@@ -165,4 +177,7 @@ def load_tokenizer(
a Hugging Face repo ID.
"""
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
# Alias for backward compatibility
load_tokenizer = load
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...

3
.prettierrc Normal file
View File

@@ -0,0 +1,3 @@
{
"useTabs": true
}

6
.swift-format Normal file
View File

@@ -0,0 +1,6 @@
{
"version": 1,
"indentation": {
"spaces": 4
}
}

121
AGENTS.md Normal file
View File

@@ -0,0 +1,121 @@
# AGENTS.md
This file provides guidance to AI coding agents when working with code in this repository.
## Project Overview
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
## Build & Run Commands
```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..
# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo
# Run with verbose logging
uv run exo -v # or -vv for more verbose
# Run tests (excludes slow tests by default)
uv run pytest
# Run all tests including slow tests
uv run pytest -m ""
# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check`, which verifies formatting and linting and runs Rust tests.
## Architecture
### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions
### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py); a hypothetical sketch follows the list:
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates
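A hypothetical sketch of a typed topic (names are illustrative, not exo's actual API):

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

PayloadT = TypeVar("PayloadT")

@dataclass(frozen=True)
class Topic(Generic[PayloadT]):
    # Pairing a topic name with its payload type lets publish/subscribe
    # helpers be type-checked end to end.
    name: str
    payload_type: type[PayloadT]

@dataclass(frozen=True)
class ElectionMessage:
    candidate_node_id: str
    term: int

ELECTION_MESSAGES = Topic("election_messages", ElectionMessage)
```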
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events (a minimal sketch follows)
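A sketch of that loop, with illustrative stand-in types (not exo's actual models):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class State:
    node_count: int = 0

@dataclass(frozen=True)
class NodeJoined:
    node_id: str

def apply(state: State, event: NodeJoined) -> State:
    # Pure: returns a new State instead of mutating the old one.
    return replace(state, node_count=state.node_count + 1)

state = State()
for event in [NodeJoined("a"), NodeJoined("b")]:  # indexed events, in order
    state = apply(state, event)
assert state.node_count == 2
```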
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
- `events.py`: Event types (discriminated union)
- `commands.py`: Command types
- `tasks.py`: Task types for worker execution
- `state.py`: Cluster state model
### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations
### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
## Code Style Requirements
From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

19
Cargo.lock generated
View File

@@ -4340,25 +4340,6 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,7 +3,6 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -25,7 +24,6 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

41
MISSED_THINGS.md Normal file
View File

@@ -0,0 +1,41 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation). NOTE: we use a different convention to mlx-lm: our terminal rank is rank=n-1 whereas mlx-lm's is rank=0, hence I can see why this was changed wrongly.
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in generate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We made group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7; rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only catches ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks whether this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug; as a warning it would spam since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators no longer prioritises the Jaccl coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).

117
README.md
View File

@@ -27,13 +27,22 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>
<summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg" alt="Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
</p>
</details>
@@ -41,7 +50,7 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
<summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg" alt="Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
</p>
</details>
@@ -49,7 +58,7 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
<summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg" alt="Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
</p>
</details>
@@ -154,6 +163,24 @@ This starts the exo dashboard and API at http://localhost:52415/
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
**Configuration Options:**
- `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity.
```bash
uv run exo --no-worker
```
**File Locations (Linux):**
exo follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) on Linux:
- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
You can override these locations by setting the corresponding XDG environment variables.
### macOS App
exo ships a macOS app that runs in the background on your Mac.
@@ -166,6 +193,37 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
**Custom Namespace for Cluster Isolation:**
The macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `EXO_LIBP2P_NAMESPACE` setting:
- **Use cases**:
- Running multiple separate exo clusters on the same network
- Isolating development/testing clusters from production clusters
- Preventing accidental cluster joining
- **Configuration**: Access this setting in the app's Advanced settings (or set the `EXO_LIBP2P_NAMESPACE` environment variable when running from source)
The namespace is logged on startup for debugging purposes.
#### Uninstalling the macOS App
The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.
If you've already deleted the app, you can run the standalone uninstaller script:
```bash
sudo ./app/EXO/uninstall-exo.sh
```
This removes:
- Network setup LaunchDaemon
- Network configuration script
- Log files
- The "exo" network location
**Note:** You'll need to manually remove EXO from Login Items in System Settings → General → Login Items.
---
### Enabling RDMA on macOS
@@ -287,7 +345,56 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
For further details, see:
- API basic documentation in [docs/api.md](docs/api.md).
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
---
## Benchmarking
The `exo-bench` tool measures model prefill and token generation speed across different placement configurations. This helps you optimize model performance and validate improvements.
**Prerequisites:**
- Nodes should be running with `uv run exo` before benchmarking
- The tool uses the `/bench/chat/completions` endpoint
**Basic usage:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,256,512 \
--tg 128,256
```
**Key parameters:**
- `--model`: Model to benchmark (short ID or HuggingFace ID)
- `--pp`: Prompt size hints (comma-separated integers)
- `--tg`: Generation lengths (comma-separated integers)
- `--max-nodes`: Limit placements to N nodes (default: 4)
- `--instance-meta`: Filter by `ring`, `jaccl`, or `both` (default: both)
- `--sharding`: Filter by `pipeline`, `tensor`, or `both` (default: both)
- `--repeat`: Number of repetitions per configuration (default: 1)
- `--warmup`: Warmup runs per placement (default: 0)
- `--json-out`: Output file for results (default: bench/results.json)
**Example with filters:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,512 \
--tg 128 \
--max-nodes 2 \
--sharding tensor \
--repeat 3 \
--json-out my-results.json
```
The tool outputs performance metrics including prompt tokens per second (prompt_tps), generation tokens per second (generation_tps), and peak memory usage for each configuration.
---
@@ -299,4 +406,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.

View File

@@ -19,6 +19,7 @@
25. Rethink retry logic
26. Task cancellation. When an API HTTP request gets cancelled, it should cancel the corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.8.1;
minimumVersion = 2.9.0-beta.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
}
}
],

View File

@@ -18,10 +18,11 @@ struct ContentView: View {
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false
@State private var showAllInstances = false
@State private var showAdvanced = false
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var showAdvancedOptions = false
@State private var uninstallInProgress = false
@State private var pendingNamespace: String = ""
var body: some View {
@@ -55,6 +56,11 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}
@@ -70,10 +76,12 @@ struct ContentView: View {
.font(.caption)
.fontWeight(.semibold)
}
Text("Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO")
.font(.caption2)
.foregroundColor(.secondary)
.fixedSize(horizontal: false, vertical: true)
Text(
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
)
.font(.caption2)
.foregroundColor(.secondary)
.fixedSize(horizontal: false, vertical: true)
Button {
openLocalNetworkSettings()
} label: {
@@ -96,14 +104,18 @@ struct ContentView: View {
private func openLocalNetworkSettings() {
// Open Privacy & Security settings - Local Network section
if let url = URL(string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork") {
if let url = URL(
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
{
NSWorkspace.shared.open(url)
}
}
private var topologySection: some View {
Group {
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
if let topology = stateService.latestSnapshot?.topologyViewModel(
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
{
TopologyMiniView(topology: topology)
}
}
@@ -137,8 +149,10 @@ struct ContentView: View {
VStack(alignment: .leading, spacing: 4) {
HStack {
VStack(alignment: .leading) {
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
.font(.headline)
Text(
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
)
.font(.headline)
Text("Memory")
.font(.caption)
.foregroundColor(.secondary)
@@ -247,13 +261,7 @@ struct ContentView: View {
Divider()
.padding(.vertical, 4)
}
controlButton(title: "Check for Updates") {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
advancedSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
controller.stop()
@@ -262,7 +270,57 @@ struct ContentView: View {
}
}
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
private var advancedSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvanced)
}
.animation(nil, value: showAdvanced)
if showAdvanced {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
HoverButton(title: "Check for Updates", small: true) {
updater.checkForUpdates()
}
debugSection
HoverButton(title: "Uninstall", tint: .red, small: true) {
showUninstallConfirmationAlert()
}
.disabled(uninstallInProgress)
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvanced)
}
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
-> some View
{
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
}
@@ -293,9 +351,12 @@ struct ContentView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(
isExpanded.wrappedValue ? "Hide" : "Show All",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -383,57 +444,16 @@ struct ContentView: View {
}
}
private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Debug Info")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showDebugInfo)
VStack(alignment: .leading, spacing: 4) {
HoverButton(
title: "Debug Info",
tint: .primary,
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
small: true
) {
showDebugInfo.toggle()
}
.animation(nil, value: showDebugInfo)
if showDebugInfo {
VStack(alignment: .leading, spacing: 4) {
Text("Version: \(buildTag)")
@@ -446,15 +466,63 @@ struct ContentView: View {
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
interfaceIpList
rdmaStatusView
sendBugReportButton
.padding(.top, 6)
}
.padding(.leading, 8)
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
}
private var rdmaStatusView: some View {
let rdma = networkStatusService.status.rdmaStatus
return VStack(alignment: .leading, spacing: 1) {
Text("RDMA: \(rdmaStatusText(rdma))")
.font(.caption2)
.foregroundColor(rdmaStatusColor(rdma))
if !rdma.devices.isEmpty {
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
.font(.caption2)
.foregroundColor(.secondary)
}
if !rdma.activePorts.isEmpty {
Text(" Active Ports:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(rdma.activePorts, id: \.device) { port in
Text(" \(port.device) port \(port.port): \(port.state)")
.font(.caption2)
.foregroundColor(.green)
}
}
}
}
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
switch rdma.rdmaCtlEnabled {
case .some(true):
return "Enabled"
case .some(false):
return "Disabled"
case nil:
return rdma.devices.isEmpty ? "Not Available" : "Available"
}
}
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
switch rdma.rdmaCtlEnabled {
case .some(true):
return .green
case .some(false):
return .orange
case nil:
return rdma.devices.isEmpty ? .secondary : .green
}
}
private var sendBugReportButton: some View {
VStack(alignment: .leading, spacing: 4) {
Button {
@@ -544,6 +612,88 @@ struct ContentView: View {
bugReportInFlight = false
}
private func showUninstallConfirmationAlert() {
let alert = NSAlert()
alert.messageText = "Uninstall EXO"
alert.informativeText = """
This will remove EXO and all its system components:
• Network configuration daemon
• Launch at login registration
• EXO network location
The app will be moved to Trash.
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Uninstall")
alert.addButton(withTitle: "Cancel")
// Style the Uninstall button as destructive
if let uninstallButton = alert.buttons.first {
uninstallButton.hasDestructiveAction = true
}
let response = alert.runModal()
if response == .alertFirstButtonReturn {
performUninstall()
}
}
private func performUninstall() {
uninstallInProgress = true
// Stop EXO process first
controller.cancelPendingLaunch()
controller.stop()
stateService.stopPolling()
// Run the privileged uninstall on a background thread
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
DispatchQueue.global(qos: .utility).async {
do {
// Remove network setup daemon and components (requires admin privileges)
try NetworkSetupHelper.uninstall()
DispatchQueue.main.async {
// Unregister from launch at login
LaunchAtLoginHelper.disable()
// Move app to trash
self.moveAppToTrash()
// Quit the app
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
NSApplication.shared.terminate(nil)
}
}
} catch {
DispatchQueue.main.async {
self.showErrorAlert(message: error.localizedDescription)
self.uninstallInProgress = false
}
}
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Uninstall Failed"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
private func moveAppToTrash() {
guard let appURL = Bundle.main.bundleURL as URL? else { return }
do {
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
} catch {
// If we can't trash the app, that's OK - user can do it manually
// The important system components have already been cleaned up
}
}
private var buildTag: String {
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
}
@@ -557,14 +707,27 @@ private struct HoverButton: View {
let title: String
let tint: Color
let trailingSystemImage: String?
let small: Bool
let action: () -> Void
init(
title: String, tint: Color = .primary, trailingSystemImage: String? = nil,
small: Bool = false, action: @escaping () -> Void
) {
self.title = title
self.tint = tint
self.trailingSystemImage = trailingSystemImage
self.small = small
self.action = action
}
@State private var isHovering = false
var body: some View {
Button(action: action) {
HStack {
Text(title)
.font(small ? .caption : nil)
Spacer()
if let systemName = trailingSystemImage {
Image(systemName: systemName)
@@ -572,8 +735,8 @@ private struct HoverButton: View {
}
}
.frame(maxWidth: .infinity, alignment: .leading)
.padding(.vertical, 6)
.padding(.horizontal, 8)
.padding(.vertical, small ? 4 : 6)
.padding(.horizontal, small ? 6 : 8)
.background(
RoundedRectangle(cornerRadius: 6)
.fill(
@@ -588,4 +751,3 @@ private struct HoverButton: View {
.onHover { isHovering = $0 }
}
}

View File

@@ -8,9 +8,9 @@
import AppKit
import CoreImage
import CoreImage.CIFilterBuiltins
import ServiceManagement
import Sparkle
import SwiftUI
import ServiceManagement
import UserNotifications
import os.log
@@ -113,7 +113,7 @@ struct EXOApp: App {
filter.contrast = 0.9
guard let output = filter.outputImage,
let rendered = ciContext.createCGImage(output, from: output.extent)
let rendered = ciContext.createCGImage(output, from: output.extent)
else {
return nil
}
@@ -126,7 +126,26 @@ struct EXOApp: App {
do {
try SMAppService.mainApp.register()
} catch {
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
Logger().error(
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
}
/// Helper for managing EXO's launch-at-login registration
enum LaunchAtLoginHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LaunchAtLogin")
/// Unregisters EXO from launching at login
static func disable() {
guard SMAppService.mainApp.status == .enabled else { return }
do {
try SMAppService.mainApp.unregister()
logger.info("Unregistered EXO from launch at login")
} catch {
logger.error(
"Failed to unregister EXO from launch at login: \(error.localizedDescription, privacy: .public)"
)
}
}
}
@@ -151,7 +170,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
controller.updater.automaticallyChecksForUpdates = true
controller.updater.automaticallyDownloadsUpdates = false
controller.updater.updateCheckInterval = 900 // 15 minutes
controller.updater.updateCheckInterval = 900 // 15 minutes
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
controller?.updater.checkForUpdatesInBackground()
}
@@ -218,7 +237,8 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
func userNotificationCenter(
_ center: UNUserNotificationCenter,
willPresent notification: UNNotification,
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
Void
) {
completionHandler([.banner, .list, .sound])
}

View File

@@ -31,7 +31,8 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
}()
{
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
@@ -221,7 +222,9 @@ final class ExoProcessController: ObservableObject {
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
return tag
}
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
!short.isEmpty
{
return short
}
return "dev"

View File

@@ -8,6 +8,8 @@
<string>$(EXO_BUILD_TAG)</string>
<key>EXOBuildCommit</key>
<string>$(EXO_BUILD_COMMIT)</string>
<key>EXOBugReportPresignedUrlEndpoint</key>
<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
<key>NSBonjourServices</key>

View File

@@ -16,10 +16,13 @@ struct ClusterState: Decodable {
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
let rawDownloads =
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
}
@@ -41,7 +44,8 @@ private struct TaggedInstance: Decodable {
let payloads = try container.decode([String: ClusterInstancePayload].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
)
}
self.instance = ClusterInstance(
@@ -77,7 +81,8 @@ struct RunnerStatusSummary: Decodable {
let payloads = try container.decode([String: RunnerStatusDetail].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
)
}
self.status = entry.key
@@ -257,7 +262,9 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {
func promptPreview() -> String? {
guard let messages else { return nil }
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
if let userMessage = messages.last(where: {
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
}) {
return userMessage.content
}
return messages.last?.content
@@ -365,5 +372,3 @@ extension ClusterState {
func availableModels() -> [ModelOption] { [] }
}

View File

@@ -1,4 +1,3 @@
import CryptoKit
import Foundation
struct BugReportOutcome: Equatable {
@@ -7,17 +6,17 @@ struct BugReportOutcome: Equatable {
}
enum BugReportError: LocalizedError {
case missingCredentials
case invalidEndpoint
case presignedUrlFailed(String)
case uploadFailed(String)
case collectFailed(String)
var errorDescription: String? {
switch self {
case .missingCredentials:
return "Bug report upload credentials are not set."
case .invalidEndpoint:
return "Bug report endpoint is invalid."
case .presignedUrlFailed(let message):
return "Failed to get presigned URLs: \(message)"
case .uploadFailed(let message):
return "Bug report upload failed: \(message)"
case .collectFailed(let message):
@@ -27,11 +26,13 @@ enum BugReportError: LocalizedError {
}
struct BugReportService {
struct AWSConfig {
let accessKey: String
let secretKey: String
let region: String
let bucket: String
private struct PresignedUrlsRequest: Codable {
let keys: [String]
}
private struct PresignedUrlsResponse: Codable {
let urls: [String: String]
let expiresIn: Int?
}
func sendReport(
@@ -39,9 +40,9 @@ struct BugReportService {
now: Date = Date(),
isManual: Bool = false
) async throws -> BugReportOutcome {
let credentials = try loadCredentials()
let timestamp = ISO8601DateFormatter().string(from: now)
let prefix = "reports/\(timestamp)/"
let timestamp = Self.runTimestampString(now)
let dayPrefix = Self.dayPrefixString(now)
let prefix = "reports/\(dayPrefix)/\(timestamp)/"
let logData = readLog()
let ifconfigText = try await captureIfconfig()
@@ -66,28 +67,82 @@ struct BugReportService {
("\(prefix)exo.log", logData),
("\(prefix)state.json", stateData),
("\(prefix)events.json", eventsData),
("\(prefix)report.json", reportJSON)
("\(prefix)report.json", reportJSON),
]
let uploader = try S3Uploader(config: credentials)
for item in uploads {
guard let data = item.data else { continue }
try await uploader.upload(
objectPath: item.path,
body: data
)
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
guard let body = item.data else { return nil }
return (key: item.path, body: body)
}
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
guard !uploadItems.isEmpty else {
return BugReportOutcome(success: false, message: "No data to upload")
}
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
for item in uploadItems {
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
}
try await uploadToPresignedUrl(url: url, body: item.body)
}
return BugReportOutcome(
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
}
private func loadCredentials() throws -> AWSConfig {
return AWSConfig(
accessKey: "AKIAYEKP5EMXTOBYDGHX",
secretKey: "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",
region: "us-east-1",
bucket: "exo-bug-reports"
)
private static func dayPrefixString(_ date: Date) -> String {
var calendar = Calendar(identifier: .gregorian)
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
let components = calendar.dateComponents([.year, .month, .day], from: date)
let year = components.year ?? 0
let month = components.month ?? 0
let day = components.day ?? 0
return String(format: "%04d/%02d/%02d", year, month, day)
}
private static func runTimestampString(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.locale = Locale(identifier: "en_US_POSIX")
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
return formatter.string(from: date)
}
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
-> [String: String]
{
guard
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
as? String
else {
throw BugReportError.invalidEndpoint
}
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
else {
throw BugReportError.invalidEndpoint
}
var request = URLRequest(url: endpoint)
request.httpMethod = "POST"
request.timeoutInterval = 10
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
let encoder = JSONEncoder()
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.presignedUrlFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
}
let decoder = JSONDecoder()
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
return decoded.urls
}
private func readLog() -> Data? {
@@ -100,7 +155,8 @@ struct BugReportService {
private func captureIfconfig() async throws -> String {
let result = runCommand(["/sbin/ifconfig"])
guard result.exitCode == 0 else {
throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
throw BugReportError.collectFailed(
result.error.isEmpty ? "ifconfig failed" : result.error)
}
return result.output
}
@@ -108,12 +164,23 @@ struct BugReportService {
private func readDebugInfo() -> DebugInfo {
DebugInfo(
thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
interfaces: readInterfaces()
interfaces: readInterfaces(),
rdma: readRDMADebugInfo()
)
}
private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
DebugInfo.RDMADebugInfo(
rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
)
}
private func readThunderboltBridgeDisabled() -> Bool? {
let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
let result = runCommand([
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased()
if output.contains("enabled") {
@@ -156,7 +223,8 @@ struct BugReportService {
request.timeoutInterval = 5
do {
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
else {
return nil
}
return data
@@ -165,6 +233,36 @@ struct BugReportService {
}
}
private func uploadToPresignedUrl(url: URL, body: Data) async throws {
let maxAttempts = 2
var lastError: Error?
for attempt in 1...maxAttempts {
do {
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.timeoutInterval = 30
let (_, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.uploadFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
}
return
} catch {
lastError = error
if attempt < maxAttempts {
try await Task.sleep(nanoseconds: 400_000_000)
}
}
}
throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
}
private func makeReportJson(
timestamp: String,
hostName: String,
@@ -182,7 +280,7 @@ struct BugReportService {
"system": system,
"exo_version": exo.version as Any,
"exo_commit": exo.commit as Any,
"report_type": isManual ? "manual" : "automated"
"report_type": isManual ? "manual" : "automated",
]
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
}
@@ -213,10 +311,13 @@ struct BugReportService {
let user = safeRunCommand(["/usr/bin/whoami"])
let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
let uptime = safeRunCommand(["/usr/bin/uptime"])
let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
let diskRoot = safeRunCommand([
"/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
])
let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
let interfacesAndIPs = interfacesList?
let interfacesAndIPs =
interfacesList?
.split(whereSeparator: { $0 == " " || $0 == "\n" })
.compactMap { iface -> [String: Any]? in
let name = String(iface)
@@ -227,7 +328,8 @@ struct BugReportService {
} ?? []
let wifiSSID: String?
let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
let airportPath =
"/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
if FileManager.default.isExecutableFile(atPath: airportPath) {
wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
} else {
@@ -255,7 +357,7 @@ struct BugReportService {
"disk_root": diskRoot as Any,
"interfaces_and_ips": interfacesAndIPs,
"ipconfig_getiflist": interfacesList as Any,
"wifi_ssid": wifiSSID as Any
"wifi_ssid": wifiSSID as Any,
]
}
@@ -313,7 +415,8 @@ struct BugReportService {
for line in airportOutput.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("SSID:") {
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
in: .whitespaces)
}
}
return nil
@@ -350,6 +453,7 @@ struct BugReportService {
private struct DebugInfo {
let thunderboltBridgeDisabled: Bool?
let interfaces: [InterfaceStatus]
let rdma: RDMADebugInfo
struct InterfaceStatus {
let name: String
@@ -358,7 +462,21 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"name": name,
"ip": ip as Any
"ip": ip as Any,
]
}
}
struct RDMADebugInfo {
let rdmaCtlStatus: String?
let ibvDevices: String?
let ibvDevinfo: String?
func toDictionary() -> [String: Any] {
[
"rdma_ctl_status": rdmaCtlStatus as Any,
"ibv_devices": ibvDevices as Any,
"ibv_devinfo": ibvDevinfo as Any,
]
}
}
@@ -366,7 +484,8 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
"interfaces": interfaces.map { $0.toDictionary() }
"interfaces": interfaces.map { $0.toDictionary() },
"rdma": rdma.toDictionary(),
]
}
}
@@ -376,163 +495,3 @@ private struct CommandResult {
let output: String
let error: String
}
private struct S3Uploader {
let config: BugReportService.AWSConfig
init(config: BugReportService.AWSConfig) throws {
self.config = config
}
func upload(objectPath: String, body: Data) async throws {
let host = "\(config.bucket).s3.amazonaws.com"
guard let url = URL(string: "https://\(host)/\(objectPath)") else {
throw BugReportError.invalidEndpoint
}
let now = Date()
let amzDate = awsTimestamp(now)
let dateStamp = dateStamp(now)
let payloadHash = sha256Hex(body)
let headers = [
"host": host,
"x-amz-content-sha256": payloadHash,
"x-amz-date": amzDate
]
let canonicalRequest = buildCanonicalRequest(
method: "PUT",
url: url,
headers: headers,
payloadHash: payloadHash
)
let stringToSign = buildStringToSign(
amzDate: amzDate,
dateStamp: dateStamp,
canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
)
let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))
let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
let authorization = """
AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
"""
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
request.setValue(host, forHTTPHeaderField: "Host")
request.setValue(authorization, forHTTPHeaderField: "Authorization")
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
_ = data // ignore response body for UX
throw BugReportError.uploadFailed("HTTP status \(statusText)")
}
}
private func buildCanonicalRequest(
method: String,
url: URL,
headers: [String: String],
payloadHash: String
) -> String {
let canonicalURI = encodePath(url.path)
let canonicalQuery = url.query ?? ""
let sortedHeaders = headers.sorted { $0.key < $1.key }
let canonicalHeaders = sortedHeaders
.map { "\($0.key.lowercased()):\($0.value)\n" }
.joined()
let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")
return [
method,
canonicalURI,
canonicalQuery,
canonicalHeaders,
signedHeaders,
payloadHash
].joined(separator: "\n")
}
private func encodePath(_ path: String) -> String {
return path
.split(separator: "/")
.map { segment in
segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
}
.joined(separator: "/")
.prependSlashIfNeeded()
}
private func buildStringToSign(
amzDate: String,
dateStamp: String,
canonicalRequestHash: String
) -> String {
"""
AWS4-HMAC-SHA256
\(amzDate)
\(dateStamp)/\(config.region)/s3/aws4_request
\(canonicalRequestHash)
"""
}
private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
let kRegion = hmac(key: kDate, data: Data(region.utf8))
let kService = hmac(key: kRegion, data: Data(service.utf8))
return hmac(key: kService, data: Data("aws4_request".utf8))
}
private func hmac(key: Data, data: Data) -> Data {
let keySym = SymmetricKey(data: key)
let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
return Data(mac)
}
private func hmacHex(key: Data, data: Data) -> String {
hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
}
private func sha256Hex(_ data: Data) -> String {
let digest = SHA256.hash(data: data)
return digest.compactMap { String(format: "%02x", $0) }.joined()
}
private func awsTimestamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private func dateStamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private static let rfc3986: CharacterSet = {
var set = CharacterSet.alphanumerics
set.insert(charactersIn: "-._~")
return set
}()
}
private extension String {
func prependSlashIfNeeded() -> String {
if hasPrefix("/") {
return self
}
return "/" + self
}
}
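
Note on the new upload flow: report keys are now grouped by UTC day and run timestamp, so an upload lands at e.g. reports/2026/01/20/2026-01-20T150703.000Z/exo.log. Below is a minimal sketch of the presigned-URL round trip, using the same Codable shapes as the diff's `PresignedUrlsRequest`/`PresignedUrlsResponse`; the endpoint URL here is hypothetical (the app reads the real one from the `EXOBugReportPresignedUrlEndpoint` key in Info.plist).

import Foundation

struct PresignRequest: Codable { let keys: [String] }
struct PresignResponse: Codable { let urls: [String: String] }

func uploadViaPresignedUrls(_ files: [(key: String, body: Data)]) async throws {
    // Hypothetical endpoint, for illustration only.
    var request = URLRequest(url: URL(string: "https://example.invalid/presign")!)
    request.httpMethod = "POST"
    request.setValue("application/json", forHTTPHeaderField: "Content-Type")
    request.httpBody = try JSONEncoder().encode(PresignRequest(keys: files.map { $0.key }))

    let (data, _) = try await URLSession.shared.data(for: request)
    let urls = try JSONDecoder().decode(PresignResponse.self, from: data).urls

    for file in files {
        // The signature lives in the presigned URL's query string, so the
        // PUT needs no Authorization header and no client-side credentials.
        guard let target = urls[file.key].flatMap(URL.init(string:)) else { continue }
        var put = URLRequest(url: target)
        put.httpMethod = "PUT"
        put.httpBody = file.body
        _ = try await URLSession.shared.data(for: put)
    }
}

The payoff is that the hardcoded AWSConfig credentials and the hand-rolled SigV4 signer in S3Uploader can be deleted outright.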

View File

@@ -57,7 +57,9 @@ final class ClusterStateService: ObservableObject {
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
return
}
if let nodeId = try? decoder.decode(String.self, from: data) {
@@ -113,7 +115,9 @@ final class ClusterStateService: ObservableObject {
}
}
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
async
{
do {
var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
request.httpMethod = "POST"
@@ -122,7 +126,7 @@ final class ClusterStateService: ObservableObject {
"model_id": modelId,
"sharding": sharding,
"instance_meta": instanceMeta,
"min_nodes": minNodes
"min_nodes": minNodes,
]
request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
let (_, response) = try await session.data(for: request)
@@ -143,7 +147,9 @@ final class ClusterStateService: ObservableObject {
do {
let url = baseURL.appendingPathComponent("models")
let (data, response) = try await session.data(from: url)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
throw URLError(.badServerResponse)
}
let list = try decoder.decode(ModelListResponse.self, from: data)

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,30 +35,43 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
lastConnectionState = "connecting"
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
checkTask = Task { [weak self] in
guard let self else { return }
let result = await self.performCheck()
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
connection?.cancel()
connection = nil
@@ -84,22 +97,7 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
conn.stateUpdateHandler = { state in
switch state {
case .ready:
resumeOnce(.working)
@@ -108,10 +106,12 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH") ||
errorStr.contains("permission") || errorStr.contains("denied") {
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|| errorStr.contains("permission") || errorStr.contains("denied")
{
resumeOnce(.notWorking(reason: "Permission denied"))
} else {
resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
@@ -126,7 +126,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: 3_000_000_000)
try? await Task.sleep(nanoseconds: timeout)
let state = conn.state
switch state {
case .ready:
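
The doc comment above describes the probe; here is a self-contained sketch of the same idea, assuming the constants from the comment (224.0.0.251:5353 over UDP) and an illustrative timeout (the real checker uses 30s on first launch so the permission prompt has time to appear, 3s afterwards):

import Network

func probeLocalNetwork(timeout: TimeInterval = 3) async -> Bool {
    let conn = NWConnection(host: "224.0.0.251", port: 5353, using: .udp)
    return await withCheckedContinuation { continuation in
        var resumed = false
        func finish(_ ok: Bool) {
            guard !resumed else { return }
            resumed = true
            conn.cancel()
            continuation.resume(returning: ok)
        }
        conn.stateUpdateHandler = { state in
            switch state {
            case .ready: finish(true)    // multicast route is usable
            case .failed: finish(false)  // e.g. EHOSTUNREACH when permission is denied
            default: break               // .preparing / .waiting: keep waiting
            }
        }
        conn.start(queue: .main)
        // Give up once the timeout elapses without reaching .ready.
        DispatchQueue.main.asyncAfter(deadline: .now() + timeout) { finish(false) }
    }
}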

View File

@@ -5,64 +5,37 @@ import os.log
enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
#!/usr/bin/env bash
set -euo pipefail
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
Task.detached {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
do {
if daemonAlreadyInstalled() {
return
@@ -70,19 +43,86 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
/// Requires admin privileges.
static func uninstall() throws {
let uninstallScript = makeUninstallScript()
try runShellAsAdmin(uninstallScript)
logger.info("EXO network setup components removed successfully")
}
/// Checks if there are any EXO network components installed that need cleanup
static func hasInstalledComponents() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
return scriptExists || plistExists
}
private static func makeUninstallScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
# Unload the LaunchDaemon if running
launchctl bootout system/"$LABEL" 2>/dev/null || true
# Remove LaunchDaemon plist
rm -f "$PLIST_DEST"
# Remove the script and parent directory if empty
rm -f "$SCRIPT_DEST"
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
# Remove log files
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic 2>/dev/null || true
# Delete the exo network location if it exists
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable Thunderbolt Bridge if it exists
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
} || true
echo "EXO network components removed successfully"
"""
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
@@ -92,7 +132,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
@@ -105,58 +147,59 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript = script
let escapedScript =
script
.replacingOccurrences(of: "\\", with: "\\\\")
.replacingOccurrences(of: "\"", with: "\\\"")
let appleScriptSource = """
do shell script "\(escapedScript)" with administrator privileges
"""
do shell script "\(escapedScript)" with administrator privileges
"""
guard let appleScript = NSAppleScript(source: appleScriptSource) else {
throw NetworkSetupError.scriptCreationFailed
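
The hunk above ends mid-function; for reference, the full shape of the "do shell script ... with administrator privileges" pattern that runShellAsAdmin wraps, as a self-contained sketch (the error enum here is illustrative, not the app's NetworkSetupError):

import Foundation

enum AdminShellError: Error { case scriptCreationFailed, executionFailed(String) }

func runAsAdmin(_ script: String) throws {
    // Escape backslashes first, then quotes, so the script survives being
    // embedded in an AppleScript string literal.
    let escaped = script
        .replacingOccurrences(of: "\\", with: "\\\\")
        .replacingOccurrences(of: "\"", with: "\\\"")
    let source = "do shell script \"\(escaped)\" with administrator privileges"
    guard let appleScript = NSAppleScript(source: source) else {
        throw AdminShellError.scriptCreationFailed
    }
    var errorInfo: NSDictionary?
    appleScript.executeAndReturnError(&errorInfo)  // blocks; shows the macOS auth dialog
    if let errorInfo {
        throw AdminShellError.executionFailed(String(describing: errorInfo))
    }
}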

View File

@@ -35,14 +35,34 @@ struct NetworkStatus: Equatable {
let thunderboltBridgeState: ThunderboltState?
let bridgeInactive: Bool?
let interfaceStatuses: [InterfaceIpStatus]
let rdmaStatus: RDMAStatus
static let empty = NetworkStatus(
thunderboltBridgeState: nil,
bridgeInactive: nil,
interfaceStatuses: []
interfaceStatuses: [],
rdmaStatus: .empty
)
}
struct RDMAStatus: Equatable {
let rdmaCtlEnabled: Bool?
let devices: [String]
let activePorts: [RDMAPort]
var isAvailable: Bool {
rdmaCtlEnabled == true || !devices.isEmpty
}
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
}
struct RDMAPort: Equatable {
let device: String
let port: String
let state: String
}
struct InterfaceIpStatus: Equatable {
let interfaceName: String
let ipAddress: String?
@@ -59,10 +79,79 @@ private struct NetworkStatusFetcher {
NetworkStatus(
thunderboltBridgeState: readThunderboltBridgeState(),
bridgeInactive: readBridgeInactive(),
interfaceStatuses: readInterfaceStatuses()
interfaceStatuses: readInterfaceStatuses(),
rdmaStatus: readRDMAStatus()
)
}
private func readRDMAStatus() -> RDMAStatus {
let rdmaCtlEnabled = readRDMACtlEnabled()
let devices = readRDMADevices()
let activePorts = readRDMAActivePorts()
return RDMAStatus(
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
}
private func readRDMACtlEnabled() -> Bool? {
let result = runCommand(["rdma_ctl", "status"])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
if output.contains("enabled") {
return true
}
if output.contains("disabled") {
return false
}
return nil
}
private func readRDMADevices() -> [String] {
let result = runCommand(["ibv_devices"])
guard result.exitCode == 0 else { return [] }
var devices: [String] = []
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
|| trimmed.isEmpty
{
continue
}
let parts = trimmed.split(separator: " ", maxSplits: 1)
if let deviceName = parts.first {
devices.append(String(deviceName))
}
}
return devices
}
private func readRDMAActivePorts() -> [RDMAPort] {
let result = runCommand(["ibv_devinfo"])
guard result.exitCode == 0 else { return [] }
var ports: [RDMAPort] = []
var currentDevice: String?
var currentPort: String?
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("hca_id:") {
currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("port:") {
currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("state:") {
let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
in: .whitespaces)
if let device = currentDevice, let port = currentPort {
if state.lowercased().contains("active") {
ports.append(RDMAPort(device: device, port: port, state: state))
}
}
}
}
return ports
}
private func readThunderboltBridgeState() -> ThunderboltState? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
guard result.exitCode == 0 else {
@@ -85,10 +174,11 @@ private struct NetworkStatusFetcher {
private func readBridgeInactive() -> Bool? {
let result = runCommand(["ifconfig", "bridge0"])
guard result.exitCode == 0 else { return nil }
guard let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
guard
let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
else {
return nil
}
@@ -171,4 +261,3 @@ private struct NetworkStatusFetcher {
)
}
}
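
For readers unfamiliar with ibv_devinfo output, a made-up sample of the stanza layout that readRDMAActivePorts walks, and what it yields:

// Illustrative input (values invented); the parser keys off the
// "hca_id:", "port:" and "state:" prefixes after trimming whitespace.
let sample = """
hca_id: rdma0
    port:   1
        state:          PORT_ACTIVE (4)
    port:   2
        state:          PORT_DOWN (1)
"""
// Only states containing "active" are kept, so parsing this yields:
//   [RDMAPort(device: "rdma0", port: "1", state: "PORT_ACTIVE (4)")]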

View File

@@ -57,7 +57,7 @@ struct InstanceViewModel: Identifiable, Equatable {
case waiting
case failed
case idle
case unknown
case preparing
var label: String {
switch self {
@@ -68,7 +68,7 @@ struct InstanceViewModel: Identifiable, Equatable {
case .waiting: return "Waiting"
case .failed: return "Failed"
case .idle: return "Idle"
case .unknown: return "Unknown"
case .preparing: return "Preparing"
}
}
}
@@ -107,10 +107,13 @@ extension ClusterState {
let nodeToRunner = instance.shardAssignments.nodeToRunner
let nodeIds = Array(nodeToRunner.keys)
let runnerIds = Array(nodeToRunner.values)
let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
let nodeNames = nodeIds.compactMap {
nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
}
let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
let downloadProgress = aggregateDownloadProgress(for: nodeIds)
let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
let state = InstanceViewModel.State(
statuses: statuses, hasActiveDownload: downloadProgress != nil)
let chatTasks = (chatTasksByInstance[entry.key] ?? [])
.sorted(by: { $0.sortPriority < $1.sortPriority })
.map { InstanceTaskViewModel(task: $0) }
@@ -165,8 +168,8 @@ extension ClusterState {
}
}
private extension InstanceViewModel.State {
init(statuses: [String], hasActiveDownload: Bool = false) {
extension InstanceViewModel.State {
fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
if statuses.contains(where: { $0.contains("failed") }) {
self = .failed
} else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
@@ -182,7 +185,7 @@ private extension InstanceViewModel.State {
} else if statuses.isEmpty {
self = .idle
} else {
self = .unknown
self = .preparing
}
}
}
@@ -243,4 +246,3 @@ extension InstanceTaskViewModel {
self.parameters = task.parameters
}
}

View File

@@ -87,7 +87,9 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
guard !allNodes.isEmpty else { return nil }
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
@@ -106,18 +108,24 @@ extension ClusterState {
}
// Rotate so the local node (from /node_id API) is first
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
if let localId = localNodeId,
let index = orderedNodes.firstIndex(where: { $0.id == localId })
{
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
}
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }
return TopologyEdgeViewModel(
sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edges = Set(edgesArray)
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
return TopologyViewModel(
nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
}
}

View File

@@ -20,8 +20,8 @@ struct InstanceRowView: View {
if let progress = instance.downloadProgress {
downloadStatusView(progress: progress)
} else {
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
}
if let progress = instance.downloadProgress {
GeometryReader { geometry in
@@ -83,7 +83,7 @@ struct InstanceRowView: View {
case .ready: return .teal
case .waiting, .idle: return .gray
case .failed: return .red
case .unknown: return .secondary
case .preparing: return .secondary
}
}
@@ -97,7 +97,8 @@ struct InstanceRowView: View {
.font(.caption)
.fontWeight(.semibold)
if let subtitle = task.subtitle,
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
{
Text(subtitle)
.font(.caption2)
.foregroundColor(.secondary)
@@ -234,9 +235,12 @@ struct InstanceRowView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(
isExpanded.wrappedValue ? "Hide" : "Show",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -311,7 +315,9 @@ struct InstanceRowView: View {
}
@ViewBuilder
private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
private func detailRow(
icon: String? = nil, title: String, value: String, tint: Color = .secondary
) -> some View {
HStack(alignment: .firstTextBaseline, spacing: 6) {
if let icon {
Image(systemName: icon)
@@ -329,4 +335,3 @@ struct InstanceRowView: View {
}
}
}

View File

@@ -32,4 +32,3 @@ struct NodeDetailView: View {
}
}
}

View File

@@ -28,4 +28,3 @@ struct NodeRowView: View {
.padding(.vertical, 4)
}
}

View File

@@ -76,30 +76,33 @@ struct TopologyMiniView: View {
private func connectionLines(in size: CGSize) -> some View {
let positions = positionedNodes(in: size)
let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
let positionById = Dictionary(
uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
return Canvas { context, _ in
guard !topology.edges.isEmpty else { return }
let nodeRadius: CGFloat = 32
let arrowLength: CGFloat = 10
let arrowSpread: CGFloat = .pi / 7
for edge in topology.edges {
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
else { continue }
let dx = end.x - start.x
let dy = end.y - start.y
let distance = max(CGFloat(hypot(dx, dy)), 1)
let ux = dx / distance
let uy = dy / distance
let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedStart = CGPoint(
x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)
var linePath = Path()
linePath.move(to: adjustedStart)
linePath.addLine(to: adjustedEnd)
context.stroke(
context.stroke(
linePath,
with: .color(.secondary.opacity(0.3)),
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
let angle = atan2(uy, ux)
let tip = adjustedEnd
@@ -168,5 +171,3 @@ private struct NodeGlyphView: View {
.frame(width: 95)
}
}
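
The hunk stops just before the arrowhead is drawn. The geometry it sets up (unit vector, angle, tip) resolves to two wing points behind the tip, rotated by ±arrowSpread from the line direction; a sketch with the same constants as the diff:

import Foundation
import CoreGraphics

func arrowheadWings(
    tip: CGPoint, angle: CGFloat,
    length: CGFloat = 10, spread: CGFloat = .pi / 7
) -> (CGPoint, CGPoint) {
    // Step back `length` from the tip along the directions rotated
    // ±spread away from the line's angle.
    let left = CGPoint(x: tip.x - length * cos(angle - spread),
                       y: tip.y - length * sin(angle - spread))
    let right = CGPoint(x: tip.x - length * cos(angle + spread),
                        y: tip.y - length * sin(angle + spread))
    return (left, right)
}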

View File

@@ -6,6 +6,7 @@
//
import Testing
@testable import EXO
struct EXOTests {

app/EXO/uninstall-exo.sh Executable file, 154 lines
View File

@@ -0,0 +1,154 @@
#!/usr/bin/env bash
#
# EXO Uninstaller Script
#
# This script removes all EXO system components that persist after deleting the app.
# Run with: sudo ./uninstall-exo.sh
#
# Components removed:
# - LaunchDaemon: /Library/LaunchDaemons/io.exo.networksetup.plist
# - Network script: /Library/Application Support/EXO/
# - Log files: /var/log/io.exo.networksetup.*
# - Network location: "exo"
# - Launch at login registration
#
set -euo pipefail
LABEL="io.exo.networksetup"
SCRIPT_DEST="/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
PLIST_DEST="/Library/LaunchDaemons/io.exo.networksetup.plist"
LOG_OUT="/var/log/${LABEL}.log"
LOG_ERR="/var/log/${LABEL}.err.log"
APP_BUNDLE_ID="io.exo.EXO"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
echo "========================================"
echo " EXO Uninstaller"
echo "========================================"
echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
# SMAppService-based login items cannot be removed from a shell script.
# They can only be unregistered from within the app itself or manually via System Settings.
echo_warn "Launch at login must be removed manually:"
echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
echo "========================================"
echo_info "EXO uninstall complete!"
echo "========================================"
echo ""
echo "The following have been removed:"
echo " • Network setup LaunchDaemon"
echo " • Network configuration script"
echo " • Log files"
echo " • 'exo' network location"
echo ""
echo "Your network has been restored to use the 'Automatic' location."
echo "Thunderbolt Bridge has been re-enabled (if present)."
echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -15,9 +16,6 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
@@ -26,7 +24,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -104,22 +102,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -241,6 +263,9 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn
@@ -296,6 +321,12 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -317,7 +348,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -396,7 +427,7 @@ def main() -> int:
):
continue
if 0 < n <= args.max_nodes:
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -438,7 +469,13 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)
@@ -450,17 +487,17 @@ def main() -> int:
logger.debug(f" warmup {i + 1}/{args.warmup} done")
for pp in pp_list:
if (
pp * n_nodes > 2048
and "ring" in instance_meta.lower()
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.metadata.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
)
continue
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):

dashboard/dashboard.nix Normal file, 60 lines
View File

@@ -0,0 +1,60 @@
{ lib
, config
, dream2nix
, ...
}:
let
# Read and parse the lock file
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
# For packages with bundleDependencies, filter out deps that are bundled
# (bundled deps are inside the tarball, not separate lockfile entries)
fixedPackages = lib.mapAttrs
(path: entry:
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
then entry // {
dependencies = lib.filterAttrs
(name: _: !(lib.elem name entry.bundleDependencies))
(entry.dependencies or { });
}
else entry
)
(rawLockFile.packages or { });
fixedLockFile = rawLockFile // { packages = fixedPackages; };
in
{
imports = [
dream2nix.modules.dream2nix.nodejs-package-lock-v3
dream2nix.modules.dream2nix.nodejs-granular-v3
];
name = "exo-dashboard";
version = "1.0.0";
mkDerivation = {
src = config.deps.dashboardSrc;
buildPhase = ''
runHook preBuild
npm run build
runHook postBuild
'';
installPhase = ''
runHook preInstall
cp -r build $out/build
runHook postInstall
'';
};
deps = { nixpkgs, ... }: {
inherit (nixpkgs) stdenv;
dashboardSrc = null; # Injected by parts.nix
};
nodejs-package-lock-v3 = {
# Don't use packageLockFile - provide the fixed lock content directly
packageLock = fixedLockFile;
};
}

dashboard/parts.nix Normal file, 44 lines
View File

@@ -0,0 +1,44 @@
{ inputs, ... }:
{
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to only include dashboard directory
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
inDashboardDir =
(lib.hasInfix "/dashboard/" path)
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|| (baseName == "dashboard" && type == "directory");
in
inDashboardDir;
};
# Build the dashboard with dream2nix (includes node_modules in output)
dashboardFull = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
# Inject the filtered source
{
deps.dashboardSrc = lib.mkForce "${src}/dashboard";
}
];
};
in
{
# Extract just the static site from the full build
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
cp -r ${dashboardFull}/build $out
'';
};
}

View File

@@ -11,4 +11,3 @@ declare global {
}
export {};

View File

@@ -60,12 +60,39 @@
return models;
});
// Auto-select the first available model if none is selected
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
const models = availableModels();
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -53,62 +53,285 @@
marked.use({ renderer });
/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
// Storage for math expressions extracted before markdown processing
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
let mathCounter = 0;
// Storage for HTML snippets that need protection from markdown
const htmlSnippets: Map<string, string> = new Map();
let htmlCounter = 0;
// Use alphanumeric placeholders that won't be interpreted as HTML tags
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
/**
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
*/
function preprocessLaTeX(text: string): string {
// Protect code blocks
// Reset storage
mathExpressions.clear();
mathCounter = 0;
htmlSnippets.clear();
htmlCounter = 0;
// Protect code blocks first
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `<<CODE_${codeBlocks.length - 1}>>`;
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
});
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Remove LaTeX document commands
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\begin\{document\}/g, '');
processed = processed.replace(/\\end\{document\}/g, '');
processed = processed.replace(/\\maketitle/g, '');
processed = processed.replace(/\\title\{[^}]*\}/g, '');
processed = processed.replace(/\\author\{[^}]*\}/g, '');
processed = processed.replace(/\\date\{[^}]*\}/g, '');
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
processed = processed.replace(/\\require\{[^}]*\}/g, '');
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
htmlCounter++;
return placeholder;
});
// Strip center environment (layout only, no content change)
processed = processed.replace(/\\begin\{center\}/g, '');
processed = processed.replace(/\\end\{center\}/g, '');
// Strip other layout environments
processed = processed.replace(/\\begin\{flushleft\}/g, '');
processed = processed.replace(/\\end\{flushleft\}/g, '');
processed = processed.replace(/\\begin\{flushright\}/g, '');
processed = processed.replace(/\\end\{flushright\}/g, '');
processed = processed.replace(/\\label\{[^}]*\}/g, '');
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
for (const env of mathEnvs) {
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
processed = processed.replace(wrappedRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
// Handle bare \begin{env}...\end{env} (without dollar signs)
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
processed = processed.replace(bareRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
}
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
processed = processed.replace(
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
(_, content) => {
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
}
);
// Convert LaTeX theorem-like environments
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
for (const env of theoremEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
const envName = env.charAt(0).toUpperCase() + env.slice(1);
processed = processed.replace(envRegex, (_, content) => {
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
});
}
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<u>${content}</u>`);
htmlCounter++;
return placeholder;
});
// Handle LaTeX line breaks and spacing
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
processed = processed.replace(/\\newline/g, '\n');
processed = processed.replace(/\\par\b/g, '\n\n');
processed = processed.replace(/\\quad/g, ' ');
processed = processed.replace(/\\qquad/g, ' ');
processed = processed.replace(/~~/g, ' '); // collapse doubled LaTeX non-breaking spaces (~~) into a plain space
// Remove other common LaTeX commands that don't render
processed = processed.replace(/\\centering/g, '');
processed = processed.replace(/\\noindent/g, '');
processed = processed.replace(/\\hfill/g, '');
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});
// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});
// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});
// Extract inline math ($...$) BEFORE markdown processing
// Allow single-line only, skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});
// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Clean up any remaining stray backslashes from unrecognized commands
// (done before code blocks are restored, so backslashes inside code survive)
processed = processed.replace(/\\(?=[a-zA-Z])/g, '');
// Restore code blocks (both placeholder formats are restored defensively)
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
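// Illustrative sketch (not part of the component): the preprocessing above
// hides math behind opaque placeholders so the markdown pass cannot mangle it.
// Roughly, with the placeholder prefix abbreviated:
//   'Euler: $e^{i\pi}+1=0$ costs \$5'  ->  'Euler: ...INLINE0END costs $5'
// renderMath() below then swaps each placeholder for rendered KaTeX and
// restores the protected HTML snippets.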
/**
* Render math expressions with KaTeX and restore HTML placeholders
*/
function renderMath(html: string): string {
// Fallback: render raw display math ($$...$$) that was not extracted during preprocessing
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
// Fallback: render raw inline math ($...$) that survived preprocessing,
// but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});
if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}
// Restore HTML placeholders (for \textbf, \emph, etc.)
for (const [placeholder, htmlContent] of htmlSnippets) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
html = html.replace(regex, htmlContent);
}
return html;
}
@@ -154,16 +377,50 @@
}
}
async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;
const source = decodeURIComponent(encodedSource);
try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
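// dataset.listenerBound guards against double-binding: setup runs again on
// every re-render via $effect, and without the flag each pass would stack
// another click handler on the same button.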
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}
$effect(() => {
@@ -424,28 +681,290 @@
color: #60a5fa;
}
/* KaTeX math styling - Base */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}
/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}
/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}
.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}
.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}
/* Display math content area */
.markdown-content :global(.math-display-content) {
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}
.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}
.markdown-content :global(.math-display-content .katex-display > .katex) {
text-align: center;
}
/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}
.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}
/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}
/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}
/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}
.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}
.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}
/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}
/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.25rem 0.5rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}
.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}
/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}
.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-proof-header::after) {
content: '.';
}
.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}
.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}
/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}
/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}
.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}
.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}
.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}
/* LaTeX diagram/figure placeholder */
.markdown-content :global(.latex-diagram-placeholder) {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
margin: 1rem 0;
padding: 1.5rem 2rem;
background: rgba(255, 255, 255, 0.02);
border: 1px dashed rgba(255, 215, 0, 0.25);
border-radius: 0.5rem;
color: rgba(255, 215, 0, 0.6);
font-size: 0.875rem;
}
.markdown-content :global(.latex-diagram-icon) {
font-size: 1.25rem;
opacity: 0.8;
}
.markdown-content :global(.latex-diagram-text) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
}
</style>


@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;


@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: typeof data.nodes[string]): string | null {
function checkNode(node: NodeInfo | undefined): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,17 +39,19 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
}
return null;
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -255,21 +257,24 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ip = edge.sendBackIp || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -338,9 +343,8 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -381,32 +385,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, typeof debugEdgeLabels> = {
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
edges.forEach(edge => {
quadrantEdges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;


@@ -1,8 +1,7 @@
export { default as TopologyGraph } from './TopologyGraph.svelte';
export { default as ChatForm } from './ChatForm.svelte';
export { default as ChatMessages } from './ChatMessages.svelte';
export { default as ChatAttachments } from './ChatAttachments.svelte';
export { default as ChatSidebar } from './ChatSidebar.svelte';
export { default as ModelCard } from './ModelCard.svelte';
export { default as MarkdownContent } from './MarkdownContent.svelte';
export { default as TopologyGraph } from "./TopologyGraph.svelte";
export { default as ChatForm } from "./ChatForm.svelte";
export { default as ChatMessages } from "./ChatMessages.svelte";
export { default as ChatAttachments } from "./ChatAttachments.svelte";
export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";


File diff suppressed because it is too large.


@@ -13,55 +13,124 @@ export interface ChatUploadedFile {
}
export interface ChatAttachment {
type: 'image' | 'text' | 'pdf' | 'audio';
type: "image" | "text" | "pdf" | "audio";
name: string;
content?: string;
base64Url?: string;
mimeType?: string;
}
export type FileCategory = 'image' | 'text' | 'pdf' | 'audio' | 'unknown';
export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";
export const IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.svg'];
export const IMAGE_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/svg+xml'];
export const IMAGE_EXTENSIONS = [
".jpg",
".jpeg",
".png",
".gif",
".webp",
".svg",
];
export const IMAGE_MIME_TYPES = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"image/svg+xml",
];
export const TEXT_EXTENSIONS = [
'.txt', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', '.log',
'.js', '.ts', '.jsx', '.tsx', '.py', '.java', '.cpp', '.c', '.h',
'.css', '.html', '.htm', '.sql', '.sh', '.bat', '.rs', '.go',
'.rb', '.php', '.swift', '.kt', '.scala', '.r', '.dart', '.vue', '.svelte'
".txt",
".md",
".json",
".xml",
".yaml",
".yml",
".csv",
".log",
".js",
".ts",
".jsx",
".tsx",
".py",
".java",
".cpp",
".c",
".h",
".css",
".html",
".htm",
".sql",
".sh",
".bat",
".rs",
".go",
".rb",
".php",
".swift",
".kt",
".scala",
".r",
".dart",
".vue",
".svelte",
];
export const TEXT_MIME_TYPES = [
'text/plain', 'text/markdown', 'text/csv', 'text/html', 'text/css',
'application/json', 'application/xml', 'text/xml', 'application/javascript',
'text/javascript', 'application/typescript'
"text/plain",
"text/markdown",
"text/csv",
"text/html",
"text/css",
"application/json",
"application/xml",
"text/xml",
"application/javascript",
"text/javascript",
"application/typescript",
];
export const PDF_EXTENSIONS = ['.pdf'];
export const PDF_MIME_TYPES = ['application/pdf'];
export const PDF_EXTENSIONS = [".pdf"];
export const PDF_MIME_TYPES = ["application/pdf"];
export const AUDIO_EXTENSIONS = ['.mp3', '.wav', '.ogg', '.m4a'];
export const AUDIO_MIME_TYPES = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/mp4'];
export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
export const AUDIO_MIME_TYPES = [
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/mp4",
];
/**
* Get file category based on MIME type and extension
*/
export function getFileCategory(mimeType: string, fileName: string): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf('.'));
if (IMAGE_MIME_TYPES.includes(mimeType) || IMAGE_EXTENSIONS.includes(extension)) {
return 'image';
export function getFileCategory(
mimeType: string,
fileName: string,
): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));
if (
IMAGE_MIME_TYPES.includes(mimeType) ||
IMAGE_EXTENSIONS.includes(extension)
) {
return "image";
}
if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
return 'pdf';
return "pdf";
}
if (AUDIO_MIME_TYPES.includes(mimeType) || AUDIO_EXTENSIONS.includes(extension)) {
return 'audio';
if (
AUDIO_MIME_TYPES.includes(mimeType) ||
AUDIO_EXTENSIONS.includes(extension)
) {
return "audio";
}
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith('text/')) {
return 'text';
if (
TEXT_MIME_TYPES.includes(mimeType) ||
TEXT_EXTENSIONS.includes(extension) ||
mimeType.startsWith("text/")
) {
return "text";
}
return 'unknown';
return "unknown";
}
/**
@@ -69,36 +138,36 @@ export function getFileCategory(mimeType: string, fileName: string): FileCategor
*/
export function getAcceptString(categories: FileCategory[]): string {
const accepts: string[] = [];
for (const category of categories) {
switch (category) {
case 'image':
case "image":
accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
break;
case 'text':
case "text":
accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
break;
case 'pdf':
case "pdf":
accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
break;
case 'audio':
case "audio":
accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
break;
}
}
return accepts.join(',');
return accepts.join(",");
}
/**
* Format file size for display
*/
export function formatFileSize(bytes: number): string {
if (bytes === 0) return '0 B';
if (bytes === 0) return "0 B";
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const sizes = ["B", "KB", "MB", "GB"];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
}
/**
@@ -128,42 +197,44 @@ export function readFileAsText(file: File): Promise<string> {
/**
* Process uploaded files into ChatUploadedFile format
*/
export async function processUploadedFiles(files: File[]): Promise<ChatUploadedFile[]> {
export async function processUploadedFiles(
files: File[],
): Promise<ChatUploadedFile[]> {
const results: ChatUploadedFile[] = [];
for (const file of files) {
const id = Date.now().toString() + Math.random().toString(36).substring(2, 9);
const id =
Date.now().toString() + Math.random().toString(36).substring(2, 9);
const category = getFileCategory(file.type, file.name);
const base: ChatUploadedFile = {
id,
name: file.name,
size: file.size,
type: file.type,
file
file,
};
try {
if (category === 'image') {
if (category === "image") {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else if (category === 'text' || category === 'unknown') {
} else if (category === "text" || category === "unknown") {
const textContent = await readFileAsText(file);
results.push({ ...base, textContent });
} else if (category === 'pdf') {
} else if (category === "pdf") {
results.push(base);
} else if (category === 'audio') {
} else if (category === "audio") {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else {
results.push(base);
}
} catch (error) {
console.error('Error processing file:', file.name, error);
console.error("Error processing file:", file.name, error);
results.push(base);
}
}
return results;
}


@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Always auto-select the newly launched model so the user chats with what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -436,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -593,7 +591,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
// Unwrap the instance
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { isDownloading: false, progress: null, statusText: 'UNKNOWN', perNode: [] };
return { isDownloading: false, progress: null, statusText: 'PREPARING', perNode: [] };
}
const inst = instance as { shardAssignments?: { nodeToRunner?: Record<string, string>; runnerToShard?: Record<string, unknown>; modelId?: string } };
@@ -706,7 +704,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function deriveInstanceStatus(instanceWrapped: unknown): { statusText: string; statusClass: string } {
const [, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { statusText: 'UNKNOWN', statusClass: 'inactive' };
return { statusText: 'PREPARING', statusClass: 'inactive' };
}
const inst = instance as { shardAssignments?: { runnerToShard?: Record<string, unknown> } };
@@ -735,7 +733,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const has = (s: string) => statuses.includes(s);
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
if (statuses.length === 0) return { statusText: 'PREPARING', statusClass: 'inactive' };
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);
@@ -895,7 +915,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
return { runnerId, tag, deviceRank };
});
@@ -1267,9 +1287,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div
<div
bind:this={instancesContainerRef}
class="max-h-72 space-y-3 overflow-y-auto"
class="max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px"
>
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
@@ -1773,7 +1793,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
<div class="space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1">
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
{@const statusText = downloadInfo.statusText}


@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;
@@ -199,7 +199,13 @@
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
?? (downloadPayload as Record<string, unknown>).downloadProgress
?? {};
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
const totalBytes = getBytes(
(downloadPayload as Record<string, unknown>).total_bytes
?? (downloadPayload as Record<string, unknown>).totalBytes
?? (rawProgress as Record<string, unknown>).total_bytes
?? (rawProgress as Record<string, unknown>).totalBytes
);
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -332,8 +338,13 @@
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
<div>
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
</div>
</div>
</div>
@@ -385,7 +396,7 @@
</div>
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
{#if model.status !== 'completed'}
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
{/if}


@@ -1,16 +1,15 @@
import tailwindcss from '@tailwindcss/vite';
import { sveltekit } from '@sveltejs/kit/vite';
import { defineConfig } from 'vite';
import tailwindcss from "@tailwindcss/vite";
import { sveltekit } from "@sveltejs/kit/vite";
import { defineConfig } from "vite";
export default defineConfig({
plugins: [tailwindcss(), sveltekit()],
server: {
proxy: {
'/v1': 'http://localhost:52415',
'/state': 'http://localhost:52415',
'/models': 'http://localhost:52415',
'/instance': 'http://localhost:52415'
}
}
"/v1": "http://localhost:52415",
"/state": "http://localhost:52415",
"/models": "http://localhost:52415",
"/instance": "http://localhost:52415",
},
},
});

docs/api.md (new file, 212 lines)

@@ -0,0 +1,212 @@
# EXO API Technical Reference
This document describes the REST API exposed by the **EXO** service, as implemented in:
`src/exo/master/api.py`
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.
Base URL example:
```
http://localhost:52415
```
## 1. General / Meta Endpoints
### Get Master Node ID
**GET** `/node_id`
Returns the identifier of the current master node.
**Response (example):**
```json
{
"node_id": "node-1234"
}
```
### Get Cluster State
**GET** `/state`
Returns the current state of the cluster, including nodes and active instances.
**Response:**
JSON object describing topology, nodes, and instances.
### Get Events
**GET** `/events`
Returns the list of internal events recorded by the master (mainly for debugging and observability).
**Response:**
Array of event objects.
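As a quick illustration, all three meta endpoints can be hit with plain `fetch` calls. This is a sketch only; response shapes beyond what is documented above are treated as opaque JSON, and an ES module context with top-level await is assumed.
```ts
// Sketch: poll the meta endpoints for observability.
const BASE = "http://localhost:52415";

const nodeId = await fetch(`${BASE}/node_id`).then((r) => r.json());
const state = await fetch(`${BASE}/state`).then((r) => r.json());
const events = await fetch(`${BASE}/events`).then((r) => r.json());

console.log(nodeId.node_id, Object.keys(state).length, events.length);
```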
## 2. Model Instance Management
### Create Instance
**POST** `/instance`
Creates a new model instance in the cluster.
**Request body (example):**
```json
{
"instance": {
"model_id": "llama-3.2-1b",
"placement": { }
}
}
```
**Response:**
JSON description of the created instance.
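A minimal TypeScript sketch of the create/delete lifecycle, using the request body shown above (the delete endpoint is described next). The `id` field on the created-instance response is an assumption made for illustration; this document does not pin down the exact shape.
```ts
const BASE = "http://localhost:52415";

// Create an instance for a model; the response describes the created instance.
async function createInstance(modelId: string): Promise<string> {
  const res = await fetch(`${BASE}/instance`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ instance: { model_id: modelId, placement: {} } }),
  });
  if (!res.ok) throw new Error(`create failed: ${res.status}`);
  const created = (await res.json()) as { id?: string }; // `id` is assumed
  return created.id ?? "";
}

// Tear the instance down again by ID.
async function deleteInstance(instanceId: string): Promise<void> {
  const res = await fetch(`${BASE}/instance/${instanceId}`, { method: "DELETE" });
  if (!res.ok) throw new Error(`delete failed: ${res.status}`);
}
```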
### Delete Instance
**DELETE** `/instance/{instance_id}`
Deletes an existing instance by ID.
**Path parameters:**
* `instance_id`: string, ID of the instance to delete
**Response:**
Status / confirmation JSON.
### Get Instance
**GET** `/instance/{instance_id}`
Returns details of a specific instance.
**Path parameters:**
* `instance_id`: string
**Response:**
JSON description of the instance.
### Preview Placements
**GET** `/instance/previews?model_id=...`
Returns possible placement previews for a given model.
**Query parameters:**
* `model_id`: string, required
**Response:**
Array of placement preview objects.
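For example, a small hypothetical helper that fetches previews before launching anything:
```ts
// Sketch: list candidate placements for a model without creating an instance.
async function previewPlacements(modelId: string): Promise<unknown[]> {
  const url = `http://localhost:52415/instance/previews?model_id=${encodeURIComponent(modelId)}`;
  const res = await fetch(url);
  if (!res.ok) throw new Error(`previews failed: ${res.status}`);
  // The preview object fields are not specified here, so keep them opaque.
  return (await res.json()) as unknown[];
}
```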
### Compute Placement
**GET** `/instance/placement`
Computes a placement for a potential instance without creating it.
**Query parameters (typical):**
* `model_id`: string
* `sharding`: string or config
* `instance_meta`: JSON-encoded metadata
* `min_nodes`: integer
**Response:**
JSON object describing the proposed placement / instance configuration.
### Place Instance (Dry Run)
**POST** `/place_instance`
Performs the placement step for an instance (planning only), without necessarily creating it.
**Request body:**
JSON describing the instance to be placed.
**Response:**
Placement result.
## 3. Models
### List Models
**GET** `/models`
**GET** `/v1/models` (alias)
Returns the list of available models and their metadata.
**Response:**
Array of model descriptors.
## 4. Inference / Chat Completions
### OpenAI-Compatible Chat Completions
**POST** `/v1/chat/completions`
Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"messages": [
{ "role": "system", "content": "You are a helpful assistant." },
{ "role": "user", "content": "Hello" }
],
"stream": false
}
```
**Response:**
OpenAI-compatible chat completion response.
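A sketch of consuming the streaming mode from TypeScript. It assumes standard OpenAI SSE framing (`data: {...}` lines terminated by `data: [DONE]`), which this document implies but does not spell out; `process.stdout.write` makes the example Node-specific.
```ts
async function streamChat(prompt: string): Promise<void> {
  const res = await fetch("http://localhost:52415/v1/chat/completions", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      model: "llama-3.2-1b",
      messages: [{ role: "user", content: prompt }],
      stream: true,
    }),
  });
  if (!res.ok || !res.body) throw new Error(`request failed: ${res.status}`);

  const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
  let buffered = "";
  for (;;) {
    const { value, done } = await reader.read();
    if (done) break;
    buffered += value;
    const lines = buffered.split("\n");
    buffered = lines.pop() ?? ""; // keep any partial line for the next chunk
    for (const line of lines) {
      if (!line.startsWith("data:")) continue;
      const data = line.slice(5).trim();
      if (!data || data === "[DONE]") continue;
      const delta = JSON.parse(data).choices?.[0]?.delta?.content;
      if (delta) process.stdout.write(delta);
    }
  }
}
```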
### Benchmarked Chat Completions
**POST** `/bench/chat/completions`
Same as `/v1/chat/completions`, but also returns performance and generation statistics.
**Request body:**
Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.
## 5. Complete Endpoint Summary
```
GET /node_id
GET /state
GET /events
POST /instance
GET /instance/{instance_id}
DELETE /instance/{instance_id}
GET /instance/previews
GET /instance/placement
POST /place_instance
GET /models
GET /v1/models
POST /v1/chat/completions
POST /bench/chat/completions
```
## 6. Notes
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL (see the client sketch below).
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
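For instance, the official `openai` npm client only needs its `baseURL` overridden. This is a sketch: the API key value is a placeholder, since this document does not describe any authentication.
```ts
import OpenAI from "openai";

const client = new OpenAI({
  baseURL: "http://localhost:52415/v1",
  apiKey: "exo", // placeholder; the client requires a value to be set
});

const models = await client.models.list(); // GET /v1/models
console.log(models.data.map((m) => m.id));

const completion = await client.chat.completions.create({
  model: "llama-3.2-1b",
  messages: [{ role: "user", content: "Hello" }],
});
console.log(completion.choices[0].message.content);
```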


Binary image file not shown (187 KiB).

flake.lock (generated, 185 lines)

@@ -1,5 +1,42 @@
{
"nodes": {
"crane": {
"locked": {
"lastModified": 1767744144,
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
"owner": "ipetkov",
"repo": "crane",
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
"type": "github"
},
"original": {
"owner": "ipetkov",
"repo": "crane",
"type": "github"
}
},
"dream2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
"owner": "nix-community",
"repo": "dream2nix",
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "dream2nix",
"type": "github"
}
},
"fenix": {
"inputs": {
"nixpkgs": [
@@ -8,11 +45,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"owner": "nix-community",
"repo": "fenix",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"type": "github"
},
"original": {
@@ -21,25 +58,59 @@
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"flake-compat": {
"flake": false,
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-parts": {
"inputs": {
"nixpkgs-lib": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1768135262,
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -50,27 +121,74 @@
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
"purescript-overlay": {
"inputs": {
"flake-compat": "flake-compat",
"nixpkgs": [
"dream2nix",
"nixpkgs"
],
"slimlock": "slimlock"
},
"locked": {
"lastModified": 1728546539,
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
"type": "github"
},
"original": {
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"type": "github"
}
},
"root": {
"inputs": {
"crane": "crane",
"dream2nix": "dream2nix",
"fenix": "fenix",
"flake-utils": "flake-utils",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"rev": "725349602e525df37f377701e001fe8aab807878",
"type": "github"
},
"original": {
@@ -80,18 +198,25 @@
"type": "github"
}
},
"systems": {
"slimlock": {
"inputs": {
"nixpkgs": [
"dream2nix",
"purescript-overlay",
"nixpkgs"
]
},
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"lastModified": 1688756706,
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
"owner": "thomashoneyman",
"repo": "slimlock",
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"owner": "thomashoneyman",
"repo": "slimlock",
"type": "github"
}
},
@@ -102,11 +227,11 @@
]
},
"locked": {
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"type": "github"
},
"original": {

flake.nix (196 lines)

@@ -3,118 +3,134 @@
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
# Provides Rust dev-env integration:
flake-parts = {
url = "github:hercules-ci/flake-parts";
inputs.nixpkgs-lib.follows = "nixpkgs";
};
crane.url = "github:ipetkov/crane";
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Provides formatting infrastructure:
treefmt-nix = {
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};
# TODO: figure out caching story
# nixConfig = {
# # nix community cachix
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
# extra-substituters = "https://nix-community.cachix.org";
# };
nixConfig = {
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
extra-substituters = "https://exo.cachix.org";
};
outputs =
inputs:
let
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
systems = [
"x86_64-linux"
"aarch64-darwin"
"aarch64-linux"
];
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
in
inputs.flake-utils.lib.eachSystem systems (
system:
let
pkgs = import inputs.nixpkgs {
inherit system;
overlays = [ inputs.fenix.overlays.default ];
};
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix";
programs.ruff-format.enable = true;
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
programs.rustfmt.enable = true;
programs.rustfmt.package = (fenixToolchain system).rustfmt;
programs.nixpkgs-fmt.enable = true;
};
in
{
formatter = treefmtEval.config.build.wrapper;
checks.formatting = treefmtEval.config.build.check inputs.self;
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = pkgs.mkShell {
packages =
with pkgs;
[
# PYTHON
python313
uv
ruff
basedpyright
imports = [
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
];
# RUST
((fenixToolchain system).withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
};
rustfmt = {
enable = true;
package = config.rust.toolchain;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
};
};
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];
packages =
[
# FORMATTING
config.treefmt.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
config.rust.toolchain
maturin
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ lib.optionals stdenv.isLinux [
unixtools.ifconfig
]
++ lib.optionals stdenv.isDarwin [
macmon
];
OPENSSL_NO_VENDOR = "1";
shellHook = ''
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
${lib.optionalString stdenv.isLinux ''
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
''}
'';
};
};
}
);
};
}


@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt


@@ -8,33 +8,23 @@ dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"typeguard>=4.4.4",
"pydantic>=2.11.7",
"base58>=2.1.1",
"cryptography>=45.0.5",
"fastapi>=0.116.1",
"filelock>=3.18.0",
"aiosqlite>=0.21.0",
"networkx>=3.5",
"protobuf>=6.32.0",
"rich>=14.1.0",
"rustworkx>=0.17.1",
"sqlmodel>=0.0.24",
"sqlalchemy[asyncio]>=2.0.43",
"greenlet>=3.2.4",
"huggingface-hub>=0.33.4",
"psutil>=7.0.0",
"loguru>=0.7.3",
"textual>=5.3.0",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]
@@ -45,6 +35,7 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"basedpyright>=1.29.0",
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
@@ -110,6 +101,7 @@ root = "src"
# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
@@ -135,3 +127,6 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

rust/parts.nix (new file, 145 lines)

@@ -0,0 +1,145 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
"rust-analyzer"
];
# Crane with fenix toolchain
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
# Source filtering - only include rust/ directory and root Cargo files
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
parentDir = builtins.dirOf path;
inRustDir =
(lib.hasInfix "/rust/" path)
|| (lib.hasSuffix "/rust" parentDir)
|| (baseName == "rust" && type == "directory");
isRootCargoFile =
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
&& (builtins.dirOf path == toString inputs.self);
in
isRootCargoFile
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
};
# Common arguments for all Rust builds
commonArgs = {
inherit src;
pname = "exo-rust";
version = "0.0.1";
strictDeps = true;
nativeBuildInputs = [
pkgs.pkg-config
pkgs.python313 # Required for pyo3-build-config
];
buildInputs = [
pkgs.openssl
pkgs.python313 # Required for pyo3 tests
];
OPENSSL_NO_VENDOR = "1";
# Required for pyo3 tests to find libpython
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
};
# Build dependencies once for caching
cargoArtifacts = craneLib.buildDepsOnly (
commonArgs
// {
cargoExtraArgs = "--workspace";
}
);
in
{
# Export toolchain for use in treefmt and devShell
options.rust = {
toolchain = lib.mkOption {
type = lib.types.package;
default = rustToolchain;
description = "The Rust toolchain to use";
};
};
config = {
packages = {
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
pname = "exo_pyo3_bindings";
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
pkgs.maturin
];
buildPhaseCargoCommand = ''
maturin build \
--release \
--manylinux off \
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
--features "pyo3/extension-module,pyo3/experimental-async" \
--interpreter ${pkgs.python313}/bin/python \
--out dist
'';
# Don't use crane's default install behavior
doNotPostBuildInstallCargoBinaries = true;
installPhaseCommand = ''
mkdir -p $out
cp dist/*.whl $out/
'';
}
);
};
checks = {
# Full workspace build (all crates)
cargo-build = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Run tests with nextest
cargo-nextest = craneLib.cargoNextest (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Build documentation
cargo-doc = craneLib.cargoDoc (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
};
};
};
}


@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }


@@ -1,4 +0,0 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is meant to be a long-lived process that precedes the
//! launch of the Exo application and is responsible for ensuring the system (configuration,
//! settings, etc.) is in an appropriate state to facilitate running the Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which the Exo application uses to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo application;
//! when the daemon goes to sleep, it will revert those changes as far as it can, in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are currently macOS-only; broaden them later
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back into the overall state collected -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link via e.g. TCP, UDP, etc. Will the
//! handshake require knowing the Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and proper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -1,6 +1,7 @@
import argparse
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -195,6 +196,8 @@ class Node:
def main():
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
@@ -202,6 +205,14 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -215,6 +226,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -256,6 +268,20 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
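# Example invocations (a sketch): passing --fast-synch yields fast_synch=True
# and exports EXO_FAST_SYNCH=on for runner subprocesses; --no-fast-synch
# yields False (EXO_FAST_SYNCH=off); passing neither leaves fast_synch=None,
# i.e. auto-detection.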
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,31 +1,30 @@
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
import anyio
from anyio import create_task_group
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
get_model_card,
)
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -35,6 +34,8 @@ from exo.shared.types.api import (
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ModelList,
@@ -55,9 +56,13 @@ from exo.shared.types.commands import (
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -67,8 +72,6 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -87,12 +90,12 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card.metadata
return model_card
else:
return await get_model_meta(model_id)
return await get_model_card(model_id)
class API:
@@ -123,6 +126,7 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -153,6 +157,21 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
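# Shape of the resulting body (a sketch): HTTPException(404, "Resource not
# found") serializes to
#   {"error": {"message": "Resource not found", "type": "Not Found", "code": 404}}
# i.e. the OpenAI-style error envelope that the handler tests assert on.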
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -181,7 +200,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_meta=await resolve_model_meta(payload.model_id),
model_card=await resolve_model_card(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -191,15 +210,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
model_card=command.model_card,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -216,26 +235,28 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
model_card=model_card,
)
async def get_placement(
self,
model_id: str,
model_id: ModelId,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_meta = await resolve_model_meta(model_id)
model_card = await resolve_model_card(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -262,7 +283,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -280,32 +301,33 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for card in cards:
model_meta = card.metadata
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
)
except ValueError as exc:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -316,17 +338,17 @@ class API:
]
if len(new_instances) != 1:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -335,7 +357,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_meta.storage_size.in_bytes
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -343,14 +365,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
card.model_id,
model_card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -358,7 +380,7 @@ class API:
error=None,
)
)
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -381,35 +403,8 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -417,16 +412,10 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -442,11 +431,23 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -458,7 +459,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -466,7 +467,13 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -495,7 +502,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -503,7 +510,13 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -542,10 +555,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -562,18 +573,17 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
return await self._collect_chat_completion(command.command_id)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -589,19 +599,15 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
for memory in self.state.node_memory.values():
total_available += memory.ram_available
return total_available
@@ -610,13 +616,13 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.short_id,
id=card.model_id,
hugging_face_id=card.model_id,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
)
for card in MODEL_CARDS.values()
]
@@ -654,14 +660,14 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -27,6 +27,7 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -158,6 +159,8 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -200,9 +203,7 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
connected_node_ids = set(self.state.topology.list_nodes())
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
@@ -237,6 +238,8 @@ class Master:
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)

View File

@@ -6,23 +6,25 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -52,37 +54,33 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
candidate_cycles, node_memory, command.model_card.storage_size
)
if not cycles_with_sufficient_memory:
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
if not command.model_card.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works well enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
if command.model_card.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -92,44 +90,38 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycles_with_leaf_nodes: list[Cycle] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
if any(topology.node_is_leaf(node_id) for node_id in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
(node_memory[node_id].ram_available for node_id in cycle),
start=Memory(),
),
)
shard_assignments = get_shard_assignments(
command.model_meta, selected_cycle, command.sharding
command.model_card, selected_cycle, command.sharding, node_memory
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -137,19 +129,20 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
[node_id for node_id in selected_cycle],
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
selected_cycle,
coordinator=selected_cycle.node_ids[0],
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_network=node_network,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
@@ -158,6 +151,7 @@ def place_instance(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
node_network=node_network,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,

View File

@@ -1,15 +1,13 @@
from collections.abc import Generator
from typing import TypeGuard, cast
from collections.abc import Generator, Mapping
from loguru import logger
from pydantic import BaseModel
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -19,67 +17,113 @@ from exo.shared.types.worker.shards import (
)
class NodeWithProfile(BaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
cycles: list[Cycle],
node_memory: Mapping[NodeId, MemoryUsage],
required_memory: Memory,
) -> list[Cycle]:
filtered_cycles: list[Cycle] = []
for cycle in cycles:
if not narrow_all_nodes(cycle):
if not all(node in node_memory for node in cycle):
continue
total_mem = sum(
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[NodeInfo], cycle))
filtered_cycles.append(cycle)
return filtered_cycles
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
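# Usage sketch (values mirror the parametrized placement tests): twelve
# layers split by memory fraction via the largest-remainder method.
#   allocate_layers_proportionally(12, [0.25, 0.25, 0.5])     -> [3, 3, 6]
#   allocate_layers_proportionally(12, [1 / 3, 1 / 3, 1 / 3]) -> [4, 4, 4]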
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -88,11 +132,11 @@ def get_shard_assignments_for_pipeline_parallel(
)
runner_to_shard[runner_id] = shard
node_to_runner[node.node_id] = runner_id
node_to_runner[node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -101,17 +145,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
model_card: ModelCard,
cycle: Cycle,
):
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node in enumerate(selected_cycle):
for i, node_id in enumerate(cycle):
shard = TensorShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -122,10 +166,10 @@ def get_shard_assignments_for_tensor_parallel(
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node.node_id] = runner_id
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -134,22 +178,22 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeInfo],
model_card: ModelCard,
cycle: Cycle,
sharding: Sharding,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_meta=model_meta,
selected_cycle=selected_cycle,
model_card=model_card,
cycle=cycle,
node_memory=node_memory,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_meta=model_meta,
selected_cycle=selected_cycle,
model_card=model_card,
cycle=cycle,
)
@@ -164,38 +208,40 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
)
return []
cycle = cycles[0]
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
if cycle_digraph.is_thunderbolt_cycle(cycle):
get_thunderbolt = True
logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")
cycle = cycles[0]
hosts: list[Host] = []
for i in range(len(cycle)):
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
current_node = cycle.node_ids[i]
next_node = cycle.node_ids[(i + 1) % len(cycle)]
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
for connection in cycle_digraph.get_all_connections_between(
source=current_node, sink=next_node
):
if not isinstance(connection, SocketConnection):
continue
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
return hosts
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -214,72 +260,37 @@ def get_mlx_ibv_devices_matrix(
if i == j:
continue
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
break
else:
logger.warning(
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
f"Failed to find interface name between {node_i} and {node_j}"
)
raise ValueError(
"Current ibv backend requires all-to-all rdma connections"
"Current jaccl backend requires all-to-all RDMA connections"
)
return matrix
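# Shape sketch for a fully connected pair of nodes (interface names are
# hypothetical): entry [i][j] names the RDMA interface node i uses to reach
# node j, and the diagonal stays None.
#   [[None, "rdma_en2"],
#    ["rdma_en3", None]]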
def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
node_i: NodeId,
node_j: NodeId,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
ip_address: str, node_network: NodeNetworkInfo
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
for interface in node_network.interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -287,7 +298,10 @@ def _find_interface_name_for_ip(
def _find_ip_prioritised(
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
node_id: NodeId,
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -298,9 +312,14 @@ def _find_ip_prioritised(
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
iface_map = {
_find_interface_name_for_ip(
ip, node_network.get(other_node_id, NodeNetworkInfo())
): ip
for ip, _ in ips
}
en0_ip = iface_map.get("en0")
if en0_ip:
@@ -324,9 +343,10 @@ def _find_ip_prioritised(
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeInfo],
selected_cycle: Cycle,
cycle_digraph: Topology,
ephemeral_port: int,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
@@ -341,14 +361,13 @@ def get_mlx_ring_hosts_by_node(
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
for rank, node_id in enumerate(selected_cycle):
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
hosts_for_node: list[Host] = []
for idx, other_node in enumerate(selected_cycle):
for idx, other_node_id in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
@@ -358,10 +377,12 @@ def get_mlx_ring_hosts_by_node(
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network
)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
@@ -375,31 +396,34 @@ def get_mlx_ring_hosts_by_node(
def get_mlx_jaccl_coordinators(
selected_cycle: list[NodeInfo],
coordinator: NodeId,
coordinator_port: int,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
For each node, select an IP address at which it can reach the rank 0 node.
Returns an address in the format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.info(f"Selecting coordinator: {coordinator}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
return "0.0.0.0"
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
if ip is not None:
return ip
logger.warning(
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
n: f"{get_ip_for_node(n)}:{coordinator_port}"
for n in cycle_digraph.list_nodes()
}
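# Result sketch (hypothetical addresses and port): the coordinator binds on
# all interfaces while every other node dials a directly reachable IP, e.g.
#   {coordinator_id: "0.0.0.0:52416", peer_id: "169.254.0.2:52416"}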

View File

@@ -1,67 +1,37 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
MemoryUsage,
NetworkInterfaceInfo,
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_node_memory(memory: int) -> MemoryUsage:
return MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def create_node_network() -> NodeNetworkInfo:
return NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
for i in range(10)
]
)
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
return _create_connection
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)

View File

@@ -0,0 +1,107 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -7,6 +7,7 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -19,15 +20,12 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodePerformanceMeasured,
NodeGatheredInfo,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
MemoryUsage,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -75,29 +73,22 @@ async def test_master():
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
session=session_id,
event=(
NodePerformanceMeasured(
NodeGatheredInfo(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
)
),
@@ -108,7 +99,7 @@ async def test_master():
logger.info("wait for initial topology event")
while len(list(master.state.topology.list_nodes())) == 0:
await anyio.sleep(0.001)
while len(master.state.node_profiles) == 0:
while len(master.state.node_memory) == 0:
await anyio.sleep(0.001)
logger.info("inject a CreateInstance Command")
@@ -118,9 +109,8 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -163,7 +153,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
@@ -176,9 +166,8 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -1,20 +1,24 @@
from typing import Callable
import pytest
from loguru import logger
from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_node_memory,
create_node_network,
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -26,11 +30,6 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -44,21 +43,20 @@ def instance() -> Instance:
@pytest.fixture
def model_meta() -> ModelMetadata:
return ModelMetadata(
def model_card() -> ModelCard:
return ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -70,47 +68,75 @@ def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
((312, 468, 1092), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
model_card: ModelCard,
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
cic = place_instance_command(model_card)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)
)
conn_a_c = Connection(
source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)
)
conn_b_a = Connection(
source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
)
node_memory = {
node_id_a: create_node_memory(available_memory[0]),
node_id_b: create_node_memory(available_memory[1]),
node_id_c: create_node_memory(available_memory[2]),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(conn_b_a)
# act
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, node_memory, node_network)
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_meta.model_id
assert instance.shard_assignments.model_id == model_card.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -130,23 +156,22 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_exact_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, node_memory, node_network)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -157,23 +182,22 @@ def test_get_instance_placements_one_node_exact_fit(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1001 * 1024, node_id))
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, node_memory, node_network)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -184,17 +208,16 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_not_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -202,7 +225,7 @@ def test_get_instance_placements_one_node_not_fit(
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {})
place_instance(cic, topology, {}, node_memory, node_network)
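These placement tests lean on conftest helpers (create_node_memory, create_node_network, create_socket_connection) that the diff calls but never shows. A minimal sketch of what they plausibly look like, inferred from call sites and from types used elsewhere in this diff; the real fixtures in exo.master.tests.conftest may differ:

from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodeNetworkInfo
from exo.shared.types.topology import SocketConnection

def create_node_memory(available_bytes: int) -> MemoryPerformanceProfile:
    # assumption: only available RAM matters to the placement logic under test
    return MemoryPerformanceProfile.from_bytes(
        ram_total=available_bytes,
        ram_available=available_bytes,
        swap_total=0,
        swap_available=0,
    )

def create_node_network() -> NodeNetworkInfo:
    # placement only needs an entry per node; the interface list can stay empty
    return NodeNetworkInfo(interfaces=[])

def create_socket_connection(i: int) -> SocketConnection:
    # the expected_hosts assertions later in this diff suggest the pattern
    # /ip4/169.254.0.<i>/tcp/1234
    return SocketConnection(
        sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{i}/tcp/1234")
    )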
def test_get_transition_events_no_change(instance: Instance):
@@ -247,217 +270,175 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
def test_placement_selects_leaf_nodes(
model_card: ModelCard,
):
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
# arrange
topology = Topology()
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_card.storage_size = Memory.from_bytes(1000)
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
node_memory = {
node_id_a: create_node_memory(500),
node_id_b: create_node_memory(600),
node_id_c: create_node_memory(600),
node_id_d: create_node_memory(500),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
node_id_d: create_node_network(),
}
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
# Daisy chain topology (directed)
topology.add_connection(
Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
)
# Act
placements = place_instance(cic, topology, {})
cic = place_instance_command(model_card=model_card)
# Assert: D-E-F cycle should be selected as it has more total memory
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
instance = list(placements.values())[0]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(
node_id_c,
node_id_d,
)
)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
model_card: ModelCard,
):
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
node_memory = {
node_a: create_node_memory(500),
node_b: create_node_memory(500),
node_c: create_node_memory(500),
}
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="192.168.1.100",
ip_address="10.0.0.1",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
)
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
node_network = {
node_a: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_b: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_c: NodeNetworkInfo(interfaces=[ethernet_interface]),
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
# RDMA connections (directed)
topology.add_connection(
Connection(source=node_a, sink=node_b, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_a, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_c, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_c, sink=node_b, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_a, sink=node_c, edge=create_rdma_connection(5))
)
topology.add_connection(
Connection(source=node_c, sink=node_a, edge=create_rdma_connection(5))
)
# Ethernet connections (directed)
topology.add_connection(Connection(source=node_a, sink=node_b, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_a, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_b, edge=ethernet_conn))
cic = PlaceInstance(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
min_nodes=1,
)
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.jaccl_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
matrix = instance.jaccl_devices
assert len(matrix) == 3
for i in range(3):
assert matrix[i][i] is None
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3
@@ -469,7 +450,5 @@ def test_tensor_rdma_backend_connectivity_matrix(
if node_id == assigned_nodes[0]:
assert coordinator.startswith("0.0.0.0:")
else:
# Non-rank-0 nodes should have valid IP addresses (can be link-local)
ip_part = coordinator.split(":")[0]
# Just verify it's a valid IP format
assert len(ip_part.split(".")) == 4


@@ -1,162 +1,182 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = topology.get_cycles()
cycles = [c for c in topology.get_cycles() if len(c) != 1]
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 2
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
assert set(n for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_insufficient_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), Memory.from_kb(2001)
topology.get_cycles(), node_memory, Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_multiple_cycles_by_memory():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
node_a_mem = create_node_memory(500 * 1024)
node_b_mem = create_node_memory(500 * 1024)
node_c_mem = create_node_memory(1000 * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 3
assert set(n.node_id for n in filtered_cycles[0]) == {
assert set(n for n in filtered_cycles[0]) == {
node_a_id,
node_b_id,
node_c_id,
}
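Across these three tests the contract of filter_cycles_by_memory is pinned: a cycle survives exactly when the pooled available RAM of its members covers the requirement. A sketch consistent with that contract, assuming the profile exposes ram_available as a Memory (it is built with MemoryPerformanceProfile.from_bytes elsewhere in this diff); the shipped placement_utils implementation may differ:

from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryPerformanceProfile

def filter_cycles_by_memory(
    cycles: list[list[NodeId]],
    node_memory: dict[NodeId, MemoryPerformanceProfile],
    required_memory: Memory,
) -> list[list[NodeId]]:
    # keep a cycle only if its nodes can jointly hold the whole model
    return [
        cycle
        for cycle in cycles
        if sum(node_memory[n].ram_available.in_bytes for n in cycle)
        >= required_memory.in_bytes
    ]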
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_get_smallest_cycles():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
# act
smallest_cycles = get_smallest_cycles(topology.get_cycles())
smallest_cycles = get_smallest_cycles(cycles)
# assert
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
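These assertions only leave room for one behaviour: keep the cycles tied for minimum length. A sketch:

from exo.shared.types.common import NodeId

def get_smallest_cycles(cycles: list[list[NodeId]]) -> list[list[NodeId]]:
    # of all candidate cycles, keep those with the fewest nodes
    if not cycles:
        return []
    shortest = min(len(cycle) for cycle in cycles)
    return [cycle for cycle in cycles if len(cycle) == shortest]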
@pytest.mark.parametrize(
@@ -165,12 +185,12 @@ def test_get_smallest_cycles(
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -180,44 +200,61 @@ def test_get_shard_assignments(
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
# create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(4)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
node_a_mem = create_node_memory(available_memory[0] * 1024)
node_b_mem = create_node_memory(available_memory[1] * 1024)
node_c_mem = create_node_memory(available_memory[2] * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
# pick the 3-node cycle deterministically (cycle ordering can vary)
selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline
model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
assert (
shard_assignments.runner_to_shard[runner_id_a].end_layer
- shard_assignments.runner_to_shard[runner_id_a].start_layer
@@ -228,30 +265,37 @@ def test_get_shard_assignments(
- shard_assignments.runner_to_shard[runner_id_b].start_layer
== expected_layers[1]
)
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_hosts_from_subgraph():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
# act
hosts = get_hosts_from_subgraph(topology)
@@ -259,95 +303,68 @@ def test_get_hosts_from_subgraph(
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
Host(ip="169.254.0.1", port=1234),
Host(ip="169.254.0.2", port=1234),
Host(ip="169.254.0.3", port=1234),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_mlx_jaccl_coordinators():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(5)
)
conn_a_c = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
network_a = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
)
network_b = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
)
network_c = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
)
node_network = {
node_a_id: network_a,
node_b_id: network_b,
node_c_id: network_c,
}
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
@@ -356,11 +373,12 @@ def test_get_mlx_jaccl_coordinators(
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
node_a_id,
coordinator_port=5000,
cycle_digraph=topology,
node_network=node_network,
)
# assert
@@ -381,19 +399,129 @@ def test_get_mlx_jaccl_coordinators(
f"Coordinator for {node_id} should use port 5000"
)
# Rank 0 (node_a) treats this as the listen socket so should listen on all
# IPs
# Rank 0 (node_a) treats this as the listen socket so should listen on all IPs
assert coordinators[node_a_id].startswith("0.0.0.0:"), (
"Rank 0 node should use localhost as coordinator"
"Rank 0 node should use 0.0.0.0 as coordinator listen address"
)
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
assert isinstance(conn_b_a.edge, SocketConnection)
assert (
coordinators[node_b_id] == f"{conn_b_a.edge.sink_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert isinstance(conn_c_a.edge, SocketConnection)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
f"{conn_c_a.edge.sink_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
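One implementation family that satisfies every case in this class is largest-remainder apportionment with a floor of one layer per node. A sketch that passes the behaviours pinned above; the real placement_utils version may differ:

import math

def allocate_layers_proportionally(
    total_layers: int, memory_fractions: list[float]
) -> list[int]:
    if not memory_fractions:
        raise ValueError("cannot allocate layers to an empty node list")
    if total_layers < len(memory_fractions):
        raise ValueError("need at least 1 layer per node")
    quotas = [fraction * total_layers for fraction in memory_fractions]
    # floor each proportional share, but never below one layer
    allocation = [max(1, math.floor(quota)) for quota in quotas]
    # hand any shortfall to the most under-served nodes
    while sum(allocation) < total_layers:
        i = max(range(len(allocation)), key=lambda k: quotas[k] - allocation[k])
        allocation[i] += 1
    # claw back any excess from the most over-served nodes, keeping the floor
    while sum(allocation) > total_layers:
        i = max(
            (k for k in range(len(allocation)) if allocation[k] > 1),
            key=lambda k: allocation[k] - quotas[k],
        )
        allocation[i] -= 1
    return allocation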
def test_get_shard_assignments_insufficient_memory_raises():
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a_mem = create_node_memory(900 * 1024)
node_b_mem = create_node_memory(50 * 1024)
node_c_mem = create_node_memory(10 * 1024) # Insufficient memory
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
model_card = ModelCard(
model_id=ModelId("test-model"),
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)


@@ -1,13 +1,9 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
from exo.shared.types.topology import Connection, SocketConnection
@pytest.fixture
@@ -16,189 +12,97 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
def socket_connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=memory_profile,
network_interfaces=[],
system=system_profile,
)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
def test_add_node(topology: Topology):
# arrange
node_id = NodeId()
# act
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
topology.add_node(node_id)
# assert
data = topology.get_node_profile(node_id)
assert data == node_profile
assert topology.node_is_leaf(node_id)
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_add_connection(topology: Topology, socket_connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
node_a = NodeId()
node_b = NodeId()
connection = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(connection)
# act
data = topology.get_connection_profile(connection)
data = list(topology.list_connections())
# assert
assert data == connection.connection_profile
assert data == [connection]
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, socket_connection: SocketConnection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
# act
topology.remove_connection(connection)
topology.remove_connection(conn)
# assert
assert topology.get_connection_profile(connection) is None
assert list(topology.get_all_connections_between(node_a, node_b)) == []
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, socket_connection: SocketConnection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
# act
topology.remove_node(connection.local_node_id)
topology.remove_node(node_b)
# assert
assert topology.get_node_profile(connection.local_node_id) is None
assert list(topology.out_edges(node_a)) == []
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_list_nodes(topology: Topology, socket_connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}
assert all(isinstance(node, NodeId) for node in nodes)
assert set(node for node in nodes) == set([node_a, node_b])


@@ -11,10 +11,8 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeGatheredInfo,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -27,13 +25,27 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacThunderboltConnections,
MacThunderboltIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
@@ -47,16 +59,12 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -188,120 +196,133 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
}
topology = copy.deepcopy(state.topology)
topology.remove_node(event.node_id)
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
node_system = {
key: value for key, value in state.node_system.items() if key != event.node_id
}
node_network = {
key: value for key, value in state.node_network.items() if key != event.node_id
}
node_thunderbolt = {
key: value
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
}
)
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
# Build update dict with only the mappings that change
update: dict[str, object] = {
"last_seen": {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
},
"topology": topology,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
match info:
case MacmonMetrics():
update["node_system"] = {
**state.node_system,
event.node_id: info.system_profile,
}
update["node_memory"] = {**state.node_memory, event.node_id: info.memory}
case MemoryUsage():
update["node_memory"] = {**state.node_memory, event.node_id: info}
case NodeConfig():
pass
case MiscData():
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"friendly_name": info.friendly_name}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
case StaticNodeInformation():
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"model_id": info.model, "chip_id": info.chip}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
case NodeNetworkInterfaces():
update["node_network"] = {
**state.node_network,
event.node_id: NodeNetworkInfo(interfaces=info.ifaces),
}
case MacThunderboltIdentifiers():
update["node_thunderbolt"] = {
**state.node_thunderbolt,
event.node_id: NodeThunderboltInfo(interfaces=info.idents),
}
case MacThunderboltConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_thunderbolt
for tb_ident in state.node_thunderbolt[nid].interfaces
}
as_rdma_conns = [
Connection(
source=event.node_id,
sink=conn_map[tb_conn.sink_uuid][0],
edge=RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
return state.model_copy(update=update)
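The consolidation replaces several bespoke events with one NodeGatheredInfo whose payload type selects the handling branch. A self-contained miniature of that dispatch-and-immutable-update pattern (the dataclasses here are stand-ins, not exo's types):

from dataclasses import dataclass

@dataclass
class MemorySample:
    available: int

@dataclass
class FriendlyName:
    name: str

def route(
    info: MemorySample | FriendlyName, state: dict[str, object]
) -> dict[str, object]:
    # match on the payload class, build an update dict, never mutate in place
    update: dict[str, object] = {}
    match info:
        case MemorySample():
            update["node_memory"] = info.available
        case FriendlyName():
            update["node_identity"] = info.name
    return {**state, **update}

assert route(MemorySample(available=64), {})["node_memory"] == 64
assert route(FriendlyName(name="studio"), {})["node_identity"] == "studio"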
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.add_connection(event.conn)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.conn)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})


@@ -38,6 +38,7 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"


@@ -11,9 +11,6 @@ class InterceptLogger(HypercornLogger):
def __init__(self, config: Config):
super().__init__(config)
assert self.error_logger
# TODO: Decide if we want to provide access logs
# assert self.access_logger
# self.access_logger.handlers = [_InterceptHandler()]
self.error_logger.handlers = [_InterceptHandler()]
@@ -29,6 +26,11 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru


@@ -1,572 +1,463 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
name: str
description: str
tags: list[str]
metadata: ModelMetadata
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: custom models do not currently support tensor parallelism. We could add a dynamic check for this?
supports_tensor=False,
)
_card_cache[model_id] = mc
return mc
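Usage sketch: load resolves short ids against the static MODEL_CARDS table defined below, and anything else falls through to from_hf, which fetches config and safetensors metadata from Hugging Face and caches the result in _card_cache:

import asyncio

async def main() -> None:
    # "llama-3.2-1b" is a key in MODEL_CARDS below, so no network round-trip
    card = await ModelCard.load(ModelId("llama-3.2-1b"))
    print(card.n_layers, card.hidden_size, card.supports_tensor)

asyncio.run(main())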
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
# "deepseek-v3-0324:4bit": ModelCard(
# short_id="deepseek-v3-0324:4bit",
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
# name="DeepSeek V3 0324 (4-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
# pretty_name="DeepSeek V3 0324 (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# ),
# ),
# "deepseek-v3-0324": ModelCard(
# short_id="deepseek-v3-0324",
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
# name="DeepSeek V3 0324 (8-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
# pretty_name="DeepSeek V3 0324 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# ),
# ),
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# "deepseek-v3.2": ModelCard(
# short_id="deepseek-v3.2",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# name="DeepSeek V3.2 (8-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# pretty_name="DeepSeek V3.2 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-v3.2-4bit": ModelCard(
# short_id="deepseek-v3.2-4bit",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# name="DeepSeek V3.2 (4-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# pretty_name="DeepSeek V3.2 (4-bit)",
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# deepseek r1
# "deepseek-r1-0528-4bit": ModelCard(
# short_id="deepseek-r1-0528-4bit",
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
# name="DeepSeek-R1-0528 (4-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
# pretty_name="DeepSeek R1 671B (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-r1-0528": ModelCard(
# short_id="deepseek-r1-0528",
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
# name="DeepSeek-R1-0528 (8-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
# pretty_name="DeepSeek R1 671B (8-bit)",
# storage_size=Memory.from_bytes(754998771712),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
# Needs to be quantized g32 or g16.
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
short_id="glm-4.5-air-8bit",
# Needs to be quantized g32 or g16 to work with tensor parallel
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-8bit-gs32": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
"minimax-m2.1-3bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),
}
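For orientation, a minimal lookup against the MODEL_CARDS table above; the field names follow the simplified ModelCard introduced in this diff, and the asserted values are taken from the glm-4.7-flash-4bit entry:
card = MODEL_CARDS["glm-4.7-flash-4bit"]
assert card.model_id == ModelId("mlx-community/GLM-4.7-Flash-4bit")
assert card.n_layers == 47
assert card.hidden_size == 2048
assert card.supports_tensor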
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
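ConfigData deliberately accepts any one of several layer-count field names, since different architectures spell it differently in config.json. A small sketch of the fallback behaviour (payloads are illustrative, not real configs):
# num_hidden_layers wins when present; unknown fields are dropped via extra="ignore"
cfg = ConfigData.model_validate(
    {"num_hidden_layers": 32, "hidden_size": 4096, "architectures": ["LlamaForCausalLM"]}
)
assert cfg.layer_count == 32
# GPT-2-style configs use n_layer instead
assert ConfigData.model_validate({"n_layer": 24}).layer_count == 24
try:
    _ = ConfigData.model_validate({}).layer_count
except ValueError:
    pass  # no recognised layer field present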
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
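The size lookup prefers the local safetensors index and only falls back to the Hugging Face API when the index carries no metadata. A sketch of the index parsing, with a hypothetical payload shaped like model.safetensors.index.json:
raw = '{"metadata": {"total_size": 4638564352}, "weight_map": {"model.embed_tokens.weight": "model-00001-of-00002.safetensors"}}'
index = ModelSafetensorsIndex.model_validate_json(raw)
assert index.metadata is not None
size = Memory.from_bytes(index.metadata.total_size)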
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_card(model_id: ModelId) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_card(model_id: ModelId) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == model_id),
None,
)
return ModelCard(
model_id=model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)
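Usage sketch for the cached entry point; the first call hits the network (config.json and the safetensors index are downloaded), while the second is served from _model_card_cache:
import asyncio

async def demo() -> None:
    card = await get_model_card(ModelId("mlx-community/Qwen3-0.6B-4bit"))
    again = await get_model_card(ModelId("mlx-community/Qwen3-0.6B-4bit"))
    assert again is card  # cache hit returns the same object

asyncio.run(demo())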


@@ -1,126 +0,0 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)


@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,9 +31,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,


@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})
@@ -39,7 +43,4 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}


@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
from exo.shared.types.topology import Connection, SocketConnection
def test_state_serialization_roundtrip() -> None:
@@ -12,9 +12,11 @@ def test_state_serialization_roundtrip() -> None:
node_b = NodeId("node-b")
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
source=node_a,
sink=node_b,
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
),
)
state = State()
@@ -23,5 +25,11 @@ def test_state_serialization_roundtrip() -> None:
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert (
state.topology.to_snapshot().nodes
== restored_state.topology.to_snapshot().nodes
)
assert set(state.topology.to_snapshot().connections) == set(
restored_state.topology.to_snapshot().connections
)
assert restored_state.model_dump_json() == json_repr


@@ -1,203 +1,227 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.topology import (
Connection,
Cycle,
RDMAConnection,
SocketConnection,
)
class TopologySnapshot(BaseModel):
nodes: list[NodeInfo]
connections: list[Connection]
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
model_config = ConfigDict(frozen=True, extra="forbid")
@dataclass
class Topology:
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
nodes=list(self.list_nodes()), connections=self.map_connections()
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node in snapshot.nodes:
for node_id in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node)
topology.add_node(node_id)
for connection in snapshot.connections:
topology.add_connection(connection)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for edge in snapshot.connections[source][sink]:
topology.add_connection(
Connection(source=source, sink=sink, edge=edge)
)
return topology
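A round-trip sketch against the new snapshot API; add_connection registers both endpoints implicitly, so no explicit add_node calls are needed (the Multiaddr address is illustrative):
topo = Topology()
topo.add_connection(
    Connection(
        source=NodeId("node-a"),
        sink=NodeId("node-b"),
        edge=SocketConnection(sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001")),
    )
)
restored = Topology.from_snapshot(topo.to_snapshot())
assert set(restored.list_nodes()) == {NodeId("node-a"), NodeId("node-b")}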
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
return
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
]
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
def out_edges(self, node_id: NodeId) -> Iterable[Connection]:
if node_id not in self._vertex_indices:
return []
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
return (
Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)
for source, sink, edge in self._graph.out_edges(
self._vertex_indices[node_id]
)
]
)
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map
return node_id in self._vertex_indices
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
connection: Connection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
def add_connection(self, conn: Connection) -> None:
source, sink, edge = conn.source, conn.sink, conn.edge
del conn
if edge in self.get_all_connections_between(source, sink):
return
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
_ = self._graph.add_edge(src_id, sink_id, edge)
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
try:
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def list_connections(
self,
) -> Iterable[Connection]:
return (
(
Connection(
source=self._graph[src_id],
sink=self._graph[sink_id],
edge=connection,
)
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._node_id_to_rx_id_map:
if node_id not in self._vertex_indices:
return
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
rx_idx = self._vertex_indices[node_id]
self._graph.remove_node(rx_idx)
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
del self._vertex_indices[node_id]
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
def replace_all_out_rdma_connections(
self, source: NodeId, new_connections: Sequence[Connection]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for conn in new_connections:
self.add_connection(conn)
def remove_connection(self, conn: Connection) -> None:
if (
conn.source not in self._vertex_indices
or conn.sink not in self._vertex_indices
):
return
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[conn.source], self._vertex_indices[conn.sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[Cycle]:
"""Get simple cycles in the graph, including singleton cycles"""
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeInfo]] = []
cycles: list[Cycle] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])
cycles.append(cycle)
for node_id in self.list_nodes():
cycles.append(Cycle(node_ids=[node_id]))
return cycles
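Note that get_cycles now appends a singleton Cycle for every node, so even an isolated node yields one cycle; a minimal check:
topo = Topology()
topo.add_node(NodeId("solo"))
cycles = topo.get_cycles()
assert any(list(c) == [NodeId("solo")] for c in cycles)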
def get_cycles_tb(self) -> list[list[NodeInfo]]:
def get_cycles_tb(self) -> list[Cycle]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
tb_graph.add_edge(u, v, conn)
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeInfo]] = []
cycles: list[Cycle] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for node_id in node_ids:
topology.add_node(node_id)
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
if connection.source in node_ids and connection.sink in node_ids:
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:


@@ -4,17 +4,28 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call"
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"
@@ -157,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: str
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
@@ -195,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
model_card: ModelCard
class DeleteInstanceResponse(BaseModel):


@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):
@@ -22,6 +22,7 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):


@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
model_card: ModelCard
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int


@@ -16,13 +16,23 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.str_schema()
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
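With the validator now returning the subclass, Id types survive pydantic validation, and ModelId's helpers behave as below (the model id value is illustrative):
from pydantic import TypeAdapter

assert isinstance(TypeAdapter(NodeId).validate_python("node-1"), NodeId)

mid = ModelId("mlx-community/Qwen3-0.6B-4bit")
assert mid.normalize() == "mlx-community--Qwen3-0.6B-4bit"
assert mid.short() == "Qwen3-0.6B-4bit"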
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int


@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,25 +76,15 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
class NodePerformanceMeasured(BaseEvent):
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
info: GatheredInfo
class NodeDownloadProgress(BaseEvent):
@@ -107,11 +97,11 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
edge: Connection
conn: Connection
class TopologyEdgeDeleted(BaseEvent):
edge: Connection
conn: Connection
Event = (
@@ -125,10 +115,8 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeGatheredInfo
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated


@@ -1,18 +0,0 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool


@@ -1,10 +1,11 @@
import re
from typing import ClassVar
from pydantic import BaseModel, computed_field, field_validator
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [


@@ -1,12 +1,14 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import ThunderboltIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryPerformanceProfile(CamelCaseModel):
class MemoryUsage(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -44,7 +46,6 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -52,16 +53,21 @@ class NetworkInterfaceInfo(CamelCaseModel):
ip_address: str
class NodePerformanceProfile(CamelCaseModel):
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class NodeIdentity(CamelCaseModel):
"""Static and slow-changing node identification data."""
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float
class NodeNetworkInfo(CamelCaseModel):
"""Network interface information for a node."""
interfaces: Sequence[NetworkInterfaceInfo] = []
class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []


@@ -7,7 +7,13 @@ from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import (
MemoryUsage,
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -35,11 +41,17 @@ class State(CamelCaseModel):
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
last_seen: Mapping[NodeId, datetime] = {}
topology: Topology = Field(default_factory=Topology)
last_event_applied_idx: int = Field(default=-1, ge=-1)
# Granular node state mappings (update independently at different frequencies)
node_identities: Mapping[NodeId, NodeIdentity] = {}
node_memory: Mapping[NodeId, MemoryUsage] = {}
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:
return value.to_snapshot()


@@ -0,0 +1,81 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class ThunderboltConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class ThunderboltIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
## Intentionally minimal, only collecting data we care about - there's a lot more
class _ReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
class _ConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
class ThunderboltConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[_ConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: _ReceptacleTag | None = None
def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces  # doesn't need to be an assertion, but I'm confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return ThunderboltIdentifier(
rdma_interface=iface, domain_uuid=self.domain_uuid_key
)
def conn(self) -> ThunderboltConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None
return ThunderboltConnection(
source_uuid=self.domain_uuid_key, sink_uuid=sink_key
)
class ThunderboltConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[ThunderboltConnectivityData] = []
@classmethod
async def gather(cls) -> list[ThunderboltConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return ThunderboltConnectivity.model_validate_json(
proc.stdout
).SPThunderboltDataType
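A usage sketch; gather() shells out to system_profiler, so this only returns data on macOS and yields None when the command fails:
import anyio

async def demo() -> None:
    data = await ThunderboltConnectivity.gather()
    if data is None:
        return
    for entry in data:
        print(entry.conn())  # ThunderboltConnection or None per entry

anyio.run(demo)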


@@ -1,37 +1,41 @@
from collections.abc import Iterator
from dataclasses import dataclass
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
@dataclass(frozen=True)
class Cycle:
node_ids: list[NodeId]
def __len__(self) -> int:
return self.node_ids.__len__()
def __iter__(self) -> Iterator[NodeId]:
return self.node_ids.__iter__()
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
def is_thunderbolt(self) -> bool:
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
return True
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
class Connection(FrozenModel):
source: NodeId
sink: NodeId
edge: RDMAConnection | SocketConnection
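Construction sketch for the new edge types; is_thunderbolt relies on Multiaddr exposing the parsed ipv4_address, and link-local 169.254.x.x addresses are treated as Thunderbolt:
conn = Connection(
    source=NodeId("node-a"),
    sink=NodeId("node-b"),
    edge=SocketConnection(sink_multiaddr=Multiaddr(address="/ip4/169.254.1.2/tcp/10001")),
)
assert conn.edge.is_thunderbolt()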


@@ -1,3 +1,8 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -28,7 +33,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
pass
total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):
@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)


@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]


@@ -1,43 +0,0 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)


@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.models import ModelMetadata
from exo.shared.models.model_cards import ModelCard
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_meta: ModelMetadata
model_card: ModelCard
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_meta.model_id,
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,


@@ -0,0 +1,235 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
ThunderboltConnectivity,
ThunderboltIdentifier,
)
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacThunderboltIdentifiers(TaggedModel):
idents: Sequence[ThunderboltIdentifier]
class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if IS_DARWIN:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

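A hedged usage sketch for `InfoGatherer`. The memory-stream wiring is an assumption — if `Sender` in `exo.utils.channels` is not compatible with an anyio memory object stream, substitute the real constructor; only `InfoGatherer` and `GatheredInfo` come from this file:

```python
import anyio

async def main() -> None:
    send, recv = anyio.create_memory_object_stream[GatheredInfo](100)
    gatherer = InfoGatherer(info_sender=send)  # assumes Sender-compatible stream
    async with anyio.create_task_group() as tg:
        tg.start_soon(gatherer.run)
        async for info in recv:
            print(type(info).__name__)  # MacmonMetrics, MemoryUsage, ...

anyio.run(main)
```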
View File

@@ -0,0 +1,70 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))

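A sketch of the parse path: a hand-built payload carrying just the fields `RawMacmonMetrics` requires (real macmon output has more keys, which `extra="ignore"` drops). The values are invented for illustration:

```python
sample = """{
  "timestamp": "2026-01-20T15:00:00Z",
  "temp": {"cpu_temp_avg": 48.2, "gpu_temp_avg": 44.9},
  "memory": {"ram_total": 34359738368, "ram_usage": 17179869184,
             "swap_total": 2147483648, "swap_usage": 0},
  "ecpu_usage": [1200, 12.5],
  "pcpu_usage": [3200, 35.0],
  "gpu_usage": [1398, 22.0],
  "all_power": 18.4, "ane_power": 0.0, "cpu_power": 6.1,
  "gpu_power": 4.2, "gpu_ram_power": 0.8, "ram_power": 1.1,
  "sys_power": 18.4
}"""

metrics = MacmonMetrics.from_raw_json(sample)
print(metrics.system_profile.gpu_usage)  # 22.0 - usage %, frequency dropped
```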
View File

@@ -0,0 +1,114 @@
from collections.abc import Mapping
import anyio
import httpx
from anyio import create_task_group
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodeNetworkInfo
REACHABILITY_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
url = f"http://[{target_ip}]:52415/node_id"
else:
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
remote_node_id = NodeId(body)
break
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
if remote_node_id is None:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(
topology: Topology,
self_node_id: NodeId,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node_id in topology.list_nodes():
if node_id not in node_network:
continue
if node_id == self_node_id:
continue
for iface in node_network[node_id].interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node_id,
reachable,
client,
)
return reachable

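A hedged sketch of calling the checker; constructing `Topology` and `NodeNetworkInfo` is elided because their shapes aren't shown in this diff — only `check_reachable`'s signature comes from this file:

```python
async def probe(topology, self_id, node_network) -> None:
    reachable = await check_reachable(topology, self_id, node_network)
    for node_id, ips in reachable.items():
        # every IP that answered /node_id with the expected identity
        print(node_id, sorted(ips))
```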
Some files were not shown because too many files have changed in this diff.