Compare commits

...

19 Commits

Author SHA1 Message Date
Jake Hillion
bc32c20cab bench: add multi-node M3 Ultra benchmark specs for 2, 3, and 4 nodes
The existing benchmark spec only covered single-node M3 Ultra
configurations. Multi-node benchmarks are needed to test sharded
inference across 2, 3, and 4 M3 Ultra 80-core clusters connected
via Thunderbolt 5 with RDMA.

Added 2x, 3x, and 4x-m3-ultra.toml benchmark specs with all_to_all
topology (min Thunderbolt version 5), All(Rdma) constraint,
min_nodes matching the host count, and skip_tensor_ring. Models are
tiered by per-node memory (>=96GiB and >=256GiB), with >=512GiB
commented out for now. Renamed single-m3-ultra.toml to
1x-m3-ultra.toml for consistent naming.

Test plan:
- CI
2026-02-19 12:03:41 +00:00
rltakashige
51021f6fc6 Add cancellation button and the ability to cancel during prefill (#1540)
## Motivation
There's no way to easily use the cancellation features we added! Also,
prefill can take ages so let's allow cancelling out of that.

## Changes

Wiring up our existing functionality to easily cancel during generation
(and adding stuff to do so during prefill)

## Test Plan

### Manual Testing
Tested it works during both prefill and decode.

### Automated testing
Needs testing to see if this causes a GPU timeout error on large prefill
on large models in pipeline parallel. However, from manually testing GLM
5 pipeline ring on 2 nodes, and from reading the code, it does not seem
like this will be the case.
2026-02-19 11:40:59 +00:00
Alex Cheema
025ed9fd82 feat: add prefill progress bar for long prompts (#1181)
## Motivation

Users processing long prompts have no visibility into when token
generation will start. This feature adds a progress bar showing prefill
progress, giving users real-time feedback during prompt processing.

## Changes

### Backend
- Added `PrefillProgress` event type with `command_id`,
`processed_tokens`, `total_tokens`
- Added `PrefillProgressResponse` type (though now using direct callback
approach)
- Wired `prompt_progress_callback` through MLX's `stream_generate()`
- Progress events sent directly from callback for real-time updates (not
batched)
- API generates SSE named events: `event: prefill_progress\ndata: {...}`
- Added `PrefillProgressData` dataclass and `StreamEvent` union type in
API

### Dashboard
- Added `PrefillProgress` interface to store
- Updated SSE parsing to handle `event:` lines (named events)
- Created `PrefillProgressBar.svelte` with animated progress bar
- Shows "Processing prompt: X/Y tokens" with percentage
- Progress bar disappears when first token arrives

## Why It Works

MLX's `stream_generate()` accepts a `prompt_progress_callback(processed,
total)` that's called after each prefill chunk. By sending events
directly from this callback (rather than yielding from the generator),
progress updates are sent in real-time during prefill.

Using SSE named events (`event: prefill_progress`) maintains full
OpenAI/Claude API compatibility - standard clients ignore named events
they don't recognize, while the exo dashboard explicitly listens for
them.

## Test Plan

### Manual Testing
- Hardware: MacBook Pro M3 Max
- Set `prefill_step_size=256` for more frequent updates
- Tested with long prompts (pasted large documents)
- Verified progress bar updates incrementally during prefill
- Confirmed progress bar disappears when generation starts
- Tested with curl - standard `data:` events still work normally

Here is it working:


https://github.com/user-attachments/assets/5cc6f075-c5b2-4a44-bb4d-9efb246bc5fe


### Automated Testing
- Type checker passes (0 errors)
- All 192 tests pass
- Dashboard builds successfully

### API Compatibility
- Named SSE events are ignored by OpenAI SDK clients
- Regular token data uses standard `data: {...}` format
- `[DONE]` sentinel works as expected

---

**Note:** `prefill_step_size` is temporarily set to 256 for testing.
Should be changed back to 2048 before merging for production
performance.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-19 03:18:25 +00:00
rltakashige
19bc09550d Add status=downloaded filter for model endpoint (#1539)
## Motivation

https://github.com/exo-explore/exo/issues/1346#issuecomment-3831427905


## Test Plan

### Manual Testing
**Without filter**
<img width="1708" height="1010" alt="Screenshot 2026-02-18 at 22 26 22"
src="https://github.com/user-attachments/assets/f4bf7142-717d-4042-ac28-d8a55a8e45e7"
/>

**With filter**
<img width="1723" height="1021" alt="Screenshot 2026-02-18 at 22 26 45"
src="https://github.com/user-attachments/assets/40a522d5-c6e6-4148-b21a-02caa1221ebe"
/>
2026-02-18 22:34:11 +00:00
Alex Cheema
7cadca4f27 Try multiple endpoints for internet connectivity check (#1516)
## Summary
- `_test_internet_connection()` previously only tried `1.1.1.1:443`,
which some ISPs/networks block, causing exo to incorrectly report no
internet and fail downloads on startup
- Now tries `1.1.1.1`, `8.8.8.8`, and `1.0.0.1` in sequence, succeeding
if any endpoint responds
- Returns early on first success for minimal latency in the common case

Fixes #1425

## Test plan
- [ ] Verify downloads work on networks that block `1.1.1.1`
- [ ] Verify existing behavior unchanged on networks where `1.1.1.1`
works
- [ ] Verify `internet_connection` is set to `False` only when all three
endpoints fail

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 22:10:07 +00:00
rltakashige
24e99ce197 Cleanup mistakes (#1537)
Oops
2026-02-18 22:05:26 +00:00
Alex Cheema
315992549b fix: unblock MpReceiver.close() to prevent shutdown hang (#1511)
## Summary

- `MpReceiver.close()` did not unblock threads stuck on `queue.get()` in
`receive_async()`, causing abandoned threads (via
`abandon_on_cancel=True`) to keep the Python process alive indefinitely
after tests pass
- This caused the `aarch64-darwin` CI jobs in PR #1462 to hang for ~6
hours until the GitHub Actions timeout killed them
- Sends an `_MpEndOfStream` sentinel before closing the buffer,
mirroring what `MpSender.close()` already does

## Test plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — clean
- [x] `nix fmt` — 0 changed
- [x] `uv run pytest` — 188 passed, 1 skipped in 12s (no hang)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-18 21:59:02 +00:00
Alex Cheema
ce5a65d3b9 Add MiniMax M2.5 model cards (#1514)
## Summary
- Adds model cards for MiniMax M2.5 in three quantizations: 4bit (~129
GB), 6bit (~186 GB), 8bit (~243 GB)
- No code changes needed — `MiniMaxM2ForCausalLM` is already in the
tensor parallel whitelist and `MiniMaxShardingStrategy` is already
implemented in `auto_parallel.py`
- Credit to @vskiwi for confirming MiniMax M2.5 works out of the box
with existing code

Closes #1480

## Test plan
- [x] `basedpyright` passes with 0 errors
- [x] `ruff check` passes
- [x] `pytest` passes (260 passed, 1 skipped)
- [ ] Verify MiniMax M2.5 models appear in model selector on dashboard

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 21:11:13 +00:00
rltakashige
c2f2111b88 Fix tool calling (#1529)
## Motivation

GPT OSS tool calling issues.

## Changes

Fixes those and adds a bunch of evals for tool calling.
Fixes GLM5 prefix caching, where CacheList wasn't getting handled
properly.
Extracts a bunch of the setup functionality of exo bench to a harness
that can be reused elsewhere, such as in the tool calling eval.

## Test Plan
### Automated Testing
Let's run the evals for all models
2026-02-18 20:29:18 +00:00
Alex Cheema
6c322ebb72 feat: only show thinking toggle for models that support it (#1497)
## Summary
- Adds `thinking_toggle` capability to 26 model cards that support
toggling thinking mode on/off
- GPT-OSS models (20b, 120b) excluded — they always think and don't
support toggling
- Dashboard UI updated to check for `thinking_toggle` capability before
showing the toggle button

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 0 failed
- [x] Security review passed (no secrets, eval/exec, innerHTML, or dep
changes)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 17:05:00 +00:00
vskiwi
2ebe6216b4 feat: add explicit --offline mode for air-gapped clusters (#1525)
## Motivation

Closes #1510

There is currently no reliable way to run exo on an air-gapped or offline cluster where models are pre-staged on local disks. The two existing mechanisms — `--no-downloads` and `HF_HUB_OFFLINE=1` — each cover only a subset of the problem:

1. **`--no-downloads` blocks model loading**: When passed, `DownloadCoordinator` is not created. No `NodeDownloadProgress` events are ever emitted, so `_model_needs_download()` in `plan.py` perpetually returns `DownloadModel`, short-circuiting `_load_model()` and preventing the model from ever being loaded.

2. **`HF_HUB_OFFLINE=1` doesn't cover exo's aiohttp code**: exo's download pipeline primarily uses raw `aiohttp` for HTTP operations (file list fetching, file downloads, HEAD verification), not the `huggingface_hub` library. These calls will attempt connections and time out on air-gapped networks.

3. **`skip_internet` is not propagated to `download_file_with_retry()`**: Even when `internet_connection = False`, the `_download_file()` function still makes HTTP HEAD calls via `file_meta()` to verify local files and unconditionally attempts downloads for missing files.

## Changes

### `src/exo/main.py`
- Add `--offline` flag to `Args` with env var detection (`EXO_OFFLINE=1`, `HF_HUB_OFFLINE=1`)
- Pass `offline` to `DownloadCoordinator` at creation and re-creation (election loop)

### `src/exo/download/coordinator.py`
- Add `offline: bool = False` field
- In offline mode: set `internet_connection = False` immediately in `__post_init__`, skip `_test_internet_connection()` ping (avoids 3s timeout), skip `_check_internet_connection` periodic loop
- In `_start_download()`: if model is not fully available locally, emit `DownloadFailed` with clear message instead of starting a download task

### `src/exo/download/download_utils.py`
- Add `skip_internet: bool` parameter to `download_file_with_retry()` and `_download_file()`
- When `skip_internet=True` in `_download_file()`: return local file immediately without HTTP HEAD verification; raise `FileNotFoundError` for missing files
- Propagate `skip_internet` from `download_shard()` to `download_file_with_retry()`

### `src/exo/download/tests/test_offline_mode.py` (new)
- 8 tests covering `_download_file`, `download_file_with_retry`, and `fetch_file_list_with_cache` in offline mode

## Why It Works

Unlike `--no-downloads` which disables `DownloadCoordinator` entirely, `--offline` keeps the coordinator running in a restricted mode. The existing `_emit_existing_download_progress()` disk scanner still runs every 60 seconds, emitting `DownloadCompleted` events for pre-staged models. These events flow through the event-sourcing pipeline and populate `state.downloads`, which unblocks `_model_needs_download()` in `plan.py` — no changes to the planning logic required.

```
--offline flag
  → DownloadCoordinator (offline mode)
    → Skip 1.1.1.1 ping, internet_connection = False
    → _emit_existing_download_progress scans disk
      → Emits DownloadCompleted for pre-staged models
        → _model_needs_download sees DownloadCompleted
          → _load_model proceeds normally
```

## Test Plan

### Automated Testing
- `ruff check` — passes
- 8 new tests in `test_offline_mode.py` — all pass
- 11 existing download tests in `test_download_verification.py` — all pass (no regressions)

### Manual Testing
1. Pre-stage a model on disk (e.g., `~/.exo/models/mlx-community--Qwen3-0.6B-4bit/`)
2. Start exo with `--offline` (or `EXO_OFFLINE=1`)
3. Place an instance via API or dashboard
4. Verify: model loads into memory and inference works without any network calls

### Environment
- macOS (Apple Silicon), multi-node cluster with Thunderbolt interconnect
- Models pre-staged via rsync / NFS mount
2026-02-18 16:18:09 +00:00
ciaranbor
f54c80b121 Ciaran/image edit api (#1500)
## Motivation

- Image editing previously ignored input image dimensions, always
defaulting to 1024x1024
- Size dropdown was hidden in edit mode, giving users no control over
output dimensions
- Portrait/landscape presets used non-standard aspect ratios (1024x1365
/ 1365x1024)

## Changes

- Added "auto" size option that uses input image dimensions for edits,
defaults to 1024x1024 for generation
- Introduced ImageSize Literal type and normalize_image_size() validator
(replaces raw str size fields)
  - Updated portrait/landscape presets to standard 1024x1536 / 1536x1024
  - Made size selector visible in edit mode (previously hidden)
  - Default size changed from "1024x1024" to "auto"

## Why It Works

- "auto" reads actual input image dimensions via PIL at generation time,
so edits preserve the original aspect ratio
- Pydantic field_validator on both ImageGenerationTaskParams and
ImageEditsTaskParams normalizes None → "auto", keeping the API
backward-compatible

## Test Plan

### Manual Testing

- Verify image edits output at the input image's native resolution when
size is "auto"
- Verify size dropdown appears and works in both generate and edit modes
2026-02-18 16:05:39 +00:00
rltakashige
48b8f86395 Add support for GLM 5 (#1526)
## Motivation

Add GLM 5 support in favor of #1513 

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-02-18 14:04:06 +00:00
Evan
5cbd6377a2 prioritize official model cards over custom model cards
our old model card search path would override official model cards with
custom model cards - our packaged model cards should always be the
default here
2026-02-18 13:20:05 +00:00
Evan Quiney
8f01523ddb remove dead code (#1496) 2026-02-18 11:43:27 +00:00
Alex Cheema
3addeadea8 Update mlx-lm to 0.30.7 (#1520)
## Summary
- Bumps `mlx-lm` from 0.30.6 to 0.30.7 in `pyproject.toml` and `uv.lock`

## Test plan
- [x] `uv lock` resolves successfully
- [x] `basedpyright` — no new errors (63 pre-existing in unrelated
`test_tool_call_tracker.py`)
- [x] `ruff check` — all checks passed
- [x] `nix fmt` — no formatting changes
- [x] `pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 11:14:23 +00:00
rltakashige
f2be929211 Leo/address rdma gpu locks 2 (#1515)
Same as #1489 . Had to revert and redo thanks to Claude.

---------

Co-authored-by: Jake Hillion <jake@hillion.co.uk>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:00:52 -08:00
rltakashige
83af8c63fa Revert "Use custom fork that resolves GPU locks" (#1502)
Reverts exo-explore/exo#1489

Goddammit Claude...
2026-02-17 18:18:54 +00:00
Evan Quiney
eccc6298d1 Revert "Add MetaInstance declarative layer (#1447)"
This reverts commit a962a28afc.
2026-02-17 18:11:47 +00:00
121 changed files with 4558 additions and 4865 deletions

View File

@@ -200,7 +200,7 @@ class Module(dict):
) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
"""Return the submodules that do not contain other modules."""
def update(self, parameters: dict, strict: bool = ...) -> Module:
def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.

View File

@@ -7,7 +7,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlx.core import MX_ARRAY_TREE
def tree_map(
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[..., bool] | None = ...,
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -44,11 +47,11 @@ def tree_map(
"""
def tree_map_with_path(
fn: Callable,
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Optional[Callable] = ...,
path: Optional[Any] = ...,
is_leaf: Callable[..., bool] | None = ...,
path: str | None = ...,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -80,9 +83,9 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = ...,
is_leaf: Optional[Callable] = ...,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
is_leaf: Callable[..., bool] | None = ...,
destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
) -> list[tuple[str, Any]] | dict[str, Any]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -118,7 +121,7 @@ def tree_flatten(
the Python tree.
"""
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python

View File

@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

123
Cargo.lock generated
View File

@@ -141,12 +141,6 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -304,19 +298,6 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -516,15 +497,6 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -746,29 +718,6 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -939,22 +888,17 @@ name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures",
"impl-trait-for-tuples",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
@@ -1640,17 +1584,6 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1829,12 +1762,6 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2824,16 +2751,13 @@ name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
@@ -2918,17 +2842,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3279,28 +3192,14 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3741,16 +3640,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -4615,24 +4504,12 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"

View File

@@ -26,49 +26,21 @@ opt-level = 3
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"

View File

@@ -72,12 +72,19 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly:
```bash
nix run .#exo
```
**Note:** To accept the Cachix binary cache (and avoid the Xcode Metal ToolChain), add to `/etc/nix/nix.conf`:
```
trusted-users = root (or your username)
experimental-features = nix-command flakes
```
Then restart the Nix daemon: `sudo launchctl kickstart -k system/org.nixos.nix-daemon`
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)

219
bench/2x-m3-ultra.toml Normal file
View File

@@ -0,0 +1,219 @@
# 2-node M3 Ultra benchmarks (2 × 96 GiB = 192 GiB total, or 2 × 256 GiB = 512 GiB total)
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=2)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
"All(Rdma)",
]
[topology]
type = "all_to_all"
min_version = 5
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
min_nodes = 2
skip_tensor_ring = true
# ── 96 GiB per-node models (total storage < 192 GiB) ─────────────────────────
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
# ── 256 GiB per-node models (192 GiB ≤ total storage < 512 GiB) ──────────────
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
# ── 512 GiB per-node models (total storage ≥ 512 GiB) ────────────────────────
# [[benchmark]]
# model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
# extra_constraints = ["All(Memory(>=512GiB))"]
# [[benchmark]]
# model = "mlx-community/Kimi-K2-Instruct-4bit"
# extra_constraints = ["All(Memory(>=512GiB))"]
# [[benchmark]]
# model = "mlx-community/Kimi-K2.5"
# extra_constraints = ["All(Memory(>=512GiB))"]
# [[benchmark]]
# model = "mlx-community/Kimi-K2-Thinking"
# extra_constraints = ["All(Memory(>=512GiB))"]
# [[benchmark]]
# model = "mlx-community/DeepSeek-V3.1-8bit"
# extra_constraints = ["All(Memory(>=512GiB))"]

217
bench/3x-m3-ultra.toml Normal file
View File

@@ -0,0 +1,217 @@
# 3-node M3 Ultra benchmarks (3 × 96 GiB = 288 GiB total, or 3 × 256 GiB = 768 GiB total)
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=3)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
"All(Rdma)",
]
[topology]
type = "all_to_all"
min_version = 5
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
min_nodes = 3
skip_tensor_ring = true
# ── 96 GiB per-node models (total storage < 288 GiB) ─────────────────────────
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
# ── 256 GiB per-node models (288 GiB ≤ total storage < 768 GiB) ──────────────
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Kimi-K2-Instruct-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Kimi-K2.5"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Kimi-K2-Thinking"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

