Compare commits

...

34 Commits

Author SHA1 Message Date
Alex Cheema
f9ffdaef5f Preserve last_failure_error across instance recreation, fix RDMA banner wording
- apply_instance_created no longer clears last_failure_error so the
  error context persists while the new instance starts up
- Dashboard retryError shows the error without (N/3) prefix when
  consecutiveFailures is 0 (instance was recreated)
- Jaccl warning tooltip now says "experimental RDMA driver in macOS"

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:48:30 -08:00
Alex Cheema
8c2416c9ea chore: remove temporary screenshot files
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:40:16 -08:00
Alex Cheema
e5007f619a temp: add jaccl warning screenshots for PR comment
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:53 -08:00
Alex Cheema
a627f67253 dashboard: show warning banner for [jaccl] RDMA driver errors
Detect errors containing "[jaccl]" in MetaInstance failure errors and
display a red dismissible alert banner. The tooltip explains this is a
macOS RDMA driver issue and that the affected machine needs to be
restarted. Alert re-appears if a new error arrives after dismissal.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:42 -08:00
Alex Cheema
f189222bfc Merge remote-tracking branch 'origin/main' into alexcheema/meta-instance
# Conflicts:
#	dashboard/src/lib/stores/app.svelte.ts
#	dashboard/src/routes/+page.svelte
2026-02-11 15:59:50 -08:00
Alex Cheema
62e8110e97 fix: prevent DownloadModel TaskCreated event flood (#1452)
## Motivation

When a model download fails repeatedly (e.g. `ContentLengthError` on a
large model like `zai-org/GLM-5`), the download coordinator accumulates
duplicate progress callbacks — one per retry cycle. Each callback
independently throttles at 1 event/sec, so after N retries, every
download progress tick generates N events instead of 1. After an hour of
failures (~60 retry cycles), this produces ~60 `NodeDownloadProgress`
events/sec, overwhelming the master, delaying heartbeats, and causing
the node to time itself out.

### The callback accumulation cycle
1. `_start_download_task()` calls
`shard_downloader.on_progress(callback)` which **appends** to a list
2. Download fails → `DownloadFailed` status set, but old callback stays
in the list
3. 60s later: `_emit_existing_download_progress()` scans disk → resets
status to `DownloadPending`
4. Worker sends new `StartDownload` → coordinator accepts (guard didn't
check `DownloadFailed`)
5. `_start_download_task()` appends **another** callback
6. Each callback has its own throttle → N callbacks = N events per
progress tick

## Changes

### Commit 1: `src/exo/worker/main.py`
Move the `DownloadModel` backoff check **before** `TaskCreated` emission
in `plan_step()`. Previously `TaskCreated` was emitted unconditionally
every 0.1s even when backoff blocked the download command.

### Commit 2: `src/exo/download/coordinator.py`
1. **Register progress callback once** in `__post_init__` instead of
per-download in `_start_download_task()`. Uses a per-model throttle dict
instead of per-callback closure variables.
2. **Add `DownloadFailed` to the `_start_download()` guard** so
redundant `_start_download_task()` calls don't happen. Retries still
work because `_emit_existing_download_progress` resets `DownloadFailed`
→ `DownloadPending` by scanning disk every 60s.

## Why It Works

The root cause was callbacks accumulating in
`ResumableShardDownloader.on_progress_callbacks` (a list that only
appends, never clears). By registering one callback per coordinator
lifetime and guarding against re-entry on `DownloadFailed`, we ensure
exactly one progress event per model per progress tick regardless of how
many retry cycles have occurred.
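
For illustration, a minimal sketch of the new shape (names simplified; the full implementation in the coordinator diff below also handles the `complete` status, the `DownloadFailed` guard, and active-download cleanup):

```python
import time


class DownloadCoordinatorSketch:
    """Sketch: one progress callback per coordinator lifetime, throttled per model."""

    THROTTLE_SECS = 1.0

    def __init__(self, shard_downloader, event_sender, node_id):
        self.node_id = node_id
        self.event_sender = event_sender
        self._last_progress_time: dict[str, float] = {}
        # Registered exactly once here, instead of once per _start_download_task() call.
        shard_downloader.on_progress(self._on_progress)

    async def _on_progress(self, shard, progress) -> None:
        model_id = shard.model_card.model_id
        now = time.monotonic()
        # A per-model throttle dict replaces per-callback closure state, so N retry
        # cycles still yield at most one in-progress event per model per second.
        if now - self._last_progress_time.get(model_id, 0.0) < self.THROTTLE_SECS:
            return
        self._last_progress_time[model_id] = now
        # The real code wraps this in a NodeDownloadProgress event before sending.
        await self.event_sender.send(progress)
```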

## Test Plan

### Manual Testing
- Verified the download retry flow: failed download → 60s scan resets
status → new `StartDownload` accepted → download retries with single
callback

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `uv run pytest` — 188 passed

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:50:43 +00:00
Alex Cheema
98773437f3 Make info gatherer monitors resilient with retry loops and timeouts (#1448)
## Motivation

Info gatherer monitors could silently stop posting events, causing stale
node state after rejoins. The macmon monitor was especially fragile — it
had no retry loop, so a crash or parse error would kill it permanently.
Worse, the unhandled exception would propagate to the TaskGroup and take
down *all* sibling monitors. Additionally, none of the monitors had
timeouts on their subprocess calls, so a hung `system_profiler` or
`networksetup` could stall a monitor indefinitely.

## Changes

- Wrap `_monitor_macmon` in a `while True` retry loop with `except
Exception`, matching the pattern used by all other monitors
- Add `fail_after` timeouts to all monitor loop bodies:
- 10s for lightweight commands (`_monitor_misc`, `_watch_system_info`,
`_gather_iface_map` init)
- 30s for heavier commands (`_monitor_system_profiler_thunderbolt_data`,
`_monitor_thunderbolt_bridge_status`)
- Remove unused `CalledProcessError` and `cast` imports

## Why It Works

All monitors now follow the same resilient pattern: `while True` → `try`
with `fail_after` → `except Exception` (logs warning) → `sleep`. If a
subprocess hangs, the timeout fires and `TimeoutError` is caught by the
existing `except Exception` handler. If macmon crashes, it restarts
after the interval instead of dying permanently. No single monitor
failure can cascade to kill the others.
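
The shared shape, as a hedged sketch (assuming anyio's `fail_after`/`sleep`; `poll_once` stands in for the per-monitor body):

```python
import logging

import anyio

logger = logging.getLogger(__name__)


async def run_monitor(poll_once, interval: float, timeout: float) -> None:
    """Resilient monitor loop: retry forever, bound each iteration with a timeout."""
    while True:
        try:
            with anyio.fail_after(timeout):  # raises TimeoutError if the body hangs
                await poll_once()  # e.g. run a subprocess and post the parsed event
        except Exception as exc:  # TimeoutError lands here too; cancellation does not
            logger.warning("monitor iteration failed: %s", exc)
        await anyio.sleep(interval)
```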

## Test Plan

### Manual Testing
Hardware: macOS with macmon installed
- Run exo, kill macmon process (`kill $(pgrep macmon)`), verify it
restarts and metrics resume
- Verify all monitors continue posting events after simulated hangs

### Automated Testing
- All 188 existing tests pass
- basedpyright: 0 errors
- ruff: all checks passed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:07:36 +00:00
Alex Cheema
ad6d35d68a Retry runners within the same Instance instead of recreating
When runners fail for a MetaInstance-backed Instance, retry up to 3
times by restarting runners within the same Instance rather than
deleting and recreating it each time. After 3 failures, delete the
Instance so MetaInstanceReconciler can create a fresh one.

- Add InstanceRetrying event that removes runners from state (signaling
  workers to restart) and increments consecutive_failures on MetaInstance
- InstanceHealthReconciler emits InstanceRetrying when under retry limit,
  InstanceDeleted when exhausted or no MetaInstance
- Worker _kill_runner detects retry signal (runner deleted from state +
  terminal supervisor) and cleans up for _create_runner to recreate
- Worker _create_runner guards against oscillation by blocking creation
  while any peer runner has explicit terminal status
- InstanceCreated resets consecutive_failures for fresh starts
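
A hypothetical sketch of the reconciler's decision (event and field names are from this commit; the constructors and failure-detection inputs are simplified):

```python
from dataclasses import dataclass

MAX_CONSECUTIVE_FAILURES = 3


@dataclass
class InstanceRetrying:  # simplified stand-in for the real event
    instance_id: str


@dataclass
class InstanceDeleted:  # simplified stand-in for the real event
    instance_id: str
    failure_error: str | None = None


def on_all_runners_terminal(instance_id, meta_instance, failure_error):
    """What InstanceHealthReconciler emits once every runner is terminal."""
    if meta_instance is None or meta_instance.consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
        # No MetaInstance to retry under, or retry budget exhausted: delete so
        # MetaInstanceReconciler can place a fresh Instance.
        return [InstanceDeleted(instance_id=instance_id, failure_error=failure_error)]
    # Otherwise signal workers to restart runners within the same Instance.
    return [InstanceRetrying(instance_id=instance_id)]
```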

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 14:21:11 -08:00
Jake Hillion
a8acb3cafb dashboard: show available disk space on downloads page
The downloads page previously only showed the approximate space used by
downloaded models (summed from completed download sizes), but did not
show how much disk space was actually available. This made it difficult
to know if a download would succeed before pressing the button.

Added disk space tracking to the InfoGatherer that polls the models
directory partition every 30 seconds. The DiskUsage type captures total
and available space, which flows through the event system to State and
is exposed via the /state API. The dashboard now displays "X on disk /
Y available" for each node in the downloads view.

Test plan:
- CI
2026-02-11 21:57:28 +00:00
Alex Cheema
a0721dbe57 feat: warn when cluster nodes have mismatched macOS versions (#1436)
## Motivation

When nodes in an exo cluster run different macOS versions, inference can
produce incompatible results or fail silently. Users currently have no
way to know this from the dashboard.

## Changes

- Added `get_os_version()` to `system_info.py` that returns the macOS
version (e.g. `"15.3"`) or platform name for non-Mac nodes
- Added `os_version` field to `NodeIdentity` and
`StaticNodeInformation`, gathered once at startup
- Propagated `os_version` through the event sourcing pipeline
(`apply.py`)
- Exposed `nodeIdentities` from the dashboard store with `osVersion`
- Added a derived `macosVersionMismatch` check in `+page.svelte` that
triggers when 2+ macOS nodes report different versions
- Rendered a yellow "INCOMPATIBLE macOS VERSIONS" warning badge
(matching the existing Thunderbolt Bridge cycle warning style) with a
hover tooltip listing each node's name and version, in all three
topology view sizes (large, medium, compact)

## Why It Works

The OS version is a static property gathered once at node startup via
`platform.mac_ver()`. It flows through the existing
`StaticNodeInformation` → `NodeGatheredInfo` event → `NodeIdentity`
state pipeline, so no new event types or state fields beyond
`os_version` on `NodeIdentity` are needed. The dashboard derives the
mismatch by comparing `osVersion` across all nodes whose version looks
like a macOS version string (starts with a digit).
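
A plausible sketch of the helper (the exact non-Mac fallback string is an assumption):

```python
import platform


def get_os_version() -> str:
    """Return the macOS version (e.g. "15.3"), or a platform name on non-Mac nodes."""
    mac_release, _version_info, _machine = platform.mac_ver()
    if mac_release:
        return mac_release  # starts with a digit, so the dashboard treats it as macOS
    return platform.system()  # assumption: e.g. "Linux" elsewhere
```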

## Test Plan

### Manual Testing
Hardware: 4x Mac Studio M2 Ultra 512GB (s18, s17 (2), james, mike),
connected via Thunderbolt
- s18 and s17 (2) on macOS 26.2, james and mike on macOS 26.3
- Verified the "INCOMPATIBLE macOS VERSIONS" warning badge appears in
the topology view
- Verified the hover tooltip lists all four nodes with their respective
versions
- Screenshots attached in comment below

### Automated Testing
- basedpyright: 0 errors
- ruff check: all checks passed
- nix fmt: no formatting changes needed
- Dashboard builds successfully

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 21:18:59 +00:00
Alex Cheema
c236d62caf Remove timestamp-based retry cooldown
Remove last_failure_at field and RETRY_COOLDOWN_SECONDS logic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:59:39 -08:00
Alex Cheema
a8069e8a30 Consolidate failure state onto MetaInstance, add 5s retry cooldown
Move placement_error, consecutive_failures, last_failure_error, and
last_failure_at directly onto the MetaInstance model instead of keeping
them as separate State mappings (meta_instance_errors, InstanceFailureInfo,
meta_instance_failure_info). Adds a 5-second cooldown between retry attempts
to prevent rapid instance churn when runners fail instantly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:55:47 -08:00
Alex Cheema
84ce555d55 Show retry attempt count with error message, e.g. (2/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:43:20 -08:00
Alex Cheema
50e2bcf93e fix: RDMA debug labels, TB5 info box, and rdma_ctl status detection (#1437)
## Motivation

Several RDMA/Thunderbolt UX issues in the dashboard and macOS app:

1. **Debug mode showed "? ?" for RDMA connections** — the topology view
only extracted IPs from socket connections, not RDMA interface names
2. **No way to detect if RDMA is actually enabled** — the system only
knew about TB5 hardware and RDMA topology edges, not whether `rdma_ctl`
was enabled on each node
3. **False "RDMA AVAILABLE" info box** — showed on Mac Minis with idle
TB5 ports even when RDMA was already enabled, and on single nodes with
TB5
4. **macOS app only showed local RDMA status** — ran `rdma_ctl` locally
with no visibility into other nodes in the cluster

## Changes

### Dashboard: Fix RDMA debug labels (`0abc90c4`)
- Added `sourceRdmaIface` and `sinkRdmaIface` to `TopologyEdge`
interface
- Updated `TopologyGraph.svelte` and `ModelCard.svelte` to show `RDMA
en2 → en3` instead of `? ?`

### Dashboard: TB5 RDMA info box (`a3795552`, `8ce8e173`)
- Added dismissible info box when 2+ nodes have TB5 hardware but RDMA is
disabled
- Includes setup instructions (Recovery mode → `rdma_ctl enable` →
reboot, TB5 cables, macOS version match)
- Requires 2+ exo nodes with TB5 to avoid false positives from
single-node setups

### Backend: `rdma_ctl status` detection (`ae07239b`)
- Added `RdmaCtlStatus` event to `info_gatherer.py` — runs `rdma_ctl
status` with 5s timeout, `shutil.which` guard, and `OSError` handling
(polls every 10s on macOS)
- Added `NodeRdmaCtlStatus` model to `profiling.py` and `node_rdma_ctl`
field to `State`
- Handle in `apply.py` (event apply + node timeout cleanup)
- Exposed `nodeRdmaCtl` in dashboard store (`app.svelte.ts`)
- Info box detection now uses actual RDMA status instead of TB5 link
speeds
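
As a hedged sketch of the probe described above (the real monitor is async inside `info_gatherer.py` and re-polls every 10s; the exact output parsing is an assumption):

```python
import shutil
import subprocess


def read_rdma_ctl_enabled() -> bool | None:
    """Best-effort `rdma_ctl status` check; None means unknown or unavailable."""
    if shutil.which("rdma_ctl") is None:  # guard: binary missing on non-macOS or older macOS
        return None
    try:
        result = subprocess.run(
            ["rdma_ctl", "status"], capture_output=True, text=True, timeout=5
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    output = result.stdout.strip().lower()
    if "disabled" in output:
        return False
    if "enabled" in output:
        return True
    return None
```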

### Dashboard: Per-node RDMA debug labels (`ae07239b`)
- Debug mode shows `RDMA:ON` (green) or `RDMA:OFF` (dim) per node in
topology view, below the TB bridge label

### macOS app: Cluster-wide RDMA status from `/state` (`a1455b61`,
`d0d77b63`)
- Added `NodeRdmaCtlStatus` to `ClusterState.swift` — decoded from
`/state` endpoint
- Replaced local-only `rdma_ctl status` check with cluster-wide
`nodeRdmaCtl` from state
- Debug section shows per-node RDMA enabled/disabled for all nodes in
the cluster
- Still shows local `ibv_devices` and `ibv_devinfo` details (device
names, active ports) for richer local debugging

## Files changed

| Area | File | Change |
|------|------|--------|
| Backend | `src/exo/utils/info_gatherer/info_gatherer.py` |
`RdmaCtlStatus` event, monitor task |
| Backend | `src/exo/shared/types/profiling.py` | `NodeRdmaCtlStatus`
model |
| Backend | `src/exo/shared/types/state.py` | `node_rdma_ctl` field |
| Backend | `src/exo/shared/apply.py` | Event handler + timeout cleanup
|
| Dashboard | `dashboard/src/lib/stores/app.svelte.ts` | `nodeRdmaCtl` +
`nodeThunderbolt` in store |
| Dashboard | `dashboard/src/routes/+page.svelte` | Info box with RDMA
detection + instructions |
| Dashboard | `dashboard/src/lib/components/TopologyGraph.svelte` | RDMA
debug labels per node + fix "? ?" |
| Dashboard | `dashboard/src/lib/components/ModelCard.svelte` | RDMA
interface display fix |
| App | `app/EXO/EXO/Models/ClusterState.swift` | `NodeRdmaCtlStatus`
struct + decode |
| App | `app/EXO/EXO/ContentView.swift` | Cluster-wide RDMA view + local
device details |
| App | `app/EXO/EXO/Services/NetworkStatusService.swift` | Remove local
`rdma_ctl`, keep `ibv_*` |

## Test Plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — pass
- [x] `nix fmt` — clean
- [x] `cd dashboard && npm run build` — success
- [x] `uv run pytest` — 188 passed
- [x] Xcode build — compiles (only pre-existing `dist/exo` resource
error)
- [x] Deployed to Mac Minis — `nodeRdmaCtl` shows `enabled: true`, no
false info box
- [x] Deployed to James cluster — RDMA debug labels show correctly

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 20:43:11 +00:00
Alex Cheema
b78ea438bc Include node friendly names in runner error messages
Each error in the combined message is now prefixed with the node's friendly
name (e.g. "MacBook Pro: OOM; Mac Studio: connection reset") so the root
cause node is easily identifiable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:41:10 -08:00
Alex Cheema
1960b16f9f Remove permanent retry blocking, allow continuous retry batches
The dashboard % 3 logic already handles displaying retry progress in batches
(RETRYING 1/3, 2/3, 3/3, then PLACING with error, repeat). No need to
permanently block placement after 3 failures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:35:03 -08:00
Alex Cheema
7bed91c9c2 feat: add Recent tab to model picker (#1440)
## Motivation

When frequently switching between models, it's tedious to search through
the full model list to find ones you've used before. A "Recent" tab
provides quick access to previously launched models.

## Changes

- **New store** (`dashboard/src/lib/stores/recents.svelte.ts`):
`RecentsStore` class persisting recently launched model IDs with
timestamps to localStorage (key: `exo-recent-models`). Caps at 20
entries, deduplicates on re-launch (moves to top).
- **FamilySidebar**: Added "Recent" tab between Favorites and Hub,
conditionally shown when there are recent models.
- **FamilyLogos**: Added clock/history icon for the recents tab.
- **ModelPickerModal**: Added `recentModelIds`/`hasRecents` props.
Derives single-variant `ModelGroup[]` from recent IDs and renders them
using the same `ModelPickerGroup` component as all other tabs —
consistent styling, memory grey-out, favorites, info button, download
indicators.
- **+page.svelte**: Calls `recordRecentLaunch(modelId)` after successful
instance launch. Passes reactive recent state to the modal.

## Why It Works

Follows the exact same pattern as the existing Favorites feature
(localStorage persistence, conditional tab display, reactive Svelte 5
`$state`/`$derived`). Recent models are wrapped as single-variant
`ModelGroup` objects so they reuse `ModelPickerGroup` for identical row
rendering across all tabs.

## Test Plan

### Manual Testing
Hardware: MacBook Pro
- Launch a model instance → reopen model picker → "Recent" tab appears
with the launched model
- Launch a second model → it appears at top of the Recent list
- Re-launch the first model → it moves back to top
- Search within the Recent tab filters the list
- Models that don't fit in memory are greyed out (same as All tab)
- Close/reopen browser → recents persist from localStorage

### Automated Testing
- Dashboard builds successfully (`npm run build`)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-11 12:34:08 -08:00
Alex Cheema
c6838c8fd8 Show retry count in exceeded retry limit message (3/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:28:17 -08:00
Alex Cheema
420d9b9e76 Collect all runner error messages instead of just the last one
When multiple runners fail, concatenate all error messages with "; " so the
real error isn't hidden by generic side-effect failures from other runners.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:27:49 -08:00
Alex Cheema
13f1e9c489 Stop infinite retries after 3 failures, show errors persistently in dashboard
MetaInstanceReconciler now checks failure count before placement — after 3
consecutive failures it emits MetaInstancePlacementFailed instead of retrying
forever. Dashboard shows "Retrying after error: <msg>" in orange throughout
the retry cycle, not just during the brief window with no backing instance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:21:11 -08:00
Alex Cheema
451a06b3d8 Add instance retry logic with max 3 retries and failure tracking
- Extend InstanceDeleted with failure_error field for runner crash info
- Add InstanceFailureInfo model tracking consecutive failures per MetaInstance
- InstanceHealthReconciler now detects runner failures (all terminal with
  at least one RunnerFailed) in addition to connection failures
- apply_instance_deleted increments failure counter for meta-bound instances
- Dashboard shows RETRYING (N/3) status with error messages, and
  "Instance re-created due to failure" after 3 consecutive failures
- Extract and display RunnerFailed error messages in instance status

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:09:42 -08:00
Alex Cheema
94b55d66f4 Fix MetaInstance.node_ids frozenset failing JSON deserialization
frozenset serializes to a JSON array but cannot be deserialized back
in strict mode through the TaggedModel wrap validator (list → frozenset
coercion is rejected). Changed to list[NodeId] since the model is
already frozen/immutable.
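
A minimal pydantic v2 reproduction of the failure mode, assuming strict validation of an already-parsed JSON payload (the TaggedModel wrap validator is omitted):

```python
import json

from pydantic import BaseModel, ConfigDict, ValidationError


class MetaInstanceStrict(BaseModel):
    model_config = ConfigDict(strict=True, frozen=True)
    node_ids: frozenset[str]


class MetaInstanceFixed(BaseModel):
    model_config = ConfigDict(strict=True, frozen=True)
    node_ids: list[str]


payload = json.loads('{"node_ids": ["node-a", "node-b"]}')  # JSON arrays parse as lists

try:
    MetaInstanceStrict.model_validate(payload)
except ValidationError:
    print("strict mode rejects list -> frozenset coercion")

print(MetaInstanceFixed.model_validate(payload))  # validates fine; the model stays frozen
```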

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:54:56 -08:00
Alex Cheema
2b68b931c5 Send node_ids from placement preview when launching instances
The dashboard now extracts node IDs from the selected preview's
memory_delta_by_node, ensuring the backend places on exactly the
nodes the user was shown. Also reverts incorrect RDMA min_nodes >= 2
enforcement since single-node RDMA is valid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:49:37 -08:00
Alex Cheema
4aecaa7748 Enforce min_nodes >= 2 for RDMA (MlxJaccl) instances
RDMA requires at least 2 nodes — a single-node RDMA instance is
nonsensical. Enforce this in both the dashboard (when building the
launch request) and the backend placement (when filtering cycles).
Previously, selecting RDMA would still place on 1 node because
min_nodes defaulted to 1 and the placement silently switched to Ring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:42:21 -08:00
Alex Cheema
25e2891c30 Ensure min_nodes >= node filter size when launching
When user selects specific nodes via the filter, min_nodes should be at
least the number of filtered nodes to prevent placement from picking a
smaller cycle.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:36:51 -08:00
Alex Cheema
16345e0ffa Send node_ids from dashboard, error on RDMA when unavailable
Dashboard was not including the user's node filter in the POST to
/meta_instance, so placement ignored which nodes the user selected.
Also, placement silently fell back to Ring when RDMA was requested but
no RDMA-connected cycles were available — now raises an error that
surfaces via MetaInstancePlacementFailed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:26:29 -08:00
Alex Cheema
3a845f90b0 Fix use_default validator silently ignoring sharding/instance_meta
The mode="plain" validator bypassed Pydantic's string-to-enum coercion,
so JSON strings like "Tensor" and "MlxJaccl" from the dashboard failed
the isinstance check and silently fell back to Pipeline/MlxRing defaults.
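
A minimal reproduction of the bug (enum and field names simplified from the API params model):

```python
from enum import Enum

from pydantic import BaseModel, field_validator


class Sharding(str, Enum):
    Pipeline = "Pipeline"
    Tensor = "Tensor"


class CreateParamsSketch(BaseModel):
    sharding: Sharding = Sharding.Pipeline

    @field_validator("sharding", mode="plain")
    @classmethod
    def use_default(cls, v: object) -> object:
        # mode="plain" replaces pydantic's own validation, so the incoming value is
        # still the raw JSON string "Tensor" here and the isinstance check fails.
        return v if isinstance(v, Sharding) else Sharding.Pipeline


print(CreateParamsSketch.model_validate({"sharding": "Tensor"}).sharding)  # Sharding.Pipeline
```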

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:05:00 -08:00
Alex Cheema
dccf2440ba Add placement error feedback and per-node loading status
Show why MetaInstance placement fails instead of stuck "PLACING", and
show per-node runner status during loading for multi-node instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:01:07 -08:00
Alex Cheema
f96f3f2c0f Show MetaInstance sharding/type while PLACING, fix MlxIbv references
When a MetaInstance has no backing instance yet, derive the strategy
display from the MetaInstance's own sharding and instanceMeta fields
rather than showing "Unknown (Unknown)".

Also clean up all stale MlxIbv references across the dashboard —
the backend enum is MlxJaccl.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:23:44 -08:00
Alex Cheema
7d54e468d5 Extract reconciler into ProcessManager protocol, fix RDMA instance type
- Replace inline _plan() with ProcessManager loop (_reconcile), tick
  every 1s instead of 10s — safe because all PMs are idempotent
- Fix dashboard sending "MlxIbv" instead of "MlxJaccl" for RDMA
  instance type, which silently fell back to MlxRing default
- Remove all stale MlxIbv references from dashboard

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:19:13 -08:00
Alex Cheema
124d504f95 Extract reconciler into ProcessManager protocol
Replace inline _plan() steps with a list of ProcessManagers, each
implementing async reconcile(State) -> Sequence[Event]. Tick every
1s instead of 10s — safe because all PMs are idempotent against state.
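
Sketched as a Protocol (State and Event are the existing exo types, referenced here only by name):

```python
from collections.abc import Sequence
from typing import Protocol


class ProcessManager(Protocol):
    """One reconciliation step; must be idempotent against the current State."""

    async def reconcile(self, state: "State") -> Sequence["Event"]: ...


# Hypothetical tick-loop shape (the real master applies and broadcasts each event):
#
#   while True:
#       for pm in process_managers:
#           for event in await pm.reconcile(state):
#               apply_and_broadcast(event)
#       await anyio.sleep(1.0)
```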

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:06:41 -08:00
Alex Cheema
9ab4a40989 Simplify MetaInstance binding: put meta_instance_id on Instance
The separate MetaInstanceBound event + meta_instance_backing map
introduced two bugs: stale exclusion sets in the reconciler loop and
a delete ordering race. Embedding meta_instance_id directly on
BaseInstance eliminates the binding mechanism entirely — when an
instance is created for a MetaInstance it carries the ID, when
deleted the binding is gone. No separate map, no cleanup, no races.

Also fixes delete_meta_instance to cascade-delete backing instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 16:15:29 -08:00
Alex Cheema
f4329c72c2 Add explicit MetaInstance binding, slim MetaInstance to use ModelId
- Add MetaInstanceBound event and meta_instance_backing State field
  for explicit MetaInstance → Instance binding (prevents ambiguous
  linking when two MetaInstances have identical constraints)
- Replace model_card: ModelCard with model_id: ModelId on MetaInstance
  (load ModelCard on-demand at placement time)
- Add MetaInstance API endpoints (POST /meta_instance, DELETE)
- Update dashboard to use MetaInstances as primary primitive with
  unified display items merging MetaInstances and orphan instances
- Dashboard launches via MetaInstance instead of direct Instance creation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 15:53:07 -08:00
Alex Cheema
ceb76b8f6c Add MetaInstance declarative layer with connection health checking
Introduces MetaInstance as a declarative constraint ensuring an instance
matching given parameters (model, sharding, min_nodes) always exists.
The master's reconciliation loop continuously checks for unsatisfied
meta-instances and attempts placement. Connection health checking
verifies that specific IPs (MlxRing) and RDMA interfaces (MlxJaccl)
stored on instances still exist as topology edges, enabling automatic
recovery when cables are swapped or interfaces change.

Also eliminates the master's loopback event path, unifying all event
emission through _apply_and_broadcast for simpler control flow.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 13:43:50 -08:00
39 changed files with 3286 additions and 545 deletions

View File

@@ -119,3 +119,78 @@ From .cursorrules:
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
## Dashboard UI Testing & Screenshots
### Building and Running the Dashboard
```bash
# Build the dashboard (must be done before running exo)
cd dashboard && npm install && npm run build && cd ..
# Start exo (serves the dashboard at http://localhost:52415)
uv run exo &
sleep 8 # Wait for server to start
```
### Taking Headless Screenshots with Playwright
Use Playwright with headless Chromium for programmatic screenshots — no manual browser interaction needed.
**Setup (one-time):**
```bash
npx --yes playwright install chromium
cd /tmp && npm init -y && npm install playwright
```
**Taking screenshots:**
```javascript
// Run from /tmp where playwright is installed: cd /tmp && node -e "..."
const { chromium } = require('playwright');
(async () => {
  const browser = await chromium.launch({ headless: true });
  const page = await browser.newPage({ viewport: { width: 1280, height: 800 } });
  await page.goto('http://localhost:52415', { waitUntil: 'networkidle' });
  await page.waitForTimeout(2000);
  // Inject test data into localStorage if needed (e.g., recent models)
  await page.evaluate(() => {
    localStorage.setItem('exo-recent-models', JSON.stringify([
      { modelId: 'mlx-community/Qwen3-30B-A3B-4bit', launchedAt: Date.now() },
    ]));
  });
  await page.reload({ waitUntil: 'networkidle' });
  await page.waitForTimeout(2000);
  // Interact with UI elements
  await page.locator('text=SELECT MODEL').click();
  await page.waitForTimeout(1000);
  // Take screenshot
  await page.screenshot({ path: '/tmp/screenshot.png', fullPage: false });
  await browser.close();
})();
```
### Uploading Images to GitHub PRs
GitHub's API doesn't support direct image upload for PR comments. Workaround:
1. **Commit images to the branch** (temporarily):
```bash
cp /tmp/screenshot.png .
git add screenshot.png
git commit -m "temp: add screenshots for PR"
git push origin <branch>
COMMIT_SHA=$(git rev-parse HEAD)
```
2. **Post PR comment** referencing the raw image URL (uses permanent commit SHA so images survive deletion):
```bash
gh pr comment <PR_NUMBER> --body "![Screenshot](https://raw.githubusercontent.com/exo-explore/exo/${COMMIT_SHA}/screenshot.png)"
```
3. **Remove the images** from the branch:
```bash
git rm screenshot.png
git commit -m "chore: remove temporary screenshot files"
git push origin <branch>
```
The images still render in the PR comment because they reference the permanent commit SHA.

View File

@@ -563,21 +563,45 @@ struct ContentView: View {
}
private var rdmaStatusView: some View {
let rdma = networkStatusService.status.rdmaStatus
let rdmaStatuses = stateService.latestSnapshot?.nodeRdmaCtl ?? [:]
let localNodeId = stateService.localNodeId
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
let localDevices = networkStatusService.status.localRdmaDevices
let localPorts = networkStatusService.status.localRdmaActivePorts
return VStack(alignment: .leading, spacing: 1) {
Text("RDMA: \(rdmaStatusText(rdma))")
.font(.caption2)
.foregroundColor(rdmaStatusColor(rdma))
if !rdma.devices.isEmpty {
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
if rdmaStatuses.isEmpty {
Text("Cluster RDMA: No data")
.font(.caption2)
.foregroundColor(.secondary)
} else {
Text("Cluster RDMA Status:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(Array(rdmaStatuses.keys.sorted()), id: \.self) { nodeId in
if let status = rdmaStatuses[nodeId] {
let nodeName =
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
let isLocal = nodeId == localNodeId
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
let statusText = status.enabled ? "Enabled" : "Disabled"
let color: Color = status.enabled ? .green : .orange
Text("\(prefix) \(statusText)")
.font(.caption2)
.foregroundColor(color)
}
}
}
if !localDevices.isEmpty {
Text(" Local Devices: \(localDevices.joined(separator: ", "))")
.font(.caption2)
.foregroundColor(.secondary)
}
if !rdma.activePorts.isEmpty {
Text(" Active Ports:")
if !localPorts.isEmpty {
Text(" Local Active Ports:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(rdma.activePorts, id: \.device) { port in
ForEach(localPorts, id: \.device) { port in
Text(" \(port.device) port \(port.port): \(port.state)")
.font(.caption2)
.foregroundColor(.green)
@@ -586,28 +610,6 @@ struct ContentView: View {
}
}
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
switch rdma.rdmaCtlEnabled {
case .some(true):
return "Enabled"
case .some(false):
return "Disabled"
case nil:
return rdma.devices.isEmpty ? "Not Available" : "Available"
}
}
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
switch rdma.rdmaCtlEnabled {
case .some(true):
return .green
case .some(false):
return .orange
case nil:
return rdma.devices.isEmpty ? .secondary : .green
}
}
private var sendBugReportButton: some View {
VStack(alignment: .leading, spacing: 4) {
Button {

View File

@@ -15,6 +15,7 @@ struct ClusterState: Decodable {
let nodeMemory: [String: MemoryInfo]
let nodeSystem: [String: SystemInfo]
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
let nodeRdmaCtl: [String: NodeRdmaCtlStatus]
/// Computed property for backwards compatibility - merges granular state into NodeProfile
var nodeProfiles: [String: NodeProfile] {
@@ -65,6 +66,10 @@ struct ClusterState: Decodable {
try container.decodeIfPresent(
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
) ?? [:]
self.nodeRdmaCtl =
try container.decodeIfPresent(
[String: NodeRdmaCtlStatus].self, forKey: .nodeRdmaCtl
) ?? [:]
}
private enum CodingKeys: String, CodingKey {
@@ -78,6 +83,7 @@ struct ClusterState: Decodable {
case nodeMemory
case nodeSystem
case nodeThunderboltBridge
case nodeRdmaCtl
}
}
@@ -159,6 +165,10 @@ struct ThunderboltBridgeStatus: Decodable {
let serviceName: String?
}
struct NodeRdmaCtlStatus: Decodable {
let enabled: Bool
}
struct MemoryInfo: Decodable {
let ramTotal: MemoryValue?
let ramAvailable: MemoryValue?

View File

@@ -35,28 +35,18 @@ struct NetworkStatus: Equatable {
let thunderboltBridgeState: ThunderboltState?
let bridgeInactive: Bool?
let interfaceStatuses: [InterfaceIpStatus]
let rdmaStatus: RDMAStatus
let localRdmaDevices: [String]
let localRdmaActivePorts: [RDMAPort]
static let empty = NetworkStatus(
thunderboltBridgeState: nil,
bridgeInactive: nil,
interfaceStatuses: [],
rdmaStatus: .empty
localRdmaDevices: [],
localRdmaActivePorts: []
)
}
struct RDMAStatus: Equatable {
let rdmaCtlEnabled: Bool?
let devices: [String]
let activePorts: [RDMAPort]
var isAvailable: Bool {
rdmaCtlEnabled == true || !devices.isEmpty
}
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
}
struct RDMAPort: Equatable {
let device: String
let port: String
@@ -80,31 +70,11 @@ private struct NetworkStatusFetcher {
thunderboltBridgeState: readThunderboltBridgeState(),
bridgeInactive: readBridgeInactive(),
interfaceStatuses: readInterfaceStatuses(),
rdmaStatus: readRDMAStatus()
localRdmaDevices: readRDMADevices(),
localRdmaActivePorts: readRDMAActivePorts()
)
}
private func readRDMAStatus() -> RDMAStatus {
let rdmaCtlEnabled = readRDMACtlEnabled()
let devices = readRDMADevices()
let activePorts = readRDMAActivePorts()
return RDMAStatus(
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
}
private func readRDMACtlEnabled() -> Bool? {
let result = runCommand(["rdma_ctl", "status"])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
if output.contains("enabled") {
return true
}
if output.contains("disabled") {
return false
}
return nil
}
private func readRDMADevices() -> [String] {
let result = runCommand(["ibv_devices"])
guard result.exitCode == 0 else { return [] }

View File

@@ -185,11 +185,7 @@
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {

View File

@@ -13,6 +13,12 @@
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else if family === "recents"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M13 3a9 9 0 0 0-9 9H1l3.89 3.89.07.14L9 12H6c0-3.87 3.13-7 7-7s7 3.13 7 7-3.13 7-7 7c-1.93 0-3.68-.79-4.94-2.06l-1.42 1.42A8.954 8.954 0 0 0 13 21a9 9 0 0 0 0-18zm-1 5v5l4.28 2.54.72-1.21-3.5-2.08V8H12z"
/>
</svg>
{:else if family === "llama" || family === "meta"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path

View File

@@ -5,15 +5,22 @@
families: string[];
selectedFamily: string | null;
hasFavorites: boolean;
hasRecents: boolean;
onSelect: (family: string | null) => void;
};
let { families, selectedFamily, hasFavorites, onSelect }: FamilySidebarProps =
$props();
let {
families,
selectedFamily,
hasFavorites,
hasRecents,
onSelect,
}: FamilySidebarProps = $props();
// Family display names
const familyNames: Record<string, string> = {
favorites: "Favorites",
recents: "Recent",
huggingface: "Hub",
llama: "Meta",
qwen: "Qwen",
@@ -89,6 +96,31 @@
</button>
{/if}
<!-- Recent (only show if has recent models) -->
{#if hasRecents}
<button
type="button"
onclick={() => onSelect("recents")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'recents'
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Recently launched models"
>
<FamilyLogos
family="recents"
class={selectedFamily === "recents"
? "text-exo-yellow"
: "text-white/50 group-hover:text-white/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'recents'
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}">Recent</span
>
</button>
{/if}
<!-- HuggingFace Hub -->
<button
type="button"

View File

@@ -21,7 +21,7 @@
} | null;
nodes?: Record<string, NodeInfo>;
sharding?: "Pipeline" | "Tensor";
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
runtime?: "MlxRing" | "MlxJaccl";
onLaunch?: () => void;
tags?: string[];
apiPreview?: PlacementPreview | null;
@@ -348,7 +348,7 @@
// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
const isRdma = $derived(runtime === "MlxJaccl");
// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
@@ -422,9 +422,16 @@
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];
for (const edge of topology.edges) {
const ip = edge.sendBackIp || "?";
const iface =
edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
let ip: string;
let iface: string | null;
if (edge.sourceRdmaIface || edge.sinkRdmaIface) {
ip = "RDMA";
iface = `${edge.sourceRdmaIface || "?"} \u2192 ${edge.sinkRdmaIface || "?"}`;
} else {
ip = edge.sendBackIp || "?";
iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
}
if (edge.source === nodeId1 && edge.target === nodeId2) {
aToBCandidates.push({ ip, iface });
@@ -568,7 +575,7 @@
>
{runtime === "MlxRing"
? "MLX Ring"
: runtime === "MlxIbv" || runtime === "MlxJaccl"
: runtime === "MlxJaccl"
? "MLX RDMA"
: runtime}
</span>

View File

@@ -40,6 +40,7 @@
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatusMap?: Map<string, DownloadAvailability>;
launchedAt?: number;
};
let {
@@ -54,6 +55,7 @@
onToggleFavorite,
onShowInfo,
downloadStatusMap,
launchedAt,
}: ModelPickerGroupProps = $props();
// Group-level download status: show if any variant is downloaded
@@ -75,6 +77,17 @@
return `${mb}MB`;
}
function timeAgo(ts: number): string {
const seconds = Math.floor((Date.now() - ts) / 1000);
if (seconds < 60) return "just now";
const minutes = Math.floor(seconds / 60);
if (minutes < 60) return `${minutes}m ago`;
const hours = Math.floor(minutes / 60);
if (hours < 24) return `${hours}h ago`;
const days = Math.floor(hours / 24);
return `${days}d ago`;
}
// Check if any variant can fit
const anyVariantFits = $derived(
group.variants.some((v) => canModelFit(v.id)),
@@ -300,6 +313,13 @@
</span>
{/if}
<!-- Time ago (for recent models) -->
{#if launchedAt}
<span class="text-xs font-mono text-white/20 flex-shrink-0">
{timeAgo(launchedAt)}
</span>
{/if}
<!-- Download availability indicator -->
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
<span

View File

@@ -6,6 +6,7 @@
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
import { getRecentEntries } from "$lib/stores/recents.svelte";
interface ModelInfo {
id: string;
@@ -53,6 +54,8 @@
models: ModelInfo[];
selectedModelId: string | null;
favorites: Set<string>;
recentModelIds?: string[];
hasRecents?: boolean;
existingModelIds: Set<string>;
canModelFit: (modelId: string) => boolean;
getModelFitStatus: (modelId: string) => ModelFitStatus;
@@ -79,6 +82,8 @@
models,
selectedModelId,
favorites,
recentModelIds = [],
hasRecents: hasRecentsTab = false,
existingModelIds,
canModelFit,
getModelFitStatus,
@@ -387,7 +392,11 @@
// Filter by family
if (selectedFamily === "favorites") {
result = result.filter((g) => favorites.has(g.id));
} else if (selectedFamily && selectedFamily !== "huggingface") {
} else if (
selectedFamily &&
selectedFamily !== "huggingface" &&
selectedFamily !== "recents"
) {
result = result.filter((g) => g.family === selectedFamily);
}
@@ -461,6 +470,48 @@
// Check if any favorites exist
const hasFavorites = $derived(favorites.size > 0);
// Timestamp lookup for recent models
const recentTimestamps = $derived(
new Map(getRecentEntries().map((e) => [e.modelId, e.launchedAt])),
);
// Recent models: single-variant ModelGroups in launch order
const recentGroups = $derived.by((): ModelGroup[] => {
if (!recentModelIds || recentModelIds.length === 0) return [];
const result: ModelGroup[] = [];
for (const id of recentModelIds) {
const model = models.find((m) => m.id === id);
if (model) {
result.push({
id: model.base_model || model.id,
name: model.name || model.id,
capabilities: model.capabilities || ["text"],
family: model.family || "",
variants: [model],
smallestVariant: model,
hasMultipleVariants: false,
});
}
}
return result;
});
// Filtered recent groups (apply search query)
const filteredRecentGroups = $derived.by((): ModelGroup[] => {
if (!searchQuery.trim()) return recentGroups;
const query = searchQuery.toLowerCase().trim();
return recentGroups.filter(
(g) =>
g.name.toLowerCase().includes(query) ||
g.variants.some(
(v) =>
v.id.toLowerCase().includes(query) ||
(v.name || "").toLowerCase().includes(query) ||
(v.quantization || "").toLowerCase().includes(query),
),
);
});
function toggleGroupExpanded(groupId: string) {
const next = new Set(expandedGroups);
if (next.has(groupId)) {
@@ -618,6 +669,7 @@
families={uniqueFamilies}
{selectedFamily}
{hasFavorites}
hasRecents={hasRecentsTab}
onSelect={(family) => (selectedFamily = family)}
/>
@@ -725,6 +777,44 @@
</div>
</div>
</div>
{:else if selectedFamily === "recents"}
<!-- Recent models view -->
{#if filteredRecentGroups.length === 0}
<div
class="flex flex-col items-center justify-center h-full text-white/40 p-8"
>
<svg
class="w-12 h-12 mb-3"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M13 3a9 9 0 0 0-9 9H1l3.89 3.89.07.14L9 12H6c0-3.87 3.13-7 7-7s7 3.13 7 7-3.13 7-7 7c-1.93 0-3.68-.79-4.94-2.06l-1.42 1.42A8.954 8.954 0 0 0 13 21a9 9 0 0 0 0-18zm-1 5v5l4.28 2.54.72-1.21-3.5-2.08V8H12z"
/>
</svg>
<p class="font-mono text-sm">
{searchQuery
? "No matching recent models"
: "No recently launched models"}
</p>
</div>
{:else}
{#each filteredRecentGroups as group}
<ModelPickerGroup
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatusMap={getVariantDownloadMap(group)}
launchedAt={recentTimestamps.get(group.variants[0]?.id ?? "")}
/>
{/each}
{/if}
{:else if filteredGroups.length === 0}
<div
class="flex flex-col items-center justify-center h-full text-white/40 p-8"

View File

@@ -6,6 +6,7 @@
isTopologyMinimized,
debugMode,
nodeThunderboltBridge,
nodeRdmaCtl,
type NodeInfo,
} from "$lib/stores/app.svelte";
@@ -31,6 +32,7 @@
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
const tbBridgeData = $derived(nodeThunderboltBridge());
const rdmaCtlData = $derived(nodeRdmaCtl());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -333,14 +335,27 @@
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || "?";
const ifaceInfo = getInterfaceLabel(edge.source, ip);
let ip: string;
let ifaceLabel: string;
let missingIface: boolean;
if (edge.sourceRdmaIface || edge.sinkRdmaIface) {
ip = "RDMA";
ifaceLabel = `${edge.sourceRdmaIface || "?"} \u2192 ${edge.sinkRdmaIface || "?"}`;
missingIface = false;
} else {
ip = edge.sendBackIp || "?";
const ifaceInfo = getInterfaceLabel(edge.source, ip);
ifaceLabel = ifaceInfo.label;
missingIface = ifaceInfo.missing;
}
entry.connections.push({
from: edge.source,
to: edge.target,
ip,
ifaceLabel: ifaceInfo.label,
missingIface: ifaceInfo.missing,
ifaceLabel,
missingIface,
});
pairMap.set(key, entry);
});
@@ -1120,15 +1135,17 @@
.text(` (${ramUsagePercent.toFixed(0)}%)`);
}
// Debug mode: Show TB bridge status
// Debug mode: Show TB bridge and RDMA status
if (debugEnabled) {
let debugLabelY =
nodeInfo.y +
iconBaseHeight / 2 +
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
const debugFontSize = showFullLabels ? 9 : 7;
const debugLineHeight = showFullLabels ? 11 : 9;
const tbStatus = tbBridgeData[nodeInfo.id];
if (tbStatus) {
const tbY =
nodeInfo.y +
iconBaseHeight / 2 +
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
const tbFontSize = showFullLabels ? 9 : 7;
const tbColor = tbStatus.enabled
? "rgba(234,179,8,0.9)"
: "rgba(100,100,100,0.7)";
@@ -1136,12 +1153,30 @@
nodeG
.append("text")
.attr("x", nodeInfo.x)
.attr("y", tbY)
.attr("y", debugLabelY)
.attr("text-anchor", "middle")
.attr("fill", tbColor)
.attr("font-size", tbFontSize)
.attr("font-size", debugFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(tbText);
debugLabelY += debugLineHeight;
}
const rdmaStatus = rdmaCtlData[nodeInfo.id];
if (rdmaStatus !== undefined) {
const rdmaColor = rdmaStatus.enabled
? "rgba(74,222,128,0.9)"
: "rgba(100,100,100,0.7)";
const rdmaText = rdmaStatus.enabled ? "RDMA:ON" : "RDMA:OFF";
nodeG
.append("text")
.attr("x", nodeInfo.x)
.attr("y", debugLabelY)
.attr("text-anchor", "middle")
.attr("fill", rdmaColor)
.attr("font-size", debugFontSize)
.attr("font-family", "SF Mono, Monaco, monospace")
.text(rdmaText);
}
}
});

View File

@@ -49,6 +49,7 @@ export interface NodeInfo {
};
last_macmon_update: number;
friendly_name?: string;
os_version?: string;
}
export interface TopologyEdge {
@@ -56,6 +57,8 @@ export interface TopologyEdge {
target: string;
sendBackIp?: string;
sendBackInterface?: string;
sourceRdmaIface?: string;
sinkRdmaIface?: string;
}
export interface TopologyData {
@@ -76,6 +79,8 @@ interface RawNodeIdentity {
modelId?: string;
chipId?: string;
friendlyName?: string;
osVersion?: string;
osBuildVersion?: string;
}
interface RawMemoryUsage {
@@ -163,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
model_id: string;
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance_meta: "MlxRing" | "MlxJaccl";
instance: unknown | null;
memory_delta_by_node: Record<string, number> | null;
error: string | null;
@@ -214,7 +219,6 @@ interface RawStateResponse {
string,
{
MlxRingInstance?: Instance;
MlxIbvInstance?: Instance;
MlxJacclInstance?: Instance;
}
>;
@@ -225,6 +229,19 @@ interface RawStateResponse {
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
// Thunderbolt identifiers per node
nodeThunderbolt?: Record<
string,
{
interfaces: Array<{
rdmaInterface: string;
domainUuid: string;
linkSpeed: string;
}>;
}
>;
// RDMA ctl status per node
nodeRdmaCtl?: Record<string, { enabled: boolean }>;
// Thunderbolt bridge status per node
nodeThunderboltBridge?: Record<
string,
@@ -232,6 +249,20 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// MetaInstances (declarative instance constraints)
metaInstances?: Record<string, MetaInstanceData>;
}
export interface MetaInstanceData {
metaInstanceId: string;
modelId: string;
sharding: string;
instanceMeta: string;
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
}
export interface MessageAttachment {
@@ -425,6 +456,7 @@ function transformTopology(
},
last_macmon_update: Date.now() / 1000,
friendly_name: identity?.friendlyName,
os_version: identity?.osVersion,
};
}
@@ -437,6 +469,8 @@ function transformTopology(
if (!Array.isArray(edgeList)) continue;
for (const edge of edgeList) {
let sendBackIp: string | undefined;
let sourceRdmaIface: string | undefined;
let sinkRdmaIface: string | undefined;
if (edge && typeof edge === "object" && "sinkMultiaddr" in edge) {
const multiaddr = edge.sinkMultiaddr;
if (multiaddr) {
@@ -444,10 +478,23 @@ function transformTopology(
multiaddr.ip_address ||
extractIpFromMultiaddr(multiaddr.address);
}
} else if (
edge &&
typeof edge === "object" &&
"sourceRdmaIface" in edge
) {
sourceRdmaIface = edge.sourceRdmaIface;
sinkRdmaIface = edge.sinkRdmaIface;
}
if (nodes[source] && nodes[sink] && source !== sink) {
edges.push({ source, target: sink, sendBackIp });
edges.push({
source,
target: sink,
sendBackIp,
sourceRdmaIface,
sinkRdmaIface,
});
}
}
}
@@ -490,12 +537,33 @@ class AppStore {
instances = $state<Record<string, unknown>>({});
runners = $state<Record<string, unknown>>({});
downloads = $state<Record<string, unknown[]>>({});
nodeDisk = $state<
Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>
>({});
placementPreviews = $state<PlacementPreview[]>([]);
selectedPreviewModelId = $state<string | null>(null);
isLoadingPreviews = $state(false);
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
metaInstances = $state<Record<string, MetaInstanceData>>({});
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderbolt = $state<
Record<
string,
{
interfaces: Array<{
rdmaInterface: string;
domainUuid: string;
linkSpeed: string;
}>;
}
>
>({});
nodeRdmaCtl = $state<Record<string, { enabled: boolean }>>({});
nodeThunderboltBridge = $state<
Record<
string,
@@ -837,11 +905,7 @@ class AppStore {
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {
@@ -1206,6 +1270,17 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
if (data.nodeDisk) {
this.nodeDisk = data.nodeDisk;
}
// Node identities (for OS version mismatch detection)
this.nodeIdentities = data.nodeIdentities ?? {};
// Thunderbolt identifiers per node
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
// RDMA ctl status per node
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
// Thunderbolt bridge cycles
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
// Thunderbolt bridge status per node
@@ -2956,8 +3031,10 @@ export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;
export const placementPreviews = () => appStore.placementPreviews;
export const selectedPreviewModelId = () => appStore.selectedPreviewModelId;
export const isLoadingPreviews = () => appStore.isLoadingPreviews;
@@ -3038,7 +3115,12 @@ export const setChatSidebarVisible = (visible: boolean) =>
appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();
// Thunderbolt bridge status
// Node identities (for OS version mismatch detection)
export const nodeIdentities = () => appStore.nodeIdentities;
// Thunderbolt & RDMA status
export const nodeThunderbolt = () => appStore.nodeThunderbolt;
export const nodeRdmaCtl = () => appStore.nodeRdmaCtl;
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;

View File

@@ -0,0 +1,75 @@
/**
* RecentsStore - Manages recently launched models with localStorage persistence
*/
import { browser } from "$app/environment";
const RECENTS_KEY = "exo-recent-models";
const MAX_RECENT_MODELS = 20;
interface RecentEntry {
modelId: string;
launchedAt: number;
}
class RecentsStore {
recents = $state<RecentEntry[]>([]);
constructor() {
if (browser) {
this.loadFromStorage();
}
}
private loadFromStorage() {
try {
const stored = localStorage.getItem(RECENTS_KEY);
if (stored) {
const parsed = JSON.parse(stored) as RecentEntry[];
this.recents = parsed;
}
} catch (error) {
console.error("Failed to load recent models:", error);
}
}
private saveToStorage() {
try {
localStorage.setItem(RECENTS_KEY, JSON.stringify(this.recents));
} catch (error) {
console.error("Failed to save recent models:", error);
}
}
recordLaunch(modelId: string) {
// Remove existing entry for this model (if any) to move it to top
const filtered = this.recents.filter((r) => r.modelId !== modelId);
// Prepend new entry
const next = [{ modelId, launchedAt: Date.now() }, ...filtered];
// Cap at max
this.recents = next.slice(0, MAX_RECENT_MODELS);
this.saveToStorage();
}
getRecentModelIds(): string[] {
return this.recents.map((r) => r.modelId);
}
hasAny(): boolean {
return this.recents.length > 0;
}
clearAll() {
this.recents = [];
this.saveToStorage();
}
}
export const recentsStore = new RecentsStore();
export const hasRecents = () => recentsStore.hasAny();
export const getRecentModelIds = () => recentsStore.getRecentModelIds();
export const getRecentEntries = () => recentsStore.recents;
export const recordRecentLaunch = (modelId: string) =>
recentsStore.recordLaunch(modelId);
export const clearRecents = () => recentsStore.clearAll();

View File

File diff suppressed because it is too large

View File

@@ -3,6 +3,7 @@
import {
topologyData,
downloads,
nodeDisk,
type DownloadProgress,
refreshState,
lastUpdate as lastUpdateStore,
@@ -37,10 +38,13 @@
nodeId: string;
nodeName: string;
models: ModelEntry[];
diskAvailable?: number;
diskTotal?: number;
};
const data = $derived(topologyData());
const downloadsData = $derived(downloads());
const nodeDiskData = $derived(nodeDisk());
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -327,10 +331,17 @@
];
}
// Get disk info for this node
const diskInfo = nodeDiskData?.[nodeId];
const diskAvailable = diskInfo?.available?.inBytes;
const diskTotal = diskInfo?.total?.inBytes;
built.push({
nodeId,
nodeName: getNodeLabel(nodeId),
models,
diskAvailable,
diskTotal,
});
}
@@ -417,6 +428,14 @@
<div class="text-xs text-exo-light-gray font-mono truncate">
{node.nodeId}
</div>
<div class="text-xs text-exo-light-gray font-mono mt-1">
{formatBytes(
node.models
.filter((m) => m.status === "completed")
.reduce((sum, m) => sum + m.totalBytes, 0),
)} models{#if node.diskAvailable != null}
- {formatBytes(node.diskAvailable)} free{/if}
</div>
</div>
<div
class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right"
@@ -429,13 +448,6 @@
/ {node.models.length} models</span
>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(
node.models
.filter((m) => m.status === "completed")
.reduce((sum, m) => sum + m.totalBytes, 0),
)} on disk
</div>
</div>
</div>

View File

@@ -56,8 +56,49 @@ class DownloadCoordinator:
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
model_id = callback_shard.model_card.model_id
throttle_interval_secs = 1.0
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
if model_id in self.active_downloads:
del self.active_downloads[model_id]
self._last_progress_time.pop(model_id, None)
elif (
progress.status == "in_progress"
and current_time() - self._last_progress_time.get(model_id, 0.0)
> throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
@@ -119,12 +160,12 @@ class DownloadCoordinator:
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
# Check if already downloading or complete
# Check if already downloading, complete, or recently failed
if model_id in self.download_status:
status = self.download_status[model_id]
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
if isinstance(status, (DownloadOngoing, DownloadCompleted, DownloadFailed)):
logger.debug(
f"Download for {model_id} already in progress or complete, skipping"
f"Download for {model_id} already in progress, complete, or failed, skipping"
)
return
@@ -169,46 +210,6 @@ class DownloadCoordinator:
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal last_progress_time
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[callback_shard.model_card.model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
# Clean up active download tracking
if callback_shard.model_card.model_id in self.active_downloads:
del self.active_downloads[callback_shard.model_card.model_id]
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[callback_shard.model_card.model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
@@ -283,6 +284,12 @@ class DownloadCoordinator:
_,
progress,
) in self.shard_downloader.get_shard_download_status():
model_id = progress.shard.model_card.model_id
# Active downloads emit progress via the callback — don't overwrite
if model_id in self.active_downloads:
continue
if progress.status == "complete":
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,

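The key change above is that a single callback registered in `__post_init__` shares one per-model timestamp map, so repeated retries of the same download can no longer multiply the event rate. A minimal standalone sketch of that pattern (the class and names here are illustrative, not the coordinator's actual API):

```python
import time

THROTTLE_INTERVAL_SECS = 1.0


class PerModelThrottle:
    """Illustrative sketch: one shared map of model_id -> last emit time,
    mirroring the coordinator's _last_progress_time dict (names invented
    for this example)."""

    def __init__(self) -> None:
        self._last_emit: dict[str, float] = {}

    def should_emit(self, model_id: str, now: float | None = None) -> bool:
        now = time.monotonic() if now is None else now
        last = self._last_emit.get(model_id)
        if last is None or now - last > THROTTLE_INTERVAL_SECS:
            self._last_emit[model_id] = now
            return True
        return False


throttle = PerModelThrottle()
# Ticks arrive every 0.4s: only the first tick and the first tick after each
# full second are emitted, no matter how many times a download restarts.
print([throttle.should_emit("example-org/example-model", now=t * 0.4) for t in range(5)])
# -> [True, False, False, True, False]
```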
View File

@@ -71,8 +71,11 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
CreateMetaInstanceParams,
CreateMetaInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DeleteMetaInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -115,8 +118,10 @@ from exo.shared.types.claude_api import (
from exo.shared.types.commands import (
Command,
CreateInstance,
CreateMetaInstance,
DeleteDownload,
DeleteInstance,
DeleteMetaInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
@@ -128,7 +133,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -137,6 +142,7 @@ from exo.shared.types.events import (
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
@@ -275,6 +281,8 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.post("/meta_instance")(self.create_meta_instance)
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
@@ -521,6 +529,46 @@ class API:
instance_id=instance_id,
)
async def create_meta_instance(
self, payload: CreateMetaInstanceParams
) -> CreateMetaInstanceResponse:
meta_instance = MetaInstance(
model_id=payload.model_id,
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
node_ids=payload.node_ids,
)
command = CreateMetaInstance(meta_instance=meta_instance)
await self._send(command)
return CreateMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance.meta_instance_id,
)
async def delete_meta_instance(
self, meta_instance_id: MetaInstanceId
) -> DeleteMetaInstanceResponse:
meta = self.state.meta_instances.get(meta_instance_id)
if not meta:
raise HTTPException(status_code=404, detail="MetaInstance not found")
# Delete MetaInstance first to prevent reconciler from re-placing
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
await self._send(command)
# Then cascade-delete any backing instances
for instance_id, instance in self.state.instances.items():
if instance.meta_instance_id == meta_instance_id:
await self._send(DeleteInstance(instance_id=instance_id))
return DeleteMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance_id,
)
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:

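A minimal sketch of how a client might drive the two new endpoints. The base URL and port are assumptions, the model name is a placeholder, and the JSON field names simply follow the params/response models as declared later in this diff; actual serialization casing may differ.

```python
import requests  # any HTTP client would do

BASE = "http://localhost:52415"  # assumed address of the exo API

# Declare "keep an instance of this model alive on at least two nodes".
created = requests.post(
    f"{BASE}/meta_instance",
    json={"model_id": "example-org/example-model", "min_nodes": 2},
).json()
meta_instance_id = created["meta_instance_id"]

# Later: delete the MetaInstance first (so the reconciler stops re-placing),
# which also cascade-deletes any backing instances, as in the handler above.
requests.delete(f"{BASE}/meta_instance/{meta_instance_id}")
```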
View File

@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone
from collections.abc import Sequence
from datetime import datetime, timezone
import anyio
from anyio.abc import TaskGroup
@@ -12,11 +13,19 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.process_managers import ProcessManager
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
from exo.master.reconcile import try_place_for_meta_instance
from exo.shared.apply import apply
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.commands import (
CreateInstance,
CreateMetaInstance,
DeleteInstance,
DeleteMetaInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
@@ -34,9 +43,9 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
MetaInstanceCreated,
MetaInstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
TraceEventData,
@@ -58,7 +67,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import MultiSourceBuffer
@@ -82,16 +91,15 @@ class Master:
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
self._process_managers: Sequence[ProcessManager] = [
InstanceHealthReconciler(),
NodeTimeoutReconciler(),
MetaInstanceReconciler(),
]
async def run(self):
logger.info("Starting Master")
@@ -100,15 +108,12 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
tg.start_soon(self._reconcile)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -290,6 +295,29 @@ class Master:
)
)
generated_events.extend(transition_events)
case CreateMetaInstance():
generated_events.append(
MetaInstanceCreated(meta_instance=command.meta_instance)
)
# Immediate placement attempt for responsiveness
model_card = await ModelCard.load(
command.meta_instance.model_id
)
result = try_place_for_meta_instance(
command.meta_instance,
model_card,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
generated_events.extend(result.events)
case DeleteMetaInstance():
generated_events.append(
MetaInstanceDeleted(
meta_instance_id=command.meta_instance_id
)
)
case PlaceInstance():
placement = place_instance(
command,
@@ -341,31 +369,32 @@ class Master:
):
await self._send_event(IndexedEvent(idx=i, event=event))
for event in generated_events:
await self.event_sender.send(event)
await self._apply_and_broadcast(event)
except ValueError as e:
logger.opt(exception=e).warning("Error in command processor")
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
async def _plan(self) -> None:
async def _apply_and_broadcast(self, event: Event) -> None:
"""Apply event to state, persist to disk, and broadcast to workers.
State is updated synchronously (before any await), so callers can
rely on ``self.state`` reflecting this event immediately after the
call. Python's cooperative scheduling guarantees no interleaving
between the state read and write.
"""
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
self._event_log.append(event)
await self._send_event(indexed)
async def _reconcile(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
await self.event_sender.send(
InstanceDeleted(instance_id=instance_id)
)
break
# time out dead nodes
for node_id, time in self.state.last_seen.items():
now = datetime.now(tz=timezone.utc)
if now - time > timedelta(seconds=30):
logger.info(f"Manually removing node {node_id} due to inactivity")
await self.event_sender.send(NodeTimedOut(node_id=node_id))
await anyio.sleep(10)
for pm in self._process_managers:
events = await pm.reconcile(self.state)
for event in events:
await self._apply_and_broadcast(event)
await anyio.sleep(1)
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
@@ -383,32 +412,10 @@ class Master:
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
await self._apply_and_broadcast(event)
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
@@ -440,7 +447,7 @@ class Master:
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self.event_sender.send(
await self._apply_and_broadcast(
TracesMerged(task_id=task_id, traces=all_trace_data)
)

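The master's ad-hoc plan and timeout logic above collapses into one loop that asks each process manager for corrective events and routes everything through `_apply_and_broadcast`. A stripped-down sketch of that control flow, with the state accessor and managers treated as placeholders rather than the real Master class:

```python
# Sketch only: shows the shape of the new reconcile loop. `apply_and_broadcast`
# stands in for Master._apply_and_broadcast, which applies the event to state
# synchronously before broadcasting, so a later manager in the same pass
# already sees the corrected state.
import anyio


async def reconcile_loop(get_state, process_managers, apply_and_broadcast) -> None:
    while True:
        for pm in process_managers:
            for event in await pm.reconcile(get_state()):
                await apply_and_broadcast(event)
        await anyio.sleep(1)
```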
View File

@@ -63,7 +63,9 @@ def place_instance(
required_nodes: set[NodeId] | None = None,
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles)
)
# Filter to cycles containing all required nodes (subset matching)
if required_nodes:
@@ -106,7 +108,11 @@ def place_instance(
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
if command.instance_meta == InstanceMeta.MlxJaccl:
if not smallest_rdma_cycles:
raise ValueError(
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
)
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [

View File

@@ -0,0 +1,12 @@
from collections.abc import Sequence
from typing import Protocol, runtime_checkable
from exo.shared.types.events import Event
from exo.shared.types.state import State
@runtime_checkable
class ProcessManager(Protocol):
"""A reconciliation step that examines state and returns corrective events."""
async def reconcile(self, state: State) -> Sequence[Event]: ...
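Because the protocol is structural and runtime-checkable, any object with a matching async `reconcile` satisfies it. A hypothetical do-nothing manager, just to show the required shape:

```python
from collections.abc import Sequence
from typing import final

from exo.shared.types.events import Event
from exo.shared.types.state import State


@final
class NoopReconciler:
    """Hypothetical example, not part of the codebase: inspects state and
    returns no corrective events. Anything shaped like this satisfies the
    ProcessManager protocol above and could be added to the Master's list."""

    async def reconcile(self, state: State) -> Sequence[Event]:
        return []
```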

View File

@@ -0,0 +1,49 @@
from collections.abc import Sequence
from typing import final
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
from exo.shared.types.state import State
MAX_INSTANCE_RETRIES = 3
@final
class InstanceHealthReconciler:
"""Delete instances whose network connections are broken or whose runners have all failed."""
async def reconcile(self, state: State) -> Sequence[Event]:
events: list[Event] = []
for instance_id, instance in state.instances.items():
if not instance_connections_healthy(instance, state.topology):
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error="Network connection lost",
)
)
continue
is_failed, error_message = instance_runners_failed(
instance, state.runners, state.node_identities
)
if is_failed:
# Retry within the same instance if backed by a MetaInstance
mid = instance.meta_instance_id
mi = state.meta_instances.get(mid) if mid else None
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
events.append(
InstanceRetrying(
instance_id=instance_id,
meta_instance_id=mid,
failure_error=error_message or "Runner failed",
)
)
else:
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error=error_message,
)
)
return events

View File

@@ -0,0 +1,53 @@
from collections.abc import Sequence
from typing import final
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId
@final
class MetaInstanceReconciler:
"""Place instances for unsatisfied MetaInstances."""
async def reconcile(self, state: State) -> Sequence[Event]:
all_events: list[Event] = []
# Local copy for intermediate tracking — so placement of B
# sees A's instance and doesn't double-place on same resources.
current_instances: dict[InstanceId, Instance] = dict(state.instances)
unsatisfied = find_unsatisfied_meta_instances(
state.meta_instances,
current_instances,
state.topology,
)
for meta_instance in unsatisfied:
model_card = await ModelCard.load(meta_instance.model_id)
result = try_place_for_meta_instance(
meta_instance,
model_card,
state.topology,
current_instances,
state.node_memory,
state.node_network,
)
# Update local instance map so next placement sees this one
for event in result.events:
if isinstance(event, InstanceCreated):
current_instances[event.instance.instance_id] = event.instance
all_events.extend(result.events)
# Emit placement failure if error differs from what's already in state
if result.error is not None and meta_instance.placement_error != result.error:
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=result.error,
)
)
return all_events

View File

@@ -0,0 +1,27 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final
from loguru import logger
from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State
_DEFAULT_TIMEOUT = timedelta(seconds=30)
@final
class NodeTimeoutReconciler:
"""Time out nodes that haven't been seen recently."""
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
self.timeout = timeout
async def reconcile(self, state: State) -> Sequence[Event]:
now = datetime.now(tz=timezone.utc)
events: list[Event] = []
for node_id, last_seen in state.last_seen.items():
if now - last_seen > self.timeout:
logger.info(f"Removing node {node_id} due to inactivity")
events.append(NodeTimedOut(node_id=node_id))
return events
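The timeout is injected through the constructor, so tests (or a deployment that tolerates longer gaps) can override the 30-second default. A hypothetical example:

```python
from datetime import timedelta

from exo.master.process_managers.node_timeout import NodeTimeoutReconciler

# Example only: tolerate 60 seconds of silence before timing a node out.
lenient = NodeTimeoutReconciler(timeout=timedelta(seconds=60))
```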

src/exo/master/reconcile.py
View File

@@ -0,0 +1,236 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple
from loguru import logger
from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
BaseInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
RunnerStatus,
)
class PlacementResult(NamedTuple):
"""Result of a placement attempt: events to apply and optional error reason."""
events: Sequence[Event]
error: str | None
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
"""Reconstruct ring order from shard device_rank."""
node_ranks: list[tuple[NodeId, int]] = []
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
shard = instance.shard_assignments.runner_to_shard[runner_id]
node_ranks.append((node_id, shard.device_rank))
node_ranks.sort(key=lambda x: x[1])
return [node_id for node_id, _ in node_ranks]
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
"""Check that the specific IPs used by a ring instance still exist in the topology."""
ring = _get_ring_order(instance)
n = len(ring)
for node in ring:
hosts = instance.hosts_by_node[node]
for idx in range(n):
host = hosts[idx]
if host.ip in ("0.0.0.0", "198.51.100.1"):
continue # self or placeholder
# Real connection: node → ring[idx]. Check specific IP.
connections = topology.get_all_connections_between(node, ring[idx])
if not any(
isinstance(c, SocketConnection)
and c.sink_multiaddr.ip_address == host.ip
for c in connections
):
return False
return True
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
ring = _get_ring_order(instance)
n = len(ring)
for i in range(n):
for j in range(n):
iface = instance.jaccl_devices[i][j]
if iface is None:
continue
connections = topology.get_all_connections_between(ring[i], ring[j])
if not any(
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
for c in connections
):
return False
return True
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
"""Check that an instance's nodes and specific connections are still in the topology."""
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if not all(topology.contains_node(n) for n in instance_nodes):
return False
if len(instance_nodes) <= 1:
return True
match instance:
case MlxRingInstance():
return _ring_connections_healthy(instance, topology)
case MlxJacclInstance():
return _jaccl_connections_healthy(instance, topology)
def instance_runners_failed(
instance: Instance,
runners: Mapping[RunnerId, RunnerStatus],
node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
"""Check if an instance's runners have all reached terminal failure states.
Returns ``(True, error_message)`` when ALL runners are terminal
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
Returns ``(False, None)`` when runners are still active, haven't reported
yet, or all gracefully shut down (no ``RunnerFailed``).
"""
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
if not instance_runner_ids:
return False, None
# Build reverse mapping: runner_id -> node_id
runner_to_node: dict[RunnerId, NodeId] = {
runner_id: node_id
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
}
has_any_failed = False
error_messages: list[str] = []
for runner_id in instance_runner_ids:
status = runners.get(runner_id)
if status is None:
# Runner hasn't reported yet — instance is still starting
return False, None
if isinstance(status, RunnerFailed):
has_any_failed = True
if status.error_message:
node_id = runner_to_node.get(runner_id)
name = node_identities[node_id].friendly_name if node_id and node_id in node_identities else node_id or "unknown"
error_messages.append(f"{name}: {status.error_message}")
elif isinstance(status, RunnerShutdown):
pass # Terminal but not a failure indicator on its own
else:
# Runner is still active (connecting, loading, running, etc.)
return False, None
if has_any_failed:
return True, "; ".join(error_messages) if error_messages else "Runner failed"
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
return False, None
def instance_satisfies_meta_instance(
meta_instance: MetaInstance,
instance: Instance,
) -> bool:
"""Check if a single instance satisfies a meta-instance's constraints.
This is a pure constraint check (model, min_nodes, node_ids).
Use ``instance_connections_healthy`` separately for topology health.
"""
if instance.shard_assignments.model_id != meta_instance.model_id:
return False
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if len(instance_nodes) < meta_instance.min_nodes:
return False
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
instance_nodes
)
def find_unsatisfied_meta_instances(
meta_instances: Mapping[MetaInstanceId, MetaInstance],
instances: Mapping[InstanceId, Instance],
topology: Topology,
) -> Sequence[MetaInstance]:
"""Return meta-instances that have no healthy backing instance."""
unsatisfied: list[MetaInstance] = []
for meta_id, meta_instance in meta_instances.items():
has_healthy_backing = any(
instance.meta_instance_id == meta_id
and instance_connections_healthy(instance, topology)
for instance in instances.values()
)
if not has_healthy_backing:
unsatisfied.append(meta_instance)
return unsatisfied
def try_place_for_meta_instance(
meta_instance: MetaInstance,
model_card: ModelCard,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> PlacementResult:
"""Try to place an instance satisfying the meta-instance constraints.
Returns a :class:`PlacementResult` with events on success, or an error
reason on failure.
"""
command = PlaceInstance(
model_card=model_card,
sharding=meta_instance.sharding,
instance_meta=meta_instance.instance_meta,
min_nodes=meta_instance.min_nodes,
)
try:
target_instances = place_instance(
command,
topology,
current_instances,
node_memory,
node_network,
required_nodes=(
set(meta_instance.node_ids) if meta_instance.node_ids else None
),
)
# Tag the new instance with meta_instance_id
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
if new_instance_ids:
new_id = next(iter(new_instance_ids))
target_instances[new_id] = target_instances[new_id].model_copy(
update={"meta_instance_id": meta_instance.meta_instance_id}
)
return PlacementResult(
events=list(get_transition_events(current_instances, target_instances)),
error=None,
)
except ValueError as e:
logger.debug(
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
)
return PlacementResult(events=[], error=str(e))

View File

@@ -0,0 +1,750 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerLoading,
RunnerReady,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
"""Build a topology with nodes connected in a bidirectional ring with unique IPs.
Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
between consecutive nodes (including wrap-around).
"""
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
"""Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
# Build hosts_by_node with IPs matching _topology() convention:
# node at index idx has IP 10.0.0.{idx+1}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
# --- instance_satisfies_meta_instance (pure constraint matching) ---
def test_satisfies_matching_model():
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"])
assert instance_satisfies_meta_instance(meta, inst) is True
def test_not_satisfies_wrong_model():
meta = _meta_instance("test-org/model-a")
_, inst = _instance("test-org/model-b")
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_missing_required_node():
meta = _meta_instance(node_ids=[NodeId("node-c")])
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_fewer_than_min_nodes():
meta = _meta_instance(min_nodes=3)
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_satisfies_with_node_ids_specified():
meta = _meta_instance(
node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2
)
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
assert instance_satisfies_meta_instance(meta, inst) is True
# --- instance_connections_healthy ---
def test_healthy_single_node_present():
_, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_single_node_missing():
_, inst = _instance(node_ids=["node-a"])
topology = Topology() # empty
assert instance_connections_healthy(inst, topology) is False
def test_healthy_two_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_two_node_edge_removed():
"""Nodes present but edge removed — ring broken."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", connect=False)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_two_node_ip_changed():
"""Edge exists but with a different IP than instance was configured with."""
_, inst = _instance(node_ids=["node-a", "node-b"])
# Build topology with different IPs than _instance() expects
topology = Topology()
topology.add_node(NodeId("node-a"))
topology.add_node(NodeId("node-b"))
topology.add_connection(
Connection(
source=NodeId("node-a"),
sink=NodeId("node-b"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=NodeId("node-b"),
sink=NodeId("node-a"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
),
)
)
assert instance_connections_healthy(inst, topology) is False
def test_healthy_three_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_three_node_one_edge_removed():
"""Remove one edge from a three-node ring — instance unhealthy."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
# Build topology with one direction of one edge missing
topology = Topology()
nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
for n in nodes:
topology.add_node(n)
# Add all edges except node-a → node-b
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[1],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[0],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
# Missing: node-a → node-b (ip 10.0.0.2)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_node_missing_from_topology():
"""Instance has a node that's not in the topology at all."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a") # node-b not present
assert instance_connections_healthy(inst, topology) is False
def test_healthy_extra_nodes_in_topology():
"""Extra nodes in topology don't affect instance health."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
# --- find_unsatisfied_meta_instances ---
def test_unsatisfied_no_meta_instances():
result = find_unsatisfied_meta_instances({}, {}, Topology())
assert list(result) == []
def test_unsatisfied_one_satisfied():
meta = _meta_instance()
id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == []
def test_unsatisfied_one_not_satisfied():
meta = _meta_instance("test-org/model-x")
id_a, inst_a = _instance("test-org/model-y")
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {id_a: inst_a}, topology
)
assert list(result) == [meta]
def test_unsatisfied_mix():
meta_satisfied = _meta_instance("test-org/model-a")
meta_unsatisfied = _meta_instance("test-org/model-b")
id_a, inst_a = _instance(
"test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_satisfied.meta_instance_id: meta_satisfied,
meta_unsatisfied.meta_instance_id: meta_unsatisfied,
},
{id_a: inst_a},
topology,
)
assert list(result) == [meta_unsatisfied]
def test_unsatisfied_node_disconnect():
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a") # node-b disconnected
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_edge_break():
"""Instance exists but its connections broke — meta-instance becomes unsatisfied."""
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_idempotent():
meta = _meta_instance("test-org/model-x")
topology = _topology("node-a")
meta_instances = {meta.meta_instance_id: meta}
instances: dict[InstanceId, MlxRingInstance] = {}
result_1 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
result_2 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
assert result_1 == result_2
def test_unsatisfied_exclusive_binding():
"""Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
id_inst, inst = _instance(
"test-org/model-x", meta_instance_id=meta_a.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
{id_inst: inst},
topology,
)
assert list(result) == [meta_b]
# --- apply handlers ---
def test_apply_meta_instance_created():
state = State()
meta = _meta_instance()
event = MetaInstanceCreated(meta_instance=meta)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id in new_state.meta_instances
assert new_state.meta_instances[meta.meta_instance_id] == meta
def test_apply_meta_instance_deleted():
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
def test_apply_meta_instance_deleted_clears_failure_info():
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "OOM"}
)
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
# --- instance_runners_failed ---
def test_runners_failed_all_failed():
"""All runners in RunnerFailed -> instance is failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runners = {
rid: RunnerFailed(error_message="OOM")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM" in error
def test_runners_failed_mixed_failed_shutdown():
"""One Failed + one Shutdown = failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="crash"),
runner_ids[1]: RunnerShutdown(),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "crash" in error
def test_runners_not_failed_all_shutdown():
"""All Shutdown (graceful) = not a failure."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerShutdown()
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_still_active():
"""Some runners still active = not failed yet."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerLoading(),
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_no_status():
"""Runner not yet reported = not failed."""
_, inst = _instance(node_ids=["node-a"])
is_failed, _ = instance_runners_failed(inst, {}, {})
assert is_failed is False
def test_runners_not_failed_healthy():
"""Runners in Ready state = not failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerReady()
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
# --- failure tracking in apply_instance_deleted ---
def test_apply_instance_deleted_tracks_failure():
"""InstanceDeleted with failure_error increments meta instance failure count."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "Runner OOM"
def test_apply_instance_deleted_increments_failure():
"""Subsequent failures increment the counter."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
)
iid, inst = _instance(
node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="new error")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 3
assert mi.last_failure_error == "new error"
def test_apply_instance_deleted_no_failure_no_tracking():
"""InstanceDeleted without failure_error does not track."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
def test_apply_instance_deleted_orphan_no_tracking():
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
iid, inst = _instance(node_ids=["node-a"])
state = State(instances={iid: inst})
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
# --- InstanceRetrying ---
def test_apply_instance_retrying_removes_runners():
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerShutdown(),
}
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners=runners,
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="OOM",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Instance still exists
assert iid in new_state.instances
# Runners removed
assert runner_ids[0] not in new_state.runners
assert runner_ids[1] not in new_state.runners
def test_apply_instance_retrying_increments_failure():
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "crash"
def test_apply_instance_retrying_skips_missing_runners():
"""InstanceRetrying doesn't assert if runners haven't reported yet."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
# No runners in state at all
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
# Should not raise
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert iid in new_state.instances
def test_apply_instance_created_resets_failure_counter():
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 3, "last_failure_error": "old error"}
)
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceCreated(instance=inst)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.last_failure_error == "old error"
assert mi.placement_error is None
# --- InstanceHealthReconciler retry-vs-delete ---
async def test_health_reconciler_retries_when_under_limit():
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid
assert events[0].meta_instance_id == meta.meta_instance_id
async def test_health_reconciler_deletes_when_limit_reached():
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_deletes_without_meta_instance():
"""Instances without a MetaInstance are deleted immediately on runner failure."""
iid, inst = _instance(node_ids=["node-a"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_network_failure_always_deletes():
"""Network failure always triggers InstanceDeleted regardless of retry count."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=_topology("node-a"), # node-b missing
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
assert events[0].failure_error == "Network connection lost"

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from loguru import logger
from exo.shared.types.common import NodeId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -12,6 +12,10 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -28,9 +32,11 @@ from exo.shared.types.events import (
TracesCollected,
TracesMerged,
)
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeRdmaCtlStatus,
NodeThunderboltInfo,
ThunderboltBridgeStatus,
)
@@ -47,7 +53,9 @@ from exo.utils.info_gatherer.info_gatherer import (
MemoryUsage,
MiscData,
NodeConfig,
NodeDiskUsage,
NodeNetworkInterfaces,
RdmaCtlStatus,
StaticNodeInformation,
ThunderboltBridgeInfo,
)
@@ -69,6 +77,14 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceRetrying():
return apply_instance_retrying(event, state)
case MetaInstanceCreated():
return apply_meta_instance_created(event, state)
case MetaInstanceDeleted():
return apply_meta_instance_deleted(event, state)
case MetaInstancePlacementFailed():
return apply_meta_instance_placement_failed(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -171,20 +187,119 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
def _update_meta_instance(
state: State, mid: MetaInstanceId, **fields: object
) -> Mapping[MetaInstanceId, MetaInstance]:
mi = state.meta_instances[mid]
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = event.instance
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
instance.instance_id: instance,
}
return state.model_copy(update={"instances": new_instances})
update: dict[str, object] = {"instances": new_instances}
# Reset failure tracking when a new instance is created for a meta-instance
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
mi = state.meta_instances[instance.meta_instance_id]
if mi.placement_error is not None or mi.consecutive_failures > 0:
update["meta_instances"] = _update_meta_instance(
state,
instance.meta_instance_id,
placement_error=None,
consecutive_failures=0,
)
return state.model_copy(update=update)
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
deleted_instance = state.instances.get(event.instance_id)
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
}
return state.model_copy(update={"instances": new_instances})
update: dict[str, object] = {"instances": new_instances}
# Track failure on the MetaInstance itself
if (
event.failure_error
and deleted_instance
and deleted_instance.meta_instance_id
and deleted_instance.meta_instance_id in state.meta_instances
):
mid = deleted_instance.meta_instance_id
mi = state.meta_instances[mid]
update["meta_instances"] = {
**state.meta_instances,
mid: mi.model_copy(
update={
"consecutive_failures": mi.consecutive_failures + 1,
"last_failure_error": event.failure_error,
}
),
}
return state.model_copy(update=update)
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
"""Runners failed but retry limit not reached — remove runners, keep instance."""
instance = state.instances.get(event.instance_id)
if instance is None:
return state
# Remove all runners belonging to this instance from state
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs
for rid, rs in state.runners.items()
if rid not in runner_ids_to_remove
}
update: dict[str, object] = {"runners": new_runners}
# Increment failure count on the MetaInstance
if event.meta_instance_id in state.meta_instances:
update["meta_instances"] = _update_meta_instance(
state,
event.meta_instance_id,
consecutive_failures=state.meta_instances[event.meta_instance_id].consecutive_failures + 1,
last_failure_error=event.failure_error,
)
return state.model_copy(update=update)
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
**state.meta_instances,
event.meta_instance.meta_instance_id: event.meta_instance,
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
mid: mi
for mid, mi in state.meta_instances.items()
if mid != event.meta_instance_id
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_placement_failed(
event: MetaInstancePlacementFailed, state: State
) -> State:
if event.meta_instance_id not in state.meta_instances:
return state
return state.model_copy(
update={
"meta_instances": _update_meta_instance(
state, event.meta_instance_id, placement_error=event.reason
)
}
)
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
@@ -223,6 +338,9 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
node_disk = {
key: value for key, value in state.node_disk.items() if key != event.node_id
}
node_system = {
key: value for key, value in state.node_system.items() if key != event.node_id
}
@@ -239,6 +357,9 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
for key, value in state.node_thunderbolt_bridge.items()
if key != event.node_id
}
node_rdma_ctl = {
key: value for key, value in state.node_rdma_ctl.items() if key != event.node_id
}
# Only recompute cycles if the leaving node had TB bridge enabled
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
leaving_node_had_tb_enabled = (
@@ -256,10 +377,12 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
"node_thunderbolt_bridge": node_thunderbolt_bridge,
"node_rdma_ctl": node_rdma_ctl,
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
}
)
@@ -288,6 +411,8 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
update["node_memory"] = {**state.node_memory, event.node_id: info.memory}
case MemoryUsage():
update["node_memory"] = {**state.node_memory, event.node_id: info}
case NodeDiskUsage():
update["node_disk"] = {**state.node_disk, event.node_id: info.disk_usage}
case NodeConfig():
pass
case MiscData():
@@ -302,7 +427,12 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
case StaticNodeInformation():
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"model_id": info.model, "chip_id": info.chip}
update={
"model_id": info.model,
"chip_id": info.chip,
"os_version": info.os_version,
"os_build_version": info.os_build_version,
}
)
update["node_identities"] = {
**state.node_identities,
@@ -354,6 +484,11 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
new_tb_bridge, state.node_network
)
)
case RdmaCtlStatus():
update["node_rdma_ctl"] = {
**state.node_rdma_ctl,
event.node_id: NodeRdmaCtlStatus(enabled=info.enabled),
}
return state.model_copy(update=update)

View File

@@ -3,11 +3,10 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -227,13 +226,6 @@ class PlaceInstanceParams(BaseModel):
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
def use_default(cls, v: object):
if not v or not isinstance(v, (Sharding, InstanceMeta)):
raise PydanticUseDefault()
return v
class CreateInstanceParams(BaseModel):
instance: Instance
@@ -269,6 +261,26 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
class CreateMetaInstanceParams(BaseModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
class CreateMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
class DeleteMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None

View File

@@ -6,7 +6,8 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -48,6 +49,14 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class CreateMetaInstance(BaseCommand):
meta_instance: MetaInstance
class DeleteMetaInstance(BaseCommand):
meta_instance_id: MetaInstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -89,6 +98,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| CreateMetaInstance
| DeleteMetaInstance
| TaskFinished
| SendInputChunk
)
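
CreateMetaInstance and DeleteMetaInstance join the Command union above. A hedged sketch of how a handler might match on them and translate them into the corresponding events; the real master-side apply logic is not part of this hunk, and the event constructors are assumed to default their remaining fields:

```python
def handle(command: Command):
    # Illustrative only; assumes the event types from the events module below.
    match command:
        case CreateMetaInstance(meta_instance=mi):
            return [MetaInstanceCreated(meta_instance=mi)]
        case DeleteMetaInstance(meta_instance_id=mid):
            return [MetaInstanceDeleted(meta_instance_id=mid)]
        case _:
            return []
```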

View File

@@ -42,6 +42,10 @@ class CommandId(Id):
pass
class MetaInstanceId(Id):
"""Identifier for a MetaInstance."""
class Host(CamelCaseModel):
ip: str
port: int

View File

@@ -5,7 +5,8 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -66,6 +67,30 @@ class InstanceCreated(BaseEvent):
class InstanceDeleted(BaseEvent):
instance_id: InstanceId
failure_error: str | None = None
class MetaInstanceCreated(BaseEvent):
meta_instance: MetaInstance
class MetaInstanceDeleted(BaseEvent):
meta_instance_id: MetaInstanceId
@final
class MetaInstancePlacementFailed(BaseEvent):
meta_instance_id: MetaInstanceId
reason: str
@final
class InstanceRetrying(BaseEvent):
"""Runners failed but retry count is below the limit — restart runners, keep instance."""
instance_id: InstanceId
meta_instance_id: MetaInstanceId
failure_error: str
class RunnerStatusUpdated(BaseEvent):
@@ -141,6 +166,10 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceRetrying
| MetaInstanceCreated
| MetaInstanceDeleted
| MetaInstancePlacementFailed
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut
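
InstanceRetrying keeps the instance and only restarts its runners, while a full teardown goes through InstanceDeleted. A hedged sketch of how a supervisor might choose between them using MetaInstance.consecutive_failures; the actual threshold and reconciliation code are not shown here, and the event constructors are assumed to default their remaining fields:

```python
MAX_CONSECUTIVE_FAILURES = 3  # assumed limit, not taken from this diff


def on_runner_failure(meta: MetaInstance, instance_id: InstanceId, error: str):
    if meta.consecutive_failures + 1 < MAX_CONSECUTIVE_FAILURES:
        # Below the limit: keep the instance, just restart its runners.
        return InstanceRetrying(
            instance_id=instance_id,
            meta_instance_id=meta.meta_instance_id,
            failure_error=error,
        )
    # At the limit: tear the instance down; the MetaInstance records the error
    # in last_failure_error so it is available while a replacement is placed.
    return InstanceDeleted(instance_id=instance_id, failure_error=error)
```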

View File

@@ -0,0 +1,25 @@
from typing import final
from pydantic import Field
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import FrozenModel
@final
class MetaInstance(FrozenModel):
"""Declarative constraint: ensure an instance matching these parameters always exists."""
meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
# Failure tracking
placement_error: str | None = None
consecutive_failures: int = 0
last_failure_error: str | None = None
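
Because a MetaInstance is a declarative constraint, the master's job reduces to reconciliation: find MetaInstances with no backing Instance and place a new one. A hedged sketch of that check with plain dicts; the real placement logic (node selection, sharding, retries) is not part of this diff:

```python
def missing_meta_instances(
    meta_instances: dict[MetaInstanceId, MetaInstance],
    instances: dict[InstanceId, Instance],
) -> list[MetaInstance]:
    backed = {
        inst.meta_instance_id
        for inst in instances.values()
        if inst.meta_instance_id is not None
    }
    # Every MetaInstance without a live backing Instance needs a new placement.
    return [mi for mid, mi in meta_instances.items() if mid not in backed]
```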

View File

@@ -1,4 +1,6 @@
import shutil
from collections.abc import Sequence
from pathlib import Path
from typing import Literal, Self
import psutil
@@ -38,6 +40,22 @@ class MemoryUsage(CamelCaseModel):
)
class DiskUsage(CamelCaseModel):
"""Disk space usage for the models directory."""
total: Memory
available: Memory
@classmethod
def from_path(cls, path: Path) -> Self:
"""Get disk usage stats for the partition containing path."""
total, _used, free = shutil.disk_usage(path)
return cls(
total=Memory.from_bytes(total),
available=Memory.from_bytes(free),
)
class SystemPerformanceProfile(CamelCaseModel):
# TODO: flops_fp16: float
@@ -63,6 +81,8 @@ class NodeIdentity(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
os_version: str = "Unknown"
os_build_version: str = "Unknown"
class NodeNetworkInfo(CamelCaseModel):
@@ -77,6 +97,12 @@ class NodeThunderboltInfo(CamelCaseModel):
interfaces: Sequence[ThunderboltIdentifier] = []
class NodeRdmaCtlStatus(CamelCaseModel):
"""Whether RDMA is enabled on this node (via rdma_ctl)."""
enabled: bool
class ThunderboltBridgeStatus(CamelCaseModel):
"""Whether the Thunderbolt Bridge network service is enabled on this node."""

View File

@@ -6,11 +6,14 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
DiskUsage,
MemoryUsage,
NodeIdentity,
NodeNetworkInfo,
NodeRdmaCtlStatus,
NodeThunderboltInfo,
SystemPerformanceProfile,
ThunderboltBridgeStatus,
@@ -39,6 +42,7 @@ class State(CamelCaseModel):
arbitrary_types_allowed=True,
)
instances: Mapping[InstanceId, Instance] = {}
meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}
@@ -49,10 +53,12 @@ class State(CamelCaseModel):
# Granular node state mappings (update independently at different frequencies)
node_identities: Mapping[NodeId, NodeIdentity] = {}
node_memory: Mapping[NodeId, MemoryUsage] = {}
node_disk: Mapping[NodeId, DiskUsage] = {}
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] = {}
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []

View File

@@ -12,6 +12,7 @@ class ThunderboltConnection(CamelCaseModel):
class ThunderboltIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
link_speed: str = ""
## Intentionally minimal, only collecting data we care about - there's a lot more
@@ -19,6 +20,7 @@ class ThunderboltIdentifier(CamelCaseModel):
class _ReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
current_speed_key: str | None = None
class _ConnectivityItem(BaseModel, extra="ignore"):
@@ -42,7 +44,9 @@ class ThunderboltConnectivityData(BaseModel, extra="ignore"):
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return ThunderboltIdentifier(
rdma_interface=iface, domain_uuid=self.domain_uuid_key
rdma_interface=iface,
domain_uuid=self.domain_uuid_key,
link_speed=self.receptacle_1_tag.current_speed_key or "",
)
def conn(self) -> ThunderboltConnection | None:

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +19,7 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
meta_instance_id: MetaInstanceId | None = None
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -8,16 +8,17 @@ from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio import create_task_group, fail_after, open_process, to_thread
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from pydantic import ValidationError
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.constants import EXO_CONFIG_FILE, EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
DiskUsage,
MemoryUsage,
NetworkInterfaceInfo,
ThunderboltBridgeStatus,
@@ -31,7 +32,13 @@ from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
get_os_build_version,
get_os_version,
)
IS_DARWIN = sys.platform == "darwin"
@@ -177,11 +184,18 @@ class StaticNodeInformation(TaggedModel):
model: str
chip: str
os_version: str
os_build_version: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
return cls(
model=model,
chip=chip,
os_version=get_os_version(),
os_build_version=await get_os_build_version(),
)
class NodeNetworkInterfaces(TaggedModel):
@@ -196,6 +210,28 @@ class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class RdmaCtlStatus(TaggedModel):
enabled: bool
@classmethod
async def gather(cls) -> Self | None:
if not IS_DARWIN or shutil.which("rdma_ctl") is None:
return None
try:
with anyio.fail_after(5):
proc = await anyio.run_process(["rdma_ctl", "status"], check=False)
except (TimeoutError, OSError):
return None
if proc.returncode != 0:
return None
output = proc.stdout.decode("utf-8").lower().strip()
if "enabled" in output:
return cls(enabled=True)
if "disabled" in output:
return cls(enabled=False)
return None
class ThunderboltBridgeInfo(TaggedModel):
status: ThunderboltBridgeStatus
@@ -284,6 +320,20 @@ class MiscData(TaggedModel):
return cls(friendly_name=await get_friendly_name())
class NodeDiskUsage(TaggedModel):
"""Disk space information for the models directory."""
disk_usage: DiskUsage
@classmethod
async def gather(cls) -> Self:
return cls(
disk_usage=await to_thread.run_sync(
lambda: DiskUsage.from_path(EXO_MODELS_DIR)
)
)
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
@@ -310,10 +360,12 @@ GatheredInfo = (
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| RdmaCtlStatus
| ThunderboltBridgeInfo
| NodeConfig
| MiscData
| StaticNodeInformation
| NodeDiskUsage
)
@@ -326,6 +378,9 @@ class InfoGatherer:
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
static_info_poll_interval: float | None = 60
rdma_ctl_poll_interval: float | None = 10 if IS_DARWIN else None
disk_poll_interval: float | None = 30
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
@@ -335,25 +390,38 @@ class InfoGatherer:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._monitor_thunderbolt_bridge_status)
tg.start_soon(self._monitor_rdma_ctl_status)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
tg.start_soon(self._monitor_static_info)
tg.start_soon(self._monitor_disk_usage)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_static_info(self):
if self.static_info_poll_interval is None:
return
while True:
try:
with fail_after(30):
await self.info_sender.send(await StaticNodeInformation.gather())
except Exception as e:
logger.warning(f"Error gathering static node info: {e}")
await anyio.sleep(self.static_info_poll_interval)
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
while True:
try:
await self.info_sender.send(await MiscData.gather())
with fail_after(10):
await self.info_sender.send(await MiscData.gather())
except Exception as e:
logger.warning(f"Error gathering misc data: {e}")
await anyio.sleep(self.misc_poll_interval)
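
Every poll loop in this file now wraps its gather call in fail_after, so one hung subprocess cannot stall the monitor forever. The generic shape of that loop, simplified (the real monitors send typed GatheredInfo payloads and use different timeouts):

```python
import anyio
from anyio import fail_after
from loguru import logger


async def poll(gather, send, interval: float, timeout: float) -> None:
    while True:
        try:
            with fail_after(timeout):  # raises TimeoutError if gather() hangs
                await send(await gather())
        except Exception as e:
            logger.warning(f"poll failed: {e}")
        await anyio.sleep(interval)
```
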
@@ -361,20 +429,26 @@ class InfoGatherer:
async def _monitor_system_profiler_thunderbolt_data(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
while True:
try:
data = await ThunderboltConnectivity.gather()
assert data is not None
with fail_after(30):
iface_map = await _gather_iface_map()
if iface_map is None:
raise ValueError("Failed to gather interface map")
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
data = await ThunderboltConnectivity.gather()
assert data is not None
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
idents = [
it for i in data if (it := i.ident(iface_map)) is not None
]
await self.info_sender.send(
MacThunderboltIdentifiers(idents=idents)
)
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
except Exception as e:
logger.warning(f"Error gathering Thunderbolt data: {e}")
await anyio.sleep(self.system_profiler_interval)
@@ -402,8 +476,9 @@ class InfoGatherer:
return
while True:
try:
nics = await get_network_interfaces()
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
with fail_after(10):
nics = await get_network_interfaces()
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
except Exception as e:
logger.warning(f"Error gathering network interfaces: {e}")
await anyio.sleep(self.interface_watcher_interval)
@@ -413,37 +488,70 @@ class InfoGatherer:
return
while True:
try:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None:
await self.info_sender.send(curr)
with fail_after(30):
curr = await ThunderboltBridgeInfo.gather()
if curr is not None:
await self.info_sender.send(curr)
except Exception as e:
logger.warning(f"Error gathering Thunderbolt Bridge status: {e}")
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
async def _monitor_rdma_ctl_status(self):
if self.rdma_ctl_poll_interval is None:
return
while True:
try:
curr = await RdmaCtlStatus.gather()
if curr is not None:
await self.info_sender.send(curr)
except Exception as e:
logger.warning(f"Error gathering RDMA ctl status: {e}")
await anyio.sleep(self.rdma_ctl_poll_interval)
async def _monitor_disk_usage(self):
if self.disk_poll_interval is None:
return
while True:
try:
with fail_after(5):
await self.info_sender.send(await NodeDiskUsage.gather())
except Exception as e:
logger.warning(f"Error gathering disk usage: {e}")
await anyio.sleep(self.disk_poll_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
while True:
try:
async with await open_process(
[
macmon_path,
"pipe",
"--interval",
str(self.macmon_interval * 1000),
]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)
except Exception as e:
logger.warning(f"Error in macmon monitor: {e}")
await anyio.sleep(self.macmon_interval)
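
The macmon monitor is now a restart loop: if the subprocess exits or errors, the failure is logged, the loop sleeps, and macmon is launched again instead of telemetry silently stopping. Its essential shape, with the stderr-decoding details omitted:

```python
import anyio
from anyio import open_process
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger


async def supervise(cmd: list[str], backoff: float) -> None:
    while True:
        try:
            async with await open_process(cmd) as p:
                assert p.stdout is not None
                async for text in TextReceiveStream(BufferedByteReceiveStream(p.stdout)):
                    ...  # forward each output line (metrics in the real monitor)
        except Exception as e:
            logger.warning(f"{cmd[0]} monitor error: {e}")
        # Back off before restarting the subprocess.
        await anyio.sleep(backoff)
```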

View File

@@ -1,3 +1,4 @@
import platform
import socket
import sys
from subprocess import CalledProcessError
@@ -8,6 +9,34 @@ from anyio import run_process
from exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo
def get_os_version() -> str:
"""Return the OS version string for this node.
On macOS this is the macOS version (e.g. ``"15.3"``).
On other platforms it falls back to the platform name (e.g. ``"Linux"``).
"""
if sys.platform == "darwin":
version = platform.mac_ver()[0]
return version if version else "Unknown"
return platform.system() or "Unknown"
async def get_os_build_version() -> str:
"""Return the macOS build version string (e.g. ``"24D5055b"``).
On non-macOS platforms, returns ``"Unknown"``.
"""
if sys.platform != "darwin":
return "Unknown"
try:
process = await run_process(["sw_vers", "-buildVersion"])
except CalledProcessError:
return "Unknown"
return process.stdout.decode("utf-8", errors="replace").strip() or "Unknown"
async def get_friendly_name() -> str:
"""
Asynchronously gets the 'Computer Name' (friendly name) of a Mac.

View File

@@ -184,6 +184,14 @@ class Worker:
)
if task is None:
continue
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
# to prevent flooding the event log with useless events
if isinstance(task, DownloadModel):
model_id = task.shard_metadata.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
logger.info(f"Worker plan: {task.__class__.__name__}")
assert task.task_status
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
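
With the backoff check moved ahead of the TaskCreated emission, a blocked download produces no events at all instead of one per plan tick. The _download_backoff helper itself is not shown in this diff; a hypothetical sketch of the contract it appears to provide, with an assumed exponential policy:

```python
import time
from dataclasses import dataclass, field


@dataclass
class DownloadBackoff:  # hypothetical; field names and policy are assumptions
    base_delay: float = 5.0
    max_delay: float = 300.0
    _next_allowed: dict[str, float] = field(default_factory=dict)
    _failures: dict[str, int] = field(default_factory=dict)

    def should_proceed(self, model_id: str) -> bool:
        return time.monotonic() >= self._next_allowed.get(model_id, 0.0)

    def record_attempt(self, model_id: str) -> None:
        n = self._failures.get(model_id, 0)
        self._failures[model_id] = n + 1
        delay = min(self.base_delay * (2**n), self.max_delay)
        self._next_allowed[model_id] = time.monotonic() + delay
```
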
@@ -199,9 +207,6 @@ class Worker:
)
case DownloadModel(shard_metadata=shard):
model_id = shard.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(

View File

@@ -34,6 +34,7 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerStatus,
RunnerWarmingUp,
)
@@ -54,7 +55,7 @@ def plan(
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _create_runner(node_id, runners, instances, all_runners)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
@@ -73,6 +74,12 @@ def _kill_runner(
if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
return Shutdown(instance_id=instance_id, runner_id=runner_id)
# Master removed our runner from state (retry signal) and process is dead
if runner_id not in all_runners and isinstance(
runner.status, (RunnerFailed, RunnerShutdown)
):
return Shutdown(instance_id=instance_id, runner_id=runner_id)
for (
global_runner_id
) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
@@ -90,6 +97,7 @@ def _create_runner(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus],
) -> CreateRunner | None:
for instance in instances.values():
runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
@@ -99,6 +107,16 @@ def _create_runner(
if runner_id in runners:
continue
# Don't create while any peer runner is in a terminal state — wait for
# the master to emit InstanceRetrying which removes them from state.
has_terminal_peer = any(
isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
for peer_rid in instance.shard_assignments.node_to_runner.values()
if peer_rid != runner_id
)
if has_terminal_peer:
continue
shard = instance.shard(runner_id)
assert shard is not None
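
Together with the short-circuit ordering in plan(), the terminal-peer gate means a node holds off recreating its runner until the master has acknowledged the failure (by emitting InstanceRetrying and clearing the dead runners from shared state). A simplified illustration with string statuses standing in for the real RunnerStatus classes:

```python
def peers_clear(
    my_runner_id: str,
    peer_ids: list[str],
    all_runners: dict[str, str],  # runner_id -> status name, as seen in shared state
) -> bool:
    # Hold off on CreateRunner until no peer is still reported failed/shutdown;
    # the master removes those entries when it emits InstanceRetrying.
    return not any(
        all_runners.get(rid) in ("RunnerFailed", "RunnerShutdown")
        for rid in peer_ids
        if rid != my_runner_id
    )
```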