Compare commits

..

34 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
31cd8b9d68 Delete model earlier in shutdown case 2026-02-19 19:56:07 +00:00
Evan Quiney
4c4c6ce99f simplify rust ident module
this is partly dead code, partly narrowing the rust-python boundary in
prep for future rewrites. no testing as this is all type safe
refactoring.
2026-02-19 17:19:31 +00:00
Jake Hillion
42e1e7322b bench: restore --danger-delete-downloads planning phase (#1542)
c2f2111b extracted shared utilities from exo_bench.py into harness.py
but accidentally dropped the run_planning_phase function and
--danger-delete-downloads CLI argument in the process.

Restored run_planning_phase in harness.py (where its dependencies now
live) and re-added the --danger-delete-downloads argument to
add_common_instance_args. Re-wired the planning phase call in
exo_bench.py's main() before the benchmark loop.
2026-02-19 15:42:02 +00:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
Mustafa Alp Yılmaz
2e29605194 fix: finalize cancel tasks (#1498)
# Cancel task finalization (main.py)

After forwarding the cancel to the runner supervisor, emit TaskStatusUpdated(Complete) for the cancel task itself. This ensures the cancel task is properly removed from state.tasks.
2026-02-19 13:27:34 +00:00
Evan Quiney
cacb456cb2 remove nightly (#1538)
we have no good need for rust nightly (nor futures, for that matter)
2026-02-19 12:55:31 +00:00
rltakashige
51021f6fc6 Add cancellation button and the ability to cancel during prefill (#1540)
## Motivation
There's no way to easily use the cancellation features we added! Also,
prefill can take ages so let's allow cancelling out of that.

## Changes

Wiring up our existing functionality to easily cancel during generation
(and adding stuff to do so during prefill)

## Test Plan

### Manual Testing
Tested it works during both prefill and decode.

### Automated testing
Needs testing to see if this causes a GPU timeout error on large prefill
on large models in pipeline parallel. However, from manually testing GLM
5 pipeline ring on 2 nodes, and from reading the code, it does not seem
like this will be the case.
2026-02-19 11:40:59 +00:00
Alex Cheema
025ed9fd82 feat: add prefill progress bar for long prompts (#1181)
## Motivation

Users processing long prompts have no visibility into when token
generation will start. This feature adds a progress bar showing prefill
progress, giving users real-time feedback during prompt processing.

## Changes

### Backend
- Added `PrefillProgress` event type with `command_id`,
`processed_tokens`, `total_tokens`
- Added `PrefillProgressResponse` type (though now using direct callback
approach)
- Wired `prompt_progress_callback` through MLX's `stream_generate()`
- Progress events sent directly from callback for real-time updates (not
batched)
- API generates SSE named events: `event: prefill_progress\ndata: {...}`
- Added `PrefillProgressData` dataclass and `StreamEvent` union type in
API

### Dashboard
- Added `PrefillProgress` interface to store
- Updated SSE parsing to handle `event:` lines (named events)
- Created `PrefillProgressBar.svelte` with animated progress bar
- Shows "Processing prompt: X/Y tokens" with percentage
- Progress bar disappears when first token arrives

## Why It Works

MLX's `stream_generate()` accepts a `prompt_progress_callback(processed,
total)` that's called after each prefill chunk. By sending events
directly from this callback (rather than yielding from the generator),
progress updates are sent in real-time during prefill.

Using SSE named events (`event: prefill_progress`) maintains full
OpenAI/Claude API compatibility - standard clients ignore named events
they don't recognize, while the exo dashboard explicitly listens for
them.

## Test Plan

### Manual Testing
- Hardware: MacBook Pro M3 Max
- Set `prefill_step_size=256` for more frequent updates
- Tested with long prompts (pasted large documents)
- Verified progress bar updates incrementally during prefill
- Confirmed progress bar disappears when generation starts
- Tested with curl - standard `data:` events still work normally

Here is it working:


https://github.com/user-attachments/assets/5cc6f075-c5b2-4a44-bb4d-9efb246bc5fe


### Automated Testing
- Type checker passes (0 errors)
- All 192 tests pass
- Dashboard builds successfully

### API Compatibility
- Named SSE events are ignored by OpenAI SDK clients
- Regular token data uses standard `data: {...}` format
- `[DONE]` sentinel works as expected

---

**Note:** `prefill_step_size` is temporarily set to 256 for testing.
Should be changed back to 2048 before merging for production
performance.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-19 03:18:25 +00:00
rltakashige
19bc09550d Add status=downloaded filter for model endpoint (#1539)
## Motivation

https://github.com/exo-explore/exo/issues/1346#issuecomment-3831427905


## Test Plan

### Manual Testing
**Without filter**
<img width="1708" height="1010" alt="Screenshot 2026-02-18 at 22 26 22"
src="https://github.com/user-attachments/assets/f4bf7142-717d-4042-ac28-d8a55a8e45e7"
/>

**With filter**
<img width="1723" height="1021" alt="Screenshot 2026-02-18 at 22 26 45"
src="https://github.com/user-attachments/assets/40a522d5-c6e6-4148-b21a-02caa1221ebe"
/>
2026-02-18 22:34:11 +00:00
Alex Cheema
7cadca4f27 Try multiple endpoints for internet connectivity check (#1516)
## Summary
- `_test_internet_connection()` previously only tried `1.1.1.1:443`,
which some ISPs/networks block, causing exo to incorrectly report no
internet and fail downloads on startup
- Now tries `1.1.1.1`, `8.8.8.8`, and `1.0.0.1` in sequence, succeeding
if any endpoint responds
- Returns early on first success for minimal latency in the common case

Fixes #1425

## Test plan
- [ ] Verify downloads work on networks that block `1.1.1.1`
- [ ] Verify existing behavior unchanged on networks where `1.1.1.1`
works
- [ ] Verify `internet_connection` is set to `False` only when all three
endpoints fail

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 22:10:07 +00:00
rltakashige
24e99ce197 Cleanup mistakes (#1537)
Oops
2026-02-18 22:05:26 +00:00
Alex Cheema
315992549b fix: unblock MpReceiver.close() to prevent shutdown hang (#1511)
## Summary

- `MpReceiver.close()` did not unblock threads stuck on `queue.get()` in
`receive_async()`, causing abandoned threads (via
`abandon_on_cancel=True`) to keep the Python process alive indefinitely
after tests pass
- This caused the `aarch64-darwin` CI jobs in PR #1462 to hang for ~6
hours until the GitHub Actions timeout killed them
- Sends an `_MpEndOfStream` sentinel before closing the buffer,
mirroring what `MpSender.close()` already does

## Test plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — clean
- [x] `nix fmt` — 0 changed
- [x] `uv run pytest` — 188 passed, 1 skipped in 12s (no hang)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-18 21:59:02 +00:00
Alex Cheema
ce5a65d3b9 Add MiniMax M2.5 model cards (#1514)
## Summary
- Adds model cards for MiniMax M2.5 in three quantizations: 4bit (~129
GB), 6bit (~186 GB), 8bit (~243 GB)
- No code changes needed — `MiniMaxM2ForCausalLM` is already in the
tensor parallel whitelist and `MiniMaxShardingStrategy` is already
implemented in `auto_parallel.py`
- Credit to @vskiwi for confirming MiniMax M2.5 works out of the box
with existing code

Closes #1480

## Test plan
- [x] `basedpyright` passes with 0 errors
- [x] `ruff check` passes
- [x] `pytest` passes (260 passed, 1 skipped)
- [ ] Verify MiniMax M2.5 models appear in model selector on dashboard

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 21:11:13 +00:00
rltakashige
c2f2111b88 Fix tool calling (#1529)
## Motivation

GPT OSS tool calling issues.

## Changes

Fixes those and adds a bunch of evals for tool calling.
Fixes GLM5 prefix caching, where CacheList wasn't getting handled
properly.
Extracts a bunch of the setup functionality of exo bench to a harness
that can be reused elsewhere, such as in the tool calling eval.

## Test Plan
### Automated Testing
Let's run the evals for all models
2026-02-18 20:29:18 +00:00
Alex Cheema
6c322ebb72 feat: only show thinking toggle for models that support it (#1497)
## Summary
- Adds `thinking_toggle` capability to 26 model cards that support
toggling thinking mode on/off
- GPT-OSS models (20b, 120b) excluded — they always think and don't
support toggling
- Dashboard UI updated to check for `thinking_toggle` capability before
showing the toggle button

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 0 failed
- [x] Security review passed (no secrets, eval/exec, innerHTML, or dep
changes)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 17:05:00 +00:00
vskiwi
2ebe6216b4 feat: add explicit --offline mode for air-gapped clusters (#1525)
## Motivation

Closes #1510

There is currently no reliable way to run exo on an air-gapped or offline cluster where models are pre-staged on local disks. The two existing mechanisms — `--no-downloads` and `HF_HUB_OFFLINE=1` — each cover only a subset of the problem:

1. **`--no-downloads` blocks model loading**: When passed, `DownloadCoordinator` is not created. No `NodeDownloadProgress` events are ever emitted, so `_model_needs_download()` in `plan.py` perpetually returns `DownloadModel`, short-circuiting `_load_model()` and preventing the model from ever being loaded.

2. **`HF_HUB_OFFLINE=1` doesn't cover exo's aiohttp code**: exo's download pipeline primarily uses raw `aiohttp` for HTTP operations (file list fetching, file downloads, HEAD verification), not the `huggingface_hub` library. These calls will attempt connections and time out on air-gapped networks.

3. **`skip_internet` is not propagated to `download_file_with_retry()`**: Even when `internet_connection = False`, the `_download_file()` function still makes HTTP HEAD calls via `file_meta()` to verify local files and unconditionally attempts downloads for missing files.

## Changes

### `src/exo/main.py`
- Add `--offline` flag to `Args` with env var detection (`EXO_OFFLINE=1`, `HF_HUB_OFFLINE=1`)
- Pass `offline` to `DownloadCoordinator` at creation and re-creation (election loop)

### `src/exo/download/coordinator.py`
- Add `offline: bool = False` field
- In offline mode: set `internet_connection = False` immediately in `__post_init__`, skip `_test_internet_connection()` ping (avoids 3s timeout), skip `_check_internet_connection` periodic loop
- In `_start_download()`: if model is not fully available locally, emit `DownloadFailed` with clear message instead of starting a download task

### `src/exo/download/download_utils.py`
- Add `skip_internet: bool` parameter to `download_file_with_retry()` and `_download_file()`
- When `skip_internet=True` in `_download_file()`: return local file immediately without HTTP HEAD verification; raise `FileNotFoundError` for missing files
- Propagate `skip_internet` from `download_shard()` to `download_file_with_retry()`

### `src/exo/download/tests/test_offline_mode.py` (new)
- 8 tests covering `_download_file`, `download_file_with_retry`, and `fetch_file_list_with_cache` in offline mode

## Why It Works

Unlike `--no-downloads` which disables `DownloadCoordinator` entirely, `--offline` keeps the coordinator running in a restricted mode. The existing `_emit_existing_download_progress()` disk scanner still runs every 60 seconds, emitting `DownloadCompleted` events for pre-staged models. These events flow through the event-sourcing pipeline and populate `state.downloads`, which unblocks `_model_needs_download()` in `plan.py` — no changes to the planning logic required.

```
--offline flag
  → DownloadCoordinator (offline mode)
    → Skip 1.1.1.1 ping, internet_connection = False
    → _emit_existing_download_progress scans disk
      → Emits DownloadCompleted for pre-staged models
        → _model_needs_download sees DownloadCompleted
          → _load_model proceeds normally
```

## Test Plan

### Automated Testing
- `ruff check` — passes
- 8 new tests in `test_offline_mode.py` — all pass
- 11 existing download tests in `test_download_verification.py` — all pass (no regressions)

### Manual Testing
1. Pre-stage a model on disk (e.g., `~/.exo/models/mlx-community--Qwen3-0.6B-4bit/`)
2. Start exo with `--offline` (or `EXO_OFFLINE=1`)
3. Place an instance via API or dashboard
4. Verify: model loads into memory and inference works without any network calls

### Environment
- macOS (Apple Silicon), multi-node cluster with Thunderbolt interconnect
- Models pre-staged via rsync / NFS mount
2026-02-18 16:18:09 +00:00
ciaranbor
f54c80b121 Ciaran/image edit api (#1500)
## Motivation

- Image editing previously ignored input image dimensions, always
defaulting to 1024x1024
- Size dropdown was hidden in edit mode, giving users no control over
output dimensions
- Portrait/landscape presets used non-standard aspect ratios (1024x1365
/ 1365x1024)

## Changes

- Added "auto" size option that uses input image dimensions for edits,
defaults to 1024x1024 for generation
- Introduced ImageSize Literal type and normalize_image_size() validator
(replaces raw str size fields)
  - Updated portrait/landscape presets to standard 1024x1536 / 1536x1024
  - Made size selector visible in edit mode (previously hidden)
  - Default size changed from "1024x1024" to "auto"

## Why It Works

- "auto" reads actual input image dimensions via PIL at generation time,
so edits preserve the original aspect ratio
- Pydantic field_validator on both ImageGenerationTaskParams and
ImageEditsTaskParams normalizes None → "auto", keeping the API
backward-compatible

## Test Plan

### Manual Testing

- Verify image edits output at the input image's native resolution when
size is "auto"
- Verify size dropdown appears and works in both generate and edit modes
2026-02-18 16:05:39 +00:00
rltakashige
48b8f86395 Add support for GLM 5 (#1526)
## Motivation

Add GLM 5 support in favor of #1513 

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-02-18 14:04:06 +00:00
Evan
5cbd6377a2 prioritize official model cards over custom model cards
our old model card search path would override official model cards with
custom model cards - our packaged model cards should always be the
default here
2026-02-18 13:20:05 +00:00
Evan Quiney
8f01523ddb remove dead code (#1496) 2026-02-18 11:43:27 +00:00
Alex Cheema
3addeadea8 Update mlx-lm to 0.30.7 (#1520)
## Summary
- Bumps `mlx-lm` from 0.30.6 to 0.30.7 in `pyproject.toml` and `uv.lock`

## Test plan
- [x] `uv lock` resolves successfully
- [x] `basedpyright` — no new errors (63 pre-existing in unrelated
`test_tool_call_tracker.py`)
- [x] `ruff check` — all checks passed
- [x] `nix fmt` — no formatting changes
- [x] `pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 11:14:23 +00:00
rltakashige
f2be929211 Leo/address rdma gpu locks 2 (#1515)
Same as #1489 . Had to revert and redo thanks to Claude.

---------

Co-authored-by: Jake Hillion <jake@hillion.co.uk>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:00:52 -08:00
rltakashige
83af8c63fa Revert "Use custom fork that resolves GPU locks" (#1502)
Reverts exo-explore/exo#1489

Goddammit Claude...
2026-02-17 18:18:54 +00:00
Evan Quiney
eccc6298d1 Revert "Add MetaInstance declarative layer (#1447)"
This reverts commit a962a28afc.
2026-02-17 18:11:47 +00:00
Evan Quiney
c8997217cf Revert "feat: better onboarding UX for new users (#1479)"
This reverts commit 490d2e46ba.
2026-02-17 18:02:32 +00:00
Alex Cheema
490d2e46ba feat: better onboarding UX for new users (#1479)
## Summary

- **Auto-open dashboard** in browser on first launch (uses
`~/.exo/.dashboard_opened` marker)
- **Welcome overlay** with "Choose a Model" CTA button when no model
instance is running
- **Tutorial progress messages** during model download → loading → ready
lifecycle stages
- **Fix conversation sidebar** text contrast — bumped to white text,
added active state background
- **Simplify technical jargon** — sharding/instance type/min nodes
hidden behind collapsible "Advanced Options" toggle; strategy display
hidden behind debug mode
- **Polished DMG installer** with drag-to-Applications layout, custom
branded background, and AppleScript-configured window positioning

## Test plan

- [ ] Launch exo for the first time (delete `~/.exo/.dashboard_opened`
to simulate) — browser should auto-open
- [ ] Verify welcome overlay appears on topology when no model is loaded
- [ ] Launch a model and verify download/loading/ready messages appear
in instance cards
- [ ] Check conversation sidebar text is readable (white on dark, yellow
when active)
- [ ] Verify "Advanced Options" toggle hides/shows sharding controls
- [ ] Build DMG with `packaging/dmg/create-dmg.sh` and verify
drag-to-Applications layout

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 17:52:49 +00:00
rltakashige
facf2d4d03 Use custom fork that resolves GPU locks (#1489)
## Motivation

There is an issue on Macs that means that an explicit synchronization is
necessary for memory to be updated from L1 cache. This means that GPU
locks can occur when a spin wait does not see the updated timestamp.

## Changes

Updated in my own personal fork.

## Why It Works

https://github.com/ARM-software/acle/releases

## Test Plan

### Manual Testing
Tested manually that no GPU locks occur (even with multiple simultaneous
instances running) and that the performance differential is negligible
(267 vs 269 tps on Llama 3.2 1B at an approx 10k context.)


------------------------------------------------------
I have seen a GPU lock, specifically when sending a particularly large
chat completion while the model was loading. However, I have since been
unable to reproduce and this may be something I did wrong. Please do
create an issue and tag me if any GPU locks do occur.

---------

Co-authored-by: Jake Hillion <jake@hillion.co.uk>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 17:48:43 +00:00
Alex Cheema
a962a28afc Add MetaInstance declarative layer (#1447)
## Motivation

Users currently manage instances directly, which means if a node
disconnects or connections break, the instance dies and nothing
recreates it. MetaInstance is a declarative primitive: "ensure an
instance matching these parameters always exists." The reconciler
watches for unhealthy or missing backing instances and re-places them
automatically.

## Changes

- **MetaInstance type** (`meta_instance.py`): declarative constraint
with `model_id`, `min_nodes`, optional `node_ids`, and `sharding`
- **Reconciler** (`reconcile.py`): `find_unsatisfied_meta_instances`
checks which MetaInstances lack a healthy backing instance,
`try_place_for_meta_instance` creates one
- **Master loop** (`main.py`): periodically reconciles unsatisfied
MetaInstances; immediate placement on `CreateMetaInstance` command
- **API** (`api.py`): `create_meta_instance` / `delete_meta_instance` /
`GET /meta_instances` endpoints; delete cascades to backing instances
with task cancellation
- **Binding via `meta_instance_id` on Instance** (`instances.py`): no
separate binding event or backing map — the instance carries its parent
MetaInstance ID directly, eliminating race conditions in the reconciler
- **Dashboard**: sidebar shows MetaInstances with their backing instance
status; orphan instances (created directly) still shown separately
- **Tests**: constraint matching, connection health, unsatisfied
detection, exclusive binding, cascade delete with task cancellation

### Recent improvements

- **fix: cancel active tasks on cascade delete** — `DeleteMetaInstance`
now emits `TaskStatusUpdated(Cancelled)` for any Pending/Running tasks
on backing instances before emitting `InstanceDeleted`. Previously,
cascade-deleting backing instances left orphaned task references in
state.
- **Lifecycle logging** — added `logger.info`/`logger.warning` for:
`CreateMetaInstance` (model, min_nodes, sharding), `DeleteMetaInstance`
(with cascade count), reconciler placement success/failure, and retry
decisions with attempt counts in `InstanceHealthReconciler`.
- **GET `/meta_instances` endpoint** — lists all meta-instances without
needing to fetch full state.
- **2 regression tests** — `test_cascade_delete_cancels_active_tasks`
and `test_cascade_delete_skips_completed_tasks` verify the
cascade-delete event sequence.

## Why It Works

Putting `meta_instance_id` on `BaseInstance` makes binding inherent to
instance creation. When the reconciler creates an instance for a
MetaInstance, it tags it via `model_copy`. When the instance is deleted,
the binding disappears with it. This avoids the two bugs that a separate
binding mechanism would introduce:
1. Stale exclusion sets — the reconciler loop can't accidentally bind
two MetaInstances to the same instance
2. Delete ordering race — no window between deleting an instance and its
binding where the reconciler could re-place

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
- Created MetaInstance via dashboard, verified instance placed
- Verified delete cascades (deleting MetaInstance removes backing
instance)
- Verified orphan instances still work independently

### Automated Testing
- 30 tests in `test_meta_instance_edge_cases.py`: lifecycle, retry
logic, error handling, concurrent operations, cascade delete with task
cancellation
- 24 tests in `test_reconcile.py`: constraint matching, connection
health (single/multi-node, edge removal, IP changes), unsatisfied
detection, exclusive binding, idempotency
- All 261 tests pass
- basedpyright 0 errors, ruff clean, dashboard builds

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 09:48:19 -08:00
Alex Cheema
db79c350c1 Fix graceful process shutdown in macOS app (#1372)
## Motivation

Fixes #1370

When the macOS app stops exo, GPU/system memory isn't released. This
happens because:

1. The macOS app calls `process.terminate()` (SIGTERM) but the Python
process only registers a graceful shutdown handler for SIGINT, not
SIGTERM. SIGTERM's default Python behavior raises `SystemExit` which
bypasses the cleanup cascade (runner subprocess MLX cleanup via
`mx.clear_cache()`, channel closing, etc.).
2. The app doesn't wait for the process to actually finish cleanup — it
immediately nils out the process reference.

## Changes

**`src/exo/main.py`**: Register SIGTERM handler alongside SIGINT so the
graceful shutdown cascade (`Node.shutdown()` → cancel task group →
worker/runner cleanup → `mx.clear_cache()` + `gc.collect()`) runs
regardless of which signal is received.

**`app/EXO/EXO/ExoProcessController.swift`**: Replace immediate
`process.terminate()` with escalating shutdown per @Evanev7's
suggestion:
1. Send SIGINT via `process.interrupt()` — triggers the registered
Python handler for graceful cleanup
2. Wait up to 5 seconds for the process to exit
3. If still running, escalate to SIGTERM via `process.terminate()`
4. Wait up to 3 seconds
5. If still running, force kill via SIGKILL

The escalation runs in a detached `Task` so the UI updates immediately
(status → stopped) without blocking.

## Why It Works

The root cause is that SIGTERM wasn't triggering the graceful shutdown
path. By registering a SIGTERM handler in Python and sending SIGINT
first from the macOS app, the process gets a chance to run the full
cleanup cascade: cancelling the task group, shutting down runners (which
call `del model; mx.clear_cache(); gc.collect()`), closing channels, and
flushing logs. The escalation to SIGTERM and SIGKILL ensures the process
always terminates even if graceful shutdown hangs.

## Test Plan

### Manual Testing
<!-- Hardware: Mac Studio M4 Max 128GB -->
- Start exo via macOS app, load a model, run inference
- Stop via the toggle switch, verify memory is released without
requiring a system restart
- Test rapid stop/start (restart) to ensure no race conditions

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — no changes

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan Quiney <evanev7@gmail.com>
2026-02-17 09:03:54 -08:00
Alex Cheema
d6301ed593 dashboard: redesign downloads page as model×node table (#1465)
## Motivation

The current downloads page uses a node-centric card grid layout that is
messy and hard to read — the same model across different nodes appears
in separate cards, and deep nesting wastes space. This makes it
difficult to quickly see which models are on which nodes.

## Changes

Rewrote the downloads page
(`dashboard/src/routes/downloads/+page.svelte`) from a card grid to a
clean table layout:

- **Rows** = models (unique across all nodes)
- **Columns** = nodes (with disk free shown in header)
- **Cells** show status at a glance:
  -  Green checkmark + size for completed downloads
  - 🟡 Yellow percentage + mini progress bar + speed for active downloads
  - `...` for pending downloads
  -  Red X for failed downloads
  - `--` for models not present on a node
- Delete/download action buttons appear on row hover
- Model name column is sticky on horizontal scroll (for many-node
clusters)
- Models sorted by number of nodes with completed downloads
- Imported shared utilities from `$lib/utils/downloads` instead of
inline re-implementations

### Backend: model directory in download events

- Added `model_directory` field to `BaseDownloadProgress` so all
download status events include the on-disk path
- Added `_model_dir()` helper to `DownloadCoordinator` to compute the
path from `EXO_MODELS_DIR`
- Dashboard uses this to show file location and enable "open in Finder"
for completed downloads

### Info modal

- Clicking a model name opens an info modal showing card details
(family, quantization, capabilities, storage size, layer count, tensor
parallelism support)

### Other fixes

- Fixed model name truncation in the table
- Excluded `tests/start_distributed_test.py` from pytest collection (CLI
script that calls `sys.exit()` at import time)

## Test Plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all passed
- [x] `nix fmt` — clean
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:31:47 +00:00
Evan Quiney
6d1ca6689b don't time out node identities (#1493)
currently nodes leaving and rejoining the cluster can lose their identity. We have no need to delete this data on node timing out, so let's just persist it.
2026-02-17 11:48:28 +00:00
Evan
c01b6fff21 eprint banner
our banner was being printed to stdout but should be printed to stderr
as its essentially a log message
2026-02-17 11:43:06 +00:00
Jake Hillion
8392e78afe bench: add spec for automatic canary benchmarks (#1483)
Adds all the models that can fit onto a single M3 Ultra for single
machine benchmarks. Fixes the macOS version, GPU spec, and chip type for
maximum reproducibility. Specifies the minimum memory accordingly for
each type of model, using the smallest machine available (the smallest
M3 Ultra is 96GiB).

Test plan:
- Running this with some code that makes machines of this spec available
and stores the results. It works.

This will become part of a larger testing/stability strategy once we've
collected more of the data.
2026-02-17 10:52:05 +00:00
Evan
86735ece78 begins
begins
2026-02-16 19:26:19 +00:00
139 changed files with 4943 additions and 4143 deletions

View File

@@ -1,15 +0,0 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

View File

@@ -1,44 +0,0 @@
name: e2e-tests
on:
push:
branches:
- e2e-tests
pull_request:
branches:
- staging
- main
jobs:
e2e:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
/opt/microsoft /opt/az
docker system prune -af
df -h /
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build E2E image with cache
uses: docker/build-push-action@v6
with:
context: .
file: e2e/Dockerfile
tags: exo-e2e:latest
load: true
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Run E2E tests
run: python3 e2e/run_all.py

View File

@@ -200,7 +200,7 @@ class Module(dict):
) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
"""Return the submodules that do not contain other modules."""
def update(self, parameters: dict, strict: bool = ...) -> Module:
def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.

View File

@@ -7,7 +7,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlx.core import MX_ARRAY_TREE
def tree_map(
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[..., bool] | None = ...,
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -44,11 +47,11 @@ def tree_map(
"""
def tree_map_with_path(
fn: Callable,
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Optional[Callable] = ...,
path: Optional[Any] = ...,
is_leaf: Callable[..., bool] | None = ...,
path: str | None = ...,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -80,9 +83,9 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = ...,
is_leaf: Optional[Callable] = ...,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
is_leaf: Callable[..., bool] | None = ...,
destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
) -> list[tuple[str, Any]] | dict[str, Any]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -118,7 +121,7 @@ def tree_flatten(
the Python tree.
"""
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python

View File

@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

136
Cargo.lock generated
View File

@@ -141,12 +141,6 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -304,19 +298,6 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -516,15 +497,6 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -746,29 +718,6 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -939,22 +888,17 @@ name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures",
"impl-trait-for-tuples",
"futures-lite",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
@@ -970,6 +914,12 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1078,7 +1028,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
@@ -1640,17 +1593,6 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1829,12 +1771,6 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2824,16 +2760,13 @@ name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-lite",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
@@ -2918,17 +2851,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3279,28 +3201,14 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3741,16 +3649,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -4615,24 +4513,12 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"

View File

@@ -26,49 +26,20 @@ opt-level = 3
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-lite = "2.6.1"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"

View File

@@ -72,16 +72,30 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly:
```bash
nix run .#exo
```
**Note:** To accept the Cachix binary cache (and avoid the Xcode Metal ToolChain), add to `/etc/nix/nix.conf`:
```
trusted-users = root (or your username)
experimental-features = nix-command flakes
```
Then restart the Nix daemon: `sudo launchctl kickstart -k system/org.nixos.nix-daemon`
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -126,11 +126,37 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
if process.isRunning {
process.terminate()
}
self.process = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
}
func restart() {

7
bench/bench.toml Normal file
View File

@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
]

1104
bench/eval_tool_calls.py Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,29 +1,48 @@
# type: ignore
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""Tool-calling eval for exo's OpenAI-compatible API.
Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary
Start exo with a model first, then run:
uv run python tool_call_eval.py --model <model-id>
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
uv run python tool_call_eval.py --model <model-id> --repeat 3
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""
from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -103,154 +122,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
@@ -269,184 +140,6 @@ def parse_int_list(values: list[str]) -> list[int]:
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -538,76 +231,12 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
add_common_instance_args(ap)
ap.add_argument(
"--pp",
nargs="+",
@@ -620,34 +249,6 @@ def main() -> int:
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
@@ -657,9 +258,6 @@ def main() -> int:
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
@@ -674,17 +272,6 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -719,24 +306,10 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
@@ -760,6 +333,10 @@ def main() -> int:
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,

477
bench/harness.py Normal file
View File

@@ -0,0 +1,477 @@
# type: ignore
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode
from loguru import logger
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if (m.get("name") or "").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def settle_and_fetch_placements(
client: ExoClient,
full_model_id: str,
args: argparse.Namespace,
settle_timeout: float = 0,
) -> list[dict[str, Any]]:
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
return selected
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
"httpx>=0.27.0",
"loguru>=0.7.3",
"transformers>=5.0.0",
"huggingface-hub>=0.33.4",

306
bench/scenarios.toml Normal file
View File

@@ -0,0 +1,306 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.
[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]
[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"
[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"
[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]
[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]
[tools.search_products.properties.query]
type = "string"
description = "Search query string"
[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"
[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"
[tools.create_todos]
description = "Create a structured todo list"
required = ["todos"]
[tools.create_todos.properties.todos]
type = "array"
description = "List of todo items"
[tools.create_todos.properties.todos.items]
type = "object"
required = ["content", "status", "priority"]
[tools.create_todos.properties.todos.items.properties.content]
type = "string"
description = "The todo item text"
[tools.create_todos.properties.todos.items.properties.status]
type = "string"
description = "Status: pending, in_progress, or completed"
[tools.create_todos.properties.todos.items.properties.priority]
type = "string"
description = "Priority: low, normal, or high"
# -- Should call a tool --
[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"
[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"
[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"
# -- Multi-turn: tool call then follow-up --
[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"
[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"
[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[scenarios.tool_result]
result = 491682
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"
[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8
[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6
[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"
# -- Sequential tool calls --
[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."
[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'
[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."
[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Nested object schema (regression for lossy chat template rendering) --
[[scenarios]]
name = "nested_schema_tool_call"
description = "Tool call with nested object array schema -> create_todos"
expect_tool_call = true
expected_function = "create_todos"
required_arg_keys = ["todos"]
nested_array_key = "todos"
required_item_keys = ["content", "status", "priority"]
tools = ["create_todos"]
[[scenarios.messages]]
role = "user"
content = "Create a todo list with 3 items to learn Python"
# -- Tool name integrity (regression for harmony token leaking into name) --
[tools.glob]
description = "Search for files matching a glob pattern in the codebase"
required = ["pattern"]
[tools.glob.properties.pattern]
type = "string"
description = "The glob pattern to match files against, e.g. '**/*.py'"
[tools.glob.properties.path]
type = "string"
description = "The directory to search in"
[[scenarios]]
name = "tool_name_integrity"
description = "Tool name must not contain harmony tokens like <|channel|>"
expect_tool_call = true
expected_function = "glob"
required_arg_keys = ["pattern"]
tools = ["glob"]
[[scenarios.messages]]
role = "user"
content = "Find all Python files in the src directory"
# -- Should NOT call a tool --
[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."
[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"

189
bench/single-m3-ultra.toml Normal file
View File

@@ -0,0 +1,189 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=1)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
]
[topology]
type = "none"
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

View File

@@ -1 +0,0 @@
collect_ignore = ["tests/start_distributed_test.py"]

View File

@@ -14,6 +14,7 @@
totalTokens,
thinkingEnabled as thinkingEnabledStore,
setConversationThinking,
stopGeneration,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -103,7 +104,7 @@
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking") && caps.includes("text");
return caps.includes("thinking_toggle") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(
@@ -653,86 +654,92 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => stopGeneration()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white"
aria-label="Stop generation"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
class="w-3 h-3 sm:w-3.5 sm:h-3.5"
fill="currentColor"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
<rect x="6" y="6" width="12" height="12" rx="1" />
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">Cancel</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>
<!-- Bottom accent line -->

View File

@@ -3,16 +3,17 @@
messages,
currentResponse,
isLoading,
prefillProgress,
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import ImageLightbox from "./ImageLightbox.svelte";
interface Props {
@@ -25,6 +26,7 @@
const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());
const prefill = $derived(prefillProgress());
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
@@ -428,6 +430,9 @@
{:else}
<!-- Assistant message styling -->
<div class="p-3 sm:p-4">
{#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
<PrefillProgressBar progress={prefill} class="mb-3" />
{/if}
{#if message.thinking && message.thinking.trim().length > 0}
<div
class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"

View File

@@ -26,7 +26,8 @@
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
function formatNumber(num: number | undefined): string {
if (num == null) return "0";
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {

View File

@@ -59,13 +59,14 @@
}
const sizeOptions: ImageGenerationParams["size"][] = [
"auto",
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
"1024x1365",
"1365x1024",
"1024x1536",
"1536x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -176,92 +177,90 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
{params.size.toUpperCase()}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size.toUpperCase()}</span>
</button>
{/each}
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
{/if}
</div>
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -311,7 +310,7 @@
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition()
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
>

View File

@@ -0,0 +1,52 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -273,6 +273,11 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -306,13 +311,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
export interface ImageGenerationParams {
// Basic params
size:
| "auto"
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1365"
| "1365x1024";
| "1024x1536"
| "1536x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -336,7 +342,7 @@ export interface EditingImage {
}
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
size: "auto",
quality: "medium",
outputFormat: "png",
numImages: 1,
@@ -519,6 +525,10 @@ class AppStore {
ttftMs = $state<number | null>(null); // Time to first token in ms
tps = $state<number | null>(null); // Tokens per second
totalTokens = $state<number>(0); // Total tokens in current response
prefillProgress = $state<PrefillProgress | null>(null);
// Abort controller for stopping generation
private currentAbortController: AbortController | null = null;
// Topology state
topologyData = $state<TopologyData | null>(null);
@@ -2004,6 +2014,7 @@ class AppStore {
reader: ReadableStreamDefaultReader<Uint8Array>,
targetConversationId: string,
onChunk: (parsed: T) => void,
onEvent?: Record<string, (data: unknown) => void>,
): Promise<void> {
const decoder = new TextDecoder();
let buffer = "";
@@ -2024,6 +2035,24 @@ class AppStore {
const trimmed = line.trim();
if (!trimmed) continue;
// Handle SSE comments (": key json") for prefill progress etc.
if (trimmed.startsWith(": ") && onEvent) {
const comment = trimmed.slice(2);
const spaceIdx = comment.indexOf(" ");
if (spaceIdx > 0) {
const key = comment.slice(0, spaceIdx);
if (onEvent[key]) {
try {
const parsed = JSON.parse(comment.slice(spaceIdx + 1));
onEvent[key](parsed);
} catch {
// Skip malformed JSON in comment
}
}
}
continue;
}
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
@@ -2255,6 +2284,9 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
const abortController = new AbortController();
this.currentAbortController = abortController;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
@@ -2271,6 +2303,7 @@ class AppStore {
enable_thinking: enableThinking,
}),
}),
signal: abortController.signal,
});
if (!response.ok) {
@@ -2308,6 +2341,11 @@ class AppStore {
reader,
targetConversationId,
(parsed) => {
// Clear prefill progress when first token data arrives
if (this.prefillProgress) {
this.prefillProgress = null;
}
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
@@ -2370,8 +2408,26 @@ class AppStore {
this.persistConversation(targetConversationId);
}
},
{
prefill_progress: (data) => {
// TaggedModel wraps as {"PrefillProgressChunk": {...}}
// model_dump_json() uses snake_case (by_alias defaults to False)
const raw = data as Record<string, unknown>;
const inner = (raw["PrefillProgressChunk"] ?? raw) as {
processed_tokens: number;
total_tokens: number;
};
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
};
},
},
);
// Clear prefill progress after stream ends
this.prefillProgress = null;
// Calculate final TPS
if (firstTokenTime !== null && tokenCount > 1) {
const totalGenerationTime = performance.now() - firstTokenTime;
@@ -2402,20 +2458,31 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
if (error instanceof DOMException && error.name === "AbortError") {
// User stopped generation — not an error
} else {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
}
} finally {
this.currentAbortController = null;
this.prefillProgress = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
stopGeneration(): void {
this.currentAbortController?.abort();
this.currentAbortController = null;
}
/**
* Generate an image using the image generation API
*/
@@ -3042,6 +3109,7 @@ export const isLoading = () => appStore.isLoading;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const runners = () => appStore.runners;
@@ -3059,6 +3127,7 @@ export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const stopGeneration = () => appStore.stopGeneration();
export const startChat = () => appStore.startChat();
export const sendMessage = (
content: string,

View File

@@ -932,13 +932,6 @@
};
}
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log("[Download Debug] Current downloads:", downloadsData);
}
});
// Helper to get download status for an instance
function getInstanceDownloadStatus(
instanceId: string,

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,58 +0,0 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
# libblas-dev/liblapack-dev/liblapacke-dev are required by MLX CPU backend on Linux
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
libblas-dev \
liblapack-dev \
liblapacke-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
# Must be done BEFORE uv sync so any source builds also get the fix
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
CMD [".venv/bin/exo", "-v"]

View File

@@ -1,195 +0,0 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
# Skip build if the image was pre-built (e.g. in CI with buildx cache)
proc = await asyncio.create_subprocess_exec(
"docker",
"image",
"inspect",
"exo-e2e:latest",
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await proc.wait()
if proc.returncode == 0:
print(" Using pre-built image (exo-e2e:latest)")
return
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")

View File

@@ -1,20 +0,0 @@
services:
exo-node-1:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]

View File

@@ -1,83 +0,0 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests with '# slow' on the first line of their docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
if "--update-snapshots" in sys.argv:
os.environ["UPDATE_SNAPSHOTS"] = "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
# Retry once — Docker networking (mDNS) can be slow on first boot
print(f"\n=== {name} === RETRYING (attempt 2/2)")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,78 +0,0 @@
"""Snapshot testing infrastructure for E2E tests.
Provides deterministic regression testing by comparing inference output
against committed baseline snapshots. Tests FAIL if no baseline exists —
baselines must be explicitly generated and committed.
Generate baselines: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow
Update after intentional changes: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow
Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
since floating-point results differ between CPU architectures.
"""
import difflib
import json
import os
import platform
from pathlib import Path
ARCH = platform.machine()
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH
def assert_snapshot(
name: str,
content: str,
metadata: dict,
) -> None:
"""Compare content against a saved snapshot, or create one if missing.
Args:
name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
content: The actual inference output to compare.
metadata: Additional context stored alongside content (model, seed, etc.).
Not used for comparison -- purely documentary.
Raises:
AssertionError: If content doesn't match the saved snapshot.
Environment:
UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
"""
snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
update = os.environ.get("UPDATE_SNAPSHOTS") == "1"
if update:
# Explicitly regenerate snapshot
SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
snapshot_data = {**metadata, "arch": ARCH, "content": content}
snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
print(f" Updated snapshot: {ARCH}/{snapshot_file.name}")
elif not snapshot_file.exists():
raise AssertionError(
f"No baseline snapshot for '{name}' on {ARCH}.\n"
f"Expected file: {snapshot_file}\n\n"
f"Generate baselines with: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
)
else:
snapshot = json.loads(snapshot_file.read_text())
expected = snapshot["content"]
if content != expected:
diff = "\n".join(
difflib.unified_diff(
expected.splitlines(),
content.splitlines(),
fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
tofile="actual",
lineterm="",
)
)
raise AssertionError(
f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
f"{diff}\n\n"
f"Expected: {expected!r}\n"
f"Actual: {content!r}\n\n"
f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
)
print(f" Output matches snapshot ({ARCH}/{snapshot_file.name})")

View File

@@ -1,22 +0,0 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,61 +0,0 @@
"""Test: Deterministic inference output (snapshot test).
Sends a chat completion request with a fixed seed,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Uses MLX CPU backend in Docker on x86 Linux.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="inference_snapshot",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,47 +0,0 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,65 +0,0 @@
"""Test: Runner chaos — abrupt runner death detection.
slow
Sends a chat completion with the EXO_RUNNER_MUST_DIE trigger, which causes
the runner process to call os._exit(1) (simulating an OOM kill). Verifies that
the RunnerSupervisor health check detects the death and the system doesn't hang.
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import contextlib
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
async def main():
async with Cluster("runner_chaos") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Place the model so a runner is loaded and ready
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
# Send a chat request with the die trigger.
# The runner will call os._exit(1) mid-inference, simulating OOM kill.
# The chat request itself will fail — that's expected.
print(" Sending EXO_RUNNER_MUST_DIE trigger...")
with contextlib.suppress(Exception):
await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": "EXO RUNNER MUST DIE"}],
timeout=60,
)
# Wait for the health check to detect the death and emit RunnerFailed
async def health_check_detected():
log = await cluster.logs()
return "runner process died unexpectedly" in log
await cluster.wait_for(
"Health check detected runner death",
health_check_detected,
timeout=30,
)
# Verify RunnerFailed was emitted (visible in logs)
log = await cluster.logs()
assert "runner process died unexpectedly" in log, (
f"Expected health check to detect runner death but it didn't.\nLogs:\n{log}"
)
print("PASSED: runner_chaos")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,60 +0,0 @@
"""Test: Code generation snapshot.
slow
Verifies deterministic output for a code generation prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = (
"Write a Python function to reverse a string. Only output the code, no explanation."
)
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_code_gen") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_code_gen",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_code_gen")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,65 +0,0 @@
"""Test: Edge case snapshots.
slow
Verifies deterministic output for edge-case prompts: single word input,
special characters, and unicode.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
MAX_TOKENS = 32
CASES = [
("edge_single_word", "Hi"),
("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
]
async def main():
async with Cluster("snapshot_edge") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
for snapshot_name, prompt in CASES:
print(f" [{snapshot_name}] Sending: {prompt!r}")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": prompt}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{snapshot_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": prompt,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_edge")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,58 +0,0 @@
"""Test: Longer output snapshot.
slow
Verifies deterministic output with a higher max_tokens (128).
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "Explain how a binary search algorithm works."
MAX_TOKENS = 128
async def main():
async with Cluster("snapshot_long_output") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_long_output",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_long_output")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,73 +0,0 @@
"""Test: Multi-model snapshot tests.
slow
Verifies deterministic output across different model architectures to catch
model-specific regressions. Each model uses its own snapshot file.
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
SEED = 42
PROMPT = "What is the capital of France?"
MAX_TOKENS = 32
MODELS = [
"mlx-community/SmolLM2-135M-Instruct",
"mlx-community/Llama-3.2-1B-Instruct-4bit",
"mlx-community/gemma-2-2b-it-4bit",
]
async def main():
async with Cluster("snapshot_multi_model") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
for model in MODELS:
short_name = (
model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
)
snapshot_name = f"snapshot_multi_{short_name}"
print(f" Launching model {model}...")
await cluster.place_model(model)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=model,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{short_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": model,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print(f" [{short_name}] PASSED")
print("PASSED: snapshot_multi_model")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,58 +0,0 @@
"""Test: Reasoning/math snapshot.
slow
Verifies deterministic output for a simple reasoning prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_reasoning") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_reasoning",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_reasoning")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,32 +0,0 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v

View File

@@ -74,7 +74,6 @@
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
@@ -115,7 +114,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.6"; in
version = let v = "0.30.7.dev20260218+14841977"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
};
patches = [

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"mlx-lm==0.30.7",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -64,6 +64,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
@@ -132,7 +133,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -58,6 +58,21 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -74,14 +89,25 @@
linuxOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;
@@ -132,6 +158,7 @@
exo-test-env = testVenv;
} // {
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 405874409472

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 765577920512

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 122406567936

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 229780750336

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 198556925568

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 286737579648

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 396963397248

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 19327352832

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 22548578304

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 26843545600

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 34359738368

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 706522120192

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 662498705408

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 100086644736

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 128666664960

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 185826705408

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 342884352

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 698351616

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 141733920768

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 268435456000

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 17612931072

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 33279705088

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 47080074240

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/SmolLM2-135M-Instruct"
n_layers = 30
hidden_size = 576
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "SmolLM2 135M"
capabilities = ["text"]
[storage_size]
in_bytes = 269060381

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 114572190076

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 159039627774

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 209082699847

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/gemma-2-2b-it-4bit"
n_layers = 26
hidden_size = 2304
supports_tensor = false
tasks = ["TextGeneration"]
family = "gemma2"
quantization = "4bit"
base_model = "Gemma 2 2B"
capabilities = ["text"]
[storage_size]
in_bytes = 1492755242

View File

@@ -1,2 +0,0 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]

View File

@@ -25,17 +25,17 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
# "nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
@@ -45,33 +45,18 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
futures-lite = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"

View File

@@ -19,7 +19,7 @@ class ConnectionUpdate:
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
def peer_id(self) -> builtins.str:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@@ -40,92 +40,22 @@ class Keypair:
Identity keypair of a node.
"""
@staticmethod
def generate_ed25519() -> Keypair:
def generate() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def generate_ecdsa() -> Keypair:
def from_bytes(bytes: bytes) -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
Construct an Ed25519 keypair from secret key bytes
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
Get the secret key bytes underlying the keypair
"""
def to_string(self) -> builtins.str:
def to_node_id(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
"""
@typing.final
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""

View File

@@ -2,11 +2,10 @@
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
pin::{Pin, pin},
pin::Pin,
task::{Context, Poll},
};
@@ -26,15 +25,13 @@ where
impl<F> Future for AllowThreads<F>
where
F: Future + Ungil,
F::Output: Ungil,
F: Future + Send,
F::Output: Send,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
}
}

View File

@@ -1,240 +0,0 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -0,0 +1,47 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Construct an Ed25519 keypair from secret key bytes
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Get the secret key bytes underlying the keypair
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
fn to_node_id(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}

View File

@@ -4,28 +4,14 @@
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
mod ident;
mod networking;
use crate::ident::PyKeypair;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -34,24 +20,11 @@ pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
@@ -62,7 +35,7 @@ pub(crate) mod ext {
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind())
Python::attach(|py| PyBytes::new(py, self).unbind())
}
}
@@ -98,7 +71,7 @@ pub(crate) mod ext {
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
Python::attach(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
@@ -175,24 +148,6 @@ pub(crate) mod ext {
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
@@ -204,8 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
m.add_class::<PyKeypair>()?;
networking_submodule(m)?;
// top-level constructs

View File

@@ -8,12 +8,12 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair;
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
@@ -25,7 +25,7 @@ use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
peer_id: String,
/// Remote connection's IPv4 address.
#[pyo3(get)]
@@ -155,7 +155,6 @@ async fn networking_task(
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
@@ -252,7 +251,7 @@ async fn networking_task(
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
@@ -273,7 +272,7 @@ async fn networking_task(
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
@@ -485,7 +484,7 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,

View File

@@ -1,159 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -19,21 +19,14 @@ either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-lite = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
@@ -41,4 +34,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,4 +1,4 @@
use futures::stream::StreamExt as _;
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
event = swarm.next() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
}))) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
@@ -64,7 +64,7 @@ async fn main() {
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }

View File

@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -1,8 +1,7 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
@@ -363,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
if self.retry_delay.poll(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)

View File

@@ -1,44 +0,0 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -3,19 +3,7 @@
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
@@ -54,11 +42,3 @@ pub(crate) mod ext {
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures::{AsyncRead, AsyncWrite};
use futures_lite::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;

View File

@@ -1,11 +1,10 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
{ inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
rustToolchain = inputs'.fenix.packages.stable.withComponents [
"cargo"
"rustc"
"clippy"

View File

@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"

View File

@@ -14,6 +14,7 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -46,6 +47,7 @@ class DownloadCoordinator:
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -61,8 +63,13 @@ class DownloadCoordinator:
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -74,6 +81,7 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -93,6 +101,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -101,23 +110,30 @@ class DownloadCoordinator:
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
logger.info(
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
)
if not self.offline:
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
if not self.offline:
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
# Try multiple endpoints since some ISPs/networks block specific IPs
for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
try:
socket.create_connection((host, 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
logger.debug(f"Internet connectivity: True (via {host})")
return
except OSError:
continue
self.shard_downloader.set_internet_connection(False)
logger.debug("Internet connectivity: False")
async def _check_internet_connection(self) -> None:
first_connection = True
@@ -170,7 +186,11 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -184,6 +204,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -191,6 +212,20 @@ class DownloadCoordinator:
)
return
if self.offline:
logger.warning(
f"Offline mode: model {model_id} is not fully available locally, cannot download"
)
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=f"Model files not found locally in offline mode: {model_id}",
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
return
# Start actual download
self._start_download_task(shard, initial_progress)
@@ -206,6 +241,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -219,6 +255,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -253,6 +290,7 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -295,11 +333,18 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
status = DownloadOngoing(
@@ -308,6 +353,9 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue

View File

@@ -448,12 +448,13 @@ async def download_file_with_retry(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress, skip_internet
)
except HuggingFaceAuthenticationError:
raise
@@ -487,10 +488,14 @@ async def _download_file(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
skip_internet: bool = False,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
if skip_internet:
return target_path
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
@@ -510,6 +515,11 @@ async def _download_file(
)
return target_path
if skip_internet:
raise FileNotFoundError(
f"File {path} not found locally and cannot download in offline mode"
)
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -814,6 +824,7 @@ async def download_shard(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
skip_internet=skip_internet,
)
if not skip_download:

Some files were not shown because too many files have changed in this diff Show More