217
bench/4x-m3-ultra.toml Normal file
View File

@@ -0,0 +1,217 @@
# 4-node M3 Ultra benchmarks (4 × 96 GiB = 384 GiB total, or 4 × 256 GiB = 1024 GiB total)
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=4)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
"All(Rdma)",
]
[topology]
type = "all_to_all"
min_version = 5
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
min_nodes = 4
skip_tensor_ring = true
# ── 96 GiB per-node models (total storage < 384 GiB) ─────────────────────────
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
# ── 256 GiB per-node models (384 GiB ≤ total storage < 1024 GiB) ─────────────
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Kimi-K2-Instruct-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Kimi-K2.5"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Kimi-K2-Thinking"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

View File

@@ -3,5 +3,8 @@
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
"1x-m3-ultra.toml",
"2x-m3-ultra.toml",
"3x-m3-ultra.toml",
"4x-m3-ultra.toml",
]

1088
bench/eval_tool_calls.py Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,29 +1,47 @@
# type: ignore
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""Tool-calling eval for exo's OpenAI-compatible API.
Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary
Start exo with a model first, then run:
uv run python tool_call_eval.py --model <model-id>
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
uv run python tool_call_eval.py --model <model-id> --repeat 3
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""
from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -103,154 +121,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
@@ -269,184 +139,6 @@ def parse_int_list(values: list[str]) -> list[int]:
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -538,76 +230,12 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
add_common_instance_args(ap)
ap.add_argument(
"--pp",
nargs="+",
@@ -620,34 +248,6 @@ def main() -> int:
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
@@ -657,9 +257,6 @@ def main() -> int:
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
@@ -674,17 +271,6 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -719,24 +305,10 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
@@ -760,16 +332,6 @@ def main() -> int:
if args.dry_run:
return 0
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

327
bench/harness.py Normal file
View File

@@ -0,0 +1,327 @@
# type: ignore
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode
from loguru import logger
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if (m.get("name") or "").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def settle_and_fetch_placements(
client: ExoClient,
full_model_id: str,
args: argparse.Namespace,
settle_timeout: float = 0,
) -> list[dict[str, Any]]:
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
return selected
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
"httpx>=0.27.0",
"loguru>=0.7.3",
"transformers>=5.0.0",
"huggingface-hub>=0.33.4",

306
bench/scenarios.toml Normal file
View File

@@ -0,0 +1,306 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.
[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]
[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"
[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"
[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]
[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]
[tools.search_products.properties.query]
type = "string"
description = "Search query string"
[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"
[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"
[tools.create_todos]
description = "Create a structured todo list"
required = ["todos"]
[tools.create_todos.properties.todos]
type = "array"
description = "List of todo items"
[tools.create_todos.properties.todos.items]
type = "object"
required = ["content", "status", "priority"]
[tools.create_todos.properties.todos.items.properties.content]
type = "string"
description = "The todo item text"
[tools.create_todos.properties.todos.items.properties.status]
type = "string"
description = "Status: pending, in_progress, or completed"
[tools.create_todos.properties.todos.items.properties.priority]
type = "string"
description = "Priority: low, normal, or high"
# -- Should call a tool --
[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"
[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"
[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"
# -- Multi-turn: tool call then follow-up --
[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"
[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"
[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[scenarios.tool_result]
result = 491682
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"
[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8
[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6
[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"
# -- Sequential tool calls --
[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."
[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'
[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."
[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Nested object schema (regression for lossy chat template rendering) --
[[scenarios]]
name = "nested_schema_tool_call"
description = "Tool call with nested object array schema -> create_todos"
expect_tool_call = true
expected_function = "create_todos"
required_arg_keys = ["todos"]
nested_array_key = "todos"
required_item_keys = ["content", "status", "priority"]
tools = ["create_todos"]
[[scenarios.messages]]
role = "user"
content = "Create a todo list with 3 items to learn Python"
# -- Tool name integrity (regression for harmony token leaking into name) --
[tools.glob]
description = "Search for files matching a glob pattern in the codebase"
required = ["pattern"]
[tools.glob.properties.pattern]
type = "string"
description = "The glob pattern to match files against, e.g. '**/*.py'"
[tools.glob.properties.path]
type = "string"
description = "The directory to search in"
[[scenarios]]
name = "tool_name_integrity"
description = "Tool name must not contain harmony tokens like <|channel|>"
expect_tool_call = true
expected_function = "glob"
required_arg_keys = ["pattern"]
tools = ["glob"]
[[scenarios.messages]]
role = "user"
content = "Find all Python files in the src directory"
# -- Should NOT call a tool --
[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."
[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"

View File

@@ -14,6 +14,7 @@
totalTokens,
thinkingEnabled as thinkingEnabledStore,
setConversationThinking,
stopGeneration,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -103,7 +104,7 @@
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking") && caps.includes("text");
return caps.includes("thinking_toggle") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(
@@ -653,86 +654,92 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => stopGeneration()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white"
aria-label="Stop generation"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
class="w-3 h-3 sm:w-3.5 sm:h-3.5"
fill="currentColor"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
<rect x="6" y="6" width="12" height="12" rx="1" />
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">Cancel</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>
<!-- Bottom accent line -->

View File

@@ -3,16 +3,17 @@
messages,
currentResponse,
isLoading,
prefillProgress,
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import ImageLightbox from "./ImageLightbox.svelte";
interface Props {
@@ -25,6 +26,7 @@
const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());
const prefill = $derived(prefillProgress());
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
@@ -428,6 +430,9 @@
{:else}
<!-- Assistant message styling -->
<div class="p-3 sm:p-4">
{#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
<PrefillProgressBar progress={prefill} class="mb-3" />
{/if}
{#if message.thinking && message.thinking.trim().length > 0}
<div
class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"

View File

@@ -185,7 +185,11 @@
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {

View File

@@ -26,7 +26,8 @@
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
function formatNumber(num: number | undefined): string {
if (num == null) return "0";
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {

View File

@@ -59,13 +59,14 @@
}
const sizeOptions: ImageGenerationParams["size"][] = [
"auto",
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
"1024x1365",
"1365x1024",
"1024x1536",
"1536x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -176,92 +177,90 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
{params.size.toUpperCase()}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size.toUpperCase()}</span>
</button>
{/each}
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
{/if}
</div>
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -311,7 +310,7 @@
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition()
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
>

View File

@@ -21,7 +21,7 @@
} | null;
nodes?: Record<string, NodeInfo>;
sharding?: "Pipeline" | "Tensor";
runtime?: "MlxRing" | "MlxJaccl";
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
onLaunch?: () => void;
tags?: string[];
apiPreview?: PlacementPreview | null;
@@ -348,7 +348,7 @@
// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === "MlxJaccl");
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
@@ -575,7 +575,7 @@
>
{runtime === "MlxRing"
? "MLX Ring"
: runtime === "MlxJaccl"
: runtime === "MlxIbv" || runtime === "MlxJaccl"
? "MLX RDMA"
: runtime}
</span>

View File

@@ -0,0 +1,52 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
model_id: string;
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxJaccl";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance: unknown | null;
memory_delta_by_node: Record<string, number> | null;
error: string | null;
@@ -219,6 +219,7 @@ interface RawStateResponse {
string,
{
MlxRingInstance?: Instance;
MlxIbvInstance?: Instance;
MlxJacclInstance?: Instance;
}
>;
@@ -249,20 +250,6 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// MetaInstances (declarative instance constraints)
metaInstances?: Record<string, MetaInstanceData>;
}
export interface MetaInstanceData {
metaInstanceId: string;
modelId: string;
sharding: string;
instanceMeta: string;
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
}
export interface MessageAttachment {
@@ -286,6 +273,11 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -319,13 +311,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
export interface ImageGenerationParams {
// Basic params
size:
| "auto"
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1365"
| "1365x1024";
| "1024x1536"
| "1536x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -349,7 +342,7 @@ export interface EditingImage {
}
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
size: "auto",
quality: "medium",
outputFormat: "png",
numImages: 1,
@@ -532,6 +525,10 @@ class AppStore {
ttftMs = $state<number | null>(null); // Time to first token in ms
tps = $state<number | null>(null); // Tokens per second
totalTokens = $state<number>(0); // Total tokens in current response
prefillProgress = $state<PrefillProgress | null>(null);
// Abort controller for stopping generation
private currentAbortController: AbortController | null = null;
// Topology state
topologyData = $state<TopologyData | null>(null);
@@ -550,7 +547,6 @@ class AppStore {
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
metaInstances = $state<Record<string, MetaInstanceData>>({});
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderbolt = $state<
Record<
@@ -909,7 +905,11 @@ class AppStore {
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {
@@ -1283,8 +1283,6 @@ class AppStore {
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
// RDMA ctl status per node
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
// Thunderbolt bridge cycles
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
// Thunderbolt bridge status per node
@@ -2016,6 +2014,7 @@ class AppStore {
reader: ReadableStreamDefaultReader<Uint8Array>,
targetConversationId: string,
onChunk: (parsed: T) => void,
onEvent?: Record<string, (data: unknown) => void>,
): Promise<void> {
const decoder = new TextDecoder();
let buffer = "";
@@ -2036,6 +2035,24 @@ class AppStore {
const trimmed = line.trim();
if (!trimmed) continue;
// Handle SSE comments (": key json") for prefill progress etc.
if (trimmed.startsWith(": ") && onEvent) {
const comment = trimmed.slice(2);
const spaceIdx = comment.indexOf(" ");
if (spaceIdx > 0) {
const key = comment.slice(0, spaceIdx);
if (onEvent[key]) {
try {
const parsed = JSON.parse(comment.slice(spaceIdx + 1));
onEvent[key](parsed);
} catch {
// Skip malformed JSON in comment
}
}
}
continue;
}
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
@@ -2267,6 +2284,9 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
const abortController = new AbortController();
this.currentAbortController = abortController;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
@@ -2283,6 +2303,7 @@ class AppStore {
enable_thinking: enableThinking,
}),
}),
signal: abortController.signal,
});
if (!response.ok) {
@@ -2320,6 +2341,11 @@ class AppStore {
reader,
targetConversationId,
(parsed) => {
// Clear prefill progress when first token data arrives
if (this.prefillProgress) {
this.prefillProgress = null;
}
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
@@ -2382,8 +2408,26 @@ class AppStore {
this.persistConversation(targetConversationId);
}
},
{
prefill_progress: (data) => {
// TaggedModel wraps as {"PrefillProgressChunk": {...}}
// model_dump_json() uses snake_case (by_alias defaults to False)
const raw = data as Record<string, unknown>;
const inner = (raw["PrefillProgressChunk"] ?? raw) as {
processed_tokens: number;
total_tokens: number;
};
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
};
},
},
);
// Clear prefill progress after stream ends
this.prefillProgress = null;
// Calculate final TPS
if (firstTokenTime !== null && tokenCount > 1) {
const totalGenerationTime = performance.now() - firstTokenTime;
@@ -2414,20 +2458,31 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
if (error instanceof DOMException && error.name === "AbortError") {
// User stopped generation — not an error
} else {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
}
} finally {
this.currentAbortController = null;
this.prefillProgress = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
stopGeneration(): void {
this.currentAbortController?.abort();
this.currentAbortController = null;
}
/**
* Generate an image using the image generation API
*/
@@ -3054,9 +3109,9 @@ export const isLoading = () => appStore.isLoading;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;
@@ -3072,6 +3127,7 @@ export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const stopGeneration = () => appStore.stopGeneration();
export const startChat = () => appStore.startChat();
export const sendMessage = (
content: string,

View File

File diff suppressed because it is too large Load Diff

View File

@@ -41,7 +41,7 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260217+50487b41"; in
version = let v = "0.30.7.dev20260218+14841977"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
@@ -49,8 +49,8 @@ let
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
};
patches = [

View File

@@ -19,7 +19,7 @@ dependencies = [
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"mlx-lm==0.30.7",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",

View File

@@ -158,6 +158,7 @@
exo-test-env = testVenv;
} // {
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 405874409472

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 765577920512

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 122406567936

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 229780750336

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 198556925568

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 286737579648

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 396963397248

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 19327352832

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 22548578304

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 26843545600

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 34359738368

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 706522120192

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 662498705408

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 100086644736

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 128666664960

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 185826705408

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 342884352

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 698351616

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 141733920768

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 268435456000

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 17612931072

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 33279705088

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 47080074240

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 88814387200

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 114572190076

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 159039627774

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 209082699847

View File

@@ -25,17 +25,17 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
@@ -45,8 +45,6 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
@@ -54,24 +52,11 @@ tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -6,7 +6,7 @@ use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
pin::{Pin, pin},
pin::Pin,
task::{Context, Poll},
};
@@ -33,8 +33,6 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
}
}

View File

@@ -1,240 +0,0 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -17,7 +17,6 @@
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
@@ -25,7 +24,6 @@ use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -36,14 +34,10 @@ pub(crate) mod r#const {
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
@@ -51,7 +45,6 @@ pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
@@ -62,7 +55,7 @@ pub(crate) mod ext {
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind())
Python::attach(|py| PyBytes::new(py, self).unbind())
}
}
@@ -98,7 +91,7 @@ pub(crate) mod ext {
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
Python::attach(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
@@ -175,24 +168,6 @@ pub(crate) mod ext {
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.

View File

@@ -11,9 +11,9 @@ use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt a
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
@@ -25,7 +25,7 @@ use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
@@ -155,7 +155,6 @@ async fn networking_task(
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
@@ -485,7 +484,7 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,

View File

@@ -19,8 +19,6 @@ either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
@@ -29,11 +27,6 @@ futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
@@ -41,4 +34,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -24,8 +24,8 @@ use libp2p::{
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

View File

@@ -1,5 +1,4 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;

View File

@@ -1,44 +0,0 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -3,19 +3,7 @@
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
@@ -54,11 +42,3 @@ pub(crate) mod ext {
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -47,6 +47,7 @@ class DownloadCoordinator:
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -62,6 +63,8 @@ class DownloadCoordinator:
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
@@ -107,23 +110,30 @@ class DownloadCoordinator:
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
logger.info(
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
)
if not self.offline:
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
if not self.offline:
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
# Try multiple endpoints since some ISPs/networks block specific IPs
for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
try:
socket.create_connection((host, 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
logger.debug(f"Internet connectivity: True (via {host})")
return
except OSError:
continue
self.shard_downloader.set_internet_connection(False)
logger.debug("Internet connectivity: False")
async def _check_internet_connection(self) -> None:
first_connection = True
@@ -202,6 +212,20 @@ class DownloadCoordinator:
)
return
if self.offline:
logger.warning(
f"Offline mode: model {model_id} is not fully available locally, cannot download"
)
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=f"Model files not found locally in offline mode: {model_id}",
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
return
# Start actual download
self._start_download_task(shard, initial_progress)
@@ -314,17 +338,7 @@ class DownloadCoordinator:
),
)
elif progress.status in ["in_progress", "not_started"]:
if (
progress.downloaded_bytes.in_bytes
>= progress.total_bytes.in_bytes
> 0
):
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.downloaded_bytes_this_session.in_bytes == 0:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -448,12 +448,13 @@ async def download_file_with_retry(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress, skip_internet
)
except HuggingFaceAuthenticationError:
raise
@@ -487,10 +488,14 @@ async def _download_file(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
skip_internet: bool = False,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
if skip_internet:
return target_path
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
@@ -510,6 +515,11 @@ async def _download_file(
)
return target_path
if skip_internet:
raise FileNotFoundError(
f"File {path} not found locally and cannot download in offline mode"
)
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -814,6 +824,7 @@ async def download_shard(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
skip_internet=skip_internet,
)
if not skip_download:

View File

@@ -0,0 +1,230 @@
"""Tests for offline/air-gapped mode."""
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
download_file_with_retry,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.worker.downloads import FileListEntry
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestDownloadFileOffline:
"""Tests for _download_file with skip_internet=True."""
async def test_returns_local_file_without_http_verification(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists locally, return it immediately
without making any HTTP calls (no file_meta verification)."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"model weights data")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_raises_file_not_found_for_missing_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file does NOT exist locally,
raise FileNotFoundError instead of attempting download."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError, match="offline mode"):
await _download_file(
model_id,
"main",
"missing_model.safetensors",
target_dir,
skip_internet=True,
)
async def test_returns_local_file_in_subdirectory(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists in a subdirectory,
return it without HTTP calls."""
target_dir = tmp_path / "downloads"
subdir = target_dir / "transformer"
await aios.makedirs(subdir, exist_ok=True)
local_file = subdir / "diffusion_pytorch_model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"weights")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"transformer/diffusion_pytorch_model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
class TestDownloadFileWithRetryOffline:
"""Tests for download_file_with_retry with skip_internet=True."""
async def test_propagates_skip_internet_to_download_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Verify skip_internet is passed through to _download_file."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "config.json"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b'{"model_type": "qwen2"}')
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_file_not_found_does_not_retry(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""FileNotFoundError from offline mode should not trigger retries."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
await download_file_with_retry(
model_id,
"main",
"nonexistent.safetensors",
target_dir,
skip_internet=True,
)
class TestFetchFileListOffline:
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
async def test_uses_cached_file_list(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and cache file exists, use it without network."""
from pydantic import TypeAdapter
cache_dir = temp_models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
cached_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=200),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
assert result == cached_list
mock_fetch.assert_not_called()
async def test_falls_back_to_local_directory_scan(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and no cache but local files exist,
build file list from local directory."""
import json
model_dir = temp_models_dir / model_id.normalize()
await aios.makedirs(model_dir, exist_ok=True)
async with aiofiles.open(model_dir / "config.json", "w") as f:
await f.write('{"model_type": "qwen2"}')
index_data = {
"metadata": {},
"weight_map": {"model.layers.0.weight": "model.safetensors"},
}
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
await f.write(json.dumps(index_data))
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
await f.write(b"x" * 500)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
mock_fetch.assert_not_called()
paths = {entry.path for entry in result}
assert "config.json" in paths
assert "model.safetensors" in paths
async def test_raises_when_no_cache_and_no_local_files(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and neither cache nor local files exist,
raise FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="No internet"):
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)

View File

@@ -39,6 +39,7 @@ class Node:
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -68,6 +69,7 @@ class Node:
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
download_coordinator = None
@@ -132,6 +134,7 @@ class Node:
api,
node_id,
event_index_counter,
args.offline,
)
async def run(self):
@@ -222,6 +225,7 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -254,12 +258,15 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -282,6 +289,7 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -329,6 +337,11 @@ class Args(CamelCaseModel):
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
parser.add_argument(
"--offline",
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -19,7 +19,12 @@ from exo.shared.types.api import (
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -123,67 +128,81 @@ def chunk_to_response(
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
match chunk:
case PrefillProgressChunk():
# Use SSE comment so third-party clients ignore it
yield f": prefill_progress {chunk.model_dump_json()}\n\n"
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
case ErrorChunk():
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
case ToolCallChunk():
last_usage = chunk.usage or last_usage
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
case TokenChunk():
last_usage = chunk.usage or last_usage
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(
update={"usage": last_usage}
)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -197,38 +216,43 @@ async def collect_chat_response(
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
match chunk:
case PrefillProgressChunk():
continue
if model is None:
model = chunk.model
case ErrorChunk():
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
case TokenChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
case ToolCallChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)

View File

@@ -5,7 +5,12 @@ from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.claude_api import (
ClaudeContentBlock,
ClaudeContentBlockDeltaEvent,
@@ -160,7 +165,9 @@ def claude_request_to_text_generation(
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -172,6 +179,9 @@ async def collect_claude_response(
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -230,7 +240,9 @@ async def collect_claude_response(
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
@@ -256,6 +268,9 @@ async def generate_claude_stream(
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
# Close text block and bail
break

View File

@@ -5,7 +5,12 @@ from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
FunctionCallInputItem,
@@ -121,7 +126,9 @@ def responses_request_to_text_generation(
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -134,6 +141,9 @@ async def collect_responses_response(
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -189,7 +199,9 @@ async def collect_responses_response(
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
@@ -243,6 +255,9 @@ async def generate_responses_stream(
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
break

View File

@@ -71,11 +71,8 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
CreateMetaInstanceParams,
CreateMetaInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DeleteMetaInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -88,6 +85,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
ImageListItem,
ImageListResponse,
ImageSize,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -103,11 +101,13 @@ from exo.shared.types.api import (
TraceRankStats,
TraceResponse,
TraceStatsResponse,
normalize_image_size,
)
from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
@@ -118,10 +118,8 @@ from exo.shared.types.claude_api import (
from exo.shared.types.commands import (
Command,
CreateInstance,
CreateMetaInstance,
DeleteDownload,
DeleteInstance,
DeleteMetaInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
@@ -134,21 +132,22 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -224,7 +223,8 @@ class API:
)
self._text_generation_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -282,9 +282,6 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/meta_instances")(self.list_meta_instances)
self.app.post("/meta_instance")(self.create_meta_instance)
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
@@ -314,27 +311,12 @@ class API:
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
async def place_instance(self, payload: PlaceInstanceParams):
model_card = await ModelCard.load(payload.model_id)
command = PlaceInstance(
model_card=model_card,
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
)
# Validate placement before sending — fail fast with a clear error
# instead of silently dropping the command in the master.
try:
get_instance_placements(
command,
topology=self.state.topology,
current_instances=self.state.instances,
node_memory=self.state.node_memory,
node_network=self.state.node_network,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
await self._send(command)
return CreateInstanceResponse(
@@ -546,67 +528,33 @@ class API:
instance_id=instance_id,
)
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
return dict(self.state.meta_instances)
async def create_meta_instance(
self, payload: CreateMetaInstanceParams
) -> CreateMetaInstanceResponse:
meta_instance = MetaInstance(
model_id=payload.model_id,
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
node_ids=payload.node_ids,
)
command = CreateMetaInstance(meta_instance=meta_instance)
await self._send(command)
return CreateMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance.meta_instance_id,
)
async def delete_meta_instance(
self, meta_instance_id: MetaInstanceId
) -> DeleteMetaInstanceResponse:
meta = self.state.meta_instances.get(meta_instance_id)
if not meta:
raise HTTPException(status_code=404, detail="MetaInstance not found")
# Command processor handles cascade-deleting backing instances
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
await self._send(command)
return DeleteMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance_id,
)
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
]:
"""Yield chunks for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._text_generation_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -625,6 +573,9 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._token_chunk_stream(command_id):
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -813,9 +764,11 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"advanced_params": _ensure_seed(payload.advanced_params),
}
)
command = ImageGeneration(
@@ -946,10 +899,10 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1032,10 +985,10 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1071,12 +1024,13 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"stream": False,
"partial_images": 0,
"advanced_params": _ensure_seed(payload.advanced_params),
}
)
command = ImageGeneration(
@@ -1097,7 +1051,7 @@ class API:
prompt: str,
model: ModelId,
n: int,
size: str,
size: ImageSize,
response_format: Literal["url", "b64_json"],
input_fidelity: Literal["low", "high"],
stream: bool,
@@ -1167,7 +1121,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
size: str | None = Form(None),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
@@ -1193,7 +1147,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=size,
size=normalize_image_size(size),
response_format=response_format,
input_fidelity=input_fidelity,
stream=stream_bool,
@@ -1229,7 +1183,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
size: str | None = Form(None),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
@@ -1249,7 +1203,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=size,
size=normalize_image_size(size),
response_format=response_format,
input_fidelity=input_fidelity,
stream=False,
@@ -1349,8 +1303,18 @@ class API:
return total_available
async def get_models(self) -> ModelList:
"""Returns list of available models."""
async def get_models(self, status: str | None = Query(default=None)) -> ModelList:
"""Returns list of available models, optionally filtered by being downloaded."""
cards = await get_model_cards()
if status == "downloaded":
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [c for c in cards if c.model_id in downloaded_model_ids]
return ModelList(
data=[
ModelListModel(
@@ -1368,7 +1332,7 @@ class API:
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in await get_model_cards()
for card in cards
]
)
@@ -1492,6 +1456,21 @@ class API:
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressChunk(
model=event.model,
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)

View File

@@ -1,5 +1,4 @@
from collections.abc import Sequence
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
import anyio
from anyio.abc import TaskGroup
@@ -13,22 +12,11 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.process_managers import ProcessManager
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.commands import (
CreateInstance,
CreateMetaInstance,
DeleteInstance,
DeleteMetaInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
@@ -48,12 +36,8 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
JacclSideChannelData,
JacclSideChannelGathered,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
@@ -76,8 +60,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -101,16 +84,16 @@ class Master:
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
self._process_managers: Sequence[ProcessManager] = [
InstanceHealthReconciler(),
NodeTimeoutReconciler(),
MetaInstanceReconciler(),
]
async def run(self):
logger.info("Starting Master")
@@ -119,12 +102,15 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._reconcile)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -306,86 +292,6 @@ class Master:
)
)
generated_events.extend(transition_events)
case CreateMetaInstance():
logger.info(
f"Creating MetaInstance for {command.meta_instance.model_id}"
f" (min_nodes={command.meta_instance.min_nodes},"
f" sharding={command.meta_instance.sharding})"
)
# Apply immediately so self.state is fresh across
# the await below and the reconciler won't race.
await self._apply_and_broadcast(
MetaInstanceCreated(meta_instance=command.meta_instance)
)
# Immediate placement attempt for responsiveness
model_card = await ModelCard.load(
command.meta_instance.model_id
)
# Re-check: reconciler may have satisfied it during the await
meta_id = command.meta_instance.meta_instance_id
still_unsatisfied = any(
m.meta_instance_id == meta_id
for m in find_unsatisfied_meta_instances(
self.state.meta_instances,
self.state.instances,
self.state.topology,
)
)
if still_unsatisfied:
result = try_place_for_meta_instance(
command.meta_instance,
model_card,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
self.state.tasks,
)
generated_events.extend(result.events)
if result.error is not None:
generated_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_id,
reason=result.error,
)
)
case DeleteMetaInstance():
backing_count = sum(
1
for inst in self.state.instances.values()
if inst.meta_instance_id == command.meta_instance_id
)
logger.info(
f"Deleting MetaInstance {command.meta_instance_id}"
f" (cascade-deleting {backing_count} backing instance(s))"
)
generated_events.append(
MetaInstanceDeleted(
meta_instance_id=command.meta_instance_id
)
)
# Cascade-delete backing instances atomically,
# cancelling any active tasks first.
for iid, inst in self.state.instances.items():
if inst.meta_instance_id == command.meta_instance_id:
for task in self.state.tasks.values():
if (
task.instance_id == iid
and task.task_status
in (
TaskStatus.Pending,
TaskStatus.Running,
)
):
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
generated_events.append(
InstanceDeleted(instance_id=iid)
)
case PlaceInstance():
placement = place_instance(
command,
@@ -417,19 +323,16 @@ class Master:
)
case TaskCancelled():
if (
command.cancelled_command_id
in self.command_task_mapping
):
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.cancelled_command_id
]
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
del self.command_task_mapping[
command.cancelled_command_id
]
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -438,10 +341,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time
@@ -452,32 +354,31 @@ class Master:
):
await self._send_event(IndexedEvent(idx=i, event=event))
for event in generated_events:
await self._apply_and_broadcast(event)
await self.event_sender.send(event)
except ValueError as e:
logger.opt(exception=e).warning("Error in command processor")
async def _apply_and_broadcast(self, event: Event) -> None:
"""Apply event to state, persist to disk, and broadcast to workers.
State is updated synchronously (before any await), so callers can
rely on ``self.state`` reflecting this event immediately after the
call. Python's cooperative scheduling guarantees no interleaving
between the state read and write.
"""
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
self._event_log.append(event)
await self._send_event(indexed)
async def _reconcile(self) -> None:
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
async def _plan(self) -> None:
while True:
for pm in self._process_managers:
events = await pm.reconcile(self.state)
for event in events:
await self._apply_and_broadcast(event)
await anyio.sleep(1)
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
await self.event_sender.send(
InstanceDeleted(instance_id=instance_id)
)
break
# time out dead nodes
for node_id, time in self.state.last_seen.items():
now = datetime.now(tz=timezone.utc)
if now - time > timedelta(seconds=30):
logger.info(f"Manually removing node {node_id} due to inactivity")
await self.event_sender.send(NodeTimedOut(node_id=node_id))
await anyio.sleep(10)
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
@@ -495,15 +396,32 @@ class Master:
await self._handle_traces_collected(event)
continue
if isinstance(event, JacclSideChannelData):
await self._apply_and_broadcast(event)
await self._handle_jaccl_side_channel(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
await self._apply_and_broadcast(event)
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
@@ -535,49 +453,10 @@ class Master:
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self._apply_and_broadcast(
await self.event_sender.send(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
"""Accumulate SideChannel contributions; when all runners for an instance
have submitted for the same sequence, emit JacclSideChannelGathered."""
iid = event.instance_id
seq = event.sequence
if iid not in self._jaccl_pending:
self._jaccl_pending[iid] = {}
if seq not in self._jaccl_pending[iid]:
self._jaccl_pending[iid][seq] = {}
self._jaccl_pending[iid][seq][event.runner_id] = event.data
instance = self.state.instances.get(iid)
if instance is None:
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
return
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
submitted = set(self._jaccl_pending[iid][seq].keys())
logger.info(
f"JACCL side channel: instance={iid} seq={seq} "
f"submitted={len(submitted)}/{len(expected_runners)}"
)
if submitted >= expected_runners:
gathered = dict(self._jaccl_pending[iid][seq])
del self._jaccl_pending[iid][seq]
if not self._jaccl_pending[iid]:
del self._jaccl_pending[iid]
await self._apply_and_broadcast(
JacclSideChannelGathered(
instance_id=iid,
sequence=seq,
gathered_data=gathered,
)
)

View File

@@ -6,11 +6,11 @@ from typing import Sequence
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_largest_cycles,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
@@ -106,27 +106,23 @@ def place_instance(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
largest_rdma_cycles = [
cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
smallest_rdma_cycles = [
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl:
if not largest_rdma_cycles:
raise ValueError(
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
)
largest_cycles = largest_rdma_cycles
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycle
for cycle in largest_cycles
for cycle in smallest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_memory[node_id].ram_available for node_id in cycle),
start=Memory(),

View File

@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
return filtered_cycles
def get_largest_cycles(
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
max_nodes = max(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == max_nodes]
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(

View File

@@ -1,12 +0,0 @@
from collections.abc import Sequence
from typing import Protocol, runtime_checkable
from exo.shared.types.events import Event
from exo.shared.types.state import State
@runtime_checkable
class ProcessManager(Protocol):
"""A reconciliation step that examines state and returns corrective events."""
async def reconcile(self, state: State) -> Sequence[Event]: ...

View File

@@ -1,62 +0,0 @@
from collections.abc import Sequence
from typing import final
from loguru import logger
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
from exo.shared.types.state import State
MAX_INSTANCE_RETRIES = 3
@final
class InstanceHealthReconciler:
"""Delete instances whose network connections are broken or whose runners have all failed."""
async def reconcile(self, state: State) -> Sequence[Event]:
events: list[Event] = []
for instance_id, instance in state.instances.items():
if not instance_connections_healthy(instance, state.topology):
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error="Network connection lost",
)
)
continue
is_failed, error_message = instance_runners_failed(
instance, state.runners, state.node_identities
)
if is_failed:
# Retry within the same instance if backed by a MetaInstance
mid = instance.meta_instance_id
mi = state.meta_instances.get(mid) if mid else None
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
logger.info(
f"Instance {instance_id} failed (attempt"
f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
f" retrying: {error_message}"
)
events.append(
InstanceRetrying(
instance_id=instance_id,
meta_instance_id=mid,
failure_error=error_message or "Runner failed",
)
)
else:
if mid and mi:
logger.warning(
f"Instance {instance_id} exceeded retry limit"
f" ({MAX_INSTANCE_RETRIES}), deleting:"
f" {error_message}"
)
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error=error_message,
)
)
return events

View File

@@ -1,92 +0,0 @@
from collections.abc import Sequence
from typing import final
import anyio
from loguru import logger
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId
MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10
@final
class MetaInstanceReconciler:
"""Place instances for unsatisfied MetaInstances."""
async def reconcile(self, state: State) -> Sequence[Event]:
all_events: list[Event] = []
# Local copy for intermediate tracking — so placement of B
# sees A's instance and doesn't double-place on same resources.
current_instances: dict[InstanceId, Instance] = dict(state.instances)
unsatisfied = find_unsatisfied_meta_instances(
state.meta_instances,
current_instances,
state.topology,
)
for meta_instance in unsatisfied:
try:
with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
model_card = await ModelCard.load(meta_instance.model_id)
except TimeoutError:
logger.warning(
f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
)
continue
except Exception as exc:
logger.warning(
f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
)
error = f"Failed to load model card: {exc}"
if meta_instance.placement_error != error:
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=error,
)
)
continue
result = try_place_for_meta_instance(
meta_instance,
model_card,
state.topology,
current_instances,
state.node_memory,
state.node_network,
state.tasks,
)
# Update local instance map so next placement sees this one
for event in result.events:
if isinstance(event, InstanceCreated):
logger.info(
f"MetaInstance reconciler placed instance"
f" {event.instance.instance_id} for"
f" {meta_instance.model_id}"
)
current_instances[event.instance.instance_id] = event.instance
all_events.extend(result.events)
# Emit placement failure if error differs from what's already in state
if (
result.error is not None
and meta_instance.placement_error != result.error
):
logger.warning(
f"MetaInstance placement failed for"
f" {meta_instance.model_id}: {result.error}"
)
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=result.error,
)
)
return all_events

View File

@@ -1,27 +0,0 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final
from loguru import logger
from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State
_DEFAULT_TIMEOUT = timedelta(seconds=30)
@final
class NodeTimeoutReconciler:
"""Time out nodes that haven't been seen recently."""
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
self.timeout = timeout
async def reconcile(self, state: State) -> Sequence[Event]:
now = datetime.now(tz=timezone.utc)
events: list[Event] = []
for node_id, last_seen in state.last_seen.items():
if now - last_seen > self.timeout:
logger.info(f"Removing node {node_id} due to inactivity")
events.append(NodeTimedOut(node_id=node_id))
return events

View File

@@ -1,244 +0,0 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple
from loguru import logger
from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
BaseInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
RunnerStatus,
)
class PlacementResult(NamedTuple):
"""Result of a placement attempt: events to apply and optional error reason."""
events: Sequence[Event]
error: str | None
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
"""Reconstruct ring order from shard device_rank."""
node_ranks: list[tuple[NodeId, int]] = []
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
shard = instance.shard_assignments.runner_to_shard[runner_id]
node_ranks.append((node_id, shard.device_rank))
node_ranks.sort(key=lambda x: x[1])
return [node_id for node_id, _ in node_ranks]
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
"""Check that the specific IPs used by a ring instance still exist in the topology."""
ring = _get_ring_order(instance)
n = len(ring)
for node in ring:
hosts = instance.hosts_by_node[node]
for idx in range(n):
host = hosts[idx]
if host.ip in ("0.0.0.0", "198.51.100.1"):
continue # self or placeholder
# Real connection: node → ring[idx]. Check specific IP.
connections = topology.get_all_connections_between(node, ring[idx])
if not any(
isinstance(c, SocketConnection)
and c.sink_multiaddr.ip_address == host.ip
for c in connections
):
return False
return True
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
ring = _get_ring_order(instance)
n = len(ring)
for i in range(n):
for j in range(n):
iface = instance.jaccl_devices[i][j]
if iface is None:
continue
connections = topology.get_all_connections_between(ring[i], ring[j])
if not any(
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
for c in connections
):
return False
return True
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
"""Check that an instance's nodes and specific connections are still in the topology."""
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if not all(topology.contains_node(n) for n in instance_nodes):
return False
if len(instance_nodes) <= 1:
return True
match instance:
case MlxRingInstance():
return _ring_connections_healthy(instance, topology)
case MlxJacclInstance():
return _jaccl_connections_healthy(instance, topology)
def instance_runners_failed(
instance: Instance,
runners: Mapping[RunnerId, RunnerStatus],
node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
"""Check if an instance's runners have all reached terminal failure states.
Returns ``(True, error_message)`` when ALL runners are terminal
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
Returns ``(False, None)`` when runners are still active, haven't reported
yet, or all gracefully shut down (no ``RunnerFailed``).
"""
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
if not instance_runner_ids:
return False, None
# Build reverse mapping: runner_id -> node_id
runner_to_node: dict[RunnerId, NodeId] = {
runner_id: node_id
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
}
has_any_failed = False
error_messages: list[str] = []
for runner_id in instance_runner_ids:
status = runners.get(runner_id)
if status is None:
# Runner hasn't reported yet — instance is still starting
return False, None
if isinstance(status, RunnerFailed):
has_any_failed = True
if status.error_message:
node_id = runner_to_node.get(runner_id)
name = (
node_identities[node_id].friendly_name
if node_id and node_id in node_identities
else node_id or "unknown"
)
error_messages.append(f"{name}: {status.error_message}")
elif isinstance(status, RunnerShutdown):
pass # Terminal but not a failure indicator on its own
else:
# Runner is still active (connecting, loading, running, etc.)
return False, None
if has_any_failed:
return True, "; ".join(error_messages) if error_messages else "Runner failed"
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
return False, None
def instance_satisfies_meta_instance(
meta_instance: MetaInstance,
instance: Instance,
) -> bool:
"""Check if a single instance satisfies a meta-instance's constraints.
This is a pure constraint check (model, min_nodes, node_ids).
Use ``instance_connections_healthy`` separately for topology health.
"""
if instance.shard_assignments.model_id != meta_instance.model_id:
return False
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if len(instance_nodes) < meta_instance.min_nodes:
return False
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
instance_nodes
)
def find_unsatisfied_meta_instances(
meta_instances: Mapping[MetaInstanceId, MetaInstance],
instances: Mapping[InstanceId, Instance],
topology: Topology,
) -> Sequence[MetaInstance]:
"""Return meta-instances that have no healthy backing instance."""
unsatisfied: list[MetaInstance] = []
for meta_id, meta_instance in meta_instances.items():
has_healthy_backing = any(
instance.meta_instance_id == meta_id
and instance_connections_healthy(instance, topology)
for instance in instances.values()
)
if not has_healthy_backing:
unsatisfied.append(meta_instance)
return unsatisfied
def try_place_for_meta_instance(
meta_instance: MetaInstance,
model_card: ModelCard,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
tasks: Mapping[TaskId, Task],
) -> PlacementResult:
"""Try to place an instance satisfying the meta-instance constraints.
Returns a :class:`PlacementResult` with events on success, or an error
reason on failure.
"""
command = PlaceInstance(
model_card=model_card,
sharding=meta_instance.sharding,
instance_meta=meta_instance.instance_meta,
min_nodes=meta_instance.min_nodes,
)
try:
target_instances = place_instance(
command,
topology,
current_instances,
node_memory,
node_network,
required_nodes=(
set(meta_instance.node_ids) if meta_instance.node_ids else None
),
)
# Tag the new instance with meta_instance_id
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
if new_instance_ids:
new_id = next(iter(new_instance_ids))
target_instances[new_id] = target_instances[new_id].model_copy(
update={"meta_instance_id": meta_instance.meta_instance_id}
)
return PlacementResult(
events=list(
get_transition_events(current_instances, target_instances, tasks)
),
error=None,
)
except ValueError as e:
logger.debug(
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
)
return PlacementResult(events=[], error=str(e))

View File

@@ -1,778 +0,0 @@
"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling."""
import pytest
from exo.master.process_managers.instance_health import (
MAX_INSTANCE_RETRIES,
InstanceHealthReconciler,
)
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NodeIdentity
from exo.shared.types.state import State
from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerReady,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
# --- Helpers (copied from test_reconcile.py for independence) ---
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
consecutive_failures: int = 0,
last_failure_error: str | None = None,
placement_error: str | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
consecutive_failures=consecutive_failures,
last_failure_error=last_failure_error,
placement_error=placement_error,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
# =============================================================================
# 1. MetaInstance lifecycle edge cases
# =============================================================================
def test_meta_instance_model_is_frozen():
"""MetaInstance should be immutable (frozen model)."""
meta = _meta_instance()
try:
meta.model_id = ModelId("something-else")
raise AssertionError("Should have raised")
except Exception:
pass # Expected — frozen model
def test_meta_instance_created_then_deleted_roundtrip():
"""Create and delete a MetaInstance through apply — state should be clean."""
state = State()
meta = _meta_instance()
state = apply(
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta))
)
assert meta.meta_instance_id in state.meta_instances
state = apply(
state,
IndexedEvent(
idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
),
)
assert meta.meta_instance_id not in state.meta_instances
assert len(state.meta_instances) == 0
def test_delete_nonexistent_meta_instance_is_safe():
"""Deleting a MetaInstance that doesn't exist should not crash."""
state = State()
event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent"))
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
def test_placement_failed_for_nonexistent_meta_instance_is_safe():
"""MetaInstancePlacementFailed for unknown ID should not crash."""
state = State()
event = MetaInstancePlacementFailed(
meta_instance_id=MetaInstanceId("nonexistent"),
reason="test",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
def test_multiple_meta_instances_for_same_model():
"""Multiple MetaInstances for the same model are tracked independently."""
state = State()
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
state = apply(
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a))
)
state = apply(
state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b))
)
assert len(state.meta_instances) == 2
assert meta_a.meta_instance_id in state.meta_instances
assert meta_b.meta_instance_id in state.meta_instances
# =============================================================================
# 2. Retry logic edge cases
# =============================================================================
def test_retry_counter_resets_on_successful_instance_creation():
"""When a new instance is created for a meta-instance, failures should reset."""
meta = _meta_instance(consecutive_failures=2, last_failure_error="old")
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
mi = state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
# last_failure_error is preserved (for UI display)
assert mi.last_failure_error == "old"
async def test_retry_count_increments_through_full_cycle():
"""Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=topology,
)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
for idx, i in enumerate(range(MAX_INSTANCE_RETRIES)):
# Simulate runners failing
state_with_runners = state.model_copy(
update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}}
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state_with_runners)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying), f"iteration {i}"
state = apply(state, IndexedEvent(idx=idx, event=events[0]))
# After MAX_INSTANCE_RETRIES retries, failure counter should be at max
mi = state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == MAX_INSTANCE_RETRIES
# Next failure should result in deletion
state_with_runners = state.model_copy(
update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}}
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state_with_runners)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_respects_exact_limit():
"""At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry."""
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_at_limit_minus_one_retries():
"""At MAX_INSTANCE_RETRIES - 1, reconciler should still retry."""
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
# =============================================================================
# 3. Error handling edge cases
# =============================================================================
def test_runners_failed_with_empty_error_message():
"""RunnerFailed with empty error_message should still report as failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerFailed(error_message="")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
# Empty error message means we get the fallback
assert error == "Runner failed"
def test_runners_failed_with_none_error_message():
"""RunnerFailed with None error_message should still report as failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerFailed(error_message=None)
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error == "Runner failed"
def test_runners_failed_collects_all_error_messages():
"""With multiple failed runners, all error messages should be collected."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"),
runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"),
runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM on GPU 0" in error
assert "OOM on GPU 1" in error
assert "OOM on GPU 2" in error
def test_runners_failed_includes_friendly_name():
"""Error messages should include node friendly names when available."""
_, inst = _instance(node_ids=["node-a"])
node_id = NodeId("node-a")
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {runner_ids[0]: RunnerFailed(error_message="OOM")}
identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")}
is_failed, error = instance_runners_failed(inst, runners, identities)
assert is_failed is True
assert error is not None
assert "My Mac Studio" in error
def test_instance_retrying_for_missing_instance_is_safe():
"""InstanceRetrying for an instance not in state should not crash.
NOTE: When the instance is missing, the handler returns early WITHOUT
incrementing the MetaInstance failure counter. This means stale retry
events for already-deleted instances are silently dropped. This is
acceptable since the InstanceDeleted handler already increments failures.
"""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceRetrying(
instance_id=InstanceId("nonexistent"),
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Does not crash, but failure count is NOT incremented (early return)
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
# =============================================================================
# 4. Backward compatibility
# =============================================================================
def test_instance_without_meta_instance_id_works():
"""Instances created without meta_instance_id should still function normally."""
_, inst = _instance(node_ids=["node-a"])
assert inst.meta_instance_id is None
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_instance_deleted_without_meta_does_not_affect_meta_instances():
"""Deleting an instance without meta_instance_id should not affect meta_instances."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"]) # no meta_instance_id
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0 # unchanged
def test_satisfies_ignores_meta_instance_id_binding():
"""instance_satisfies_meta_instance checks constraints only, not binding."""
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"]) # no meta_instance_id set
# Should match on constraints (model, min_nodes) regardless of binding
assert instance_satisfies_meta_instance(meta, inst) is True
def test_find_unsatisfied_uses_binding_not_constraints():
"""find_unsatisfied checks meta_instance_id binding, not just constraint matching."""
meta = _meta_instance()
# Instance matches constraints but is NOT bound to this meta_instance
iid, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {iid: inst}, topology
)
# Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id
assert list(result) == [meta]
# =============================================================================
# 5. Concurrent / multi-instance scenarios
# =============================================================================
async def test_health_reconciler_handles_multiple_failing_instances():
"""Multiple instances failing simultaneously should each get their own event."""
meta_a = _meta_instance()
meta_b = _meta_instance()
iid_a, inst_a = _instance(
node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id
)
iid_b, inst_b = _instance(
node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id
)
runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values())
runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values())
state = State(
meta_instances={
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
instances={iid_a: inst_a, iid_b: inst_b},
runners={
runner_ids_a[0]: RunnerFailed(error_message="OOM"),
runner_ids_b[0]: RunnerFailed(error_message="OOM"),
},
topology=_topology("node-a", "node-b"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 2
# Both should be InstanceRetrying since failures < MAX
assert all(isinstance(e, InstanceRetrying) for e in events)
instance_ids = {e.instance_id for e in events} # type: ignore[union-attr]
assert instance_ids == {iid_a, iid_b}
async def test_health_reconciler_mixed_healthy_and_failing():
"""Only failing instances should produce events; healthy ones should not."""
meta_healthy = _meta_instance()
meta_failing = _meta_instance()
iid_h, inst_h = _instance(
node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id
)
iid_f, inst_f = _instance(
node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id
)
runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values())
runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values())
state = State(
meta_instances={
meta_healthy.meta_instance_id: meta_healthy,
meta_failing.meta_instance_id: meta_failing,
},
instances={iid_h: inst_h, iid_f: inst_f},
runners={
runner_ids_h[0]: RunnerReady(),
runner_ids_f[0]: RunnerFailed(error_message="crash"),
},
topology=_topology("node-a", "node-b"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid_f
async def test_meta_instance_reconciler_empty_state():
"""MetaInstanceReconciler with no meta_instances should produce no events."""
state = State()
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 0
# =============================================================================
# 6. Placement error tracking
# =============================================================================
def test_placement_failed_sets_error():
"""MetaInstancePlacementFailed should set placement_error on the MetaInstance."""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstancePlacementFailed(
meta_instance_id=meta.meta_instance_id,
reason="Not enough memory",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.placement_error == "Not enough memory"
def test_instance_created_clears_placement_error():
"""InstanceCreated should clear placement_error on the MetaInstance."""
meta = _meta_instance(placement_error="Not enough memory")
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
mi = state.meta_instances[meta.meta_instance_id]
assert mi.placement_error is None
def test_placement_error_does_not_increment_failures():
"""Placement failures should only set placement_error, not increment consecutive_failures."""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstancePlacementFailed(
meta_instance_id=meta.meta_instance_id,
reason="No resources",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.placement_error == "No resources"
# =============================================================================
# 7. State serialization roundtrip
# =============================================================================
def test_state_with_meta_instances_serializes():
"""State with meta_instances should serialize and deserialize correctly."""
meta = _meta_instance(consecutive_failures=2, last_failure_error="test")
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
json_str = state.model_dump_json()
restored = State.model_validate_json(json_str)
assert meta.meta_instance_id in restored.meta_instances
mi = restored.meta_instances[meta.meta_instance_id]
assert mi.model_id == meta.model_id
assert mi.consecutive_failures == 2
assert mi.last_failure_error == "test"
assert iid in restored.instances
assert restored.instances[iid].meta_instance_id == meta.meta_instance_id
# =============================================================================
# 8. MetaInstanceReconciler error handling
# =============================================================================
async def test_meta_instance_reconciler_model_load_error_emits_placement_failed(
monkeypatch: "pytest.MonkeyPatch",
):
"""When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed."""
import exo.master.process_managers.meta_instance as mi_mod
meta = _meta_instance()
topo = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
topology=topo,
)
async def _failing_load(_model_id: ModelId) -> ModelCard:
raise RuntimeError("Network error")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
assert len(placement_failed) == 1
assert "Failed to load model card" in placement_failed[0].reason
assert meta.meta_instance_id == placement_failed[0].meta_instance_id
async def test_meta_instance_reconciler_model_load_error_skips_dedup(
monkeypatch: "pytest.MonkeyPatch",
):
"""When ModelCard.load error matches existing placement_error, no duplicate event."""
import exo.master.process_managers.meta_instance as mi_mod
meta = _meta_instance(placement_error="Failed to load model card: Network error")
topo = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
topology=topo,
)
async def _failing_load(_model_id: ModelId) -> ModelCard:
raise RuntimeError("Network error")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
# Error matches existing placement_error, so no duplicate event emitted
assert len(events) == 0
async def test_meta_instance_reconciler_continues_after_error(
monkeypatch: "pytest.MonkeyPatch",
):
"""Reconciler should continue to next meta-instance after one fails to load."""
import exo.master.process_managers.meta_instance as mi_mod
meta_a = _meta_instance(model_id="org/model-a")
meta_b = _meta_instance(model_id="org/model-b")
topo = _topology("node-a")
state = State(
meta_instances={
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
topology=topo,
)
call_count = 0
async def _load_second_fails(model_id: ModelId) -> ModelCard:
nonlocal call_count
call_count += 1
raise RuntimeError(f"Cannot load {model_id}")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_second_fails)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
# Both meta-instances should have been attempted (not short-circuited)
assert call_count == 2
# Both should have placement failed events
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
assert len(placement_failed) == 2
# =============================================================================
# 8. Cascade delete with task cancellation
# =============================================================================
def test_cascade_delete_cancels_active_tasks():
"""Deleting a MetaInstance should cancel tasks on backing instances.
Regression test: previously, cascade-deleting backing instances via
DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active
tasks, leaving orphaned task references in state.
"""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
task_id = TaskId()
task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running)
# Build state with meta-instance, backing instance, and active task
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
tasks={task_id: task},
topology=_topology("node-a"),
)
# Simulate the cascade-delete event sequence produced by main.py:
# 1. MetaInstanceDeleted
# 2. TaskStatusUpdated(Cancelled) for active tasks
# 3. InstanceDeleted
idx = 0
state = apply(
state,
IndexedEvent(
idx=idx,
event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id),
),
)
idx += 1
state = apply(
state,
IndexedEvent(
idx=idx,
event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled),
),
)
idx += 1
state = apply(
state,
IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)),
)
# Verify everything is cleaned up
assert len(state.meta_instances) == 0
assert len(state.instances) == 0
assert state.tasks[task_id].task_status == TaskStatus.Cancelled
def test_cascade_delete_skips_completed_tasks():
"""Cascade delete should only cancel Pending/Running tasks, not completed ones."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
running_task_id = TaskId()
completed_task_id = TaskId()
running_task = LoadModel(
task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running
)
completed_task = LoadModel(
task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
tasks={running_task_id: running_task, completed_task_id: completed_task},
topology=_topology("node-a"),
)
# Only the running task should be cancelled — we verify the logic pattern
# by checking which tasks are Pending or Running
active_tasks = [
t
for t in state.tasks.values()
if t.instance_id == iid
and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
]
assert len(active_tasks) == 1
assert active_tasks[0].task_id == running_task_id

View File

@@ -3,10 +3,10 @@ import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_largest_cycles,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
create_node_memory,
@@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory():
}
def test_get_largest_cycles():
def test_get_smallest_cycles():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
@@ -175,12 +175,12 @@ def test_get_largest_cycles():
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
# act
largest_cycles = get_largest_cycles(cycles)
smallest_cycles = get_smallest_cycles(cycles)
# assert
assert len(largest_cycles) == 1
assert len(largest_cycles[0]) == 3
assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id}
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
@pytest.mark.parametrize(

View File

@@ -1,742 +0,0 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerLoading,
RunnerReady,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
"""Build a topology with nodes connected in a bidirectional ring with unique IPs.
Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
between consecutive nodes (including wrap-around).
"""
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
"""Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
# Build hosts_by_node with IPs matching _topology() convention:
# node at index idx has IP 10.0.0.{idx+1}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
# --- instance_satisfies_meta_instance (pure constraint matching) ---
def test_satisfies_matching_model():
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"])
assert instance_satisfies_meta_instance(meta, inst) is True
def test_not_satisfies_wrong_model():
meta = _meta_instance("test-org/model-a")
_, inst = _instance("test-org/model-b")
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_missing_required_node():
meta = _meta_instance(node_ids=[NodeId("node-c")])
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_fewer_than_min_nodes():
meta = _meta_instance(min_nodes=3)
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_satisfies_with_node_ids_specified():
meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2)
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
assert instance_satisfies_meta_instance(meta, inst) is True
# --- instance_connections_healthy ---
def test_healthy_single_node_present():
_, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_single_node_missing():
_, inst = _instance(node_ids=["node-a"])
topology = Topology() # empty
assert instance_connections_healthy(inst, topology) is False
def test_healthy_two_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_two_node_edge_removed():
"""Nodes present but edge removed — ring broken."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", connect=False)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_two_node_ip_changed():
"""Edge exists but with a different IP than instance was configured with."""
_, inst = _instance(node_ids=["node-a", "node-b"])
# Build topology with different IPs than _instance() expects
topology = Topology()
topology.add_node(NodeId("node-a"))
topology.add_node(NodeId("node-b"))
topology.add_connection(
Connection(
source=NodeId("node-a"),
sink=NodeId("node-b"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=NodeId("node-b"),
sink=NodeId("node-a"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
),
)
)
assert instance_connections_healthy(inst, topology) is False
def test_healthy_three_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_three_node_one_edge_removed():
"""Remove one edge from a three-node ring — instance unhealthy."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
# Build topology with one direction of one edge missing
topology = Topology()
nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
for n in nodes:
topology.add_node(n)
# Add all edges except node-a → node-b
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[1],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[0],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
# Missing: node-a → node-b (ip 10.0.0.2)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_node_missing_from_topology():
"""Instance has a node that's not in the topology at all."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a") # node-b not present
assert instance_connections_healthy(inst, topology) is False
def test_healthy_extra_nodes_in_topology():
"""Extra nodes in topology don't affect instance health."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
# --- find_unsatisfied_meta_instances ---
def test_unsatisfied_no_meta_instances():
result = find_unsatisfied_meta_instances({}, {}, Topology())
assert list(result) == []
def test_unsatisfied_one_satisfied():
meta = _meta_instance()
id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == []
def test_unsatisfied_one_not_satisfied():
meta = _meta_instance("test-org/model-x")
id_a, inst_a = _instance("test-org/model-y")
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {id_a: inst_a}, topology
)
assert list(result) == [meta]
def test_unsatisfied_mix():
meta_satisfied = _meta_instance("test-org/model-a")
meta_unsatisfied = _meta_instance("test-org/model-b")
id_a, inst_a = _instance(
"test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_satisfied.meta_instance_id: meta_satisfied,
meta_unsatisfied.meta_instance_id: meta_unsatisfied,
},
{id_a: inst_a},
topology,
)
assert list(result) == [meta_unsatisfied]
def test_unsatisfied_node_disconnect():
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a") # node-b disconnected
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_edge_break():
"""Instance exists but its connections broke — meta-instance becomes unsatisfied."""
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_idempotent():
meta = _meta_instance("test-org/model-x")
topology = _topology("node-a")
meta_instances = {meta.meta_instance_id: meta}
instances: dict[InstanceId, MlxRingInstance] = {}
result_1 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
result_2 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
assert result_1 == result_2
def test_unsatisfied_exclusive_binding():
"""Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
id_inst, inst = _instance(
"test-org/model-x", meta_instance_id=meta_a.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
{id_inst: inst},
topology,
)
assert list(result) == [meta_b]
# --- apply handlers ---
def test_apply_meta_instance_created():
state = State()
meta = _meta_instance()
event = MetaInstanceCreated(meta_instance=meta)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id in new_state.meta_instances
assert new_state.meta_instances[meta.meta_instance_id] == meta
def test_apply_meta_instance_deleted():
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
def test_apply_meta_instance_deleted_clears_failure_info():
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "OOM"}
)
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
# --- instance_runners_failed ---
def test_runners_failed_all_failed():
"""All runners in RunnerFailed -> instance is failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runners = {
rid: RunnerFailed(error_message="OOM")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM" in error
def test_runners_failed_mixed_failed_shutdown():
"""One Failed + one Shutdown = failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="crash"),
runner_ids[1]: RunnerShutdown(),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "crash" in error
def test_runners_not_failed_all_shutdown():
"""All Shutdown (graceful) = not a failure."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerShutdown() for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_still_active():
"""Some runners still active = not failed yet."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerLoading(),
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_no_status():
"""Runner not yet reported = not failed."""
_, inst = _instance(node_ids=["node-a"])
is_failed, _ = instance_runners_failed(inst, {}, {})
assert is_failed is False
def test_runners_not_failed_healthy():
"""Runners in Ready state = not failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerReady() for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
# --- failure tracking in apply_instance_deleted ---
def test_apply_instance_deleted_tracks_failure():
"""InstanceDeleted with failure_error increments meta instance failure count."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "Runner OOM"
def test_apply_instance_deleted_increments_failure():
"""Subsequent failures increment the counter."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="new error")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 3
assert mi.last_failure_error == "new error"
def test_apply_instance_deleted_no_failure_no_tracking():
"""InstanceDeleted without failure_error does not track."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
def test_apply_instance_deleted_orphan_no_tracking():
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
iid, inst = _instance(node_ids=["node-a"])
state = State(instances={iid: inst})
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
# --- InstanceRetrying ---
def test_apply_instance_retrying_removes_runners():
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerShutdown(),
}
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners=runners,
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="OOM",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Instance still exists
assert iid in new_state.instances
# Runners removed
assert runner_ids[0] not in new_state.runners
assert runner_ids[1] not in new_state.runners
def test_apply_instance_retrying_increments_failure():
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "crash"
def test_apply_instance_retrying_skips_missing_runners():
"""InstanceRetrying doesn't assert if runners haven't reported yet."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
# No runners in state at all
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
# Should not raise
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert iid in new_state.instances
def test_apply_instance_created_resets_failure_counter():
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 3, "last_failure_error": "old error"}
)
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceCreated(instance=inst)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.last_failure_error == "old error"
assert mi.placement_error is None
# --- InstanceHealthReconciler retry-vs-delete ---
async def test_health_reconciler_retries_when_under_limit():
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid
assert events[0].meta_instance_id == meta.meta_instance_id
async def test_health_reconciler_deletes_when_limit_reached():
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_deletes_without_meta_instance():
"""Instances without a MetaInstance are deleted immediately on runner failure."""
iid, inst = _instance(node_ids=["node-a"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_network_failure_always_deletes():
"""Network failure always triggers InstanceDeleted regardless of retry count."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=_topology("node-a"), # node-b missing
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
assert events[0].failure_error == "Network connection lost"

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from loguru import logger
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -12,15 +12,10 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
JacclSideChannelData,
JacclSideChannelGathered,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -34,7 +29,6 @@ from exo.shared.types.events import (
TracesCollected,
TracesMerged,
)
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
@@ -71,24 +65,15 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
| JacclSideChannelData()
| JacclSideChannelGathered()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceRetrying():
return apply_instance_retrying(event, state)
case MetaInstanceCreated():
return apply_meta_instance_created(event, state)
case MetaInstanceDeleted():
return apply_meta_instance_deleted(event, state)
case MetaInstancePlacementFailed():
return apply_meta_instance_placement_failed(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -191,123 +176,20 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
def _update_meta_instance(
state: State, mid: MetaInstanceId, **fields: object
) -> Mapping[MetaInstanceId, MetaInstance]:
mi = state.meta_instances[mid]
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = event.instance
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
instance.instance_id: instance,
}
update: dict[str, object] = {"instances": new_instances}
# Reset failure tracking when a new instance is created for a meta-instance
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
mi = state.meta_instances[instance.meta_instance_id]
if mi.placement_error is not None or mi.consecutive_failures > 0:
update["meta_instances"] = _update_meta_instance(
state,
instance.meta_instance_id,
placement_error=None,
consecutive_failures=0,
)
return state.model_copy(update=update)
return state.model_copy(update={"instances": new_instances})
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
deleted_instance = state.instances.get(event.instance_id)
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
}
update: dict[str, object] = {"instances": new_instances}
# Track failure on the MetaInstance itself
if (
event.failure_error
and deleted_instance
and deleted_instance.meta_instance_id
and deleted_instance.meta_instance_id in state.meta_instances
):
mid = deleted_instance.meta_instance_id
mi = state.meta_instances[mid]
update["meta_instances"] = {
**state.meta_instances,
mid: mi.model_copy(
update={
"consecutive_failures": mi.consecutive_failures + 1,
"last_failure_error": event.failure_error,
}
),
}
return state.model_copy(update=update)
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
"""Runners failed but retry limit not reached — remove runners, keep instance."""
instance = state.instances.get(event.instance_id)
if instance is None:
# Instance was already deleted (e.g. cascade from DeleteMetaInstance).
# The InstanceDeleted handler already incremented consecutive_failures
# on the MetaInstance, so skipping here avoids double-counting.
return state
# Remove all runners belonging to this instance from state
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
}
update: dict[str, object] = {"runners": new_runners}
# Increment failure count on the MetaInstance
if event.meta_instance_id in state.meta_instances:
update["meta_instances"] = _update_meta_instance(
state,
event.meta_instance_id,
consecutive_failures=state.meta_instances[
event.meta_instance_id
].consecutive_failures
+ 1,
last_failure_error=event.failure_error,
)
return state.model_copy(update=update)
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
**state.meta_instances,
event.meta_instance.meta_instance_id: event.meta_instance,
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
mid: mi
for mid, mi in state.meta_instances.items()
if mid != event.meta_instance_id
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_placement_failed(
event: MetaInstancePlacementFailed, state: State
) -> State:
if event.meta_instance_id not in state.meta_instances:
return state
return state.model_copy(
update={
"meta_instances": _update_meta_instance(
state, event.meta_instance_id, placement_error=event.reason
)
}
)
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:

View File

@@ -44,7 +44,8 @@ async def _refresh_card_cache():
async for toml_file in path.rglob("*.toml"):
try:
card = await ModelCard.load_from_path(toml_file)
_card_cache[card.model_id] = card
if card.model_id not in _card_cache:
_card_cache[card.model_id] = card
except (ValidationError, TOMLKitError):
pass
@@ -182,6 +183,7 @@ class ConfigData(BaseModel):
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["GlmMoeDsaForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],

View File

@@ -1,12 +1,12 @@
import time
from collections.abc import Generator
from typing import Annotated, Any, Literal
from typing import Annotated, Any, Literal, get_args
from uuid import uuid4
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -262,24 +262,25 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
class CreateMetaInstanceParams(BaseModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
ImageSize = Literal[
"auto",
"512x512",
"768x768",
"1024x768",
"768x1024",
"1024x1024",
"1024x1536",
"1536x1024",
]
class CreateMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
class DeleteMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
def normalize_image_size(v: object) -> ImageSize:
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
if v is None:
return "auto"
if v not in get_args(ImageSize):
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
return v # pyright: ignore[reportReturnType]
class AdvancedImageParams(BaseModel):
@@ -301,7 +302,7 @@ class ImageGenerationTaskParams(BaseModel):
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
size: ImageSize = "auto"
stream: bool | None = False
style: str | None = "vivid"
user: str | None = None
@@ -309,6 +310,11 @@ class ImageGenerationTaskParams(BaseModel):
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
bench: bool = True
@@ -325,13 +331,18 @@ class ImageEditsTaskParams(BaseModel):
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
size: ImageSize = "auto"
image_strength: float | None = 0.7
stream: bool = False
partial_images: int | None = 0
advanced_params: AdvancedImageParams | None = None
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
if name == "image_data":

View File

@@ -76,4 +76,13 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressChunk(BaseChunk):
"""Data class for prefill progress events during streaming."""
processed_tokens: int
total_tokens: int
GenerationChunk = (
TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk
)

View File

@@ -6,8 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -53,14 +52,6 @@ class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class CreateMetaInstance(BaseCommand):
meta_instance: MetaInstance
class DeleteMetaInstance(BaseCommand):
meta_instance_id: MetaInstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -103,8 +94,6 @@ Command = (
| CreateInstance
| DeleteInstance
| TaskCancelled
| CreateMetaInstance
| DeleteMetaInstance
| TaskFinished
| SendInputChunk
)

View File

@@ -42,10 +42,6 @@ class CommandId(Id):
pass
class MetaInstanceId(Id):
"""Identifier for a MetaInstance."""
class Host(CamelCaseModel):
ip: str
port: int

View File

@@ -1,14 +1,11 @@
import base64
from collections.abc import Mapping
from datetime import datetime
from typing import Annotated, final
from typing import final
from pydantic import BeforeValidator, Field, PlainSerializer
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -17,28 +14,6 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
def _decode_base64_bytes(v: bytes | str) -> bytes:
if isinstance(v, bytes):
return v
return base64.b64decode(v)
def _encode_base64_bytes(v: bytes) -> str:
return base64.b64encode(v).decode("ascii")
Base64Bytes = Annotated[
bytes,
BeforeValidator(_decode_base64_bytes),
PlainSerializer(_encode_base64_bytes, return_type=str),
]
"""bytes that serialize to/from base64 strings in JSON.
Needed because TaggedModel's wrap validator converts JSON→Python validation
context, which breaks strict-mode bytes deserialization from JSON strings.
"""
class EventId(Id):
"""
Newtype around `ID`
@@ -91,30 +66,6 @@ class InstanceCreated(BaseEvent):
class InstanceDeleted(BaseEvent):
instance_id: InstanceId
failure_error: str | None = None
class MetaInstanceCreated(BaseEvent):
meta_instance: MetaInstance
class MetaInstanceDeleted(BaseEvent):
meta_instance_id: MetaInstanceId
@final
class MetaInstancePlacementFailed(BaseEvent):
meta_instance_id: MetaInstanceId
reason: str
@final
class InstanceRetrying(BaseEvent):
"""Runners failed but retry count is below the limit — restart runners, keep instance."""
instance_id: InstanceId
meta_instance_id: MetaInstanceId
failure_error: str
class RunnerStatusUpdated(BaseEvent):
@@ -151,6 +102,13 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
model: ModelId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -181,25 +139,6 @@ class TracesMerged(BaseEvent):
traces: list[TraceEventData]
@final
class JacclSideChannelData(BaseEvent):
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
instance_id: InstanceId
runner_id: RunnerId
sequence: int
data: Base64Bytes
@final
class JacclSideChannelGathered(BaseEvent):
"""Gathered result of a JACCL SideChannel all_gather round."""
instance_id: InstanceId
sequence: int
gathered_data: Mapping[RunnerId, Base64Bytes]
Event = (
TestEvent
| TaskCreated
@@ -209,10 +148,6 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceRetrying
| MetaInstanceCreated
| MetaInstanceDeleted
| MetaInstancePlacementFailed
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut
@@ -220,12 +155,11 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected
| TracesMerged
| JacclSideChannelData
| JacclSideChannelGathered
)

View File

@@ -1,25 +0,0 @@
from typing import final
from pydantic import Field
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import FrozenModel
@final
class MetaInstance(FrozenModel):
"""Declarative constraint: ensure an instance matching these parameters always exists."""
meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
# Failure tracking
placement_error: str | None = None
consecutive_failures: int = 0
last_failure_error: str | None = None

View File

@@ -4,10 +4,13 @@ from collections.abc import Sequence
from mlx_lm.models.cache import (
ArraysCache,
CacheList,
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
KVCacheType = Sequence[
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
]

View File

@@ -6,8 +6,7 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
DiskUsage,
MemoryUsage,
@@ -42,7 +41,6 @@ class State(CamelCaseModel):
arbitrary_types_allowed=True,
)
instances: Mapping[InstanceId, Instance] = {}
meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}

View File

@@ -61,7 +61,7 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask): # emitted by Worker when master cancels a task
class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId

Some files were not shown because too many files have changed in this diff Show More