Compare commits

...

50 Commits

Author SHA1 Message Date
Alex Cheema
3cf844e7f1 Merge branch 'main' into meta-instance-split/download-completion-detection 2026-02-22 06:57:27 -08:00
Alex Cheema
18717023ad chore: remove deprecated MlxIbv dashboard references (#1584)
## Summary
- Remove legacy MlxIbvInstance references from ChatSidebar and ModelCard
components
- MlxIbv was replaced by MlxJaccl; these are leftover type checks
- Split from #1519 for independent review

## Test plan
- [x] Visual inspection of dashboard components

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 06:56:12 -08:00
Alex Cheema
4768f50d56 fix: detect completed downloads by checking final file exists
The previous byte-comparison fallback in the coordinator could falsely
report .partial files as complete (e.g. when a process was killed after
download but before hash verification and rename). Instead, fix the
source: only mark a file as "complete" during status scanning when the
final (non-.partial) file exists on disk, which implies hash verification
and rename succeeded. Remove the coordinator-level byte comparison
workaround since the source now reports correctly.
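The new completion rule reduces to a few lines. A minimal sketch (the function name `scan_status` and the status strings are illustrative, not the actual coordinator code):

```python
from pathlib import Path

def scan_status(final_path: Path) -> str:
    """Only the final, non-.partial file counts as complete: the rename
    away from .partial happens only after hash verification succeeds,
    so a .partial file can never be reported as complete."""
    if final_path.exists():
        return "complete"
    partial = final_path.with_name(final_path.name + ".partial")
    if partial.exists():
        return "in_progress"
    return "not_started"
```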

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 13:14:07 -08:00
Alex Cheema
3b54e7dfa7 fix: detect completed downloads via byte comparison
When scanning existing download status, a download could report
status "in_progress" or "not_started" even though all bytes have
been downloaded. This adds a fallback check: if downloaded >= total
bytes (and total > 0), treat it as completed regardless of the
reported status string.
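The fallback is a one-line override of the reported status. A sketch under illustrative names:

```python
def effective_status(reported: str, downloaded_bytes: int, total_bytes: int) -> str:
    """Trust the byte counts over the reported status string: if all
    bytes are on disk, the download is complete regardless of what
    the status scan claims."""
    if total_bytes > 0 and downloaded_bytes >= total_bytes:
        return "completed"
    return reported
```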

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 07:21:44 -08:00
Alex Cheema
1780e4ade4 fix: change RDMA AVAILABLE to RDMA NOT ENABLED warning (#1580)
## Summary
- Changed blue info badge "RDMA AVAILABLE" to yellow warning badge "RDMA
NOT ENABLED" — more accurately describes the state
- Added hover tooltip with enable instructions to all views (was missing
in 2 of 4 instances)
- Warning icon instead of info icon, consistent with other cluster
warnings (TB cycle, macOS mismatch)

## Screenshots

**Badge (yellow warning):**
![RDMA warning
badge](https://raw.githubusercontent.com/exo-explore/exo/3f7bdb482c5011d60f140aa84ab21023032e4a57/rdma-warning.png)

**Hover tooltip with instructions:**
![RDMA warning
hover](https://raw.githubusercontent.com/exo-explore/exo/3f7bdb482c5011d60f140aa84ab21023032e4a57/rdma-warning-hover.png)

## Test plan
- [x] Dashboard builds successfully
- [ ] Verify badge appears when 2+ TB5 nodes have RDMA disabled
- [ ] Verify hover tooltip shows in normal layout
- [ ] Verify hover tooltip shows in topology-only mode
- [ ] Verify dismiss button works
- [ ] Verify compact badge in status bar shows yellow warning

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 21:40:07 +00:00
Jake Hillion
ab9273e723 downloads: add read_only flag to DownloadCompleted for EXO_MODELS_PATH
Models in EXO_MODELS_PATH are pre-downloaded into read-only directories
and must not be deleted. The DownloadCoordinator had no awareness of
these paths, so they never appeared as completed downloads in cluster
state, and the bench harness could attempt to delete them when freeing
disk space.

Added a `read_only: bool` field to `DownloadCompleted` (default False).
The DownloadCoordinator now checks `resolve_model_in_path` in
`_start_download`, proactively scans EXO_MODELS_PATH in
`_emit_existing_download_progress` to emit DownloadCompleted events for
all pre-downloaded models (overriding DownloadPending from the regular
scan), and refuses deletion of read-only models. The bench harness
filters out read-only models from deletion candidates.
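The event change and the harness-side filter can be sketched together (an illustrative subset of the event's fields, not the real class):

```python
from dataclasses import dataclass

@dataclass
class DownloadCompleted:
    """Illustrative subset of the event; only the new field matters here."""
    model_id: str
    read_only: bool = False  # pre-downloaded via EXO_MODELS_PATH

def deletion_candidates(completed: list[DownloadCompleted]) -> list[str]:
    # Bench-harness side: never offer read-only models for deletion.
    return [e.model_id for e in completed if not e.read_only]
```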

Test plan:
- Ran with EXO_MODELS_PATH. Available models now show as downloaded in
  the UI. There isn't good UI for the fact they can't be deleted, but it
  should work with exo_bench.
2026-02-20 20:27:45 +00:00
Jake Hillion
71e48c0f62 model-cards: add missing metadata for Qwen3 Coder Next variants (#1576)
The Qwen3-Coder-Next model card TOML files were missing family,
quantization, base_model, and capabilities fields. This caused them not
to appear under the Qwen family filter in the dashboard's model picker.

Added the missing metadata to all five variants (4bit, 5bit, 6bit, 8bit,
bf16), matching the format used by the existing Qwen3-Coder-480B model
cards.

Test plan:
- Eyeballs
2026-02-20 18:25:49 +00:00
Jake Hillion
42da58c297 worker: add EXO_MODELS_PATH for pre-downloaded model directories
Users with pre-existing model files (e.g. on shared NFS mounts or from
prior downloads) had no way to point exo at those directories without
going through the download coordinator. EXO_MODELS_DIR only moves the
download target directory; it doesn't support read-only search paths.

Added EXO_MODELS_PATH environment variable as a colon-separated list of
directories to search for models. When the worker's plan loop encounters
a DownloadModel task, it checks these directories first and emits a
synthetic DownloadCompleted event if found, bypassing the download
coordinator entirely. The runner's build_model_path also checks these
directories first so the correct path is used during model loading.

This keeps the existing event sourcing state machine unchanged — the
DownloadCompleted event propagates naturally through the system, so
_load_model and all downstream logic work without modification.

Test plan:
- `s1@s1s-Mac-Studio ~ % EXO_LIBP2P_NAMESPACE=jake EXO_MODELS_PATH="/Volumes/Definitely Leo's SSD" nix --extra-experimental-features 'nix-command flakes' run github:exo-explore/exo/f2babbc2f742357d97dc177619fec062ef545be4`
- Started mlx-community/Qwen3-Coder-Next-4bit - it's present on the disk
  and it worked.
- Renamed one safetensor of mlx-community/Qwen3-Coder-Next-4bit on the
  disk. It then started the download locally, as expected.
2026-02-20 18:17:56 +00:00
Mustafa Alp Yılmaz
6b5a705959 fix: immediate cancel check after prefill completes (#1575)
## Problem

When a request is cancelled during prefill, the cancellation is not
detected until `check_for_cancel_every` additional tokens have been
generated. This is because `tokens_since_last_cancel_check` is
initialized to `0`, meaning the first cancel check only happens after
generating `check_for_cancel_every` tokens post-prefill.

For long prefills (which are the most likely to be cancelled), this adds
unnecessary latency before the cancellation is actually honoured.

## Fix

Initialize `tokens_since_last_cancel_check` to `check_for_cancel_every`
instead of `0`, so the very first token generated after prefill triggers
an immediate cancel check.

```diff
- tokens_since_last_cancel_check = 0
+ tokens_since_last_cancel_check = check_for_cancel_every
```

## Impact

- Cancellations issued during prefill are detected immediately when
generation begins
- No change in behaviour for non-cancelled requests (the counter resets
to `0` after each check as before)
- 1 line changed

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 18:00:59 +00:00
Alex Cheema
6b54a27019 fix: add downloaded_bytes to DownloadPending event (#1564)
## Summary
- Add downloaded_bytes field to existing DownloadPending event for
accurate resume progress
- Minimal change per maintainer directive — no new download states
introduced

## Test plan
- [x] 42 tests passed, 1 skipped
- [x] Verified downloaded_bytes populates correctly for partial
downloads

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 17:54:18 +00:00
rltakashige
e01f50a5cd Update mlx fork (#1565)
## Motivation

Some fixes upstream. This sort of commit will probably be quite common
until GPU locks are resolved.
2026-02-20 17:23:52 +00:00
Evan Quiney
1093080214 cancel active downloads on coordinator shutdown (#1567)
We were seeing some crashes as lost download tasks were trying to push
data toward a deleted coordinator. This cancels active download tasks
when the coordinator shuts down on master election.
2026-02-20 17:17:43 +00:00
rltakashige
1a2b8b044a Refactor runner into separate runners (#1570)
## Motivation

We're going to be refactoring the llm inference code, so we should split
the runner up into parts while we can.

## Test Plan

### Manual Testing
Works on single node, at least.

### Automated Testing
Passes CI. Will be tested by our tests today.
2026-02-20 17:11:01 +00:00
Evan Quiney
dc8d42b4dc add system ids (#1536)
Addresses some election edge cases where a new worker with an old master
would get stuck on the old worker's buffer index - we now use new system
ids each time we instantiate a node, and each event-producing system has
a unique system id for its lifespan (until the master moves).
2026-02-20 15:41:59 +00:00
Jake Hillion
d484b062e8 bench: add download timing to bench output (#1566)
The bench script downloads models during the planning phase but doesn't
record how long the download took, making it difficult to track download
performance for a given model over time.

Modified `run_planning_phase` to return download metadata: whether a
fresh download occurred, the wall-clock duration, and the model size in
bytes. These fields are included in every JSON output row alongside the
existing per-run metrics, and a summary line is logged to the console.

This allows filtering bench results by `download_occurred` and grouping
by `model_id` to compute average download times across runs.

Test plan:

```
# existing model
jake@maverick:/data/users/jake/repos/exo/ > nix run .#exo-bench -- --host s1 --model mlx-community/gpt-oss-120b-MXFP4-Q8 --pp 128 --tg 128
...
2026-02-20 15:23:49.081 | INFO     | __main__:main:340 - Planning phase: checking downloads...
2026-02-20 15:23:49.152 | INFO     | harness:run_planning_phase:402 - Started download on 12D3KooWKx41iikn188ozrxSdoG26g88jFCfie9wEA1eQR8csbPm
2026-02-20 15:23:49.184 | INFO     | __main__:main:352 - Download: model already cached
...
Wrote results JSON: bench/results.json
jake@maverick:/data/users/jake/repos/exo/ > cat bench/results.json
[
  {
    "elapsed_s": 2.9446684420108795,
    "output_text_preview": "The user just typed a long series of \"a\". Possibly they are testing. There's no explicit question. Could be they want a response? Might be a test of handling long input. We can respond politely, ask i",
    "stats": {
      "prompt_tps": 117.7872141515621,
      "generation_tps": 85.49598231498028,
      "prompt_tokens": 129,
      "generation_tokens": 128,
      "peak_memory_usage": {
        "inBytes": 68215145744
      }
    },
    "model_short_id": "gpt-oss-120b-MXFP4-Q8",
    "model_id": "mlx-community/gpt-oss-120b-MXFP4-Q8",
    "placement_sharding": "Pipeline",
    "placement_instance_meta": "MlxRing",
    "placement_nodes": 1,
    "instance_id": "68babc2a-6e94-4c70-aa07-7ec681f7c856",
    "pp_tokens": 128,
    "tg": 128,
    "repeat_index": 0
  }
]%
# no change to output
```

```
# missing model
jake@maverick:/data/users/jake/repos/exo/ > nix run .#exo-bench -- --host s1 --model mlx-community/Meta-Llama-3.1-8B-Instruct-4bit --pp 128 --tg 128
...
2026-02-20 15:24:42.553 | INFO     | __main__:main:340 - Planning phase: checking downloads...
2026-02-20 15:24:42.625 | INFO     | harness:run_planning_phase:402 - Started download on 12D3KooWKx41iikn188ozrxSdoG26g88jFCfie9wEA1eQR8csbPm
2026-02-20 15:25:37.494 | INFO     | __main__:main:350 - Download: 54.9s (freshly downloaded)
...
Wrote results JSON: bench/results.json
jake@maverick:/data/users/jake/repos/exo/ > cat bench/results.json
[
  {
    "elapsed_s": 1.500349276990164,
    "output_text_preview": "It seems like you've entered a large number of 'a's. If you'd like to discuss something or ask a question, I'm here to help. If not, is there anything else I can assist you with? \n\nIf you're intereste",
    "stats": {
      "prompt_tps": 395.43264952543666,
      "generation_tps": 128.03520443181478,
      "prompt_tokens": 129,
      "generation_tokens": 128,
      "peak_memory_usage": {
        "inBytes": 5116952079
      }
    },
    "model_short_id": "Meta-Llama-3.1-8B-Instruct-4bit",
    "model_id": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    "placement_sharding": "Pipeline",
    "placement_instance_meta": "MlxRing",
    "placement_nodes": 1,
    "instance_id": "ccd9bd71-d4cc-4b75-a37f-98090544626a",
    "pp_tokens": 128,
    "tg": 128,
    "repeat_index": 0,
    "download_duration_s": 54.88322358299047
  }
]%
# one new field
```
2026-02-20 15:33:08 +00:00
Alex Cheema
e32b649d2f fix: enable psutil fallback for memory monitoring when macmon is missing (#1478)
## Summary
- On macOS, memory monitoring relied exclusively on `macmon` — the
psutil fallback was explicitly disabled (`memory_poll_rate = None`)
- When `macmon` is not installed (e.g., mac-mini-2 through mac-mini-4 in
our cluster), **no memory data was reported**, causing nodes to show 0GB
memory in the cluster state
- This blocked the scheduler from placing shards on those nodes since it
had no memory data to work with
- Fix: when `macmon` is not found on Darwin, fall back to psutil-based
memory polling (`memory_poll_rate = 1`)

## Root cause
`InfoGatherer` has two memory monitoring paths:
1. `macmon` (Darwin-only): provides memory + GPU/CPU/power stats
2. `psutil` (non-Darwin fallback): provides memory via
`MemoryUsage.from_psutil()`

Line 378 disabled psutil on Darwin: `memory_poll_rate = None if
IS_DARWIN else 1`
Line 389 only starts macmon if the binary exists: `if
shutil.which("macmon") is not None`

If macmon is missing on Darwin, **neither path runs** — zero memory
reported.

## Test plan
- [ ] Verify `uv run basedpyright` passes (0 errors confirmed)
- [ ] Verify `uv run ruff check` passes (confirmed)
- [ ] Verify `uv run pytest src/exo/utils/info_gatherer/` passes (2/2
confirmed)
- [ ] Deploy to cluster nodes without macmon and verify memory appears
in `/state`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 13:12:22 +00:00
Alex Cheema
bddad7e79c feat: show ETA on prefill progress bar (#1557)
## Summary
- Show estimated time remaining during prefill (prompt processing phase)
- Track prefill start time via performance.now() and extrapolate from
observed token throughput
- Display ~Xs remaining or ~Xm Ys remaining next to the percentage on
the progress bar
- Wait 200ms before showing ETA to ensure a stable sample window

## Changes
**PrefillProgressBar.svelte**: Add etaText derived computation that
calculates remaining time from (remainingTokens / tokensPerMs). Renders
in a new flex row below the progress bar alongside the percentage.

**app.svelte.ts**: Add startedAt: number field to PrefillProgress
interface. Set on first prefill_progress SSE event, preserved across
subsequent updates.

## Test plan
- [ ] Start inference with a long prompt (10k+ tokens) on a multi-node
cluster
- [ ] Verify the progress bar shows ~Xs remaining after ~200ms of
prefill
- [ ] Verify the ETA decreases as prefill progresses
- [ ] Verify short prefills (<200ms) don't flash a briefly-visible ETA
- [ ] Verify ETA disappears when prefill completes and token generation
begins

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 12:37:56 +00:00
rltakashige
addf73a144 Add support for Ollama API (#1560)
## Motivation

Ollama has a bunch of integrations, such as OpenWebUI, that are very
handy. Let's support it :)

## Test Plan

### Manual Testing
<img width="3426" height="1998" alt="image"
src="https://github.com/user-attachments/assets/44b07f1e-308e-4ff1-9a11-922d8279939f"
/>
2026-02-20 12:03:27 +00:00
Mustafa Alp Yılmaz
a16ff2c047 fix: correct misleading docstring in seed_models (#1561)
## Summary
- Fixed stale docstring in `seed_models()` that referenced
`.cache/huggingface/hub` when the function actually moves models to
`EXO_MODELS_DIR` (resolved via `ensure_models_dir()`)
- The old docstring was misleading for AI coding agents analyzing the
codebase, causing incorrect conclusions about model storage paths

## Changes
`src/exo/download/download_utils.py`: Updated docstring from `"Move
model in resources folder of app to .cache/huggingface/hub"` to `"Move
models from resources folder to EXO_MODELS_DIR."`

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 11:57:55 +00:00
rltakashige
3006c8ea4e Ensure coordinator is rank 0 (#1559)
## Motivation

The coordinator can end up as a random rank. Let's pin it to rank 0, as
that's what we typically assume.

## Test Plan

### Manual Testing
Works as normal on 2 nodes.


Let's wait for a little more testing to merge this.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-02-20 11:46:24 +00:00
rltakashige
f662c129dd Prioritise tb for ring instances (#1556)
## Motivation

TB has better bandwidth and latency than ethernet. We should prioritise
TB5 where possible. This drastically improves distributed image
generation performance.

## Test Plan

### Manual Testing
Saw on the dashboard that TB (169.254) addresses were prioritised.

Tested that image models scale much better.

### Automated Testing
No regression on Kimi K2.5
2026-02-19 21:32:48 +00:00
Evan Quiney
c45ff9ad43 memory tidy (#1558)
Add some Pythonic extensions to memory and do a bunch of cleanup.
2026-02-19 21:15:33 +00:00
rltakashige
7031901ae5 Prevent common fatal crashes (#1555)
## Motivation
Occasionally, memory does not get released when we shut down. There is
no reason to delay deleting the model.

Also, handles can become None during shutdown, causing unhandled
TypeErrors that bring down exo.

Similarly, we were closing the event sender in the wrong place.

Also, let's not verify the SSL certificate for HTTP connections to local
peers, as this sometimes fails and crashes.

## Test Plan

### Manual Testing
No more crashes as described.
2026-02-19 20:51:17 +00:00
rltakashige
cf648a53b8 Add thinking in thinking blocks, and fix DeepSeek interleaved tool calls (#1548)
## Motivation

OpenCode shows <think> tags and not thinking blocks as we aren't
following the API specs properly.

Claude was also getting horrible prefix cache hits because it sends
headers.

## Changes

Handle thinking tokens properly by placing them in think tags for each
of the API endpoints.
Also support DeepSeekV3.2 tool calling properly as a minor feature.
Strips Claude headers at the API level.

## Test Plan

### Manual Testing
Tested OpenCode manually
Needs testing with Claude.

### Automated Testing
All CI and tests passing - added a new e2e test for DeepSeekV32 tool
parsing.
2026-02-19 18:44:49 +00:00
Alex Cheema
94b2ce6922 feat: Mac Studio en2 RDMA port warning v2 (#1551)
Rebuilt from scratch (replaces PR #1543). Detects when Mac Studio uses
RDMA over en2 (TB5 port next to Ethernet) which does not support RDMA.
Shows dismissible warning banner with hover tooltip showing affected
devices, SVG rear panel illustration, and fix instructions. 205 lines in
+page.svelte.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-19 18:39:17 +00:00
rltakashige
423ed0f07f Strip Claude headers to improve prefix cache hit rates (#1552)
## Motivation
Our hits are really bad at the moment (0.2%). This PR makes it 98.5% on
average.

## Changes

Also adds an example for how to run Claude using Exo.

## Why It Works
Claude sends some billing and session headers that change with each
message.

## Test Plan

### Manual Testing
Works in manual testing.
2026-02-19 18:29:34 +00:00
Evan Quiney
ed001f2409 remove prefillprogress event (#1550)
This should never have been a separate event, but I didn't quite
communicate that well when it was merged. Convert PrefillProgress to a
chunk like the rest of the runner responses.

Tested with Llama-3.3-70B; prefill progress events still show up in the
dashboard as usual.
2026-02-19 18:23:28 +00:00
Evan Quiney
4c4c6ce99f simplify rust ident module
This is partly dead-code removal, partly narrowing the Rust-Python
boundary in prep for future rewrites. No testing, as this is all
type-safe refactoring.
2026-02-19 17:19:31 +00:00
Jake Hillion
42e1e7322b bench: restore --danger-delete-downloads planning phase (#1542)
c2f2111b extracted shared utilities from exo_bench.py into harness.py
but accidentally dropped the run_planning_phase function and
--danger-delete-downloads CLI argument in the process.

Restored run_planning_phase in harness.py (where its dependencies now
live) and re-added the --danger-delete-downloads argument to
add_common_instance_args. Re-wired the planning phase call in
exo_bench.py's main() before the benchmark loop.
2026-02-19 15:42:02 +00:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
Mustafa Alp Yılmaz
2e29605194 fix: finalize cancel tasks (#1498)
# Cancel task finalization (main.py)

After forwarding the cancel to the runner supervisor, emit TaskStatusUpdated(Complete) for the cancel task itself. This ensures the cancel task is properly removed from state.tasks.
2026-02-19 13:27:34 +00:00
Evan Quiney
cacb456cb2 remove nightly (#1538)
We have no good need for Rust nightly (nor futures, for that matter).
2026-02-19 12:55:31 +00:00
rltakashige
51021f6fc6 Add cancellation button and the ability to cancel during prefill (#1540)
## Motivation
There's no way to easily use the cancellation features we added! Also,
prefill can take ages so let's allow cancelling out of that.

## Changes

Wiring up our existing functionality to easily cancel during generation
(and adding stuff to do so during prefill)

## Test Plan

### Manual Testing
Tested it works during both prefill and decode.

### Automated testing
Needs testing to see if this causes a GPU timeout error on large prefill
on large models in pipeline parallel. However, from manually testing GLM
5 pipeline ring on 2 nodes, and from reading the code, it does not seem
like this will be the case.
2026-02-19 11:40:59 +00:00
Alex Cheema
025ed9fd82 feat: add prefill progress bar for long prompts (#1181)
## Motivation

Users processing long prompts have no visibility into when token
generation will start. This feature adds a progress bar showing prefill
progress, giving users real-time feedback during prompt processing.

## Changes

### Backend
- Added `PrefillProgress` event type with `command_id`,
`processed_tokens`, `total_tokens`
- Added `PrefillProgressResponse` type (though now using direct callback
approach)
- Wired `prompt_progress_callback` through MLX's `stream_generate()`
- Progress events sent directly from callback for real-time updates (not
batched)
- API generates SSE named events: `event: prefill_progress\ndata: {...}`
- Added `PrefillProgressData` dataclass and `StreamEvent` union type in
API

### Dashboard
- Added `PrefillProgress` interface to store
- Updated SSE parsing to handle `event:` lines (named events)
- Created `PrefillProgressBar.svelte` with animated progress bar
- Shows "Processing prompt: X/Y tokens" with percentage
- Progress bar disappears when first token arrives

## Why It Works

MLX's `stream_generate()` accepts a `prompt_progress_callback(processed,
total)` that's called after each prefill chunk. By sending events
directly from this callback (rather than yielding from the generator),
progress updates are sent in real-time during prefill.

Using SSE named events (`event: prefill_progress`) maintains full
OpenAI/Claude API compatibility - standard clients ignore named events
they don't recognize, while the exo dashboard explicitly listens for
them.

## Test Plan

### Manual Testing
- Hardware: MacBook Pro M3 Max
- Set `prefill_step_size=256` for more frequent updates
- Tested with long prompts (pasted large documents)
- Verified progress bar updates incrementally during prefill
- Confirmed progress bar disappears when generation starts
- Tested with curl - standard `data:` events still work normally

Here it is working:


https://github.com/user-attachments/assets/5cc6f075-c5b2-4a44-bb4d-9efb246bc5fe


### Automated Testing
- Type checker passes (0 errors)
- All 192 tests pass
- Dashboard builds successfully

### API Compatibility
- Named SSE events are ignored by OpenAI SDK clients
- Regular token data uses standard `data: {...}` format
- `[DONE]` sentinel works as expected

---

**Note:** `prefill_step_size` is temporarily set to 256 for testing.
Should be changed back to 2048 before merging for production
performance.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-19 03:18:25 +00:00
rltakashige
19bc09550d Add status=downloaded filter for model endpoint (#1539)
## Motivation

https://github.com/exo-explore/exo/issues/1346#issuecomment-3831427905


## Test Plan

### Manual Testing
**Without filter**
<img width="1708" height="1010" alt="Screenshot 2026-02-18 at 22 26 22"
src="https://github.com/user-attachments/assets/f4bf7142-717d-4042-ac28-d8a55a8e45e7"
/>

**With filter**
<img width="1723" height="1021" alt="Screenshot 2026-02-18 at 22 26 45"
src="https://github.com/user-attachments/assets/40a522d5-c6e6-4148-b21a-02caa1221ebe"
/>
2026-02-18 22:34:11 +00:00
Alex Cheema
7cadca4f27 Try multiple endpoints for internet connectivity check (#1516)
## Summary
- `_test_internet_connection()` previously only tried `1.1.1.1:443`,
which some ISPs/networks block, causing exo to incorrectly report no
internet and fail downloads on startup
- Now tries `1.1.1.1`, `8.8.8.8`, and `1.0.0.1` in sequence, succeeding
if any endpoint responds
- Returns early on first success for minimal latency in the common case

Fixes #1425

## Test plan
- [ ] Verify downloads work on networks that block `1.1.1.1`
- [ ] Verify existing behavior unchanged on networks where `1.1.1.1`
works
- [ ] Verify `internet_connection` is set to `False` only when all three
endpoints fail

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 22:10:07 +00:00
rltakashige
24e99ce197 Cleanup mistakes (#1537)
Oops
2026-02-18 22:05:26 +00:00
Alex Cheema
315992549b fix: unblock MpReceiver.close() to prevent shutdown hang (#1511)
## Summary

- `MpReceiver.close()` did not unblock threads stuck on `queue.get()` in
`receive_async()`, causing abandoned threads (via
`abandon_on_cancel=True`) to keep the Python process alive indefinitely
after tests pass
- This caused the `aarch64-darwin` CI jobs in PR #1462 to hang for ~6
hours until the GitHub Actions timeout killed them
- Sends an `_MpEndOfStream` sentinel before closing the buffer,
mirroring what `MpSender.close()` already does
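The sentinel pattern behind the fix, as a minimal sketch (not the real `MpReceiver`, which sits on a multiprocessing buffer rather than a plain queue):

```python
import queue

_END_OF_STREAM = object()  # stand-in for the _MpEndOfStream sentinel

class Receiver:
    """close() pushes a sentinel so a thread blocked on queue.get()
    wakes up instead of keeping the process alive forever."""

    def __init__(self) -> None:
        self._q: queue.Queue = queue.Queue()

    def receive(self):
        item = self._q.get()  # blocks until data or the sentinel arrives
        return None if item is _END_OF_STREAM else item

    def close(self) -> None:
        self._q.put(_END_OF_STREAM)  # unblock any waiting receiver
```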

## Test plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — clean
- [x] `nix fmt` — 0 changed
- [x] `uv run pytest` — 188 passed, 1 skipped in 12s (no hang)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-18 21:59:02 +00:00
Alex Cheema
ce5a65d3b9 Add MiniMax M2.5 model cards (#1514)
## Summary
- Adds model cards for MiniMax M2.5 in three quantizations: 4bit (~129
GB), 6bit (~186 GB), 8bit (~243 GB)
- No code changes needed — `MiniMaxM2ForCausalLM` is already in the
tensor parallel whitelist and `MiniMaxShardingStrategy` is already
implemented in `auto_parallel.py`
- Credit to @vskiwi for confirming MiniMax M2.5 works out of the box
with existing code

Closes #1480

## Test plan
- [x] `basedpyright` passes with 0 errors
- [x] `ruff check` passes
- [x] `pytest` passes (260 passed, 1 skipped)
- [ ] Verify MiniMax M2.5 models appear in model selector on dashboard

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 21:11:13 +00:00
rltakashige
c2f2111b88 Fix tool calling (#1529)
## Motivation

GPT OSS tool calling issues.

## Changes

Fixes those and adds a bunch of evals for tool calling.
Fixes GLM5 prefix caching, where CacheList wasn't getting handled
properly.
Extracts a bunch of the setup functionality of exo bench to a harness
that can be reused elsewhere, such as in the tool calling eval.

## Test Plan
### Automated Testing
Let's run the evals for all models
2026-02-18 20:29:18 +00:00
Alex Cheema
6c322ebb72 feat: only show thinking toggle for models that support it (#1497)
## Summary
- Adds `thinking_toggle` capability to 26 model cards that support
toggling thinking mode on/off
- GPT-OSS models (20b, 120b) excluded — they always think and don't
support toggling
- Dashboard UI updated to check for `thinking_toggle` capability before
showing the toggle button

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 0 failed
- [x] Security review passed (no secrets, eval/exec, innerHTML, or dep
changes)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 17:05:00 +00:00
vskiwi
2ebe6216b4 feat: add explicit --offline mode for air-gapped clusters (#1525)
## Motivation

Closes #1510

There is currently no reliable way to run exo on an air-gapped or offline cluster where models are pre-staged on local disks. The two existing mechanisms — `--no-downloads` and `HF_HUB_OFFLINE=1` — each cover only a subset of the problem:

1. **`--no-downloads` blocks model loading**: When passed, `DownloadCoordinator` is not created. No `NodeDownloadProgress` events are ever emitted, so `_model_needs_download()` in `plan.py` perpetually returns `DownloadModel`, short-circuiting `_load_model()` and preventing the model from ever being loaded.

2. **`HF_HUB_OFFLINE=1` doesn't cover exo's aiohttp code**: exo's download pipeline primarily uses raw `aiohttp` for HTTP operations (file list fetching, file downloads, HEAD verification), not the `huggingface_hub` library. These calls will attempt connections and time out on air-gapped networks.

3. **`skip_internet` is not propagated to `download_file_with_retry()`**: Even when `internet_connection = False`, the `_download_file()` function still makes HTTP HEAD calls via `file_meta()` to verify local files and unconditionally attempts downloads for missing files.

## Changes

### `src/exo/main.py`
- Add `--offline` flag to `Args` with env var detection (`EXO_OFFLINE=1`, `HF_HUB_OFFLINE=1`)
- Pass `offline` to `DownloadCoordinator` at creation and re-creation (election loop)

### `src/exo/download/coordinator.py`
- Add `offline: bool = False` field
- In offline mode: set `internet_connection = False` immediately in `__post_init__`, skip `_test_internet_connection()` ping (avoids 3s timeout), skip `_check_internet_connection` periodic loop
- In `_start_download()`: if model is not fully available locally, emit `DownloadFailed` with clear message instead of starting a download task

### `src/exo/download/download_utils.py`
- Add `skip_internet: bool` parameter to `download_file_with_retry()` and `_download_file()`
- When `skip_internet=True` in `_download_file()`: return local file immediately without HTTP HEAD verification; raise `FileNotFoundError` for missing files
- Propagate `skip_internet` from `download_shard()` to `download_file_with_retry()`

### `src/exo/download/tests/test_offline_mode.py` (new)
- 8 tests covering `_download_file`, `download_file_with_retry`, and `fetch_file_list_with_cache` in offline mode

## Why It Works

Unlike `--no-downloads` which disables `DownloadCoordinator` entirely, `--offline` keeps the coordinator running in a restricted mode. The existing `_emit_existing_download_progress()` disk scanner still runs every 60 seconds, emitting `DownloadCompleted` events for pre-staged models. These events flow through the event-sourcing pipeline and populate `state.downloads`, which unblocks `_model_needs_download()` in `plan.py` — no changes to the planning logic required.

```
--offline flag
  → DownloadCoordinator (offline mode)
    → Skip 1.1.1.1 ping, internet_connection = False
    → _emit_existing_download_progress scans disk
      → Emits DownloadCompleted for pre-staged models
        → _model_needs_download sees DownloadCompleted
          → _load_model proceeds normally
```
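
The offline branch in `_download_file` can be sketched as follows (an illustrative stand-in, not the real function — the actual code takes more parameters and verifies hashes; only the `skip_internet` behaviour described above is modelled, and the helper name is hypothetical):

```python
from pathlib import Path


def resolve_pre_staged_file(target: Path, skip_internet: bool = False) -> Path:
    """Sketch of _download_file's skip_internet branch: trust the local
    file and never touch the network; a missing file is a hard error."""
    if skip_internet:
        if target.exists():
            # Pre-staged file: return it without an HTTP HEAD size check.
            return target
        raise FileNotFoundError(f"{target} is not pre-staged and offline mode is on")
    # Online path (not modelled here): HEAD verification via file_meta(), then download.
    raise NotImplementedError("online path out of scope for this sketch")
```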

## Test Plan

### Automated Testing
- `ruff check` — passes
- 8 new tests in `test_offline_mode.py` — all pass
- 11 existing download tests in `test_download_verification.py` — all pass (no regressions)

### Manual Testing
1. Pre-stage a model on disk (e.g., `~/.exo/models/mlx-community--Qwen3-0.6B-4bit/`)
2. Start exo with `--offline` (or `EXO_OFFLINE=1`)
3. Place an instance via API or dashboard
4. Verify: model loads into memory and inference works without any network calls

### Environment
- macOS (Apple Silicon), multi-node cluster with Thunderbolt interconnect
- Models pre-staged via rsync / NFS mount
2026-02-18 16:18:09 +00:00
ciaranbor
f54c80b121 Ciaran/image edit api (#1500)
## Motivation

- Image editing previously ignored input image dimensions, always
defaulting to 1024x1024
- Size dropdown was hidden in edit mode, giving users no control over
output dimensions
- Portrait/landscape presets used non-standard aspect ratios (1024x1365
/ 1365x1024)

## Changes

- Added "auto" size option that uses input image dimensions for edits,
defaults to 1024x1024 for generation
- Introduced ImageSize Literal type and normalize_image_size() validator
(replaces raw str size fields)
- Updated portrait/landscape presets to standard 1024x1536 / 1536x1024
- Made size selector visible in edit mode (previously hidden)
- Default size changed from "1024x1024" to "auto"

## Why It Works

- "auto" reads actual input image dimensions via PIL at generation time,
so edits preserve the original aspect ratio
- Pydantic field_validator on both ImageGenerationTaskParams and
ImageEditsTaskParams normalizes None → "auto", keeping the API
backward-compatible
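
The normalization itself is small; here is a stand-alone sketch of what the `normalize_image_size()` validator does (the preset list is assumed from this PR's description):

```python
from typing import Literal, Optional

ImageSize = Literal["auto", "1024x1024", "1024x1536", "1536x1024"]
_ALLOWED = {"auto", "1024x1024", "1024x1536", "1536x1024"}


def normalize_image_size(value: Optional[str]) -> str:
    """None means the client omitted the field: treat it as "auto"
    so existing callers keep working; reject unknown sizes."""
    if value is None:
        return "auto"
    if value not in _ALLOWED:
        raise ValueError(f"unsupported image size: {value!r}")
    return value
```

In the PR this logic runs as a Pydantic `field_validator` on both `ImageGenerationTaskParams` and `ImageEditsTaskParams`.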

## Test Plan

### Manual Testing

- Verify image edits output at the input image's native resolution when
size is "auto"
- Verify size dropdown appears and works in both generate and edit modes
2026-02-18 16:05:39 +00:00
rltakashige
48b8f86395 Add support for GLM 5 (#1526)
## Motivation

Add GLM 5 support in favor of #1513 

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-02-18 14:04:06 +00:00
Evan
5cbd6377a2 prioritize official model cards over custom model cards
Our old model-card search path would override official model cards with
custom ones; our packaged model cards should always be the default here.
2026-02-18 13:20:05 +00:00
Evan Quiney
8f01523ddb remove dead code (#1496) 2026-02-18 11:43:27 +00:00
Alex Cheema
3addeadea8 Update mlx-lm to 0.30.7 (#1520)
## Summary
- Bumps `mlx-lm` from 0.30.6 to 0.30.7 in `pyproject.toml` and `uv.lock`

## Test plan
- [x] `uv lock` resolves successfully
- [x] `basedpyright` — no new errors (63 pre-existing in unrelated
`test_tool_call_tracker.py`)
- [x] `ruff check` — all checks passed
- [x] `nix fmt` — no formatting changes
- [x] `pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 11:14:23 +00:00
rltakashige
f2be929211 Leo/address rdma gpu locks 2 (#1515)
Same as #1489. Had to revert and redo thanks to Claude.

---------

Co-authored-by: Jake Hillion <jake@hillion.co.uk>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:00:52 -08:00
rltakashige
83af8c63fa Revert "Use custom fork that resolves GPU locks" (#1502)
Reverts exo-explore/exo#1489

Goddammit Claude...
2026-02-17 18:18:54 +00:00
Evan Quiney
eccc6298d1 Revert "Add MetaInstance declarative layer (#1447)"
This reverts commit a962a28afc.
2026-02-17 18:11:47 +00:00
161 changed files with 8552 additions and 6204 deletions


@@ -200,7 +200,7 @@ class Module(dict):
) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
"""Return the submodules that do not contain other modules."""
def update(self, parameters: dict, strict: bool = ...) -> Module:
def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.


@@ -7,7 +7,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mlx.core import MX_ARRAY_TREE
def tree_map(
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[..., bool] | None = ...,
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -44,11 +47,11 @@ def tree_map(
"""
def tree_map_with_path(
fn: Callable,
fn: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Optional[Callable] = ...,
path: Optional[Any] = ...,
is_leaf: Callable[..., bool] | None = ...,
path: str | None = ...,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
returns a new collection with the results.
@@ -80,9 +83,9 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = ...,
is_leaf: Optional[Callable] = ...,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
is_leaf: Callable[..., bool] | None = ...,
destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
) -> list[tuple[str, Any]] | dict[str, Any]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -118,7 +121,7 @@ def tree_flatten(
the Python tree.
"""
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python


@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

Cargo.lock (generated, 136 changed lines)

@@ -141,12 +141,6 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -304,19 +298,6 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -516,15 +497,6 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -746,29 +718,6 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -939,22 +888,17 @@ name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures",
"impl-trait-for-tuples",
"futures-lite",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
@@ -970,6 +914,12 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1078,7 +1028,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
@@ -1640,17 +1593,6 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1829,12 +1771,6 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2824,16 +2760,13 @@ name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-lite",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
@@ -2918,17 +2851,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3279,28 +3201,14 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3741,16 +3649,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -4615,24 +4513,12 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"


@@ -26,49 +26,20 @@ opt-level = 3
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependencies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-lite = "2.6.1"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"


@@ -72,12 +72,19 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly:
```bash
nix run .#exo
```
**Note:** To accept the Cachix binary cache (and avoid the Xcode Metal ToolChain), add to `/etc/nix/nix.conf`:
```
trusted-users = root (or your username)
experimental-features = nix-command flakes
```
Then restart the Nix daemon: `sudo launchctl kickstart -k system/org.nixos.nix-daemon`
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)

bench/eval_tool_calls.py (new file, 1104 changed lines)

File diff suppressed because it is too large.


@@ -1,29 +1,48 @@
# type: ignore
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""Tool-calling eval for exo's OpenAI-compatible API.
Tests whether models correctly:
- Trigger tool calls when appropriate
- Return valid JSON arguments matching function schemas
- Handle multi-turn tool use (call -> result -> final answer)
- Avoid calling tools when unnecessary
Start exo with a model first, then run:
uv run python tool_call_eval.py --model <model-id>
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
uv run python tool_call_eval.py --model <model-id> --repeat 3
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
"""
from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from harness import (
ExoClient,
ExoHttpError,
add_common_instance_args,
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
)
from loguru import logger
from transformers import AutoTokenizer
# Backoff constants for cluster settling retry
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
@@ -103,154 +122,6 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
@@ -269,184 +140,6 @@ def parse_int_list(values: list[str]) -> list[int]:
return items
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("name").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
@@ -538,76 +231,12 @@ class PromptSizer:
return content, tok
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
add_common_instance_args(ap)
ap.add_argument(
"--pp",
nargs="+",
@@ -620,34 +249,6 @@ def main() -> int:
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
@@ -657,9 +258,6 @@ def main() -> int:
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
@@ -674,17 +272,6 @@ def main() -> int:
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -719,24 +306,10 @@ def main() -> int:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
selected = settle_and_fetch_placements(
client, full_model_id, args, settle_timeout=args.settle_timeout
)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_deadline:
backoff = _SETTLE_INITIAL_BACKOFF_S
while not selected and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
@@ -760,8 +333,12 @@ def main() -> int:
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
download_duration_s = run_planning_phase(
client,
full_model_id,
selected[0],
@@ -769,6 +346,10 @@ def main() -> int:
args.timeout,
settle_deadline,
)
if download_duration_s is not None:
logger.info(f"Download: {download_duration_s:.1f}s (freshly downloaded)")
else:
logger.info("Download: model already cached")
all_rows: list[dict[str, Any]] = []
@@ -832,6 +413,11 @@ def main() -> int:
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
**(
{"download_duration_s": download_duration_s}
if download_duration_s is not None
else {}
),
}
)
runs.append(row)

bench/harness.py Normal file

@@ -0,0 +1,489 @@
# type: ignore
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from typing import Any
from urllib.parse import urlencode
from loguru import logger
_SETTLE_INITIAL_BACKOFF_S = 1.0
_SETTLE_MAX_BACKOFF_S = 60.0
_SETTLE_BACKOFF_MULTIPLIER = 2.0
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
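`unwrap_instance` above treats each instance as a single-key tagged union serialized as `{tag: payload}`. A self-contained sketch of that convention, using a hypothetical tag and payload for illustration:

```python
from typing import Any

def unwrap(tagged: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    # Tagged unions arrive as {"TagName": {...payload...}}; exactly one key.
    if len(tagged) != 1:
        raise KeyError(f"Expected 1 key, got keys={list(tagged)}")
    tag = next(iter(tagged))
    payload = tagged[tag]
    if not isinstance(payload, dict):
        raise TypeError(f"payload for {tag} must be dict, got {type(payload)}")
    return tag, payload

# Hypothetical instance payload; real tags/fields come from the exo API.
tag, inner = unwrap({"SomeInstanceKind": {"instanceId": "abc123"}})
```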
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if (m.get("name") or "").lower() == model_arg.lower():
short_id = str(m["name"])
full_id = str(m.get("hugging_face_id") or m["name"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["name"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def fetch_and_filter_placements(
client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip single-node tensor/jaccl placements: with only one node they
# duplicate the pipeline+ring case and add nothing to the benchmark
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if (
args.skip_tensor_ring
and (
args.instance_meta == "both"
and "ring" in p.get("instance_meta", "").lower()
)
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
return selected
def settle_and_fetch_placements(
client: ExoClient,
full_model_id: str,
args: argparse.Namespace,
settle_timeout: float = 0,
) -> list[dict[str, Any]]:
selected = fetch_and_filter_placements(client, full_model_id, args)
if not selected and settle_timeout > 0:
backoff = _SETTLE_INITIAL_BACKOFF_S
deadline = time.monotonic() + settle_timeout
while not selected and time.monotonic() < deadline:
remaining = deadline - time.monotonic()
logger.warning(
f"No valid placements yet (cluster may still be settling). "
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
selected = fetch_and_filter_placements(client, full_model_id, args)
return selected
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> float | None:
"""Check disk space and ensure model is downloaded before benchmarking.
Returns the wall-clock download duration in seconds if a fresh download
was needed, or None if the model was already cached on all nodes.
"""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return None
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
needs_download = False
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
needs_download = True
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest (skip read-only models from EXO_MODELS_PATH)
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
and not p["DownloadCompleted"].get("readOnly", False)
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
download_t0 = time.perf_counter() if needs_download else None
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
if download_t0 is not None:
return time.perf_counter() - download_t0
return None
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Skip pipeline+jaccl placements, as it's often pointless.",
)
ap.add_argument(
"--skip-tensor-ring",
action="store_true",
help="Skip tensor+ring placements, as it's so slow.",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--settle-timeout",
type=float,
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)


@@ -4,6 +4,7 @@ version = "0.1.0"
description = "Benchmarking tool for exo distributed inference"
requires-python = ">=3.13"
dependencies = [
"httpx>=0.27.0",
"loguru>=0.7.3",
"transformers>=5.0.0",
"huggingface-hub>=0.33.4",

bench/scenarios.toml Normal file

@@ -0,0 +1,306 @@
# Tool definitions — each becomes an OpenAI function tool.
# All scenarios get all tools unless they specify a `tools` list.
[tools.get_current_weather]
description = "Get the current weather in a given location"
required = ["location"]
[tools.get_current_weather.properties.location]
type = "string"
description = "City and state, e.g. San Francisco, CA"
[tools.get_current_weather.properties.unit]
type = "string"
enum = ["celsius", "fahrenheit"]
description = "Temperature unit"
[tools.calculate]
description = "Evaluate a mathematical expression and return the numeric result"
required = ["expression"]
[tools.calculate.properties.expression]
type = "string"
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
[tools.search_products]
description = "Search for products in a catalog by query, category, and price"
required = ["query"]
[tools.search_products.properties.query]
type = "string"
description = "Search query string"
[tools.search_products.properties.category]
type = "string"
enum = ["electronics", "clothing", "food", "books"]
description = "Product category to filter by"
[tools.search_products.properties.max_price]
type = "number"
description = "Maximum price in USD"
[tools.create_todos]
description = "Create a structured todo list"
required = ["todos"]
[tools.create_todos.properties.todos]
type = "array"
description = "List of todo items"
[tools.create_todos.properties.todos.items]
type = "object"
required = ["content", "status", "priority"]
[tools.create_todos.properties.todos.items.properties.content]
type = "string"
description = "The todo item text"
[tools.create_todos.properties.todos.items.properties.status]
type = "string"
description = "Status: pending, in_progress, or completed"
[tools.create_todos.properties.todos.items.properties.priority]
type = "string"
description = "Priority: low, normal, or high"
# -- Should call a tool --
[[scenarios]]
name = "weather_simple"
description = "Basic weather query -> get_current_weather"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "What's the weather like in Tokyo right now?"
[[scenarios]]
name = "calculator_simple"
description = "Math question -> calculate"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 3847 * 926 + 17293"
[[scenarios]]
name = "search_with_filters"
description = "Product search with category and price filter"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.messages]]
role = "user"
content = "Find me electronics under $50"
# -- Multi-turn: tool call then follow-up --
[[scenarios]]
name = "weather_multi_turn"
description = "Weather query -> tool result -> natural language summary"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[scenarios.tool_result]
temperature = "18C"
condition = "partly cloudy"
humidity = "65%"
wind = "12 km/h NW"
[[scenarios.messages]]
role = "user"
content = "What's the weather in Paris?"
[[scenarios]]
name = "calculator_multi_turn"
description = "Math query -> tool result -> model reports the answer"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[scenarios.tool_result]
result = 491682
[[scenarios.messages]]
role = "user"
content = "Use the calculator to compute 1847 * 263 + 5921"
[[scenarios]]
name = "search_multi_turn"
description = "Search query -> tool result -> model summarizes products"
expect_tool_call = true
expected_function = "search_products"
required_arg_keys = ["query"]
[[scenarios.tool_result.results]]
name = "Hands-On Machine Learning"
price = 45.99
rating = 4.8
[[scenarios.tool_result.results]]
name = "Deep Learning with Python"
price = 39.99
rating = 4.6
[[scenarios.messages]]
role = "user"
content = "Search for books about machine learning"
# -- Sequential tool calls --
[[scenarios]]
name = "chained_tool_calls_same"
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare the weather in Tokyo and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check both cities. Let me start with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_1"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_1"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios]]
name = "chained_tool_calls_different"
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
expect_tool_call = true
expected_function = "calculate"
required_arg_keys = ["expression"]
[[scenarios.messages]]
role = "user"
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
[[scenarios.messages]]
role = "assistant"
content = "I'll handle both. Let me check Berlin's weather first."
[[scenarios.messages.tool_calls]]
id = "call_2"
name = "get_current_weather"
arguments = { location = "Berlin" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_2"
content = '{"temperature": "12C", "condition": "rainy"}'
[[scenarios]]
name = "chained_tool_calls_three"
description = "Two prior thinking+tool calls -> results -> model must make a third"
expect_tool_call = true
expected_function = "get_current_weather"
required_arg_keys = ["location"]
[[scenarios.messages]]
role = "user"
content = "Compare weather in Tokyo, Paris, and London."
[[scenarios.messages]]
role = "assistant"
content = "I'll check all three cities. Starting with Tokyo."
[[scenarios.messages.tool_calls]]
id = "call_3"
name = "get_current_weather"
arguments = { location = "Tokyo" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_3"
content = '{"temperature": "25C", "condition": "sunny"}'
[[scenarios.messages]]
role = "assistant"
content = "Got Tokyo. Now checking Paris."
[[scenarios.messages.tool_calls]]
id = "call_4"
name = "get_current_weather"
arguments = { location = "Paris" }
[[scenarios.messages]]
role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Nested object schema (regression for lossy chat template rendering) --
[[scenarios]]
name = "nested_schema_tool_call"
description = "Tool call with nested object array schema -> create_todos"
expect_tool_call = true
expected_function = "create_todos"
required_arg_keys = ["todos"]
nested_array_key = "todos"
required_item_keys = ["content", "status", "priority"]
tools = ["create_todos"]
[[scenarios.messages]]
role = "user"
content = "Create a todo list with 3 items to learn Python"
# -- Tool name integrity (regression for harmony token leaking into name) --
[tools.glob]
description = "Search for files matching a glob pattern in the codebase"
required = ["pattern"]
[tools.glob.properties.pattern]
type = "string"
description = "The glob pattern to match files against, e.g. '**/*.py'"
[tools.glob.properties.path]
type = "string"
description = "The directory to search in"
[[scenarios]]
name = "tool_name_integrity"
description = "Tool name must not contain harmony tokens like <|channel|>"
expect_tool_call = true
expected_function = "glob"
required_arg_keys = ["pattern"]
tools = ["glob"]
[[scenarios.messages]]
role = "user"
content = "Find all Python files in the src directory"
# -- Should NOT call a tool --
[[scenarios]]
name = "no_tool_joke"
description = "Joke request should NOT trigger any tool"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "Tell me a funny joke about cats."
[[scenarios]]
name = "no_tool_factual"
description = "Factual question answerable from training data"
expect_tool_call = false
[[scenarios.messages]]
role = "user"
content = "What is the capital of Japan?"
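Per the header comment in scenarios.toml, each `[tools.*]` table becomes an OpenAI function tool. A hedged sketch of how such a table could map to the OpenAI `tools` payload shape — the `to_openai_tool` helper and the exact field mapping are assumptions for illustration, not code shown in the diff:

```python
from typing import Any

def to_openai_tool(name: str, spec: dict[str, Any]) -> dict[str, Any]:
    # Map a scenarios.toml tool table (description/required/properties)
    # to the OpenAI function-tool JSON shape.
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": spec["description"],
            "parameters": {
                "type": "object",
                "properties": spec.get("properties", {}),
                "required": spec.get("required", []),
            },
        },
    }

# The "calculate" tool as parsed from the TOML above.
calculate_spec = {
    "description": "Evaluate a mathematical expression and return the numeric result",
    "required": ["expression"],
    "properties": {
        "expression": {
            "type": "string",
            "description": "The math expression to evaluate, e.g. '2 + 3 * 4'",
        }
    },
}
tool = to_openai_tool("calculate", calculate_spec)
```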


@@ -14,6 +14,7 @@
totalTokens,
thinkingEnabled as thinkingEnabledStore,
setConversationThinking,
stopGeneration,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -103,7 +104,7 @@
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking") && caps.includes("text");
return caps.includes("thinking_toggle") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(
@@ -653,86 +654,92 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => stopGeneration()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white"
aria-label="Stop generation"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
class="w-3 h-3 sm:w-3.5 sm:h-3.5"
fill="currentColor"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
<rect x="6" y="6" width="12" height="12" rx="1" />
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">Cancel</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>
<!-- Bottom accent line -->


@@ -3,16 +3,17 @@
messages,
currentResponse,
isLoading,
prefillProgress,
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import ImageLightbox from "./ImageLightbox.svelte";
interface Props {
@@ -25,6 +26,7 @@
const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());
const prefill = $derived(prefillProgress());
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
@@ -428,6 +430,9 @@
{:else}
<!-- Assistant message styling -->
<div class="p-3 sm:p-4">
{#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
<PrefillProgressBar progress={prefill} class="mb-3" />
{/if}
{#if message.thinking && message.thinking.trim().length > 0}
<div
class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"


@@ -26,7 +26,8 @@
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
function formatNumber(num: number | undefined): string {
if (num == null) return "0";
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {


@@ -59,13 +59,14 @@
}
const sizeOptions: ImageGenerationParams["size"][] = [
"auto",
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
"1024x1365",
"1365x1024",
"1024x1536",
"1536x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -176,92 +177,90 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
{params.size.toUpperCase()}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size.toUpperCase()}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
{/if}
</div>
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -311,7 +310,7 @@
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition()
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
>

View File

@@ -0,0 +1,70 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
const etaText = $derived.by(() => {
if (progress.processed <= 0 || progress.total <= 0) return null;
const elapsedMs = performance.now() - progress.startedAt;
if (elapsedMs < 200) return null; // need a minimum sample window
const tokensPerMs = progress.processed / elapsedMs;
const remainingTokens = progress.total - progress.processed;
const remainingMs = remainingTokens / tokensPerMs;
const remainingSec = Math.ceil(remainingMs / 1000);
if (remainingSec <= 0) return null;
if (remainingSec < 60) return `~${remainingSec}s remaining`;
const mins = Math.floor(remainingSec / 60);
const secs = remainingSec % 60;
return `~${mins}m ${secs}s remaining`;
});
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div
class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
>
<span>{etaText ?? ""}</span>
<span>{percentage}%</span>
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>
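The ETA derivation in this component can be exercised as a pure function; a minimal sketch follows, assuming a constant prefill rate over the elapsed window (the function name and signature are illustrative, not part of the component):

```typescript
// Standalone sketch of the prefill ETA text derivation above.
// Assumes tokens arrive at a roughly constant rate; name/signature
// are illustrative, not part of the component API.
function estimateEtaText(
  processed: number,
  total: number,
  elapsedMs: number,
): string | null {
  if (processed <= 0 || total <= 0) return null;
  if (elapsedMs < 200) return null; // need a minimum sample window
  const tokensPerMs = processed / elapsedMs;
  const remainingMs = (total - processed) / tokensPerMs;
  const remainingSec = Math.ceil(remainingMs / 1000);
  if (remainingSec <= 0) return null;
  if (remainingSec < 60) return `~${remainingSec}s remaining`;
  return `~${Math.floor(remainingSec / 60)}m ${remainingSec % 60}s remaining`;
}
```

Returning `null` for tiny sample windows avoids flashing wildly wrong estimates during the first few chunks.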

View File

@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
model_id: string;
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxJaccl";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance: unknown | null;
memory_delta_by_node: Record<string, number> | null;
error: string | null;
@@ -219,6 +219,7 @@ interface RawStateResponse {
string,
{
MlxRingInstance?: Instance;
MlxIbvInstance?: Instance;
MlxJacclInstance?: Instance;
}
>;
@@ -249,20 +250,11 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// MetaInstances (declarative instance constraints)
metaInstances?: Record<string, MetaInstanceData>;
}
export interface MetaInstanceData {
metaInstanceId: string;
modelId: string;
sharding: string;
instanceMeta: string;
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
// Disk usage per node
nodeDisk?: Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>;
}
export interface MessageAttachment {
@@ -286,6 +278,13 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
/** Timestamp (performance.now()) when prefill started. */
startedAt: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -319,13 +318,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
export interface ImageGenerationParams {
// Basic params
size:
| "auto"
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1365"
| "1365x1024";
| "1024x1536"
| "1536x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -349,7 +349,7 @@ export interface EditingImage {
}
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
size: "auto",
quality: "medium",
outputFormat: "png",
numImages: 1,
@@ -532,6 +532,10 @@ class AppStore {
ttftMs = $state<number | null>(null); // Time to first token in ms
tps = $state<number | null>(null); // Tokens per second
totalTokens = $state<number>(0); // Total tokens in current response
prefillProgress = $state<PrefillProgress | null>(null);
// Abort controller for stopping generation
private currentAbortController: AbortController | null = null;
// Topology state
topologyData = $state<TopologyData | null>(null);
@@ -550,7 +554,6 @@ class AppStore {
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
metaInstances = $state<Record<string, MetaInstanceData>>({});
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderbolt = $state<
Record<
@@ -909,7 +912,11 @@ class AppStore {
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {
@@ -1283,8 +1290,6 @@ class AppStore {
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
// RDMA ctl status per node
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
// Thunderbolt bridge cycles
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
// Thunderbolt bridge status per node
@@ -1654,11 +1659,12 @@ class AppStore {
if (!reader) throw new Error("No response body");
let fullContent = prefixText;
let streamedThinking = "";
const collectedTokens: TokenData[] = [...tokensToKeep];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
delta?: { content?: string; reasoning_content?: string };
logprobs?: {
content?: Array<{
token: string;
@@ -1679,6 +1685,7 @@ class AppStore {
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
const thinkingDelta = choice?.delta?.reasoning_content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
@@ -1697,7 +1704,11 @@ class AppStore {
}
}
if (delta) {
if (thinkingDelta) {
streamedThinking += thinkingDelta;
}
if (delta || thinkingDelta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
@@ -1711,9 +1722,14 @@ class AppStore {
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
fullContent += delta;
const { displayContent, thinkingContent } =
if (delta) {
fullContent += delta;
}
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(fullContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent;
@@ -1725,7 +1741,7 @@ class AppStore {
messageId,
(m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.thinking = combinedThinking || undefined;
m.tokens = [...collectedTokens];
},
);
@@ -1737,11 +1753,14 @@ class AppStore {
// Final update
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(fullContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.thinking = finalThinking || undefined;
m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps;
@@ -1849,11 +1868,12 @@ class AppStore {
}
let streamedContent = "";
let streamedThinking = "";
const collectedTokens: TokenData[] = [];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
delta?: { content?: string; reasoning_content?: string };
logprobs?: {
content?: Array<{
token: string;
@@ -1874,6 +1894,7 @@ class AppStore {
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
const thinkingDelta = choice?.delta?.reasoning_content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
@@ -1892,10 +1913,19 @@ class AppStore {
}
}
if (delta) {
streamedContent += delta;
const { displayContent, thinkingContent } =
if (thinkingDelta) {
streamedThinking += thinkingDelta;
}
if (delta || thinkingDelta) {
if (delta) {
streamedContent += delta;
}
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
// Only update currentResponse if target conversation is active
if (this.activeConversationId === targetConversationId) {
@@ -1908,7 +1938,7 @@ class AppStore {
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = combinedThinking || undefined;
msg.tokens = [...collectedTokens];
},
);
@@ -1920,14 +1950,17 @@ class AppStore {
// Final cleanup of the message (if conversation still exists)
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = finalThinking || undefined;
msg.tokens = [...collectedTokens];
},
);
@@ -2016,6 +2049,7 @@ class AppStore {
reader: ReadableStreamDefaultReader<Uint8Array>,
targetConversationId: string,
onChunk: (parsed: T) => void,
onEvent?: Record<string, (data: unknown) => void>,
): Promise<void> {
const decoder = new TextDecoder();
let buffer = "";
@@ -2036,6 +2070,24 @@ class AppStore {
const trimmed = line.trim();
if (!trimmed) continue;
// Handle SSE comments (": key json") for prefill progress etc.
if (trimmed.startsWith(": ") && onEvent) {
const comment = trimmed.slice(2);
const spaceIdx = comment.indexOf(" ");
if (spaceIdx > 0) {
const key = comment.slice(0, spaceIdx);
if (onEvent[key]) {
try {
const parsed = JSON.parse(comment.slice(spaceIdx + 1));
onEvent[key](parsed);
} catch {
// Skip malformed JSON in comment
}
}
}
continue;
}
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
@@ -2267,6 +2319,9 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
const abortController = new AbortController();
this.currentAbortController = abortController;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
@@ -2283,6 +2338,7 @@ class AppStore {
enable_thinking: enableThinking,
}),
}),
signal: abortController.signal,
});
if (!response.ok) {
@@ -2296,10 +2352,11 @@ class AppStore {
}
let streamedContent = "";
let streamedThinking = "";
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
delta?: { content?: string; reasoning_content?: string };
logprobs?: {
content?: Array<{
token: string;
@@ -2320,8 +2377,14 @@ class AppStore {
reader,
targetConversationId,
(parsed) => {
// Clear prefill progress when first token data arrives
if (this.prefillProgress) {
this.prefillProgress = null;
}
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
const thinkingContent = choice?.delta?.reasoning_content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
@@ -2340,7 +2403,11 @@ class AppStore {
}
}
if (tokenContent) {
if (thinkingContent) {
streamedThinking += thinkingContent;
}
if (tokenContent || thinkingContent) {
// Track first token for TTFT
if (firstTokenTime === null) {
firstTokenTime = performance.now();
@@ -2357,11 +2424,16 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000;
}
streamedContent += tokenContent;
if (tokenContent) {
streamedContent += tokenContent;
}
// Strip thinking tags for display and extract thinking content
const { displayContent, thinkingContent } =
// Use stripThinkingTags as fallback for any <think> tags still in content
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
// Only update currentResponse if target conversation is active
if (this.activeConversationId === targetConversationId) {
@@ -2374,7 +2446,7 @@ class AppStore {
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = combinedThinking || undefined;
msg.tokens = [...collectedTokens];
},
);
@@ -2382,8 +2454,27 @@ class AppStore {
this.persistConversation(targetConversationId);
}
},
{
prefill_progress: (data) => {
// TaggedModel wraps as {"PrefillProgressChunk": {...}}
// model_dump_json() uses snake_case (by_alias defaults to False)
const raw = data as Record<string, unknown>;
const inner = (raw["PrefillProgressChunk"] ?? raw) as {
processed_tokens: number;
total_tokens: number;
};
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
};
},
},
);
// Clear prefill progress after stream ends
this.prefillProgress = null;
// Calculate final TPS
if (firstTokenTime !== null && tokenCount > 1) {
const totalGenerationTime = performance.now() - firstTokenTime;
@@ -2392,14 +2483,17 @@ class AppStore {
// Final cleanup of the message (if conversation still exists)
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = finalThinking || undefined;
msg.tokens = [...collectedTokens];
// Store performance metrics on the message
if (this.ttftMs !== null) {
@@ -2414,20 +2508,31 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
if (error instanceof DOMException && error.name === "AbortError") {
// User stopped generation — not an error
} else {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
}
} finally {
this.currentAbortController = null;
this.prefillProgress = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
stopGeneration(): void {
this.currentAbortController?.abort();
this.currentAbortController = null;
}
/**
* Generate an image using the image generation API
*/
@@ -3054,9 +3159,9 @@ export const isLoading = () => appStore.isLoading;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;
@@ -3072,6 +3177,7 @@ export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const stopGeneration = () => appStore.stopGeneration();
export const startChat = () => appStore.startChat();
export const sendMessage = (
content: string,

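The SSE comment handling added in the stream-processing hunk above can be sketched as a standalone helper: comment lines of the form `: key {json}` carry out-of-band events (such as prefill progress) without polluting the regular `data:` stream. The helper name and boolean return convention below are illustrative:

```typescript
// Sketch of the SSE side-channel parser: lines starting with ": "
// are SSE comments; here they encode "key {json}" events. Returns
// true only when a registered handler consumed a well-formed event.
type EventHandlers = Record<string, (data: unknown) => void>;

function handleSseLine(line: string, onEvent: EventHandlers): boolean {
  const trimmed = line.trim();
  if (!trimmed.startsWith(": ")) return false;
  const comment = trimmed.slice(2);
  const spaceIdx = comment.indexOf(" ");
  if (spaceIdx <= 0) return false;
  const key = comment.slice(0, spaceIdx);
  const handler = onEvent[key];
  if (!handler) return false;
  try {
    handler(JSON.parse(comment.slice(spaceIdx + 1)));
    return true;
  } catch {
    return false; // skip malformed JSON in comment
  }
}
```

Piggybacking on SSE comments keeps older clients working: per the SSE format, lines beginning with `:` are ignored by consumers that do not know about the side channel.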
View File

File diff suppressed because it is too large

View File

@@ -29,7 +29,12 @@
etaMs: number;
modelDirectory?: string;
}
| { kind: "pending"; modelDirectory?: string }
| {
kind: "pending";
downloaded: number;
total: number;
modelDirectory?: string;
}
| { kind: "failed"; modelDirectory?: string }
| { kind: "not_present" };
@@ -74,7 +79,6 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -231,23 +235,14 @@
undefined;
let cell: CellStatus;
if (tag === "DownloadCompleted") {
const totalBytes = getBytes(
payload.total_bytes ?? payload.totalBytes,
);
const totalBytes = getBytes(payload.total);
cell = { kind: "completed", totalBytes, modelDirectory };
} else if (tag === "DownloadOngoing") {
const rawProgress =
payload.download_progress ?? payload.downloadProgress ?? {};
const prog = rawProgress as Record<string, unknown>;
const totalBytes = getBytes(
prog.total_bytes ??
prog.totalBytes ??
payload.total_bytes ??
payload.totalBytes,
);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const totalBytes = getBytes(prog.total ?? payload.total);
const downloadedBytes = getBytes(prog.downloaded);
const speed = (prog.speed as number) ?? 0;
const etaMs =
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;
@@ -265,7 +260,20 @@
} else if (tag === "DownloadFailed") {
cell = { kind: "failed", modelDirectory };
} else {
cell = { kind: "pending", modelDirectory };
const downloaded = getBytes(
payload.downloaded ??
payload.downloaded_bytes ??
payload.downloadedBytes,
);
const total = getBytes(
payload.total ?? payload.total_bytes ?? payload.totalBytes,
);
cell = {
kind: "pending",
downloaded,
total,
modelDirectory,
};
}
const existing = row.cells[nodeId];
@@ -275,14 +283,51 @@
}
}
function rowSortKey(row: ModelRow): number {
// in progress (4) -> completed (3) -> paused (2) -> not started (1) -> not present (0)
let best = 0;
for (const cell of Object.values(row.cells)) {
let score = 0;
if (cell.kind === "downloading") score = 4;
else if (cell.kind === "completed") score = 3;
else if (cell.kind === "pending" && cell.downloaded > 0)
score = 2; // paused
else if (cell.kind === "pending" || cell.kind === "failed") score = 1; // not started
if (score > best) best = score;
}
return best;
}
function totalCompletedBytes(row: ModelRow): number {
let total = 0;
for (const cell of Object.values(row.cells)) {
if (cell.kind === "completed") total += cell.totalBytes;
}
return total;
}
const rows = Array.from(rowMap.values()).sort((a, b) => {
const aCompleted = Object.values(a.cells).filter(
(c) => c.kind === "completed",
).length;
const bCompleted = Object.values(b.cells).filter(
(c) => c.kind === "completed",
).length;
if (aCompleted !== bCompleted) return bCompleted - aCompleted;
const aPriority = rowSortKey(a);
const bPriority = rowSortKey(b);
if (aPriority !== bPriority) return bPriority - aPriority;
// Within completed or paused, sort by biggest size first
if (aPriority === 3 && bPriority === 3) {
const sizeDiff = totalCompletedBytes(b) - totalCompletedBytes(a);
if (sizeDiff !== 0) return sizeDiff;
}
if (aPriority === 2 && bPriority === 2) {
const aSize = Math.max(
...Object.values(a.cells).map((c) =>
c.kind === "pending" ? c.total : 0,
),
);
const bSize = Math.max(
...Object.values(b.cells).map((c) =>
c.kind === "pending" ? c.total : 0,
),
);
if (aSize !== bSize) return bSize - aSize;
}
return a.modelId.localeCompare(b.modelId);
});
@@ -492,9 +537,34 @@
{:else if cell.kind === "pending"}
<div
class="flex flex-col items-center gap-0.5"
title="Download pending"
title={cell.downloaded > 0
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded`
: "Download pending"}
>
<span class="text-exo-light-gray/50 text-sm">...</span>
{#if cell.downloaded > 0 && cell.total > 0}
<span class="text-exo-light-gray/70 text-[10px]"
>{formatBytes(cell.downloaded)} / {formatBytes(
cell.total,
)}</span
>
<div
class="w-full h-1 bg-white/10 rounded-full overflow-hidden"
>
<div
class="h-full bg-exo-light-gray/40 rounded-full"
style="width: {(
(cell.downloaded / cell.total) *
100
).toFixed(1)}%"
></div>
</div>
<span class="text-exo-light-gray/40 text-[9px]"
>paused</span
>
{:else}
<span class="text-exo-light-gray/50 text-sm">...</span
>
{/if}
</div>
{:else if cell.kind === "failed"}
<div

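The cell-priority ranking behind the row sorting above can be isolated; a minimal sketch follows, with the `Cell` type trimmed to just the fields the ranking reads (downloading > completed > paused pending > pending/failed > absent):

```typescript
// Sketch of the per-cell download priority used to sort model rows.
// A row's score is the best score among its cells.
type Cell =
  | { kind: "downloading" }
  | { kind: "completed" }
  | { kind: "pending"; downloaded: number }
  | { kind: "failed" }
  | { kind: "not_present" };

function cellScore(cell: Cell): number {
  if (cell.kind === "downloading") return 4;
  if (cell.kind === "completed") return 3;
  if (cell.kind === "pending" && cell.downloaded > 0) return 2; // paused
  if (cell.kind === "pending" || cell.kind === "failed") return 1; // not started
  return 0; // not present
}

function rowScore(cells: Cell[]): number {
  return cells.reduce((best, c) => Math.max(best, cellScore(c)), 0);
}
```

Taking the maximum over cells means a model actively downloading on any one node floats above models that are merely complete elsewhere.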
View File

@@ -74,7 +74,6 @@
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in

View File

@@ -41,7 +41,7 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260217+50487b41"; in
version = let v = "0.30.7.dev20260220+13998a05"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
@@ -49,8 +49,8 @@ let
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
rev = "13998a054715edcdc93618fb1496c79c7c25ff7c";
hash = "sha256-fAqA3hFwNBx7FcoGnhQsIFpAIRbC2EerACm4Fvne0Cc=";
};
patches = [

View File

@@ -19,7 +19,7 @@ dependencies = [
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"mlx-lm==0.30.7",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",

View File

@@ -158,6 +158,7 @@
exo-test-env = testVenv;
} // {
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
};

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 405874409472

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 765577920512

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 122406567936

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 229780750336

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 198556925568

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 286737579648

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 396963397248

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 19327352832

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 22548578304

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 26843545600

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 34359738368

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 706522120192

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 662498705408

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 100086644736

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-4bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "4bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 128666664960

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-6bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "6bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 185826705408

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/MiniMax-M2.5-8bit"
n_layers = 62
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 342884352

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 698351616

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 141733920768

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 268435456000

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 17612931072

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 33279705088

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Coder Next"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 45644286500

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "5bit"
base_model = "Qwen3 Coder Next"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 57657697020

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "6bit"
base_model = "Qwen3 Coder Next"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 68899327465


@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Coder Next"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 89357758772


@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "bf16"
base_model = "Qwen3 Coder Next"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 157548627945


@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 47080074240


@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 88814387200


@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 114572190076


@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 159039627774


@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 209082699847


@@ -1,2 +0,0 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]


@@ -25,17 +25,17 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
# "nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
@@ -45,33 +45,18 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
futures-lite = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"


@@ -19,7 +19,7 @@ class ConnectionUpdate:
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
def peer_id(self) -> builtins.str:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@@ -40,92 +40,22 @@ class Keypair:
Identity keypair of a node.
"""
@staticmethod
def generate_ed25519() -> Keypair:
def generate() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def generate_ecdsa() -> Keypair:
def from_bytes(bytes: bytes) -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
Construct an Ed25519 keypair from secret key bytes
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
Get the secret key bytes underlying the keypair
"""
def to_string(self) -> builtins.str:
def to_node_id(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
"""
@typing.final
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""


@@ -2,11 +2,10 @@
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
pin::{Pin, pin},
pin::Pin,
task::{Context, Poll},
};
@@ -26,15 +25,13 @@ where
impl<F> Future for AllowThreads<F>
where
F: Future + Ungil,
F::Output: Ungil,
F: Future + Send,
F::Output: Send,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
}
}


@@ -1,240 +0,0 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

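The `drain_units` method in the removed examples module above illustrates a reusable pattern: drain a channel without blocking, count the buffered events, and only raise once the sender side is gone. A rough Python analogue using the stdlib `queue` module (the `closed` flag here is a hypothetical stand-in for tokio's `TryRecvError::Disconnected`, which `queue.Queue` does not model itself):

```python
import queue


def drain_units(q: "queue.Queue[None]", closed: bool) -> int:
    """Count and discard all currently buffered units without blocking."""
    count = 0
    while True:
        try:
            q.get_nowait()
        except queue.Empty:
            # Mirror tokio's Empty vs Disconnected distinction: an empty queue
            # whose sender is gone is an error; an empty-but-open queue just
            # returns how many units were drained.
            if closed:
                raise RuntimeError("cannot receive unit on closed channel")
            return count
        count += 1
```

As in the Rust version, a successful drain of an open channel returns the count (possibly zero), so callers can poll it repeatedly from a sync context.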

@@ -0,0 +1,47 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Construct an Ed25519 keypair from secret key bytes
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Get the secret key bytes underlying the keypair
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
fn to_node_id(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}

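The new `to_node_id` method above returns the `PeerId` in its base58 text form (a base58btc encoding of the multihash of the protobuf-encoded public key). For reference, base58btc itself is small enough to sketch; this is an illustrative stand-alone encoder, not the libp2p crate's implementation:

```python
# Minimal base58btc encoder (the alphabet used for libp2p PeerId text form).
# Illustrative sketch only -- not the libp2p implementation.
ALPHABET = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"


def base58_encode(data: bytes) -> str:
    # Leading zero bytes are preserved as leading '1' characters.
    n_zeros = len(data) - len(data.lstrip(b"\x00"))
    num = int.from_bytes(data, "big")
    digits = ""
    while num > 0:
        num, rem = divmod(num, 58)
        digits = ALPHABET[rem] + digits
    return "1" * n_zeros + digits
```

For example, `base58_encode(b"hello world")` yields the common test vector `"StV1DL6CwTryKyV"`.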

@@ -4,28 +4,14 @@
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
mod ident;
mod networking;
use crate::ident::PyKeypair;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -34,24 +20,11 @@ pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
@@ -62,7 +35,7 @@ pub(crate) mod ext {
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind())
Python::attach(|py| PyBytes::new(py, self).unbind())
}
}
@@ -98,7 +71,7 @@ pub(crate) mod ext {
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
Python::attach(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
@@ -175,24 +148,6 @@ pub(crate) mod ext {
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
@@ -204,8 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
m.add_class::<PyKeypair>()?;
networking_submodule(m)?;
// top-level constructs


@@ -8,12 +8,12 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair;
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
@@ -25,7 +25,7 @@ use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
peer_id: String,
/// Remote connection's IPv4 address.
#[pyo3(get)]
@@ -155,7 +155,6 @@ async fn networking_task(
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
@@ -252,7 +251,7 @@ async fn networking_task(
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
@@ -273,7 +272,7 @@ async fn networking_task(
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
@@ -485,7 +484,7 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,

View File

@@ -1,159 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}


@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;


@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}


@@ -19,21 +19,14 @@ either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-lite = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
@@ -41,4 +34,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }


@@ -1,4 +1,4 @@
use futures::stream::StreamExt as _;
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
event = swarm.next() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
}))) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
@@ -64,7 +64,7 @@ async fn main() {
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }


@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}


@@ -1,8 +1,7 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
@@ -363,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
if self.retry_delay.poll(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)


@@ -1,44 +0,0 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}


@@ -3,19 +3,7 @@
//! this is here as placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
@@ -54,11 +42,3 @@ pub(crate) mod ext {
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}


@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures::{AsyncRead, AsyncWrite};
use futures_lite::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;


@@ -1,11 +1,10 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
{ inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
rustToolchain = inputs'.fenix.packages.stable.withComponents [
"cargo"
"rustc"
"clippy"


@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"


@@ -1,7 +1,7 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
from random import random
import anyio
from anyio import current_time
@@ -12,20 +12,24 @@ from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
resolve_model_in_path,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.constants import EXO_MODELS_DIR, EXO_MODELS_PATH
from exo.shared.models.model_cards import ModelId, get_model_cards
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
EventId,
# TODO(evan): just for acks, should delete this ASAP
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -35,7 +39,7 @@ from exo.shared.types.worker.downloads import (
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
@@ -45,8 +49,15 @@ class DownloadCoordinator:
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
local_event_sender: Sender[LocalForwarderEvent]
# ack stuff
_global_event_receiver: Receiver[GlobalForwarderEvent]
_out_for_delivery: dict[EventId, LocalForwarderEvent] = field(default_factory=dict)
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -62,6 +73,8 @@ class DownloadCoordinator:
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
@@ -77,7 +90,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -107,23 +120,36 @@ class DownloadCoordinator:
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
logger.info(
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
)
if not self.offline:
self._test_internet_connection()
try:
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._clear_ofd)
if not self.offline:
tg.start_soon(self._check_internet_connection)
finally:
for task in self.active_downloads.values():
task.cancel()
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
# Try multiple endpoints since some ISPs/networks block specific IPs
for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
try:
socket.create_connection((host, 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
logger.debug(f"Internet connectivity: True (via {host})")
return
except OSError:
continue
self.shard_downloader.set_internet_connection(False)
logger.debug("Internet connectivity: False")
async def _check_internet_connection(self) -> None:
first_connection = True
@@ -143,6 +169,20 @@ class DownloadCoordinator:
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
# directly copied from worker
async def _resend_out_for_delivery(self) -> None:
# This can be tightened further: we should check that events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
while True:
await anyio.sleep(1 + random())
for event in self._out_for_delivery.copy().values():
await self.local_event_sender.send(event)
async def _clear_ofd(self) -> None:
with self._global_event_receiver as events:
async for event in events:
self._out_for_delivery.pop(event.event.event_id, None)
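The resend loop's comment notes that exponential backoff would help. A minimal sketch of that variant, using plain `asyncio` instead of the coordinator's `anyio` plumbing; the `out_for_delivery` dict and `send` callable are hypothetical stand-ins for the coordinator's state:

```python
import asyncio
import random

async def resend_with_backoff(out_for_delivery: dict, send, base: float = 1.0,
                              cap: float = 30.0, jitter: float = 1.0) -> None:
    # Hypothetical backoff variant of _resend_out_for_delivery: double the
    # delay while events remain unacknowledged, reset once the map drains.
    delay = base
    while True:
        await asyncio.sleep(delay + random.random() * jitter)
        if out_for_delivery:
            for event in list(out_for_delivery.values()):
                await send(event)
            delay = min(delay * 2, cap)
        else:
            delay = base
```

Resetting the delay once the map drains keeps the first resend of any future unacked event prompt, while persistent failures back off toward the cap.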
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
@@ -175,6 +215,25 @@ class DownloadCoordinator:
)
return
# Check EXO_MODELS_PATH for pre-downloaded models
found_path = resolve_model_in_path(model_id)
if found_path is not None:
logger.info(
f"DownloadCoordinator: Model {model_id} found in EXO_MODELS_PATH at {found_path}"
)
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total=shard.model_card.storage_size,
model_directory=str(found_path),
read_only=True,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
return
# Emit pending status
progress = DownloadPending(
shard_metadata=shard,
@@ -193,7 +252,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
total=initial_progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -202,6 +261,20 @@ class DownloadCoordinator:
)
return
if self.offline:
logger.warning(
f"Offline mode: model {model_id} is not fully available locally, cannot download"
)
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=f"Model files not found locally in offline mode: {model_id}",
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
return
# Start actual download
self._start_download_task(shard, initial_progress)
@@ -245,6 +318,15 @@ class DownloadCoordinator:
self.active_downloads[model_id] = task
async def _delete_download(self, model_id: ModelId) -> None:
# Protect read-only models (from EXO_MODELS_PATH) from deletion
if model_id in self.download_status:
current = self.download_status[model_id]
if isinstance(current, DownloadCompleted) and current.read_only:
logger.warning(
f"Refusing to delete read-only model {model_id} (from EXO_MODELS_PATH)"
)
return
# Cancel if active
if model_id in self.active_downloads:
logger.info(f"Cancelling active download for {model_id} before deletion")
@@ -274,19 +356,21 @@ class DownloadCoordinator:
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
self._out_for_delivery[event.event_id] = fe
async def _emit_existing_download_progress(self) -> None:
try:
@@ -308,29 +392,21 @@ class DownloadCoordinator:
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if (
progress.downloaded_bytes.in_bytes
>= progress.total_bytes.in_bytes
> 0
):
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.downloaded_bytes_this_session.in_bytes == 0:
if progress.downloaded_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
downloaded=progress.downloaded,
total=progress.total,
)
else:
status = DownloadOngoing(
@@ -350,6 +426,39 @@ class DownloadCoordinator:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
# Scan EXO_MODELS_PATH for pre-downloaded models
if EXO_MODELS_PATH is not None:
for card in await get_model_cards():
mid = card.model_id
if mid in self.active_downloads:
continue
if isinstance(
self.download_status.get(mid),
(DownloadCompleted, DownloadOngoing, DownloadFailed),
):
continue
found = resolve_model_in_path(mid)
if found is not None:
path_shard = PipelineShardMetadata(
model_card=card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
path_completed: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=path_shard,
total=card.storage_size,
model_directory=str(found),
read_only=True,
)
self.download_status[mid] = path_completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=path_completed)
)
logger.debug(
"DownloadCoordinator: Done emitting existing download progress."
)


@@ -20,7 +20,6 @@ from huggingface_hub import (
)
from loguru import logger
from pydantic import (
DirectoryPath,
TypeAdapter,
)
@@ -31,7 +30,7 @@ from exo.download.huggingface_utils import (
get_hf_endpoint,
get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.constants import EXO_MODELS_DIR, EXO_MODELS_PATH
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
@@ -80,9 +79,9 @@ def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total,
downloaded=repo_file_download_progress.downloaded,
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
total=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
@@ -95,9 +94,9 @@ def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
total=repo_download_progress.total,
downloaded=repo_download_progress.downloaded,
downloaded_this_session=repo_download_progress.downloaded_this_session,
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
@@ -111,7 +110,27 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: ModelId) -> DirectoryPath:
def resolve_model_in_path(model_id: ModelId) -> Path | None:
"""Search EXO_MODELS_PATH directories for a pre-existing model.
Checks each directory for the normalized name (org--model). A candidate
is only returned if ``is_model_directory_complete`` confirms all weight
files are present.
"""
if EXO_MODELS_PATH is None:
return None
normalized = model_id.normalize()
for search_dir in EXO_MODELS_PATH:
candidate = search_dir / normalized
if candidate.is_dir() and is_model_directory_complete(candidate):
return candidate
return None
def build_model_path(model_id: ModelId) -> Path:
found = resolve_model_in_path(model_id)
if found is not None:
return found
return EXO_MODELS_DIR / model_id.normalize()
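The lookup order above can be sketched in isolation. This is a simplified stand-in, not the real helpers: `normalize()` is assumed to map `org/model` to `org--model` per the docstring, and the completeness predicate is passed in:

```python
from pathlib import Path
from typing import Callable, Iterable

def pick_model_dir(model_id: str, search_dirs: Iterable[Path],
                   default_dir: Path,
                   is_complete: Callable[[Path], bool]) -> Path:
    # Simplified sketch of the build_model_path precedence: prefer a
    # complete pre-downloaded copy from the search path, otherwise fall
    # back to the default download directory.
    normalized = model_id.replace("/", "--")  # assumed normalize() behaviour
    for search_dir in search_dirs:
        candidate = search_dir / normalized
        if candidate.is_dir() and is_complete(candidate):
            return candidate
    return default_dir / normalized
```

An incomplete candidate falls through to the default directory, so a half-copied model in the search path never shadows a fresh download target.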
@@ -142,7 +161,7 @@ async def delete_model(model_id: ModelId) -> bool:
async def seed_models(seed_dir: str | Path):
"""Move model in resources folder of app to .cache/huggingface/hub"""
"""Move models from resources folder to EXO_MODELS_DIR."""
source_dir = Path(seed_dir)
dest_dir = await ensure_models_dir()
for path in source_dir.iterdir():
@@ -158,6 +177,72 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
def _scan_model_directory(
model_dir: Path, recursive: bool = False
) -> list[FileListEntry] | None:
"""Scan a local model directory and build a file list.
Requires at least one ``*.safetensors.index.json``. Every weight file
referenced by the index that is missing on disk gets ``size=None``.
"""
index_files = list(model_dir.glob("**/*.safetensors.index.json"))
if not index_files:
return None
entries_by_path: dict[str, FileListEntry] = {}
if recursive:
for dirpath, _, filenames in os.walk(model_dir):
for filename in filenames:
if filename.endswith(".partial"):
continue
full_path = Path(dirpath) / filename
rel_path = str(full_path.relative_to(model_dir))
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=full_path.stat().st_size,
)
else:
for item in model_dir.iterdir():
if item.is_file() and not item.name.endswith(".partial"):
entries_by_path[item.name] = FileListEntry(
type="file",
path=item.name,
size=item.stat().st_size,
)
# Add expected weight files from index that haven't been downloaded yet
for index_file in index_files:
try:
index_data = ModelSafetensorsIndex.model_validate_json(
index_file.read_text()
)
relative_dir = index_file.parent.relative_to(model_dir)
for filename in set(index_data.weight_map.values()):
rel_path = (
str(relative_dir / filename)
if relative_dir != Path(".")
else filename
)
if rel_path not in entries_by_path:
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=None,
)
except Exception:
continue
return list(entries_by_path.values())
def is_model_directory_complete(model_dir: Path) -> bool:
"""Check if a model directory contains all required weight files."""
file_list = _scan_model_directory(model_dir, recursive=True)
return file_list is not None and all(f.size is not None for f in file_list)
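As a rough illustration of the completeness rule, here is a self-contained sketch that reads only the `weight_map` of each index file, ignoring the `.partial` filtering and other metadata the real scanner handles:

```python
import json
from pathlib import Path

def weights_all_present(model_dir: Path) -> bool:
    # Sketch of is_model_directory_complete: every weight file referenced
    # by a *.safetensors.index.json must exist next to that index.
    index_files = list(model_dir.glob("**/*.safetensors.index.json"))
    if not index_files:
        return False  # no index means completeness cannot be verified
    for index_file in index_files:
        weight_map = json.loads(index_file.read_text())["weight_map"]
        for filename in set(weight_map.values()):
            if not (index_file.parent / filename).is_file():
                return False
    return True
```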
async def _build_file_list_from_local_directory(
model_id: ModelId,
recursive: bool = False,
@@ -172,59 +257,7 @@ async def _build_file_list_from_local_directory(
if not await aios.path.exists(model_dir):
return None
def _scan() -> list[FileListEntry] | None:
index_files = list(model_dir.glob("**/*.safetensors.index.json"))
if not index_files:
return None
entries_by_path: dict[str, FileListEntry] = {}
if recursive:
for dirpath, _, filenames in os.walk(model_dir):
for filename in filenames:
if filename.endswith(".partial"):
continue
full_path = Path(dirpath) / filename
rel_path = str(full_path.relative_to(model_dir))
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=full_path.stat().st_size,
)
else:
for item in model_dir.iterdir():
if item.is_file() and not item.name.endswith(".partial"):
entries_by_path[item.name] = FileListEntry(
type="file",
path=item.name,
size=item.stat().st_size,
)
# Add expected weight files from index that haven't been downloaded yet
for index_file in index_files:
try:
index_data = ModelSafetensorsIndex.model_validate_json(
index_file.read_text()
)
relative_dir = index_file.parent.relative_to(model_dir)
for filename in set(index_data.weight_map.values()):
rel_path = (
str(relative_dir / filename)
if relative_dir != Path(".")
else filename
)
if rel_path not in entries_by_path:
entries_by_path[rel_path] = FileListEntry(
type="file",
path=rel_path,
size=None,
)
except Exception:
continue
return list(entries_by_path.values())
file_list = await asyncio.to_thread(_scan)
file_list = await asyncio.to_thread(_scan_model_directory, model_dir, recursive)
if not file_list:
return None
return file_list
@@ -448,12 +481,13 @@ async def download_file_with_retry(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress, skip_internet
)
except HuggingFaceAuthenticationError:
raise
@@ -487,10 +521,14 @@ async def _download_file(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
skip_internet: bool = False,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
if skip_internet:
return target_path
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
@@ -510,6 +548,11 @@ async def _download_file(
)
return target_path
if skip_internet:
raise FileNotFoundError(
f"File {path} not found locally and cannot download in offline mode"
)
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -568,19 +611,20 @@ def calculate_repo_progress(
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
) -> RepoDownloadProgress:
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum(
(p.downloaded.in_bytes for p in file_progress.values()), 0
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
all_downloaded = sum(
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
)
all_downloaded_bytes_this_session = sum(
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
all_downloaded_this_session = sum(
(p.downloaded_this_session for p in file_progress.values()),
Memory.from_bytes(0),
)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
)
all_eta = (
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
if all_speed > 0
else timedelta(seconds=0)
)
@@ -599,11 +643,9 @@ def calculate_repo_progress(
[p for p in file_progress.values() if p.downloaded == p.total]
),
total_files=len(file_progress),
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(
all_downloaded_bytes_this_session
),
total_bytes=Memory.from_bytes(all_total_bytes),
downloaded=all_downloaded,
downloaded_this_session=all_downloaded_this_session,
total=all_total,
overall_speed=all_speed,
overall_eta=all_eta,
status=status,
@@ -781,6 +823,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
final_file_exists = await aios.path.exists(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_revision=revision,
@@ -790,7 +833,9 @@ async def download_shard(
total=Memory.from_bytes(file.size or 0),
speed=0,
eta=timedelta(0),
status="complete" if downloaded_bytes == file.size else "not_started",
status="complete"
if final_file_exists and downloaded_bytes == file.size
else "not_started",
start_time=time.time(),
)
@@ -814,6 +859,7 @@ async def download_shard(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
skip_internet=skip_internet,
)
if not skip_download:


@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
downloaded=Memory.from_bytes(0),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",


@@ -0,0 +1,98 @@
from typing import Any
import anyio
import pytest
from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import NoopShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
DownloadPending,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.utils.channels import channel
# Use the builtin NoopShardDownloader directly; it already implements the required abstract interface.
# No additional subclass is needed for this test.
@pytest.mark.anyio
async def test_ack_behaviour():
# Create channels (type Any for simplicity)
_, command_receiver = channel[Any]()
local_sender, _ = channel[Any]()
global_sender, global_receiver = channel[Any]()
# Minimal identifiers
node_id = NodeId()
session_id = SessionId(master_node_id=node_id, election_clock=0)
# Create a dummy model card and shard metadata
model_id = ModelId("test/model")
model_card = ModelCard(
model_id=model_id,
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
# Instantiate the coordinator with the dummy downloader
coord = DownloadCoordinator(
node_id=node_id,
session_id=session_id,
shard_downloader=NoopShardDownloader(),
download_command_receiver=command_receiver,
local_event_sender=local_sender,
_global_event_receiver=global_receiver,
)
async with anyio.create_task_group() as tg:
# Start the forwarding and ack-clearing loops
tg.start_soon(coord._forward_events) # pyright: ignore[reportPrivateUsage]
tg.start_soon(coord._clear_ofd) # pyright: ignore[reportPrivateUsage]
# Send a pending download progress event via the internal event sender
pending = DownloadPending(
node_id=node_id,
shard_metadata=shard,
model_directory="/tmp/model",
)
await coord.event_sender.send(NodeDownloadProgress(download_progress=pending))
# Allow the forwarder to process the event
await anyio.sleep(0.1)
# There should be exactly one entry awaiting ACK
assert len(coord._out_for_delivery) == 1 # pyright: ignore[reportPrivateUsage]
# Retrieve the stored LocalForwarderEvent
stored_fe: LocalForwarderEvent = next(iter(coord._out_for_delivery.values())) # pyright: ignore[reportPrivateUsage]
# Simulate receiving a global ack for this event
ack = GlobalForwarderEvent(
origin_idx=0,
origin=node_id,
session=session_id,
event=stored_fe.event,
)
await global_sender.send(ack)
# Give the _clear_ofd task a moment to process the ack
await anyio.sleep(0.1)
# The _out_for_delivery map should now be empty
assert len(coord._out_for_delivery) == 0 # pyright: ignore[reportPrivateUsage]
# Cancel background tasks
tg.cancel_scope.cancel()


@@ -0,0 +1,230 @@
"""Tests for offline/air-gapped mode."""
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
download_file_with_retry,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.worker.downloads import FileListEntry
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestDownloadFileOffline:
"""Tests for _download_file with skip_internet=True."""
async def test_returns_local_file_without_http_verification(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists locally, return it immediately
without making any HTTP calls (no file_meta verification)."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"model weights data")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_raises_file_not_found_for_missing_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file does NOT exist locally,
raise FileNotFoundError instead of attempting download."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError, match="offline mode"):
await _download_file(
model_id,
"main",
"missing_model.safetensors",
target_dir,
skip_internet=True,
)
async def test_returns_local_file_in_subdirectory(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists in a subdirectory,
return it without HTTP calls."""
target_dir = tmp_path / "downloads"
subdir = target_dir / "transformer"
await aios.makedirs(subdir, exist_ok=True)
local_file = subdir / "diffusion_pytorch_model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"weights")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"transformer/diffusion_pytorch_model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
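The three tests above fix the contract of `_download_file`'s offline branch: a file already on disk is returned with no HTTP metadata check, and a missing file raises `FileNotFoundError` immediately rather than attempting a download. A sketch of that contract — `download_file_offline` is a hypothetical name, not exo's actual function:

```python
# Sketch of the offline-branch contract tested above. The function name
# is hypothetical; the real logic lives inside exo's _download_file.
from pathlib import Path


def download_file_offline(target_dir: Path, rel_path: str) -> Path:
    # Offline branch: never touch the network. Return the file if it is
    # already on disk; otherwise fail fast so retry wrappers do not spin.
    local = target_dir / rel_path
    if local.exists():
        return local
    raise FileNotFoundError(f"{rel_path} not available in offline mode")
```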
class TestDownloadFileWithRetryOffline:
"""Tests for download_file_with_retry with skip_internet=True."""
async def test_propagates_skip_internet_to_download_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Verify skip_internet is passed through to _download_file."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "config.json"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b'{"model_type": "qwen2"}')
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_file_not_found_does_not_retry(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""FileNotFoundError from offline mode should not trigger retries."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
await download_file_with_retry(
model_id,
"main",
"nonexistent.safetensors",
target_dir,
skip_internet=True,
)
class TestFetchFileListOffline:
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
async def test_uses_cached_file_list(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and cache file exists, use it without network."""
from pydantic import TypeAdapter
cache_dir = temp_models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
cached_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=200),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
assert result == cached_list
mock_fetch.assert_not_called()
async def test_falls_back_to_local_directory_scan(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and no cache but local files exist,
build file list from local directory."""
import json
model_dir = temp_models_dir / model_id.normalize()
await aios.makedirs(model_dir, exist_ok=True)
async with aiofiles.open(model_dir / "config.json", "w") as f:
await f.write('{"model_type": "qwen2"}')
index_data = {
"metadata": {},
"weight_map": {"model.layers.0.weight": "model.safetensors"},
}
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
await f.write(json.dumps(index_data))
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
await f.write(b"x" * 500)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
mock_fetch.assert_not_called()
paths = {entry.path for entry in result}
assert "config.json" in paths
assert "model.safetensors" in paths
async def test_raises_when_no_cache_and_no_local_files(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and neither cache nor local files exist,
raise FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="No internet"):
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)
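Together, the three tests above establish a fallback order for resolving the file list offline: cached JSON first, then a scan of the local model directory, and finally `FileNotFoundError` when neither exists. A self-contained sketch of that order, under simplified assumptions — `resolve_file_list` is illustrative; the real function is `fetch_file_list_with_cache`:

```python
# Sketch of the offline file-list resolution order covered by the tests
# above: cache file -> local directory scan -> FileNotFoundError.
# resolve_file_list is an illustrative stand-in, not exo's API.
import json
from pathlib import Path


def resolve_file_list(cache_file: Path, model_dir: Path) -> list[dict]:
    if cache_file.exists():
        # Cached file list from a previous online run.
        return json.loads(cache_file.read_text())
    if model_dir.is_dir():
        # Fall back to whatever is pre-staged on disk.
        return [
            {
                "type": "file",
                "path": p.relative_to(model_dir).as_posix(),
                "size": p.stat().st_size,
            }
            for p in sorted(model_dir.rglob("*"))
            if p.is_file()
        ]
    raise FileNotFoundError("No internet and no local copy of the file list")
```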

View File

@@ -1,11 +1,10 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Iterator, Self
from typing import Self
import anyio
from anyio.abc import TaskGroup
@@ -38,13 +37,13 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> "Self":
async def create(cls, args: "Args") -> Self:
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
@@ -56,9 +55,6 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
@@ -67,7 +63,9 @@ class Node:
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
# TODO(evan): remove
_global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
)
else:
download_coordinator = None
@@ -93,7 +91,6 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -131,7 +128,7 @@ class Node:
master,
api,
node_id,
event_index_counter,
args.offline,
)
async def run(self):
@@ -209,8 +206,6 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
@@ -221,7 +216,11 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
# TODO(evan): remove
_global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -238,7 +237,6 @@ class Node:
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -254,12 +252,15 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -282,6 +283,7 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -329,6 +331,11 @@ class Args(CamelCaseModel):
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
parser.add_argument(
"--offline",
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -19,7 +19,12 @@ from exo.shared.types.api import (
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -54,7 +59,11 @@ def chat_request_to_text_generation(
chat_template_messages.append({"role": "system", "content": content})
else:
# Skip messages with no meaningful content
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
if (
msg.content is None
and msg.reasoning_content is None
and msg.tool_calls is None
):
continue
if msg.role in ("user", "assistant", "developer"):
@@ -106,6 +115,11 @@ def chunk_to_response(
]
)
if chunk.is_thinking:
delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
else:
delta = ChatCompletionMessage(role="assistant", content=chunk.text)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -113,7 +127,7 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=delta,
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
@@ -123,72 +137,87 @@ def chunk_to_response(
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
match chunk:
case PrefillProgressChunk():
# Use SSE comment so third-party clients ignore it
yield f": prefill_progress {chunk.model_dump_json()}\n\n"
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
case ErrorChunk():
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
case ToolCallChunk():
last_usage = chunk.usage or last_usage
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
case TokenChunk():
last_usage = chunk.usage or last_usage
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(
update={"usage": last_usage}
)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[ToolCall] = []
logprobs_content: list[LogprobsContentItem] = []
model: str | None = None
@@ -197,43 +226,52 @@ async def collect_chat_response(
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
match chunk:
case PrefillProgressChunk():
continue
if model is None:
model = chunk.model
case ErrorChunk():
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
case TokenChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
case ToolCallChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None
yield ChatCompletionResponse(
@@ -246,6 +284,7 @@ async def collect_chat_response(
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
reasoning_content=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
logprobs=Logprobs(content=logprobs_content)

View File

@@ -1,11 +1,17 @@
"""Claude Messages API adapter for converting requests/responses."""
import json
import re
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.claude_api import (
ClaudeContentBlock,
ClaudeContentBlockDeltaEvent,
@@ -23,6 +29,8 @@ from exo.shared.types.claude_api import (
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeThinkingBlock,
ClaudeThinkingDelta,
ClaudeToolResultBlock,
ClaudeToolUseBlock,
ClaudeUsage,
@@ -56,6 +64,22 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
return "".join(sub_block.text for sub_block in block.content)
# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
# or similar telemetry headers that change every request and break KV prefix caching.
_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)
def _strip_volatile_headers(text: str) -> str:
"""Remove Anthropic billing/telemetry headers from system prompt text.
Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
change every request and break KV prefix caching (the prefix diverges at ~20
tokens instead of matching thousands of conversation tokens).
"""
return _VOLATILE_HEADER_RE.sub("", text)
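The stripping behavior can be checked directly; the regex below is copied from the patch above, and the sample header line is an illustrative input, not a captured request:

```python
# Demonstration of the volatile-header stripping added above. The regex
# is copied from the diff; the sample prompt is illustrative.
import re

_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)


def strip_volatile_headers(text: str) -> str:
    # Drop per-request telemetry header lines so the system prompt prefix
    # stays byte-identical across requests (preserving KV prefix caching).
    return _VOLATILE_HEADER_RE.sub("", text)


prompt = "x-anthropic-billing-header: cc_version=1.2; cch=abc;\nYou are a helpful assistant."
assert strip_volatile_headers(prompt) == "You are a helpful assistant."
```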
def claude_request_to_text_generation(
request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
@@ -68,6 +92,8 @@ def claude_request_to_text_generation(
instructions = request.system
else:
instructions = "".join(block.text for block in request.system)
instructions = _strip_volatile_headers(instructions)
chat_template_messages.append({"role": "system", "content": instructions})
# Convert messages to input
@@ -80,12 +106,15 @@ def claude_request_to_text_generation(
# Process structured content blocks
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
tool_results: list[ClaudeToolResultBlock] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
elif isinstance(block, ClaudeThinkingBlock):
thinking_parts.append(block.thinking)
elif isinstance(block, ClaudeToolUseBlock):
tool_calls.append(
{
@@ -101,6 +130,7 @@ def claude_request_to_text_generation(
tool_results.append(block)
content = "".join(text_parts)
reasoning_content = "".join(thinking_parts) if thinking_parts else None
# Build InputMessage from text content
if msg.role in ("user", "assistant"):
@@ -108,9 +138,14 @@ def claude_request_to_text_generation(
# Build chat_template_messages preserving tool structure
if tool_calls:
chat_template_messages.append(
{"role": "assistant", "content": content, "tool_calls": tool_calls}
)
chat_msg: dict[str, Any] = {
"role": "assistant",
"content": content,
"tool_calls": tool_calls,
}
if reasoning_content:
chat_msg["reasoning_content"] = reasoning_content
chat_template_messages.append(chat_msg)
elif tool_results:
for tr in tool_results:
chat_template_messages.append(
@@ -121,7 +156,10 @@ def claude_request_to_text_generation(
}
)
else:
chat_template_messages.append({"role": msg.role, "content": content})
chat_msg = {"role": msg.role, "content": content}
if reasoning_content:
chat_msg["reasoning_content"] = reasoning_content
chat_template_messages.append(chat_msg)
# Convert Claude tool definitions to OpenAI-style function tools
tools: list[dict[str, Any]] | None = None
@@ -138,6 +176,10 @@ def claude_request_to_text_generation(
for tool in request.tools
]
enable_thinking: bool | None = None
if request.thinking is not None:
enable_thinking = request.thinking.type in ("enabled", "adaptive")
return TextGenerationTaskParams(
model=request.model,
input=input_messages
@@ -151,6 +193,7 @@ def claude_request_to_text_generation(
stop=request.stop_sequences,
stream=request.stream,
tools=tools,
enable_thinking=enable_thinking,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
@@ -160,18 +203,24 @@ def claude_request_to_text_generation(
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ClaudeMessagesResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -190,7 +239,10 @@ async def collect_claude_response(
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -199,9 +251,12 @@ async def collect_claude_response(
raise ValueError(error_message)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts)
# Build content blocks
content: list[ClaudeContentBlock] = []
if combined_thinking:
content.append(ClaudeThinkingBlock(thinking=combined_thinking))
if combined_text:
content.append(ClaudeTextBlock(text=combined_text))
content.extend(tool_use_blocks)
@@ -230,7 +285,9 @@ async def collect_claude_response(
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
@@ -244,18 +301,21 @@ async def generate_claude_stream(
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start for text block at index 0
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
next_block_index = 1 # text block is 0, tool blocks start at 1
next_block_index = 0
# Track whether we've started thinking/text blocks
thinking_block_started = False
thinking_block_index = -1
text_block_started = False
text_block_index = -1
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
# Close text block and bail
break
@@ -295,12 +355,45 @@ async def generate_claude_stream(
output_tokens += 1 # Count each chunk as one token
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.is_thinking:
# Start thinking block on first thinking token
if not thinking_block_started:
thinking_block_started = True
thinking_block_index = next_block_index
next_block_index += 1
block_start = ClaudeContentBlockStartEvent(
index=thinking_block_index,
content_block=ClaudeThinkingBlock(thinking=""),
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
delta_event = ClaudeContentBlockDeltaEvent(
index=thinking_block_index,
delta=ClaudeThinkingDelta(thinking=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
else:
# Close thinking block when transitioning to text
if thinking_block_started and text_block_index == -1:
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# Start text block on first text token
if not text_block_started:
text_block_started = True
text_block_index = next_block_index
next_block_index += 1
block_start = ClaudeContentBlockStartEvent(
index=text_block_index,
content_block=ClaudeTextBlock(text=""),
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
delta_event = ClaudeContentBlockDeltaEvent(
index=text_block_index,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -309,9 +402,22 @@ async def generate_claude_stream(
if last_usage is not None:
output_tokens = last_usage.completion_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# Close any open blocks
if thinking_block_started and text_block_index == -1:
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
if text_block_started:
block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
if not thinking_block_started and not text_block_started:
empty_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
empty_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(

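The streaming changes above replace the fixed text block at index 0 with lazily opened blocks: the thinking and text blocks each get the next free index on their first token, and the thinking block is closed when the first non-thinking token arrives. A simplified model of that state machine — the event strings here are shorthand, not the real SSE payloads:

```python
# Simplified model of the lazy block-index assignment in the streaming
# code above. Event strings are shorthand, not real Claude SSE events.
def stream_blocks(tokens: list[tuple[bool, str]]) -> list[str]:
    events: list[str] = []
    next_index = 0
    thinking_index = text_index = -1
    for is_thinking, _text in tokens:
        if is_thinking:
            if thinking_index == -1:
                # First thinking token opens the thinking block.
                thinking_index = next_index
                next_index += 1
                events.append(f"start:thinking@{thinking_index}")
            events.append(f"delta:thinking@{thinking_index}")
        else:
            if thinking_index != -1 and text_index == -1:
                # Transitioning to text closes the thinking block.
                events.append(f"stop:thinking@{thinking_index}")
            if text_index == -1:
                text_index = next_index
                next_index += 1
                events.append(f"start:text@{text_index}")
            events.append(f"delta:text@{text_index}")
    # Close whichever block is still open at end of stream.
    if text_index != -1:
        events.append(f"stop:text@{text_index}")
    elif thinking_index != -1:
        events.append(f"stop:thinking@{thinking_index}")
    return events
```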
View File

@@ -0,0 +1,456 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaDoneReason,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaMessage,
OllamaToolCall,
OllamaToolFunction,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _map_done_reason(
finish_reason: str | None,
) -> OllamaDoneReason | None:
if finish_reason is None:
return None
if finish_reason == "stop":
return "stop"
if finish_reason == "length":
return "length"
if finish_reason in ("tool_calls", "function_call"):
return "tool_call"
if finish_reason == "error":
return "error"
return "stop"
def _try_parse_json(value: str) -> dict[str, Any] | str:
try:
return json.loads(value) # type: ignore
except json.JSONDecodeError:
return value
def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
tool_calls: list[OllamaToolCall] = []
for index, tool in enumerate(chunk.tool_calls):
# tool.arguments is always str; try to parse as JSON dict for Ollama format
arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
tool_calls.append(
OllamaToolCall(
id=tool.id,
type="function",
function=OllamaToolFunction(
name=tool.name, arguments=arguments, index=index
),
)
)
return tool_calls
def _get_usage(
chunk: TokenChunk | ToolCallChunk,
) -> tuple[int | None, int | None]:
"""Extract (prompt_eval_count, eval_count) from a chunk."""
if chunk.usage is not None:
return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
if chunk.stats is not None:
return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
return (None, None)
def ollama_request_to_text_generation(
request: OllamaChatRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama chat request to exo's internal text generation format."""
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
tool_message_index = 0
for msg in request.messages:
content = msg.content or ""
if msg.role == "system":
if instructions is None:
instructions = content
else:
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
continue
if msg.role in ("user", "assistant") and (
msg.content is not None or msg.thinking is not None or msg.tool_calls
):
input_messages.append(InputMessage(role=msg.role, content=content))
dumped: dict[str, Any] = {"role": msg.role, "content": content}
if msg.thinking is not None:
dumped["thinking"] = msg.thinking
if msg.tool_calls is not None:
tool_calls_list: list[dict[str, Any]] = []
for tc in msg.tool_calls:
function: dict[str, Any] = {
"name": tc.function.name,
"arguments": (
json.dumps(tc.function.arguments)
if isinstance(tc.function.arguments, dict)
else tc.function.arguments
),
}
if tc.function.index is not None:
function["index"] = tc.function.index
tool_call: dict[str, Any] = {"function": function}
if tc.id is not None:
tool_call["id"] = tc.id
if tc.type is not None:
tool_call["type"] = tc.type
tool_calls_list.append(tool_call)
dumped["tool_calls"] = tool_calls_list
if msg.name is not None:
dumped["name"] = msg.name
if msg.role == "tool":
tool_message_index += 1
tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
dumped["tool_call_id"] = tool_call_id
if msg.tool_name is not None:
dumped["tool_name"] = msg.tool_name
chat_template_messages.append(dumped)
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
tools=request.tools,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_chat_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses in Ollama format (newline-delimited JSON)."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
error_response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content=chunk.error_message
),
done=True,
done_reason="error",
)
yield f"{error_response.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content="",
tool_calls=_build_tool_calls(chunk),
thinking="".join(thinking_parts) if thinking_parts else None,
),
done=True,
done_reason="tool_call",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content="", thinking=chunk.text
),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content=chunk.text,
),
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
else:
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(role="assistant", content=chunk.text),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
if done:
return
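Each yield above is one newline-delimited JSON object; a client reads the stream line by line and parses each line independently. A minimal sketch of that framing, using plain dicts as stand-ins for the Pydantic response models (the dict shapes here are illustrative assumptions):

```python
import json

def ndjson_lines(responses: list[dict]) -> str:
    # Serialize each response as a standalone JSON object followed by "\n",
    # mirroring the f"{response.model_dump_json(...)}\n" yields above.
    return "".join(json.dumps(r, separators=(",", ":")) + "\n" for r in responses)

def parse_ndjson(payload: str) -> list[dict]:
    # Clients split on newlines and parse each non-empty line independently.
    return [json.loads(line) for line in payload.splitlines() if line]

stream = ndjson_lines([
    {"model": "m", "message": {"role": "assistant", "content": "Hel"}, "done": False},
    {"model": "m", "message": {"role": "assistant", "content": "lo"},
     "done": True, "done_reason": "stop"},
])
assert parse_ndjson(stream)[-1]["done"] is True
```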
async def collect_ollama_chat_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect streaming chunks into a single non-streaming Ollama response.
Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
StreamingResponse cancellation handling.
"""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[OllamaToolCall] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
tool_calls.extend(_build_tool_calls(chunk))
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None
yield OllamaChatResponse(
model=model,
message=OllamaMessage(
role="assistant",
content=combined_text,
thinking=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
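The single-yield pattern above (accumulate the whole stream, then yield one serialized body) keeps the non-streaming path on the same cancellation machinery as the streaming one. A toy sketch of the pattern, assuming nothing beyond the standard library:

```python
import asyncio
import json
from collections.abc import AsyncGenerator

async def fake_chunks() -> AsyncGenerator[str, None]:
    # Stand-in for the internal chunk stream.
    for text in ("Hel", "lo"):
        yield text

async def collect(chunks: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    # Accumulate the entire stream, then yield exactly one JSON body,
    # like collect_ollama_chat_response does.
    parts: list[str] = []
    async for text in chunks:
        parts.append(text)
    yield json.dumps({"content": "".join(parts), "done": True})

async def main() -> str:
    body = ""
    async for piece in collect(fake_chunks()):
        body = piece
    return body

result = json.loads(asyncio.run(main()))
assert result["content"] == "Hello"
```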
# ── /api/generate ──
def ollama_generate_request_to_text_generation(
request: OllamaGenerateRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama generate request to exo's internal text generation format."""
chat_template_messages: list[dict[str, Any]] = []
if request.system:
chat_template_messages.append({"role": "system", "content": request.system})
chat_template_messages.append({"role": "user", "content": request.prompt})
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=[InputMessage(role="user", content=request.prompt)],
instructions=request.system,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_generate_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses for /api/generate in Ollama NDJSON format."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="error",
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
# generate endpoint doesn't support tools; emit as done
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="stop",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
thinking=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
else:
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_generate_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect chunks into a single non-streaming /api/generate response."""
text_parts: list[str] = []
thinking_parts: list[str] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
assert model is not None
yield OllamaGenerateResponse(
model=model,
response="".join(text_parts),
thinking="".join(thinking_parts) if thinking_parts else None,
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return


@@ -5,7 +5,12 @@ from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
FunctionCallInputItem,
@@ -24,8 +29,15 @@ from exo.shared.types.openai_responses import (
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningSummaryPartAddedEvent,
ResponseReasoningSummaryPartDoneEvent,
ResponseReasoningSummaryText,
ResponseReasoningSummaryTextDeltaEvent,
ResponseReasoningSummaryTextDoneEvent,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -33,6 +45,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
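`_format_sse` frames each event with an `event:` field line and a `data:` field line, terminated by a blank line, which is the standard server-sent events wire format. A standalone sketch of the framing, with a plain dict in place of the Pydantic event models:

```python
import json

def format_sse(event_type: str, payload: dict) -> str:
    # One SSE message: an "event:" field, a "data:" field, then "\n\n"
    # marking the end of the message.
    return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"

msg = format_sse("response.created", {"type": "response.created", "sequence_number": 1})
assert msg.endswith("\n\n")
assert msg.splitlines()[0] == "event: response.created"
```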
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
@@ -121,19 +138,26 @@ def responses_request_to_text_generation(
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
accumulated_text = ""
thinking_parts: list[str] = []
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -152,6 +176,10 @@ async def collect_responses_response(
)
continue
if chunk.is_thinking:
thinking_parts.append(chunk.text)
continue
accumulated_text += chunk.text
if error_message is not None:
@@ -166,13 +194,21 @@ async def collect_responses_response(
total_tokens=last_usage.total_tokens,
)
output: list[ResponseItem] = [
output: list[ResponseItem] = []
if thinking_parts:
output.append(
ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
)
)
output.append(
ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
]
)
output.extend(function_call_items)
yield ResponsesResponse(
@@ -189,11 +225,14 @@ async def collect_responses_response(
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
seq = count(1)
# response.created
@@ -207,42 +246,30 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)
accumulated_text = ""
accumulated_thinking = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
next_output_index = 1 # message item is at 0
next_output_index = 0
# Track dynamic block creation
reasoning_started = False
reasoning_output_index = -1
message_started = False
message_output_index = -1
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
break
@@ -266,7 +293,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -275,7 +302,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -285,7 +312,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -300,44 +327,205 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)
function_call_items.append(fc_done_item)
next_output_index += 1
continue
if chunk.is_thinking:
# Start reasoning block on first thinking token
if not reasoning_started:
reasoning_started = True
reasoning_output_index = next_output_index
next_output_index += 1
# response.output_item.added for reasoning
reasoning_item = ResponseReasoningItem(
id=reasoning_id,
summary=[],
status="in_progress",
)
rs_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=reasoning_item,
)
yield _format_sse(rs_added)
# response.reasoning_summary_part.added
part_added = ResponseReasoningSummaryPartAddedEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=""),
)
yield _format_sse(part_added)
accumulated_thinking += chunk.text
# response.reasoning_summary_text.delta
rs_delta = ResponseReasoningSummaryTextDeltaEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
delta=chunk.text,
)
yield _format_sse(rs_delta)
continue
# Close reasoning block when transitioning to text
if reasoning_started and not message_started:
# response.reasoning_summary_text.done
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)
# response.reasoning_summary_part.done
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)
# response.output_item.done for reasoning
rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)
# Start message block on first text token
if not message_started:
message_started = True
message_output_index = next_output_index
next_output_index += 1
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added)
accumulated_text += chunk.text
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)
# Close reasoning block if it was never followed by text
if reasoning_started and not message_started:
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)
rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)
# If no message block was started, create one now (empty text)
if not message_started:
message_output_index = next_output_index
next_output_index += 1
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)
initial_part = ResponseOutputText(text="")
part_added_evt = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added_evt)
# response.output_text.done
text_done = ResponseTextDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)
# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -346,9 +534,11 @@ async def generate_responses_stream(
status="completed",
)
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
sequence_number=next(seq),
output_index=message_output_index,
item=final_message_item,
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)
# Create usage from usage data if available
usage = None
@@ -360,7 +550,15 @@ async def generate_responses_stream(
)
# response.completed
output: list[ResponseItem] = [final_message_item]
output: list[ResponseItem] = []
if reasoning_started:
output.append(
ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
)
)
output.append(final_message_item)
output.extend(function_call_items)
final_response = ResponsesResponse(
id=response_id,
@@ -373,4 +571,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)
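The streaming logic above is a small state machine: open a reasoning block on the first thinking token, close it when the first text token arrives, open the message block, and close whichever block is still open at end of stream (creating an empty message block if no text ever arrived). A compressed sketch of that event ordering, taking only (is_thinking, text) pairs as input; the short event names are illustrative stand-ins for the full OpenAI Responses event types:

```python
def event_order(tokens: list[tuple[bool, str]]) -> list[str]:
    # Mirror the block lifecycle in generate_responses_stream: reasoning
    # opens on the first thinking token and closes before the message
    # block opens on the first text token; the tail closes any open block.
    events: list[str] = []
    reasoning_open = message_open = False
    for is_thinking, _text in tokens:
        if is_thinking:
            if not reasoning_open:
                reasoning_open = True
                events.append("reasoning.added")
            events.append("reasoning.delta")
        else:
            if reasoning_open and not message_open:
                events.append("reasoning.done")
            if not message_open:
                message_open = True
                events.append("message.added")
            events.append("text.delta")
    if reasoning_open and not message_open:
        events.append("reasoning.done")
    if not message_open:
        events.append("message.added")  # empty message block, as above
    events.append("message.done")
    return events

order = event_order([(True, "think"), (True, "more"), (False, "Hi")])
assert order == ["reasoning.added", "reasoning.delta", "reasoning.delta",
                 "reasoning.done", "message.added", "text.delta", "message.done"]
```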


@@ -32,6 +32,14 @@ from exo.master.adapters.claude import (
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.ollama import (
collect_ollama_chat_response,
collect_ollama_generate_response,
generate_ollama_chat_stream,
generate_ollama_generate_stream,
ollama_generate_request_to_text_generation,
ollama_request_to_text_generation,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
@@ -71,11 +79,8 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
CreateMetaInstanceParams,
CreateMetaInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DeleteMetaInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -88,6 +93,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
ImageListItem,
ImageListResponse,
ImageSize,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -103,11 +109,13 @@ from exo.shared.types.api import (
TraceRankStats,
TraceResponse,
TraceStatsResponse,
normalize_image_size,
)
from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
@@ -118,10 +126,8 @@ from exo.shared.types.claude_api import (
from exo.shared.types.commands import (
Command,
CreateInstance,
CreateMetaInstance,
DeleteDownload,
DeleteInstance,
DeleteMetaInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
@@ -134,21 +140,34 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaModelDetails,
OllamaModelTag,
OllamaPsModel,
OllamaPsResponse,
OllamaShowRequest,
OllamaShowResponse,
OllamaTagsResponse,
)
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -178,8 +197,7 @@ class API:
session_id: SessionId,
*,
port: int,
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
global_event_receiver: Receiver[GlobalForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
@@ -187,6 +205,7 @@ class API:
) -> None:
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
@@ -224,7 +243,8 @@ class API:
)
self._text_generation_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -237,6 +257,7 @@ class API:
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
@@ -282,9 +303,6 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/meta_instances")(self.list_meta_instances)
self.app.post("/meta_instance")(self.create_meta_instance)
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
@@ -304,6 +322,21 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
# Ollama API
self.app.head("/ollama/")(self.ollama_version)
self.app.head("/ollama/api/version")(self.ollama_version)
self.app.post("/ollama/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/ollama/api/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/ollama/api/v1/chat", response_model=None)(self.ollama_chat)
self.app.post("/ollama/api/generate", response_model=None)(self.ollama_generate)
self.app.get("/ollama/api/tags")(self.ollama_tags)
self.app.get("/ollama/api/api/tags")(self.ollama_tags)
self.app.get("/ollama/api/v1/tags")(self.ollama_tags)
self.app.post("/ollama/api/show")(self.ollama_show)
self.app.get("/ollama/api/ps")(self.ollama_ps)
self.app.get("/ollama/api/version")(self.ollama_version)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(self.stream_events)
self.app.post("/download/start")(self.start_download)
@@ -314,27 +347,12 @@ class API:
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
async def place_instance(self, payload: PlaceInstanceParams):
model_card = await ModelCard.load(payload.model_id)
command = PlaceInstance(
model_card=model_card,
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
)
# Validate placement before sending — fail fast with a clear error
# instead of silently dropping the command in the master.
try:
get_instance_placements(
command,
topology=self.state.topology,
current_instances=self.state.instances,
node_memory=self.state.node_memory,
node_network=self.state.node_network,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
await self._send(command)
return CreateInstanceResponse(
@@ -546,67 +564,33 @@ class API:
instance_id=instance_id,
)
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
return dict(self.state.meta_instances)
async def create_meta_instance(
self, payload: CreateMetaInstanceParams
) -> CreateMetaInstanceResponse:
meta_instance = MetaInstance(
model_id=payload.model_id,
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
node_ids=payload.node_ids,
)
command = CreateMetaInstance(meta_instance=meta_instance)
await self._send(command)
return CreateMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance.meta_instance_id,
)
async def delete_meta_instance(
self, meta_instance_id: MetaInstanceId
) -> DeleteMetaInstanceResponse:
meta = self.state.meta_instances.get(meta_instance_id)
if not meta:
raise HTTPException(status_code=404, detail="MetaInstance not found")
# Command processor handles cascade-deleting backing instances
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
await self._send(command)
return DeleteMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance_id,
)
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
]:
"""Yield chunks for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._text_generation_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -625,6 +609,9 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._token_chunk_stream(command_id):
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -813,9 +800,11 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"advanced_params": _ensure_seed(payload.advanced_params),
}
)
command = ImageGeneration(
@@ -946,10 +935,10 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -1032,10 +1021,10 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -1071,12 +1060,13 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"stream": False,
"partial_images": 0,
"advanced_params": _ensure_seed(payload.advanced_params),
}
)
command = ImageGeneration(
@@ -1097,7 +1087,7 @@ class API:
prompt: str,
model: ModelId,
n: int,
size: str,
size: ImageSize,
response_format: Literal["url", "b64_json"],
input_fidelity: Literal["low", "high"],
stream: bool,
@@ -1167,7 +1157,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
size: str | None = Form(None),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
@@ -1193,7 +1183,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=size,
size=normalize_image_size(size),
response_format=response_format,
input_fidelity=input_fidelity,
stream=stream_bool,
@@ -1229,7 +1219,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
size: str | None = Form(None),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
@@ -1249,7 +1239,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=size,
size=normalize_image_size(size),
response_format=response_format,
input_fidelity=input_fidelity,
stream=False,
@@ -1340,6 +1330,163 @@ class API:
media_type="application/json",
)
async def _ollama_root(self) -> JSONResponse:
"""Respond to HEAD / from Ollama CLI connectivity checks."""
return JSONResponse(content="Ollama is running")
async def ollama_chat(
self, request: Request
) -> OllamaChatResponse | StreamingResponse:
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaChatRequest.model_validate_json(body)
task_params = ollama_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_generate(
self, request: Request
) -> OllamaGenerateResponse | StreamingResponse:
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaGenerateRequest.model_validate_json(body)
task_params = ollama_generate_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_generate_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_generate_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_tags(self) -> OllamaTagsResponse:
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
def none_if_empty(value: str) -> str | None:
return value or None
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
]
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
return OllamaTagsResponse(
models=[
OllamaModelTag(
name=str(card.model_id),
model=str(card.model_id),
modified_at=now,
size=card.storage_size.in_bytes,
digest="sha256:000000000000",
details=OllamaModelDetails(
family=none_if_empty(card.family),
quantization_level=none_if_empty(card.quantization),
),
)
for card in cards
]
)
async def ollama_show(self, request: Request) -> OllamaShowResponse:
"""Returns model information in Ollama show format."""
body = await request.body()
payload = OllamaShowRequest.model_validate_json(body)
model_name = payload.name or payload.model
if not model_name:
raise HTTPException(status_code=400, detail="name or model is required")
try:
card = await ModelCard.load(ModelId(model_name))
except Exception as exc:
raise HTTPException(
status_code=404, detail=f"Model not found: {model_name}"
) from exc
return OllamaShowResponse(
modelfile=f"FROM {card.model_id}",
template="{{ .Prompt }}",
details=OllamaModelDetails(
family=card.family or None,
quantization_level=card.quantization or None,
),
)
async def ollama_ps(self) -> OllamaPsResponse:
"""Returns list of running models (active instances)."""
models: list[OllamaPsModel] = []
seen: set[str] = set()
for instance in self.state.instances.values():
model_id = str(instance.shard_assignments.model_id)
if model_id in seen:
continue
seen.add(model_id)
models.append(
OllamaPsModel(
name=model_id,
model=model_id,
size=0,
)
)
return OllamaPsResponse(models=models)
async def ollama_version(self) -> dict[str, str]:
"""Returns version information for Ollama API compatibility."""
return {"version": "exo v1.0"}
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -1349,8 +1496,18 @@ class API:
return total_available
async def get_models(self) -> ModelList:
"""Returns list of available models."""
async def get_models(self, status: str | None = Query(default=None)) -> ModelList:
"""Returns list of available models, optionally filtered by being downloaded."""
cards = await get_model_cards()
if status == "downloaded":
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [c for c in cards if c.model_id in downloaded_model_ids]
return ModelList(
data=[
ModelListModel(
@@ -1359,7 +1516,7 @@ class API:
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
storage_size_megabytes=card.storage_size.in_mb,
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),
@@ -1368,7 +1525,7 @@ class API:
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in await get_model_cards()
for card in cards
]
)
@@ -1465,6 +1622,8 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -1491,7 +1650,6 @@ class API:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
@@ -1529,12 +1687,12 @@ class API:
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
ForwarderDownloadCommand(origin=self._system_id, command=command)
)
async def start_download(


@@ -1,5 +1,4 @@
from collections.abc import Sequence
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
import anyio
from anyio.abc import TaskGroup
@@ -13,22 +12,11 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.process_managers import ProcessManager
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.commands import (
CreateInstance,
CreateMetaInstance,
DeleteInstance,
DeleteMetaInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
@@ -41,19 +29,16 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
JacclSideChannelData,
JacclSideChannelGathered,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
LocalForwarderEvent,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
@@ -76,8 +61,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -88,8 +72,8 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
local_event_receiver: Receiver[ForwarderEvent],
global_event_sender: Sender[ForwarderEvent],
local_event_receiver: Receiver[LocalForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
@@ -101,16 +85,17 @@ class Master:
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
self._process_managers: Sequence[ProcessManager] = [
InstanceHealthReconciler(),
NodeTimeoutReconciler(),
MetaInstanceReconciler(),
]
async def run(self):
logger.info("Starting Master")
@@ -119,12 +104,15 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._reconcile)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -302,90 +290,10 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
origin=self._system_id, command=cmd
)
)
generated_events.extend(transition_events)
case CreateMetaInstance():
logger.info(
f"Creating MetaInstance for {command.meta_instance.model_id}"
f" (min_nodes={command.meta_instance.min_nodes},"
f" sharding={command.meta_instance.sharding})"
)
# Apply immediately so self.state is fresh across
# the await below and the reconciler won't race.
await self._apply_and_broadcast(
MetaInstanceCreated(meta_instance=command.meta_instance)
)
# Immediate placement attempt for responsiveness
model_card = await ModelCard.load(
command.meta_instance.model_id
)
# Re-check: reconciler may have satisfied it during the await
meta_id = command.meta_instance.meta_instance_id
still_unsatisfied = any(
m.meta_instance_id == meta_id
for m in find_unsatisfied_meta_instances(
self.state.meta_instances,
self.state.instances,
self.state.topology,
)
)
if still_unsatisfied:
result = try_place_for_meta_instance(
command.meta_instance,
model_card,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
self.state.tasks,
)
generated_events.extend(result.events)
if result.error is not None:
generated_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_id,
reason=result.error,
)
)
case DeleteMetaInstance():
backing_count = sum(
1
for inst in self.state.instances.values()
if inst.meta_instance_id == command.meta_instance_id
)
logger.info(
f"Deleting MetaInstance {command.meta_instance_id}"
f" (cascade-deleting {backing_count} backing instance(s))"
)
generated_events.append(
MetaInstanceDeleted(
meta_instance_id=command.meta_instance_id
)
)
# Cascade-delete backing instances atomically,
# cancelling any active tasks first.
for iid, inst in self.state.instances.items():
if inst.meta_instance_id == command.meta_instance_id:
for task in self.state.tasks.values():
if (
task.instance_id == iid
and task.task_status
in (
TaskStatus.Pending,
TaskStatus.Running,
)
):
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
generated_events.append(
InstanceDeleted(instance_id=iid)
)
case PlaceInstance():
placement = place_instance(
command,
@@ -417,19 +325,16 @@ class Master:
)
case TaskCancelled():
if (
command.cancelled_command_id
in self.command_task_mapping
):
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.cancelled_command_id
]
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
del self.command_task_mapping[
command.cancelled_command_id
]
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -438,10 +343,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time
@@ -452,32 +356,31 @@ class Master:
):
await self._send_event(IndexedEvent(idx=i, event=event))
for event in generated_events:
await self._apply_and_broadcast(event)
await self.event_sender.send(event)
except ValueError as e:
logger.opt(exception=e).warning("Error in command processor")
async def _apply_and_broadcast(self, event: Event) -> None:
"""Apply event to state, persist to disk, and broadcast to workers.
State is updated synchronously (before any await), so callers can
rely on ``self.state`` reflecting this event immediately after the
call. Python's cooperative scheduling guarantees no interleaving
between the state read and write.
"""
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
self._event_log.append(event)
await self._send_event(indexed)
async def _reconcile(self) -> None:
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
async def _plan(self) -> None:
while True:
for pm in self._process_managers:
events = await pm.reconcile(self.state)
for event in events:
await self._apply_and_broadcast(event)
await anyio.sleep(1)
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
await self.event_sender.send(
InstanceDeleted(instance_id=instance_id)
)
break
# time out dead nodes
for node_id, time in self.state.last_seen.items():
now = datetime.now(tz=timezone.utc)
if now - time > timedelta(seconds=30):
logger.info(f"Manually removing node {node_id} due to inactivity")
await self.event_sender.send(NodeTimedOut(node_id=node_id))
await anyio.sleep(10)
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
@@ -495,21 +398,38 @@ class Master:
await self._handle_traces_collected(event)
continue
if isinstance(event, JacclSideChannelData):
await self._apply_and_broadcast(event)
await self._handle_jaccl_side_channel(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
await self._apply_and_broadcast(event)
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# This loopback would ideally not be necessary, but it is far less
# hacky than the previous workaround.
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
LocalForwarderEvent(
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
GlobalForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,
@@ -535,49 +455,10 @@ class Master:
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self._apply_and_broadcast(
await self.event_sender.send(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
"""Accumulate SideChannel contributions; when all runners for an instance
have submitted for the same sequence, emit JacclSideChannelGathered."""
iid = event.instance_id
seq = event.sequence
if iid not in self._jaccl_pending:
self._jaccl_pending[iid] = {}
if seq not in self._jaccl_pending[iid]:
self._jaccl_pending[iid][seq] = {}
self._jaccl_pending[iid][seq][event.runner_id] = event.data
instance = self.state.instances.get(iid)
if instance is None:
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
return
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
submitted = set(self._jaccl_pending[iid][seq].keys())
logger.info(
f"JACCL side channel: instance={iid} seq={seq} "
f"submitted={len(submitted)}/{len(expected_runners)}"
)
if submitted >= expected_runners:
gathered = dict(self._jaccl_pending[iid][seq])
del self._jaccl_pending[iid][seq]
if not self._jaccl_pending[iid]:
del self._jaccl_pending[iid]
await self._apply_and_broadcast(
JacclSideChannelGathered(
instance_id=iid,
sequence=seq,
gathered_data=gathered,
)
)


@@ -6,11 +6,11 @@ from typing import Sequence
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_largest_cycles,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
@@ -106,27 +106,23 @@ def place_instance(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
largest_rdma_cycles = [
cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
smallest_rdma_cycles = [
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl:
if not largest_rdma_cycles:
raise ValueError(
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
)
largest_cycles = largest_rdma_cycles
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycle
for cycle in largest_cycles
for cycle in smallest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_memory[node_id].ram_available for node_id in cycle),
start=Memory(),
@@ -145,15 +141,29 @@ def place_instance(
if len(selected_cycle) == 1:
command.instance_meta = InstanceMeta.MlxRing
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
# TODO(evan): shard assignments should contain information about ranks, this is ugly
def get_device_rank(node_id: NodeId) -> int:
runner_id = shard_assignments.node_to_runner[node_id]
shard_metadata = shard_assignments.runner_to_shard.get(runner_id)
assert shard_metadata is not None
return shard_metadata.device_rank
zero_node_ids = [
node_id
for node_id in selected_cycle.node_ids
if get_device_rank(node_id) == 0
]
assert len(zero_node_ids) == 1
coordinator_node_id = zero_node_ids[0]
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
[node_id for node_id in selected_cycle],
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle.node_ids[0],
coordinator=coordinator_node_id,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_network=node_network,


@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
return filtered_cycles
def get_largest_cycles(
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
max_nodes = max(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == max_nodes]
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
@@ -102,22 +102,21 @@ def _allocate_and_validate_layers(
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
node_memory[node_id].ram_available / total_memory for node_id in node_ids
],
)
total_storage_bytes = model_card.storage_size.in_bytes
total_storage = model_card.storage_size
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
required_memory = (total_storage * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
f"but only has {available_memory.in_gb:.2f} GB available"
)
return layer_allocations
@@ -342,6 +341,7 @@ def _find_ip_prioritised(
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
ring: bool,
) -> str | None:
"""Find an IP address between nodes with prioritization.
@@ -354,13 +354,27 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
# Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
# TODO: Profile and get actual connection speeds.
if ring:
priority = {
"thunderbolt": 0,
"maybe_ethernet": 1,
"ethernet": 2,
"wifi": 3,
"unknown": 4,
}
# RDMA prefers ethernet coordinator
else:
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
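The `ring` flag above flips the interface preference table: ring transports favour Thunderbolt as the fastest best-effort link, while the RDMA coordinator favours ethernet. A condensed sketch of the selection (function name and flat signature are illustrative; the real helper resolves `ips` and `ip_to_type` from the topology):

```python
def pick_ip(ips: list[str], ip_to_type: dict[str, str], ring: bool) -> str:
    """Pick the preferred IP: rings favour thunderbolt, RDMA coordinators favour ethernet."""
    if ring:
        priority = {"thunderbolt": 0, "maybe_ethernet": 1, "ethernet": 2, "wifi": 3, "unknown": 4}
    else:
        priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "maybe_ethernet": 3, "thunderbolt": 4}
    # Unknown interface types fall back to a middling priority of 2.
    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
```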
@@ -400,7 +414,7 @@ def get_mlx_ring_hosts_by_node(
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network
node_id, other_node_id, cycle_digraph, node_network, ring=True
)
if connection_ip is None:
raise ValueError(
@@ -431,7 +445,9 @@ def get_mlx_jaccl_coordinators(
if n == coordinator:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
ip = _find_ip_prioritised(
n, coordinator, cycle_digraph, node_network, ring=False
)
if ip is not None:
return ip


@@ -1,12 +0,0 @@
from collections.abc import Sequence
from typing import Protocol, runtime_checkable
from exo.shared.types.events import Event
from exo.shared.types.state import State
@runtime_checkable
class ProcessManager(Protocol):
"""A reconciliation step that examines state and returns corrective events."""
async def reconcile(self, state: State) -> Sequence[Event]: ...


@@ -1,62 +0,0 @@
from collections.abc import Sequence
from typing import final
from loguru import logger
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
from exo.shared.types.state import State
MAX_INSTANCE_RETRIES = 3
@final
class InstanceHealthReconciler:
"""Delete instances whose network connections are broken or whose runners have all failed."""
async def reconcile(self, state: State) -> Sequence[Event]:
events: list[Event] = []
for instance_id, instance in state.instances.items():
if not instance_connections_healthy(instance, state.topology):
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error="Network connection lost",
)
)
continue
is_failed, error_message = instance_runners_failed(
instance, state.runners, state.node_identities
)
if is_failed:
# Retry within the same instance if backed by a MetaInstance
mid = instance.meta_instance_id
mi = state.meta_instances.get(mid) if mid else None
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
logger.info(
f"Instance {instance_id} failed (attempt"
f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
f" retrying: {error_message}"
)
events.append(
InstanceRetrying(
instance_id=instance_id,
meta_instance_id=mid,
failure_error=error_message or "Runner failed",
)
)
else:
if mid and mi:
logger.warning(
f"Instance {instance_id} exceeded retry limit"
f" ({MAX_INSTANCE_RETRIES}), deleting:"
f" {error_message}"
)
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error=error_message,
)
)
return events


@@ -1,92 +0,0 @@
from collections.abc import Sequence
from typing import final
import anyio
from loguru import logger
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId
MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10
@final
class MetaInstanceReconciler:
"""Place instances for unsatisfied MetaInstances."""
async def reconcile(self, state: State) -> Sequence[Event]:
all_events: list[Event] = []
# Local copy for intermediate tracking — so placement of B
# sees A's instance and doesn't double-place on same resources.
current_instances: dict[InstanceId, Instance] = dict(state.instances)
unsatisfied = find_unsatisfied_meta_instances(
state.meta_instances,
current_instances,
state.topology,
)
for meta_instance in unsatisfied:
try:
with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
model_card = await ModelCard.load(meta_instance.model_id)
except TimeoutError:
logger.warning(
f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
)
continue
except Exception as exc:
logger.warning(
f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
)
error = f"Failed to load model card: {exc}"
if meta_instance.placement_error != error:
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=error,
)
)
continue
result = try_place_for_meta_instance(
meta_instance,
model_card,
state.topology,
current_instances,
state.node_memory,
state.node_network,
state.tasks,
)
# Update local instance map so next placement sees this one
for event in result.events:
if isinstance(event, InstanceCreated):
logger.info(
f"MetaInstance reconciler placed instance"
f" {event.instance.instance_id} for"
f" {meta_instance.model_id}"
)
current_instances[event.instance.instance_id] = event.instance
all_events.extend(result.events)
# Emit placement failure if error differs from what's already in state
if (
result.error is not None
and meta_instance.placement_error != result.error
):
logger.warning(
f"MetaInstance placement failed for"
f" {meta_instance.model_id}: {result.error}"
)
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=result.error,
)
)
return all_events


@@ -1,27 +0,0 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final
from loguru import logger
from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State
_DEFAULT_TIMEOUT = timedelta(seconds=30)
@final
class NodeTimeoutReconciler:
"""Time out nodes that haven't been seen recently."""
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
self.timeout = timeout
async def reconcile(self, state: State) -> Sequence[Event]:
now = datetime.now(tz=timezone.utc)
events: list[Event] = []
for node_id, last_seen in state.last_seen.items():
if now - last_seen > self.timeout:
logger.info(f"Removing node {node_id} due to inactivity")
events.append(NodeTimedOut(node_id=node_id))
return events


@@ -1,244 +0,0 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple
from loguru import logger
from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
BaseInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
RunnerStatus,
)
class PlacementResult(NamedTuple):
"""Result of a placement attempt: events to apply and optional error reason."""
events: Sequence[Event]
error: str | None
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
"""Reconstruct ring order from shard device_rank."""
node_ranks: list[tuple[NodeId, int]] = []
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
shard = instance.shard_assignments.runner_to_shard[runner_id]
node_ranks.append((node_id, shard.device_rank))
node_ranks.sort(key=lambda x: x[1])
return [node_id for node_id, _ in node_ranks]
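The removed `_get_ring_order` helper reconstructs ring order purely from the shards' `device_rank` values. That reduces to a sort by rank, sketched here with a flat `node_id -> rank` mapping standing in for the shard-assignment lookup:

```python
def ring_order(node_to_rank: dict[str, int]) -> list[str]:
    """Order node ids by device_rank, rank 0 first."""
    return [node_id for node_id, _ in sorted(node_to_rank.items(), key=lambda kv: kv[1])]
```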
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
"""Check that the specific IPs used by a ring instance still exist in the topology."""
ring = _get_ring_order(instance)
n = len(ring)
for node in ring:
hosts = instance.hosts_by_node[node]
for idx in range(n):
host = hosts[idx]
if host.ip in ("0.0.0.0", "198.51.100.1"):
continue # self or placeholder
# Real connection: node → ring[idx]. Check specific IP.
connections = topology.get_all_connections_between(node, ring[idx])
if not any(
isinstance(c, SocketConnection)
and c.sink_multiaddr.ip_address == host.ip
for c in connections
):
return False
return True
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
ring = _get_ring_order(instance)
n = len(ring)
for i in range(n):
for j in range(n):
iface = instance.jaccl_devices[i][j]
if iface is None:
continue
connections = topology.get_all_connections_between(ring[i], ring[j])
if not any(
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
for c in connections
):
return False
return True
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
"""Check that an instance's nodes and specific connections are still in the topology."""
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if not all(topology.contains_node(n) for n in instance_nodes):
return False
if len(instance_nodes) <= 1:
return True
match instance:
case MlxRingInstance():
return _ring_connections_healthy(instance, topology)
case MlxJacclInstance():
return _jaccl_connections_healthy(instance, topology)
def instance_runners_failed(
instance: Instance,
runners: Mapping[RunnerId, RunnerStatus],
node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
"""Check if an instance's runners have all reached terminal failure states.
Returns ``(True, error_message)`` when ALL runners are terminal
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
Returns ``(False, None)`` when runners are still active, haven't reported
yet, or all gracefully shut down (no ``RunnerFailed``).
"""
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
if not instance_runner_ids:
return False, None
# Build reverse mapping: runner_id -> node_id
runner_to_node: dict[RunnerId, NodeId] = {
runner_id: node_id
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
}
has_any_failed = False
error_messages: list[str] = []
for runner_id in instance_runner_ids:
status = runners.get(runner_id)
if status is None:
# Runner hasn't reported yet — instance is still starting
return False, None
if isinstance(status, RunnerFailed):
has_any_failed = True
if status.error_message:
node_id = runner_to_node.get(runner_id)
name = (
node_identities[node_id].friendly_name
if node_id and node_id in node_identities
else node_id or "unknown"
)
error_messages.append(f"{name}: {status.error_message}")
elif isinstance(status, RunnerShutdown):
pass # Terminal but not a failure indicator on its own
else:
# Runner is still active (connecting, loading, running, etc.)
return False, None
if has_any_failed:
return True, "; ".join(error_messages) if error_messages else "Runner failed"
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
return False, None


def instance_satisfies_meta_instance(
meta_instance: MetaInstance,
instance: Instance,
) -> bool:
"""Check if a single instance satisfies a meta-instance's constraints.
This is a pure constraint check (model, min_nodes, node_ids).
Use ``instance_connections_healthy`` separately for topology health.
"""
if instance.shard_assignments.model_id != meta_instance.model_id:
return False
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if len(instance_nodes) < meta_instance.min_nodes:
return False
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
instance_nodes
)


def find_unsatisfied_meta_instances(
meta_instances: Mapping[MetaInstanceId, MetaInstance],
instances: Mapping[InstanceId, Instance],
topology: Topology,
) -> Sequence[MetaInstance]:
"""Return meta-instances that have no healthy backing instance."""
unsatisfied: list[MetaInstance] = []
for meta_id, meta_instance in meta_instances.items():
has_healthy_backing = any(
instance.meta_instance_id == meta_id
and instance_connections_healthy(instance, topology)
for instance in instances.values()
)
if not has_healthy_backing:
unsatisfied.append(meta_instance)
return unsatisfied


def try_place_for_meta_instance(
meta_instance: MetaInstance,
model_card: ModelCard,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
tasks: Mapping[TaskId, Task],
) -> PlacementResult:
"""Try to place an instance satisfying the meta-instance constraints.
Returns a :class:`PlacementResult` with events on success, or an error
reason on failure.
"""
command = PlaceInstance(
model_card=model_card,
sharding=meta_instance.sharding,
instance_meta=meta_instance.instance_meta,
min_nodes=meta_instance.min_nodes,
)
try:
target_instances = place_instance(
command,
topology,
current_instances,
node_memory,
node_network,
required_nodes=(
set(meta_instance.node_ids) if meta_instance.node_ids else None
),
)
# Tag the new instance with meta_instance_id
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
if new_instance_ids:
new_id = next(iter(new_instance_ids))
target_instances[new_id] = target_instances[new_id].model_copy(
update={"meta_instance_id": meta_instance.meta_instance_id}
)
return PlacementResult(
events=list(
get_transition_events(current_instances, target_instances, tasks)
),
error=None,
)
except ValueError as e:
logger.debug(
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
)
return PlacementResult(events=[], error=str(e))
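The terminal-state aggregation in `instance_runners_failed` can be exercised in isolation. The sketch below is an assumption-laden standalone mirror of that rule, using hypothetical stub classes (`RunnerFailed`, `RunnerShutdown`, `RunnerRunning`) in place of exo's real runner-status types:

```python
from dataclasses import dataclass


@dataclass
class RunnerFailed:
    error_message: str = ""


@dataclass
class RunnerShutdown:
    pass


@dataclass
class RunnerRunning:  # stand-in for any non-terminal state
    pass


def runners_failed(statuses):
    """Mirror of the aggregation rule: (True, msg) only when every
    runner is terminal AND at least one actually failed."""
    has_any_failed = False
    errors = []
    for status in statuses:
        if status is None:
            return False, None  # not reported yet: instance still starting
        if isinstance(status, RunnerFailed):
            has_any_failed = True
            if status.error_message:
                errors.append(status.error_message)
        elif isinstance(status, RunnerShutdown):
            pass  # terminal, but not a failure indicator on its own
        else:
            return False, None  # still active
    if has_any_failed:
        return True, "; ".join(errors) if errors else "Runner failed"
    return False, None  # all shut down gracefully


# One failed + one shut down => reported as failure with message
print(runners_failed([RunnerFailed("oom"), RunnerShutdown()]))  # (True, 'oom')
# One still running => not terminal yet, even though another failed
print(runners_failed([RunnerFailed("oom"), RunnerRunning()]))   # (False, None)
```

Note the asymmetry this encodes: an all-`RunnerShutdown` instance is a graceful teardown, not a failure, while a single active runner defers judgment entirely.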


@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:
         parsed = _parse_sse_events(events)
-        # Two tool block starts (at indices 1 and 2)
+        # Two tool block starts (at indices 0 and 1 — no text block when only tools)
         tool_starts = [
             e
             for e in parsed
@@ -270,12 +270,11 @@ class TestGenerateClaudeStreamToolUse:
             == "tool_use"
         ]
         assert len(tool_starts) == 2
-        assert tool_starts[0]["index"] == 1
-        assert tool_starts[1]["index"] == 2
-        # Two tool block stops (at indices 1 and 2), plus text block stop at 0
+        assert tool_starts[0]["index"] == 0
+        assert tool_starts[1]["index"] == 1
+        # Two tool block stops (at indices 0 and 1)
         block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
         stop_indices = [e["index"] for e in block_stops]
         assert 0 in stop_indices
         assert 1 in stop_indices
-        assert 2 in stop_indices
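The index shift this test diff encodes (tool blocks moving from indices 1 and 2 down to 0 and 1) follows directly from enumerating content blocks by position. A minimal sketch, assuming Anthropic-style streams index each content block by its position in the message, with a hypothetical `block_indices` helper:

```python
def block_indices(blocks):
    # Content blocks are indexed by their position in the message,
    # so the first block is index 0 regardless of its type.
    return {i: kind for i, (kind, _) in enumerate(blocks)}


# Tool-only response: tool blocks occupy indices 0 and 1.
print(block_indices([("tool_use", "a"), ("tool_use", "b")]))
# {0: 'tool_use', 1: 'tool_use'}

# With a leading text block, the same tool blocks shift to 1 and 2.
print(block_indices([("text", "hi"), ("tool_use", "a"), ("tool_use", "b")]))
# {0: 'text', 1: 'tool_use', 2: 'tool_use'}
```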
