Compare commits

..

45 Commits

Author SHA1 Message Date
Alex Cheema
9cacef39f0 feat: add --bootstrap-peer flag to bypass mDNS for peer discovery
macOS TCC blocks mDNS multicast from SSH sessions, preventing
SSH-spawned exo processes from discovering peers. This adds a
--bootstrap-peer IP:PORT flag that directly dials a known peer,
bypassing mDNS. The flag is additive — mDNS still runs alongside.

Changes across Rust networking, PyO3 bindings, and Python CLI:
- discovery.rs: store bootstrap addresses, dial on startup and retry
- swarm.rs: thread bootstrap_peers through to Behaviour
- networking.rs: accept optional bootstrap_peers in PyO3 constructor
- router.py: pass bootstrap_peers to NetworkingHandle
- main.py: parse --bootstrap-peer CLI arg into multiaddr format

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:24:00 -08:00
Alex Cheema
3df896dec1 fix: cancel active tasks on meta-instance cascade delete
Previously, DeleteMetaInstance cascade-deleted backing instances without
cancelling their active tasks, leaving orphaned task references. Now emits
TaskStatusUpdated(Cancelled) for Pending/Running tasks before InstanceDeleted.

Also adds lifecycle logging for meta-instance operations, a GET /meta_instances
endpoint, and 2 regression tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 09:51:55 -08:00
Alex Cheema
29c3489f3e Merge remote-tracking branch 'origin/main' into alexcheema/mlx-distributed-transfer
Resolve merge conflicts between meta-instance feature and task cancellation:
- commands.py: Include both CreateMetaInstance/DeleteMetaInstance and TaskCancelled
- plan.py: Add _cancel_tasks from main, keep all_runners param in _create_runner
- runner_supervisor.py: Merge cancel_sender/cancelled fields with pipe relay fields
- reconcile.py: Pass empty tasks dict to get_transition_events (new param from main)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:38:47 -08:00
Alex Cheema
19f606a379 fix: dashboard derived reactivity bug and tautological check in meta-instance UI
1. DERIVED REACTIVITY BUG: `unifiedDisplayItems` used `$derived(fn)` which
   made the derived value the function itself instead of its result. Svelte
   never tracked reactive dependencies in the function body, so the instance
   list didn't update when metaInstances or instances changed. Fixed by using
   `$derived.by(fn)` and removing the `()` call-sites in the template.

2. TAUTOLOGICAL CHECK: In `getMetaInstancePlacingStatus`, the `lastError ? ...
   : null` guard inside the `failures > 0` branch was always true because
   `lastFailureError` and `consecutiveFailures` are always set together in
   `apply_instance_retrying` and `apply_instance_deleted`. Removed the dead
   `: null` branch.

Also fixes pyright errors in test file by using proper pytest.MonkeyPatch type.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:24:49 -08:00
Alex Cheema
c74b46dca0 fix: add timeout and error handling for ModelCard.load in MetaInstanceReconciler
ModelCard.load() does async I/O inside the 1-second reconcile loop. A slow
or failing load blocked all reconciliation (health checks, node timeouts,
other meta-instances). Adds a 10-second timeout, per-meta-instance error
handling with MetaInstancePlacementFailed events, and documents the
intentional early return in apply_instance_retrying.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:11:29 -08:00
Alex Cheema
7028f99d7e fix: validate placement before sending command to prevent silent failures (BUG-001c)
The place_instance API endpoint used fire-and-forget: it sent the command
and returned HTTP 200 immediately. On a fresh cluster start, the master's
state often lacks topology/memory data, so placement raises ValueError
which was silently caught and logged. The caller never learned it failed.

Two fixes:
- API: validate placement locally before sending, return HTTP 400 on
  failure instead of silently accepting an unprocessable command
- Master: emit MetaInstancePlacementFailed on immediate placement error
  in CreateMetaInstance handler so the error surfaces in state right away

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:13:38 -08:00
Alex Cheema
68d9debcb6 fix: place_instance now uses all available nodes to prevent OOM (BUG-001d)
The placement algorithm previously selected the smallest viable cycle,
causing large models to be distributed across too few nodes and running
out of memory. Changed get_smallest_cycles to get_largest_cycles so that
all healthy nodes are utilized, spreading layers more evenly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:13:03 -08:00
Alex Cheema
9bf5f86b9e Revert "feat: add task cancellation for client disconnect handling (BUG-001)"
This reverts commit 2a75672a4a.
2026-02-15 13:48:01 -08:00
Alex Cheema
2a75672a4a feat: add task cancellation for client disconnect handling (BUG-001)
- Add TaskCancelled command and Cancelled task status
- Detect API client disconnects in master/api.py
- Handle TaskCancelled in master state machine
- Add _cancel_tasks to worker for graceful task cleanup
- Add cancel_receiver to runner for inference abort
- Add mx_any helper in MLX utils for cancellable operations
- Guard instance lookup in worker to prevent KeyError
- Update tests for cancellation flow

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 13:43:15 -08:00
Alex Cheema
8a6344fba3 Fix model discovery reporting DownloadPending for fully-downloaded models
On startup, _emit_existing_download_progress() used
downloaded_bytes_this_session to decide between DownloadPending and
DownloadOngoing. Since downloaded_bytes_this_session is always 0 on
startup (it tracks the current session only), fully-downloaded models
were incorrectly reported as DownloadPending.

Now checks actual disk state: if downloaded_bytes >= total_bytes, emit
DownloadCompleted regardless of session bytes. This fixes the UI showing
models as pending when they're already available.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 09:39:06 -08:00
Alex Cheema
316a7a344e fix: use force=True for multiprocessing set_start_method
Prevents RuntimeError when the context has already been set,
e.g. when Terminal.app reuses a tab or the process restarts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 07:02:32 -08:00
Alex Cheema
16b58d6946 fix: eliminate command/reconciler interleaving race in meta-instance
Two race conditions existed in the meta-instance lifecycle:

1. CreateMetaInstance buffered MetaInstanceCreated but didn't apply it
   before awaiting ModelCard.load(). The reconciler could interleave
   during the await, leading to duplicate placements.

   Fix: apply MetaInstanceCreated eagerly via _apply_and_broadcast,
   then re-check satisfaction after the await so placement uses fresh
   state and skips if the reconciler already handled it.

2. delete_meta_instance (API handler) sent DeleteMetaInstance then
   read self.state.instances for cascade deletion. State was stale,
   so backing instances created between the send and the read were
   missed — permanently orphaning them.

   Fix: move cascade delete into the command processor's
   DeleteMetaInstance handler, where InstanceDeleted events are
   generated atomically with MetaInstanceDeleted.

Reproduced on 4-node Mac Mini cluster: 28K anomalies in stress test
including 21 permanently orphaned instances. After fix, the cascade
delete and placement are race-free.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:50:57 -08:00
Alex Cheema
39741907c7 test: add 25 edge-case tests for MetaInstance lifecycle
Cover retry logic, error handling, backward compatibility,
concurrent scenarios, placement error tracking, and serialization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:17:41 -08:00
Alex Cheema
e5c73c564c Merge remote-tracking branch 'origin/main' into alexcheema/meta-instance 2026-02-13 10:10:38 -08:00
Alex Cheema
77799a170a Fix JACCL SideChannel bytes serialization for JSON round-trip
TaggedModel's wrap validator converts JSON→Python validation context,
which breaks strict-mode bytes deserialization from JSON strings.
Use Base64Bytes type to encode/decode bytes as base64 strings in JSON.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 14:23:44 -08:00
Alex Cheema
5dddd9cd2b Use named pipes (FIFOs) for JACCL SideChannel relay
Anonymous pipes from os.pipe() don't survive multiprocessing.Process
spawn on macOS (default since Python 3.8). The FD numbers are passed
but the actual file descriptors don't exist in the child process,
causing EBADF errors.

Switch to named pipes (FIFOs) which the child opens by path in the
spawned process, getting valid FDs for the C++ SideChannel.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 12:06:37 -08:00
Alex Cheema
0e08c2bfd3 Add pipe-based JACCL SideChannel relay via exo control plane
Replace fragile TCP SideChannel with anonymous pipes relayed through
exo's event-sourced control plane. RunnerSupervisor creates pipe pairs
for MlxJaccl instances, relays all_gather rounds via JacclSideChannelData/
JacclSideChannelGathered events through the master, eliminating errno=57
crashes from Thunderbolt RDMA driver instability.

Also includes dashboard RDMA warning improvements and instance retry fixes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 08:55:40 -08:00
Alex Cheema
f9ffdaef5f Preserve last_failure_error across instance recreation, fix RDMA banner wording
- apply_instance_created no longer clears last_failure_error so the
  error context persists while the new instance starts up
- Dashboard retryError shows the error without (N/3) prefix when
  consecutiveFailures is 0 (instance was recreated)
- Jaccl warning tooltip now says "experimental RDMA driver in macOS"

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:48:30 -08:00
Alex Cheema
8c2416c9ea chore: remove temporary screenshot files
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:40:16 -08:00
Alex Cheema
e5007f619a temp: add jaccl warning screenshots for PR comment
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:53 -08:00
Alex Cheema
a627f67253 dashboard: show warning banner for [jaccl] RDMA driver errors
Detect errors containing "[jaccl]" in MetaInstance failure errors and
display a red dismissible alert banner. The tooltip explains this is a
macOS RDMA driver issue and that the affected machine needs to be
restarted. Alert re-appears if a new error arrives after dismissal.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:42 -08:00
Alex Cheema
f189222bfc Merge remote-tracking branch 'origin/main' into alexcheema/meta-instance
# Conflicts:
#	dashboard/src/lib/stores/app.svelte.ts
#	dashboard/src/routes/+page.svelte
2026-02-11 15:59:50 -08:00
Alex Cheema
ad6d35d68a Retry runners within the same Instance instead of recreating
When runners fail for a MetaInstance-backed Instance, retry up to 3
times by restarting runners within the same Instance rather than
deleting and recreating it each time. After 3 failures, delete the
Instance so MetaInstanceReconciler can create a fresh one.

- Add InstanceRetrying event that removes runners from state (signaling
  workers to restart) and increments consecutive_failures on MetaInstance
- InstanceHealthReconciler emits InstanceRetrying when under retry limit,
  InstanceDeleted when exhausted or no MetaInstance
- Worker _kill_runner detects retry signal (runner deleted from state +
  terminal supervisor) and cleans up for _create_runner to recreate
- Worker _create_runner guards against oscillation by blocking creation
  while any peer runner has explicit terminal status
- InstanceCreated resets consecutive_failures for fresh starts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 14:21:11 -08:00
Alex Cheema
c236d62caf Remove timestamp-based retry cooldown
Remove last_failure_at field and RETRY_COOLDOWN_SECONDS logic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:59:39 -08:00
Alex Cheema
a8069e8a30 Consolidate failure state onto MetaInstance, add 5s retry cooldown
Move placement_error, consecutive_failures, last_failure_error, and
last_failure_at directly onto the MetaInstance model instead of keeping
them as separate State mappings (meta_instance_errors, InstanceFailureInfo,
meta_instance_failure_info). Adds a 5-second cooldown between retry attempts
to prevent rapid instance churn when runners fail instantly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:55:47 -08:00
Alex Cheema
84ce555d55 Show retry attempt count with error message, e.g. (2/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:43:20 -08:00
Alex Cheema
b78ea438bc Include node friendly names in runner error messages
Each error in the combined message is now prefixed with the node's friendly
name (e.g. "MacBook Pro: OOM; Mac Studio: connection reset") so the root
cause node is easily identifiable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:41:10 -08:00
Alex Cheema
1960b16f9f Remove permanent retry blocking, allow continuous retry batches
The dashboard % 3 logic already handles displaying retry progress in batches
(RETRYING 1/3, 2/3, 3/3, then PLACING with error, repeat). No need to
permanently block placement after 3 failures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:35:03 -08:00
Alex Cheema
c6838c8fd8 Show retry count in exceeded retry limit message (3/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:28:17 -08:00
Alex Cheema
420d9b9e76 Collect all runner error messages instead of just the last one
When multiple runners fail, concatenate all error messages with "; " so the
real error isn't hidden by generic side-effect failures from other runners.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:27:49 -08:00
Alex Cheema
13f1e9c489 Stop infinite retries after 3 failures, show errors persistently in dashboard
MetaInstanceReconciler now checks failure count before placement — after 3
consecutive failures it emits MetaInstancePlacementFailed instead of retrying
forever. Dashboard shows "Retrying after error: <msg>" in orange throughout
the retry cycle, not just during the brief window with no backing instance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:21:11 -08:00
Alex Cheema
451a06b3d8 Add instance retry logic with max 3 retries and failure tracking
- Extend InstanceDeleted with failure_error field for runner crash info
- Add InstanceFailureInfo model tracking consecutive failures per MetaInstance
- InstanceHealthReconciler now detects runner failures (all terminal with
  at least one RunnerFailed) in addition to connection failures
- apply_instance_deleted increments failure counter for meta-bound instances
- Dashboard shows RETRYING (N/3) status with error messages, and
  "Instance re-created due to failure" after 3 consecutive failures
- Extract and display RunnerFailed error messages in instance status

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:09:42 -08:00
Alex Cheema
94b55d66f4 Fix MetaInstance.node_ids frozenset failing JSON deserialization
frozenset serializes to a JSON array but cannot be deserialized back
in strict mode through the TaggedModel wrap validator (list → frozenset
coercion is rejected). Changed to list[NodeId] since the model is
already frozen/immutable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:54:56 -08:00
Alex Cheema
2b68b931c5 Send node_ids from placement preview when launching instances
The dashboard now extracts node IDs from the selected preview's
memory_delta_by_node, ensuring the backend places on exactly the
nodes the user was shown. Also reverts incorrect RDMA min_nodes >= 2
enforcement since single-node RDMA is valid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:49:37 -08:00
Alex Cheema
4aecaa7748 Enforce min_nodes >= 2 for RDMA (MlxJaccl) instances
RDMA requires at least 2 nodes — a single-node RDMA instance is
nonsensical. Enforce this in both the dashboard (when building the
launch request) and the backend placement (when filtering cycles).
Previously, selecting RDMA would still place on 1 node because
min_nodes defaulted to 1 and the placement silently switched to Ring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:42:21 -08:00
Alex Cheema
25e2891c30 Ensure min_nodes >= node filter size when launching
When user selects specific nodes via the filter, min_nodes should be at
least the number of filtered nodes to prevent placement from picking a
smaller cycle.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:36:51 -08:00
Alex Cheema
16345e0ffa Send node_ids from dashboard, error on RDMA when unavailable
Dashboard was not including the user's node filter in the POST to
/meta_instance, so placement ignored which nodes the user selected.
Also, placement silently fell back to Ring when RDMA was requested but
no RDMA-connected cycles were available — now raises an error that
surfaces via MetaInstancePlacementFailed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:26:29 -08:00
Alex Cheema
3a845f90b0 Fix use_default validator silently ignoring sharding/instance_meta
The mode="plain" validator bypassed Pydantic's string-to-enum coercion,
so JSON strings like "Tensor" and "MlxJaccl" from the dashboard failed
the isinstance check and silently fell back to Pipeline/MlxRing defaults.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:05:00 -08:00
Alex Cheema
dccf2440ba Add placement error feedback and per-node loading status
Show why MetaInstance placement fails instead of stuck "PLACING", and
show per-node runner status during loading for multi-node instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:01:07 -08:00
Alex Cheema
f96f3f2c0f Show MetaInstance sharding/type while PLACING, fix MlxIbv references
When a MetaInstance has no backing instance yet, derive the strategy
display from the MetaInstance's own sharding and instanceMeta fields
rather than showing "Unknown (Unknown)".

Also clean up all stale MlxIbv references across the dashboard —
the backend enum is MlxJaccl.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:23:44 -08:00
Alex Cheema
7d54e468d5 Extract reconciler into ProcessManager protocol, fix RDMA instance type
- Replace inline _plan() with ProcessManager loop (_reconcile), tick
  every 1s instead of 10s — safe because all PMs are idempotent
- Fix dashboard sending "MlxIbv" instead of "MlxJaccl" for RDMA
  instance type, which silently fell back to MlxRing default
- Remove all stale MlxIbv references from dashboard

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:19:13 -08:00
Alex Cheema
124d504f95 Extract reconciler into ProcessManager protocol
Replace inline _plan() steps with a list of ProcessManagers, each
implementing async reconcile(State) -> Sequence[Event]. Tick every
1s instead of 10s — safe because all PMs are idempotent against state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:06:41 -08:00
Alex Cheema
9ab4a40989 Simplify MetaInstance binding: put meta_instance_id on Instance
The separate MetaInstanceBound event + meta_instance_backing map
introduced two bugs: stale exclusion sets in the reconciler loop and
a delete ordering race. Embedding meta_instance_id directly on
BaseInstance eliminates the binding mechanism entirely — when an
instance is created for a MetaInstance it carries the ID, when
deleted the binding is gone. No separate map, no cleanup, no races.

Also fixes delete_meta_instance to cascade-delete backing instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 16:15:29 -08:00
Alex Cheema
f4329c72c2 Add explicit MetaInstance binding, slim MetaInstance to use ModelId
- Add MetaInstanceBound event and meta_instance_backing State field
  for explicit MetaInstance → Instance binding (prevents ambiguous
  linking when two MetaInstances have identical constraints)
- Replace model_card: ModelCard with model_id: ModelId on MetaInstance
  (load ModelCard on-demand at placement time)
- Add MetaInstance API endpoints (POST /meta_instance, DELETE)
- Update dashboard to use MetaInstances as primary primitive with
  unified display items merging MetaInstances and orphan instances
- Dashboard launches via MetaInstance instead of direct Instance creation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 15:53:07 -08:00
Alex Cheema
ceb76b8f6c Add MetaInstance declarative layer with connection health checking
Introduces MetaInstance as a declarative constraint ensuring an instance
matching given parameters (model, sharding, min_nodes) always exists.
The master's reconciliation loop continuously checks for unsatisfied
meta-instances and attempts placement. Connection health checking
verifies that specific IPs (MlxRing) and RDMA interfaces (MlxJaccl)
stored on instances still exist as topology edges, enabling automatic
recovery when cables are swapped or interfaces change.

Also eliminates the master's loopback event path, unifying all event
emission through _apply_and_broadcast for simpler control flow.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 13:43:50 -08:00
75 changed files with 6108 additions and 1856 deletions

275
Cargo.lock generated
View File

@@ -125,9 +125,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.101"
version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "arc-swap"
@@ -141,6 +141,12 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -165,7 +171,7 @@ checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"synstructure",
]
@@ -177,7 +183,7 @@ checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -224,7 +230,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -298,6 +304,19 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -421,9 +440,9 @@ dependencies = [
[[package]]
name = "chrono"
version = "0.4.43"
version = "0.4.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118"
checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
dependencies = [
"iana-time-zone",
"js-sys",
@@ -497,6 +516,15 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -644,7 +672,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -670,7 +698,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d162beedaa69905488a8da94f5ac3edb4dd4788b732fadb7bd120b2625c1976"
dependencies = [
"data-encoding",
"syn 1.0.109",
"syn 2.0.111",
]
[[package]]
@@ -681,7 +709,7 @@ checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -718,6 +746,29 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -738,7 +789,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -820,7 +871,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -888,17 +939,22 @@ name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures-lite",
"futures",
"impl-trait-for-tuples",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
@@ -911,15 +967,9 @@ checksum = "311a6d2f1f9d60bff73d2c78a0af97ed27f79672f15c238192a5bbb64db56d00"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1028,10 +1078,7 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
@@ -1043,7 +1090,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -1593,6 +1640,17 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1657,7 +1715,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -1711,7 +1769,7 @@ checksum = "980af8b43c3ad5d8d349ace167ec8170839f753a42d233ba19e08afe1850fa69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -1771,6 +1829,12 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2300,7 +2364,7 @@ checksum = "dd297cf53f0cb3dee4d2620bb319ae47ef27c702684309f682bdb7e55a18ae9c"
dependencies = [
"heck",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -2760,13 +2824,16 @@ name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures-lite",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
@@ -2838,9 +2905,9 @@ dependencies = [
[[package]]
name = "num-conv"
version = "0.2.0"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
@@ -2851,6 +2918,17 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3053,7 +3131,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3165,9 +3243,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.106"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
@@ -3192,7 +3270,7 @@ checksum = "440f724eba9f6996b75d63681b0a92b06947f1457076d503a4d2e2c8f56442b8"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3201,14 +3279,28 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3236,7 +3328,7 @@ checksum = "bcd7d70ee0ca1661c40407e6f84e4463ef2658c90a9e2fbbd4515b2bcdfcaeca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3278,7 +3370,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3291,14 +3383,14 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
name = "pyo3-stub-gen"
version = "0.19.0"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b159f7704044f57d058f528a6f1f22a0a0a327dcb595c5fb38beae658e0338d6"
checksum = "398b833826a83ca72c1e26d1b2c7c71f9ca7c3bfc74eacc663901895c362ae33"
dependencies = [
"anyhow",
"chrono",
@@ -3313,25 +3405,22 @@ dependencies = [
"ordered-float",
"pyo3",
"pyo3-stub-gen-derive",
"rustpython-parser",
"serde",
"serde_json",
"time",
"toml",
]
[[package]]
name = "pyo3-stub-gen-derive"
version = "0.19.0"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8c79e7c5b1fcec7c39ab186594658a971c59911eb6fbab5a5932cf2318534be"
checksum = "2426ba759d848787239d80f9fdb1f223786976f87fb6c3da8188ca7c17744b28"
dependencies = [
"heck",
"indexmap",
"proc-macro2",
"quote",
"rustpython-parser",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3414,9 +3503,9 @@ dependencies = [
[[package]]
name = "quote"
version = "1.0.44"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
@@ -3652,6 +3741,16 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -3887,7 +3986,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3905,9 +4004,9 @@ dependencies = [
[[package]]
name = "serde_spanned"
version = "1.0.4"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776"
checksum = "e24345aa0fe688594e73770a5f6d1b216508b4f93484c0026d521acd30134392"
dependencies = [
"serde_core",
]
@@ -4095,9 +4194,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.116"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
@@ -4112,7 +4211,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4188,7 +4287,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4199,7 +4298,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4213,30 +4312,30 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.47"
version = "0.3.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c"
checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d"
dependencies = [
"deranged",
"itoa",
"num-conv",
"powerfmt",
"serde_core",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.8"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca"
checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b"
[[package]]
name = "time-macros"
version = "0.2.27"
version = "0.2.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215"
checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3"
dependencies = [
"num-conv",
"time-core",
@@ -4312,7 +4411,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4330,9 +4429,9 @@ dependencies = [
[[package]]
name = "toml"
version = "1.0.2+spec-1.1.0"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1dfefef6a142e93f346b64c160934eb13b5594b84ab378133ac6815cb2bd57f"
checksum = "f0dc8b1fb61449e27716ec0e1bdf0f6b8f3e8f6b05391e8497b8b6d7804ea6d8"
dependencies = [
"indexmap",
"serde_core",
@@ -4345,27 +4444,27 @@ dependencies = [
[[package]]
name = "toml_datetime"
version = "1.0.0+spec-1.1.0"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e"
checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533"
dependencies = [
"serde_core",
]
[[package]]
name = "toml_parser"
version = "1.0.9+spec-1.1.0"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e"
dependencies = [
"winnow",
]
[[package]]
name = "toml_writer"
version = "1.0.6+spec-1.1.0"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607"
checksum = "df8b2b54733674ad286d16267dcfc7a71ed5c776e4ac7aa3c3e2561f7c637bf2"
[[package]]
name = "tower-service"
@@ -4392,7 +4491,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4516,12 +4615,24 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"
@@ -4704,7 +4815,7 @@ dependencies = [
"bumpalo",
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"wasm-bindgen-shared",
]
@@ -4846,7 +4957,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4857,7 +4968,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4868,7 +4979,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4879,7 +4990,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -5268,7 +5379,7 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"synstructure",
]
@@ -5289,7 +5400,7 @@ checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -5309,7 +5420,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"synstructure",
]
@@ -5330,7 +5441,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -5363,5 +5474,5 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]

View File

@@ -26,22 +26,49 @@ opt-level = 3
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-util = "0.3"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"

View File

@@ -72,23 +72,16 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
if process.isRunning {
process.terminate()
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
status = .stopped
}
func restart() {

View File

@@ -1,7 +0,0 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
]

View File

@@ -1,189 +0,0 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=1)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
]
[topology]
type = "none"
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

View File

@@ -185,11 +185,7 @@
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {

View File

@@ -21,7 +21,7 @@
} | null;
nodes?: Record<string, NodeInfo>;
sharding?: "Pipeline" | "Tensor";
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
runtime?: "MlxRing" | "MlxJaccl";
onLaunch?: () => void;
tags?: string[];
apiPreview?: PlacementPreview | null;
@@ -348,7 +348,7 @@
// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
const isRdma = $derived(runtime === "MlxJaccl");
// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
@@ -575,7 +575,7 @@
>
{runtime === "MlxRing"
? "MLX Ring"
: runtime === "MlxIbv" || runtime === "MlxJaccl"
: runtime === "MlxJaccl"
? "MLX RDMA"
: runtime}
</span>

View File

@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
model_id: string;
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance_meta: "MlxRing" | "MlxJaccl";
instance: unknown | null;
memory_delta_by_node: Record<string, number> | null;
error: string | null;
@@ -219,7 +219,6 @@ interface RawStateResponse {
string,
{
MlxRingInstance?: Instance;
MlxIbvInstance?: Instance;
MlxJacclInstance?: Instance;
}
>;
@@ -250,6 +249,20 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// MetaInstances (declarative instance constraints)
metaInstances?: Record<string, MetaInstanceData>;
}
export interface MetaInstanceData {
metaInstanceId: string;
modelId: string;
sharding: string;
instanceMeta: string;
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
}
export interface MessageAttachment {
@@ -536,6 +549,7 @@ class AppStore {
isLoadingPreviews = $state(false);
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
metaInstances = $state<Record<string, MetaInstanceData>>({});
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderbolt = $state<
@@ -895,11 +909,7 @@ class AppStore {
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {
@@ -1264,6 +1274,8 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
if (data.nodeDisk) {
this.nodeDisk = data.nodeDisk;
}
@@ -3044,6 +3056,7 @@ export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;

View File

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

View File

@@ -115,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260217+50487b41"; in
version = let v = "0.30.6"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
};
patches = [

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.7",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -64,7 +64,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

View File

@@ -58,21 +58,6 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -89,25 +74,14 @@
linuxOverlay
]
);
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;

View File

@@ -25,38 +25,53 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
# "nightly", # enables better-supported GIL integration
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.19.0" }
pyo3-stub-gen = { version = "0.17.2" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures-lite = { workspace = true }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,85 +1,155 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401, F403, F405
# ruff: noqa: E501, F401
import builtins
import enum
import typing
__all__ = [
"AllQueuesFullError",
"Keypair",
"NoPeersSubscribedToTopicError",
"PyMessage",
"PySwarm",
]
@typing.final
class AllQueuesFullError(builtins.Exception):
def __new__(cls, *_a: typing.Any) -> AllQueuesFullError: ...
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final
class Keypair:
r"""
Identity keypair of a node.
"""
@staticmethod
def generate() -> Keypair:
def generate_ed25519() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def deserialize(bytes: bytes) -> Keypair:
def generate_ecdsa() -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
def serialize(self) -> bytes:
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_string(self) -> builtins.str:
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *_a: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __str__(self) -> builtins.str: ...
class PyMessage:
@typing.final
class Connection(PyMessage):
__match_args__ = ("node_id", "connected",)
@property
def node_id(self) -> builtins.str: ...
@property
def connected(self) -> builtins.bool: ...
def __new__(cls, node_id: builtins.str, connected: builtins.bool) -> PyMessage.Connection: ...
@typing.final
class Gossip(PyMessage):
__match_args__ = ("node_id", "topic", "data",)
@property
def node_id(self) -> builtins.str: ...
@property
def topic(self) -> builtins.str: ...
@property
def data(self) -> bytes: ...
def __new__(cls, node_id: builtins.str, topic: builtins.str, data: bytes) -> PyMessage.Gossip: ...
...
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
"""
def to_string(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
"""
@typing.final
class PySwarm:
def __new__(cls, identity: Keypair) -> PySwarm: ...
async def recv(self) -> PyMessage:
class NetworkingHandle:
def __new__(cls, identity: Keypair, bootstrap_peers: builtins.list[builtins.str] = ...) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next message from networking.
Receives the next `ConnectionUpdate` from networking.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> None:
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
"""
async def gossipsub_unsubscribe(self, topic: builtins.str) -> None:
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Unsubscribes from a `GossipSub` topic.
@@ -87,6 +157,65 @@ class PySwarm:
"""
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
r"""
Publishes a message to the network on a specific topic.
Publishes a message with multiple topics to the `GossipSub` network.
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...

View File

@@ -1,4 +1,8 @@
//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
@@ -6,17 +10,31 @@ use std::{
task::{Context, Poll},
};
pub struct AllowThreads<F>(pub(crate) F);
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F);
impl<F> AllowThreads<F>
where
Self: Future,
{
pub fn new(f: F) -> Self {
Self(f)
}
}
impl<F> Future for AllowThreads<F>
where
F: Future + Unpin + Send,
F::Output: Send,
F: Future + Ungil,
F::Output: Ungil,
{
type Output = F::Output;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
}
}

View File

@@ -0,0 +1,240 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -1,47 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn deserialize(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn serialize<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_string(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
Ok(())
}

View File

@@ -4,13 +4,28 @@
//!
//!
mod allow_threading;
mod ident;
mod networking;
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
use crate::ident::ident_submodule;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -19,18 +34,35 @@ pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyResult, Python};
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::attach(|py| PyBytes::new(py, self).unbind())
Python::with_gil(|py| PyBytes::new(py, self).unbind())
}
}
@@ -45,16 +77,120 @@ pub(crate) mod ext {
}
pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
fn allow_threads_py(self) -> AllowThreads<Self>
where
AllowThreads<Self>: Future,
{
AllowThreads(self)
AllowThreads::new(self)
}
}
impl<T: Future> FutureExt for T {}
#[ext(pub, name = PyErrExt)]
impl PyErr {
fn receiver_channel_closed() -> Self {
PyConnectionError::new_err("Receiver channel closed unexpectedly")
}
}
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
match self {
Ok(v) => Some(v),
Err(e) => {
// write error back to python
e.write_unraisable(py, None);
None
}
}
}
}
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
@@ -65,9 +201,16 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
// top-level constructs
// TODO: ...
Ok(())
}

View File

@@ -1,24 +1,31 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::ResultExt as _;
use crate::ext::{ByteArrayExt as _, FutureExt as _};
use crate::ident::PyKeypair;
use crate::networking::exception::{PyAllQueuesFullError, PyNoPeersSubscribedToTopicError};
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::pyclass;
use futures_lite::FutureExt as _;
use networking::swarm::{FromSwarm, Swarm, ToSwarm};
use pyo3::coroutine::CancelHandle;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::prelude::*;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3_async_runtimes::tokio::get_runtime;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods};
use std::pin::pin;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
@@ -42,11 +49,16 @@ mod exception {
#[pymethods]
impl PyNoPeersSubscribedToTopicError {
#[new]
#[pyo3(signature = (*_a))]
pub(crate) fn new(_a: &Bound<'_, PyTuple>) -> Self {
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
@@ -72,179 +84,502 @@ mod exception {
#[pymethods]
impl PyAllQueuesFullError {
#[new]
#[pyo3(signature = (*_a))]
pub(crate) fn new(_a: &Bound<'_, PyTuple>) -> Self {
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
#[gen_stub_pyclass]
#[pyclass]
struct PySwarm {
swarm: Arc<Mutex<Swarm>>,
from_swarm: Mutex<mpsc::Receiver<FromSwarm>>,
to_swarm: Mutex<mpsc::Sender<ToSwarm>>,
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass_complex_enum]
#[pyclass]
pub enum PyMessage {
Connection {
node_id: String,
connected: bool,
},
Gossip {
node_id: String,
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
data: Py<PyBytes>,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
impl TryFrom<FromSwarm> for PyMessage {
type Error = PyErr;
fn try_from(value: FromSwarm) -> Result<Self, Self::Error> {
match value {
FromSwarm::Discovered(nid) => Ok(PyMessage::Connection {
node_id: nid.to_base58(),
connected: true,
}),
FromSwarm::Expired(nid) => Ok(PyMessage::Connection {
node_id: nid.to_base58(),
connected: false,
}),
FromSwarm::Message(nid, topic, data) => Ok(PyMessage::Gossip {
node_id: nid.to_base58(),
topic,
data: data.pybytes(),
}),
FromSwarm::PublishError(e) => match e {
libp2p::gossipsub::PublishError::NoPeersSubscribedToTopic => {
Err(PyNoPeersSubscribedToTopicError::new_err())
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
libp2p::gossipsub::PublishError::AllQueuesFull(_) => {
Err(PyAllQueuesFullError::new_err())
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
e => Err(PyRuntimeError::new_err(e.to_string())),
},
}
}
}
log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
}
#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
fn new(
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
) -> Self {
Self {
to_task_tx: Some(to_task_tx),
connection_update_rx: Mutex::new(connection_update_rx),
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
}
}
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PySwarm {
impl PyNetworkingHandle {
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
// immediately beforehand to release the interpreter.
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
// ---- Lifecycle management methods ----
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
#[pyo3(signature = (identity, bootstrap_peers=vec![]))]
fn py_new(identity: Bound<'_, PyKeypair>, bootstrap_peers: Vec<String>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
use std::str::FromStr;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (to_client, from_swarm) = mpsc::channel(MPSC_CHANNEL_SIZE);
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { Swarm::new(identity, from_client, to_client) })
// parse bootstrap peer multiaddrs
let parsed_peers: Vec<libp2p::Multiaddr> = bootstrap_peers
.into_iter()
.map(|s| libp2p::Multiaddr::from_str(&s))
.collect::<Result<_, _>>()
.pyerr()?;
Ok(Self {
swarm: Arc::new(Mutex::new(swarm)),
from_swarm: Mutex::new(from_swarm),
to_swarm: Mutex::new(to_swarm),
})
if !parsed_peers.is_empty() {
log::info!("RUST: bootstrap peers: {:?}", parsed_peers);
}
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity, parsed_peers) })
.pyerr()?;
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
}
#[gen_stub(skip)]
async fn run(&self, #[pyo3(cancel_handle)] mut cancel: CancelHandle) -> PyResult<()> {
let copy = Arc::clone(&self.swarm);
let jh = get_runtime().spawn(async move {
copy.try_lock()
.expect("tried to run swarm twice")
.run()
.await
});
jh.or(async {
cancel.cancelled().await;
Ok(())
})
.await
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
// ---- Connection update receiver methods ----
/// Receives the next message from networking.
async fn recv(&self) -> PyResult<PyMessage> {
let msg = pin!(
self.from_swarm
.try_lock()
.expect("called recv concurrently")
.recv()
)
.allow_threads_py()
.await;
match msg {
None => Err(PyConnectionError::new_err("swarm closed")),
Some(msg) => msg.try_into(),
}
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate`s is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<()> {
///
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
pin!(
self.to_swarm
.try_lock()
.expect("called send concurrently")
.send(ToSwarm::Subscribe(topic))
)
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyConnectionError::new_err("swarm closed"))
self.to_task_tx()
.send_py(ToTask::GossipsubSubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
}
/// Unsubscribes from a `GossipSub` topic.
///
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<()> {
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
pin!(
self.to_swarm
.try_lock()
.expect("called send concurrently")
.send(ToSwarm::Unsubscribe(topic))
)
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyConnectionError::new_err("swarm closed"))
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & convert any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
/// Publishes a message to the network on a specific topic.
/// Publishes a message with multiple topics to the `GossipSub` network.
///
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
pin!(
self.to_swarm
.try_lock()
.expect("called send concurrently")
.send(ToSwarm::Message(topic, data))
)
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyConnectionError::new_err("swarm closed"))
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,
data,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors => ignore messageID for now!!!
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
Ok(())
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.into_iter()
.map(|(t, d)| (t, d.pybytes()))
.collect())
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<PySwarm>()?;
m.add_class::<PyMessage>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
Ok(())
}

View File

@@ -0,0 +1,159 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -0,0 +1,8 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -0,0 +1,81 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -19,14 +19,21 @@ either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures-lite = { workspace = true }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
@@ -34,4 +41,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,6 +1,6 @@
use libp2p::identity;
use networking::swarm::{FromSwarm, Swarm, ToSwarm};
use tokio::sync::mpsc;
use futures::stream::StreamExt as _;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
@@ -11,50 +11,60 @@ async fn main() {
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
let (to_swarm, from_client) = mpsc::channel(20);
let (to_client, mut from_swarm) = mpsc::channel(20);
// Configure swarm
let mut swarm = Swarm::new(
identity::Keypair::generate_ed25519(),
from_client,
to_client,
)
.expect("Swarm creation failed");
let mut swarm =
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
// Create a Gossipsub topic & subscribe
_ = to_swarm
.send(ToSwarm::Subscribe("test-net".to_owned()))
.await;
let topic = gossipsub::IdentTopic::new("test-net");
swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("Subscribing to topic failed");
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
tokio::task::spawn(async move { swarm.run().await });
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
_= to_swarm.send(ToSwarm::Message("test-net".to_owned(), line.into_bytes())).await;
}
event = from_swarm.recv() => match event {
// on gossipsub incoming
Some(FromSwarm::Message(pid, topic, content)) => {
assert_eq!(topic, "test-net");
let fmt = String::from_utf8_lossy(&content);
println!("{pid}: {fmt}");
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
Some(FromSwarm::Discovered(pid)) => {
eprintln!("\n\nConnected to: {pid}\n\n");
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
Some(FromSwarm::Expired(pid)) => {
eprintln!("\n\nDisconnected from: {pid}\n\n");
}
None => break,
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }

View File

@@ -0,0 +1,127 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -1,10 +1,11 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::futures::FutureExt;
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
@@ -105,6 +106,9 @@ pub struct Behaviour {
managed: managed::Behaviour,
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
/// Addresses provided via --bootstrap-peer, re-dialed on the retry loop.
bootstrap_addrs: Vec<Multiaddr>,
retry_delay: Delay, // retry interval
// pending events to emmit => waker-backed Deque to control polling
@@ -112,13 +116,21 @@ pub struct Behaviour {
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
pub fn new(keypair: &identity::Keypair, bootstrap_peers: Vec<Multiaddr>) -> io::Result<Self> {
let mut behaviour = Self {
managed: managed::Behaviour::new(keypair)?,
mdns_discovered: HashMap::new(),
bootstrap_addrs: bootstrap_peers,
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
})
};
// Immediately dial all bootstrap peers
for addr in &behaviour.bootstrap_addrs.clone() {
behaviour.dial_addr(addr.clone());
}
Ok(behaviour)
}
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
@@ -127,6 +139,14 @@ impl Behaviour {
})
}
/// Dial by address only — PeerId is resolved via Noise handshake on connect.
/// Used for bootstrap peers where only IP:port is known.
fn dial_addr(&mut self, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::unknown_peer_id().address(addr).build(),
})
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
// push front to make this IMMEDIATE
self.pending_events.push_front(ToSwarm::CloseConnection {
@@ -361,13 +381,16 @@ impl NetworkBehaviour for Behaviour {
Poll::Pending => {}
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
// retry connecting to all mDNS + bootstrap peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
}
}
for addr in self.bootstrap_addrs.clone() {
self.dial_addr(addr)
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
}

View File

@@ -0,0 +1,44 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -4,7 +4,18 @@
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
@@ -43,3 +54,11 @@ pub(crate) mod ext {
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -1,30 +1,9 @@
use crate::alias;
use crate::discovery;
use crate::swarm::transport::tcp_transport;
use behaviour::{Behaviour, BehaviourEvent};
use futures_lite::StreamExt;
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
use tokio::sync::mpsc;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{Multiaddr, SwarmBuilder, identity};
pub struct Swarm {
swarm: libp2p::Swarm<Behaviour>,
from_client: mpsc::Receiver<ToSwarm>,
to_client: mpsc::Sender<FromSwarm>,
}
#[derive(Debug)]
pub enum FromSwarm {
PublishError(gossipsub::PublishError),
Discovered(PeerId),
Expired(PeerId),
Message(PeerId, String, Vec<u8>),
}
#[derive(Debug)]
pub enum ToSwarm {
Message(String, Vec<u8>),
Subscribe(String),
Unsubscribe(String),
}
pub type Swarm = libp2p::Swarm<Behaviour>;
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
@@ -36,142 +15,26 @@ pub enum ToSwarm {
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
impl Swarm {
/// Create and configure a swarm which listens to all ports on OS
pub fn new(
keypair: identity::Keypair,
from_client: mpsc::Receiver<ToSwarm>,
to_client: mpsc::Sender<FromSwarm>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.build();
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(
keypair: identity::Keypair,
bootstrap_peers: Vec<Multiaddr>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(|kp| Behaviour::new(kp, bootstrap_peers))?
.build();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(Self {
swarm,
from_client,
to_client,
})
}
pub async fn run(&mut self) {
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = self.from_client.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
ToSwarm::Subscribe(topic) => {
// try to subscribe
match self.swarm.behaviour_mut().gossipsub.subscribe(&gossipsub::IdentTopic::new(topic.clone())) {
Err(e) => {
let gossipsub::SubscriptionError::PublishError(e) = e else {
unreachable!("topic filter used")
};
let Ok(()) = self.to_client.send(FromSwarm::PublishError(e)).await else {
log::warn!("RUST: client connection closed");
break
};
},
Ok(false) => log::warn!("RUST: tried to subscribe to topic twice"),
Ok(true) => {},
}
}
ToSwarm::Unsubscribe(topic) => {
// try to subscribe
if !self.swarm.behaviour_mut().gossipsub.unsubscribe(&gossipsub::IdentTopic::new(topic)) {
log::warn!("RUST: tried to unsubscribe from topic twice");
}
}
ToSwarm::Message( topic, data ) => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
match self.swarm.behaviour_mut().gossipsub.publish(
gossipsub::IdentTopic::new(topic), data
) {
Ok(_) => {},
Err(e) => {
let Ok(()) = self.to_client.send(FromSwarm::PublishError(e)).await else {
log::warn!("RUST: client connection closed");
break
};
},
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = self.swarm.next() => {
let Some(swarm_event) = swarm_event else {
log::warn!("RUST: swarm closed communication");
break
};
let SwarmEvent::Behaviour(behaviour_event) = swarm_event else {
continue
};
match behaviour_event {
BehaviourEvent::Gossipsub(gossipsub::Event::Message {
message: gossipsub::Message {
source,
topic,
data,
..
},
..
}) => {
let Some(peer_id) = source else {
log::warn!("RUST: ignoring message with unknown source on {topic}");
continue;
};
// send incoming message to channel (or exit if connection closed)
if let Err(e) = self.to_client.send(FromSwarm::Message(peer_id, topic.into_string(), data)).await {
log::warn!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
break
};
},
BehaviourEvent::Discovery(discovery::Event::ConnectionEstablished { peer_id, .. }) => {
// send connection event to channel (or exit if connection closed)
if let Err(_) = self.to_client.send(FromSwarm::Discovered(peer_id)).await {
log::warn!("RUST: swarm closed communication");
};
},
BehaviourEvent::Discovery(discovery::Event::ConnectionClosed { peer_id, .. }) => {
// send connection event to channel (or exit if connection closed)
if let Err(_) = self.to_client.send(FromSwarm::Expired(peer_id)).await {
log::warn!("RUST: swarm closed communication");
};
},
e => {
log::debug!("RUST: other event {e:?}");
}
}
}
}
}
log::info!("RUST: networking task stopped");
}
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm)
}
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures_lite::{AsyncRead, AsyncWrite};
use futures::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;
@@ -245,7 +108,7 @@ mod transport {
mod behaviour {
use crate::{alias, discovery};
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
use libp2p::{Multiaddr, gossipsub, identity};
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
@@ -256,9 +119,12 @@ mod behaviour {
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
pub fn new(
keypair: &identity::Keypair,
bootstrap_peers: Vec<Multiaddr>,
) -> alias::AnyResult<Self> {
Ok(Self {
discovery: discovery::Behaviour::new(keypair)?,
discovery: discovery::Behaviour::new(keypair, bootstrap_peers)?,
gossipsub: gossipsub_behaviour(keypair),
})
}

View File

@@ -14,7 +14,6 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -64,9 +63,6 @@ class DownloadCoordinator:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -78,7 +74,6 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -98,7 +93,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -176,11 +170,7 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -194,7 +184,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -217,7 +206,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -231,7 +219,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -266,7 +253,6 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -309,18 +295,21 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
if (
progress.downloaded_bytes.in_bytes
>= progress.total_bytes.in_bytes
> 0
):
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
total_bytes=progress.total_bytes,
)
elif progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
@@ -329,9 +318,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue

View File

@@ -44,9 +44,14 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_string())
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
bootstrap_peers: list[str] = []
if args.bootstrap_peer:
ip, port = args.bootstrap_peer.rsplit(":", 1)
bootstrap_peers.append(f"/ip4/{ip}/tcp/{port}")
logger.info(f"Bootstrap peer: {args.bootstrap_peer}")
router = Router.create(keypair, bootstrap_peers=bootstrap_peers)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
await router.register_topic(topics.COMMANDS)
@@ -136,8 +141,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -149,6 +152,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -254,7 +259,7 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn")
mp.set_start_method("spawn", force=True)
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
@@ -283,6 +288,7 @@ class Args(CamelCaseModel):
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
bootstrap_peer: str | None = None
@classmethod
def parse(cls) -> Self:
@@ -343,6 +349,13 @@ class Args(CamelCaseModel):
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
parser.add_argument(
"--bootstrap-peer",
type=str,
dest="bootstrap_peer",
default=None,
help="IP:PORT of an existing node to connect to (bypasses mDNS)",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -144,8 +144,8 @@ async def collect_responses_response(
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=tool.id,
call_id=tool.id,
id=f"fc_{tool.id}",
call_id=f"call_{tool.id}",
name=tool.name,
arguments=tool.arguments,
)

View File

@@ -71,8 +71,11 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
CreateMetaInstanceParams,
CreateMetaInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DeleteMetaInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -115,8 +118,10 @@ from exo.shared.types.claude_api import (
from exo.shared.types.commands import (
Command,
CreateInstance,
CreateMetaInstance,
DeleteDownload,
DeleteInstance,
DeleteMetaInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
@@ -129,7 +134,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -138,6 +143,7 @@ from exo.shared.types.events import (
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
@@ -276,6 +282,9 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/meta_instances")(self.list_meta_instances)
self.app.post("/meta_instance")(self.create_meta_instance)
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
@@ -305,12 +314,27 @@ class API:
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
async def place_instance(self, payload: PlaceInstanceParams):
model_card = await ModelCard.load(payload.model_id)
command = PlaceInstance(
model_card=await ModelCard.load(payload.model_id),
model_card=model_card,
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
)
# Validate placement before sending — fail fast with a clear error
# instead of silently dropping the command in the master.
try:
get_instance_placements(
command,
topology=self.state.topology,
current_instances=self.state.instances,
node_memory=self.state.node_memory,
node_network=self.state.node_network,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
await self._send(command)
return CreateInstanceResponse(
@@ -522,6 +546,44 @@ class API:
instance_id=instance_id,
)
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
return dict(self.state.meta_instances)
async def create_meta_instance(
self, payload: CreateMetaInstanceParams
) -> CreateMetaInstanceResponse:
meta_instance = MetaInstance(
model_id=payload.model_id,
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
node_ids=payload.node_ids,
)
command = CreateMetaInstance(meta_instance=meta_instance)
await self._send(command)
return CreateMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance.meta_instance_id,
)
async def delete_meta_instance(
self, meta_instance_id: MetaInstanceId
) -> DeleteMetaInstanceResponse:
meta = self.state.meta_instances.get(meta_instance_id)
if not meta:
raise HTTPException(status_code=404, detail="MetaInstance not found")
# Command processor handles cascade-deleting backing instances
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
await self._send(command)
return DeleteMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance_id,
)
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:

View File

@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone
from collections.abc import Sequence
from datetime import datetime, timezone
import anyio
from anyio.abc import TaskGroup
@@ -12,11 +13,22 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.process_managers import ProcessManager
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.commands import (
CreateInstance,
CreateMetaInstance,
DeleteInstance,
DeleteMetaInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
@@ -36,8 +48,12 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
JacclSideChannelData,
JacclSideChannelGathered,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
@@ -60,7 +76,8 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import MultiSourceBuffer
@@ -84,16 +101,16 @@ class Master:
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
self._process_managers: Sequence[ProcessManager] = [
InstanceHealthReconciler(),
NodeTimeoutReconciler(),
MetaInstanceReconciler(),
]
async def run(self):
logger.info("Starting Master")
@@ -102,15 +119,12 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
tg.start_soon(self._reconcile)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -292,6 +306,85 @@ class Master:
)
)
generated_events.extend(transition_events)
case CreateMetaInstance():
logger.info(
f"Creating MetaInstance for {command.meta_instance.model_id}"
f" (min_nodes={command.meta_instance.min_nodes},"
f" sharding={command.meta_instance.sharding})"
)
# Apply immediately so self.state is fresh across
# the await below and the reconciler won't race.
await self._apply_and_broadcast(
MetaInstanceCreated(meta_instance=command.meta_instance)
)
# Immediate placement attempt for responsiveness
model_card = await ModelCard.load(
command.meta_instance.model_id
)
# Re-check: reconciler may have satisfied it during the await
meta_id = command.meta_instance.meta_instance_id
still_unsatisfied = any(
m.meta_instance_id == meta_id
for m in find_unsatisfied_meta_instances(
self.state.meta_instances,
self.state.instances,
self.state.topology,
)
)
if still_unsatisfied:
result = try_place_for_meta_instance(
command.meta_instance,
model_card,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
generated_events.extend(result.events)
if result.error is not None:
generated_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_id,
reason=result.error,
)
)
case DeleteMetaInstance():
backing_count = sum(
1
for inst in self.state.instances.values()
if inst.meta_instance_id == command.meta_instance_id
)
logger.info(
f"Deleting MetaInstance {command.meta_instance_id}"
f" (cascade-deleting {backing_count} backing instance(s))"
)
generated_events.append(
MetaInstanceDeleted(
meta_instance_id=command.meta_instance_id
)
)
# Cascade-delete backing instances atomically,
# cancelling any active tasks first.
for iid, inst in self.state.instances.items():
if inst.meta_instance_id == command.meta_instance_id:
for task in self.state.tasks.values():
if (
task.instance_id == iid
and task.task_status
in (
TaskStatus.Pending,
TaskStatus.Running,
)
):
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
generated_events.append(
InstanceDeleted(instance_id=iid)
)
case PlaceInstance():
placement = place_instance(
command,
@@ -354,31 +447,32 @@ class Master:
):
await self._send_event(IndexedEvent(idx=i, event=event))
for event in generated_events:
await self.event_sender.send(event)
await self._apply_and_broadcast(event)
except ValueError as e:
logger.opt(exception=e).warning("Error in command processor")
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
async def _plan(self) -> None:
async def _apply_and_broadcast(self, event: Event) -> None:
"""Apply event to state, persist to disk, and broadcast to workers.
State is updated synchronously (before any await), so callers can
rely on ``self.state`` reflecting this event immediately after the
call. Python's cooperative scheduling guarantees no interleaving
between the state read and write.
"""
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
self._event_log.append(event)
await self._send_event(indexed)
async def _reconcile(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
await self.event_sender.send(
InstanceDeleted(instance_id=instance_id)
)
break
# time out dead nodes
for node_id, time in self.state.last_seen.items():
now = datetime.now(tz=timezone.utc)
if now - time > timedelta(seconds=30):
logger.info(f"Manually removing node {node_id} due to inactivity")
await self.event_sender.send(NodeTimedOut(node_id=node_id))
await anyio.sleep(10)
for pm in self._process_managers:
events = await pm.reconcile(self.state)
for event in events:
await self._apply_and_broadcast(event)
await anyio.sleep(1)
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
@@ -396,32 +490,15 @@ class Master:
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
if isinstance(event, JacclSideChannelData):
await self._apply_and_broadcast(event)
await self._handle_jaccl_side_channel(event)
continue
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
await self._apply_and_broadcast(event)
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
@@ -453,10 +530,49 @@ class Master:
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self.event_sender.send(
await self._apply_and_broadcast(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
"""Accumulate SideChannel contributions; when all runners for an instance
have submitted for the same sequence, emit JacclSideChannelGathered."""
iid = event.instance_id
seq = event.sequence
if iid not in self._jaccl_pending:
self._jaccl_pending[iid] = {}
if seq not in self._jaccl_pending[iid]:
self._jaccl_pending[iid][seq] = {}
self._jaccl_pending[iid][seq][event.runner_id] = event.data
instance = self.state.instances.get(iid)
if instance is None:
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
return
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
submitted = set(self._jaccl_pending[iid][seq].keys())
logger.info(
f"JACCL side channel: instance={iid} seq={seq} "
f"submitted={len(submitted)}/{len(expected_runners)}"
)
if submitted >= expected_runners:
gathered = dict(self._jaccl_pending[iid][seq])
del self._jaccl_pending[iid][seq]
if not self._jaccl_pending[iid]:
del self._jaccl_pending[iid]
await self._apply_and_broadcast(
JacclSideChannelGathered(
instance_id=iid,
sequence=seq,
gathered_data=gathered,
)
)

View File

@@ -6,11 +6,11 @@ from typing import Sequence
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_largest_cycles,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
@@ -106,23 +106,27 @@ def place_instance(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
smallest_rdma_cycles = [
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
largest_rdma_cycles = [
cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
smallest_cycles = smallest_rdma_cycles
if command.instance_meta == InstanceMeta.MlxJaccl:
if not largest_rdma_cycles:
raise ValueError(
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
)
largest_cycles = largest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycle
for cycle in smallest_cycles
for cycle in largest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
key=lambda cycle: sum(
(node_memory[node_id].ram_available for node_id in cycle),
start=Memory(),

View File

@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
return filtered_cycles
def get_smallest_cycles(
def get_largest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
max_nodes = max(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == max_nodes]
def allocate_layers_proportionally(

View File

@@ -0,0 +1,12 @@
from collections.abc import Sequence
from typing import Protocol, runtime_checkable
from exo.shared.types.events import Event
from exo.shared.types.state import State
@runtime_checkable
class ProcessManager(Protocol):
"""A reconciliation step that examines state and returns corrective events."""
async def reconcile(self, state: State) -> Sequence[Event]: ...

View File

@@ -0,0 +1,62 @@
from collections.abc import Sequence
from typing import final
from loguru import logger
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
from exo.shared.types.state import State
MAX_INSTANCE_RETRIES = 3
@final
class InstanceHealthReconciler:
"""Delete instances whose network connections are broken or whose runners have all failed."""
async def reconcile(self, state: State) -> Sequence[Event]:
events: list[Event] = []
for instance_id, instance in state.instances.items():
if not instance_connections_healthy(instance, state.topology):
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error="Network connection lost",
)
)
continue
is_failed, error_message = instance_runners_failed(
instance, state.runners, state.node_identities
)
if is_failed:
# Retry within the same instance if backed by a MetaInstance
mid = instance.meta_instance_id
mi = state.meta_instances.get(mid) if mid else None
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
logger.info(
f"Instance {instance_id} failed (attempt"
f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
f" retrying: {error_message}"
)
events.append(
InstanceRetrying(
instance_id=instance_id,
meta_instance_id=mid,
failure_error=error_message or "Runner failed",
)
)
else:
if mid and mi:
logger.warning(
f"Instance {instance_id} exceeded retry limit"
f" ({MAX_INSTANCE_RETRIES}), deleting:"
f" {error_message}"
)
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error=error_message,
)
)
return events

View File

@@ -0,0 +1,91 @@
from collections.abc import Sequence
from typing import final
import anyio
from loguru import logger
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId
MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10
@final
class MetaInstanceReconciler:
"""Place instances for unsatisfied MetaInstances."""
async def reconcile(self, state: State) -> Sequence[Event]:
all_events: list[Event] = []
# Local copy for intermediate tracking — so placement of B
# sees A's instance and doesn't double-place on same resources.
current_instances: dict[InstanceId, Instance] = dict(state.instances)
unsatisfied = find_unsatisfied_meta_instances(
state.meta_instances,
current_instances,
state.topology,
)
for meta_instance in unsatisfied:
try:
with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
model_card = await ModelCard.load(meta_instance.model_id)
except TimeoutError:
logger.warning(
f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
)
continue
except Exception as exc:
logger.warning(
f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
)
error = f"Failed to load model card: {exc}"
if meta_instance.placement_error != error:
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=error,
)
)
continue
result = try_place_for_meta_instance(
meta_instance,
model_card,
state.topology,
current_instances,
state.node_memory,
state.node_network,
)
# Update local instance map so next placement sees this one
for event in result.events:
if isinstance(event, InstanceCreated):
logger.info(
f"MetaInstance reconciler placed instance"
f" {event.instance.instance_id} for"
f" {meta_instance.model_id}"
)
current_instances[event.instance.instance_id] = event.instance
all_events.extend(result.events)
# Emit placement failure if error differs from what's already in state
if (
result.error is not None
and meta_instance.placement_error != result.error
):
logger.warning(
f"MetaInstance placement failed for"
f" {meta_instance.model_id}: {result.error}"
)
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=result.error,
)
)
return all_events

View File

@@ -0,0 +1,27 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final
from loguru import logger
from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State
_DEFAULT_TIMEOUT = timedelta(seconds=30)
@final
class NodeTimeoutReconciler:
"""Time out nodes that haven't been seen recently."""
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
self.timeout = timeout
async def reconcile(self, state: State) -> Sequence[Event]:
now = datetime.now(tz=timezone.utc)
events: list[Event] = []
for node_id, last_seen in state.last_seen.items():
if now - last_seen > self.timeout:
logger.info(f"Removing node {node_id} due to inactivity")
events.append(NodeTimedOut(node_id=node_id))
return events

240
src/exo/master/reconcile.py Normal file
View File

@@ -0,0 +1,240 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple
from loguru import logger
from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
BaseInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
RunnerStatus,
)
class PlacementResult(NamedTuple):
"""Result of a placement attempt: events to apply and optional error reason."""
events: Sequence[Event]
error: str | None
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
"""Reconstruct ring order from shard device_rank."""
node_ranks: list[tuple[NodeId, int]] = []
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
shard = instance.shard_assignments.runner_to_shard[runner_id]
node_ranks.append((node_id, shard.device_rank))
node_ranks.sort(key=lambda x: x[1])
return [node_id for node_id, _ in node_ranks]
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
"""Check that the specific IPs used by a ring instance still exist in the topology."""
ring = _get_ring_order(instance)
n = len(ring)
for node in ring:
hosts = instance.hosts_by_node[node]
for idx in range(n):
host = hosts[idx]
if host.ip in ("0.0.0.0", "198.51.100.1"):
continue # self or placeholder
# Real connection: node → ring[idx]. Check specific IP.
connections = topology.get_all_connections_between(node, ring[idx])
if not any(
isinstance(c, SocketConnection)
and c.sink_multiaddr.ip_address == host.ip
for c in connections
):
return False
return True
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
ring = _get_ring_order(instance)
n = len(ring)
for i in range(n):
for j in range(n):
iface = instance.jaccl_devices[i][j]
if iface is None:
continue
connections = topology.get_all_connections_between(ring[i], ring[j])
if not any(
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
for c in connections
):
return False
return True
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
"""Check that an instance's nodes and specific connections are still in the topology."""
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if not all(topology.contains_node(n) for n in instance_nodes):
return False
if len(instance_nodes) <= 1:
return True
match instance:
case MlxRingInstance():
return _ring_connections_healthy(instance, topology)
case MlxJacclInstance():
return _jaccl_connections_healthy(instance, topology)
def instance_runners_failed(
instance: Instance,
runners: Mapping[RunnerId, RunnerStatus],
node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
"""Check if an instance's runners have all reached terminal failure states.
Returns ``(True, error_message)`` when ALL runners are terminal
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
Returns ``(False, None)`` when runners are still active, haven't reported
yet, or all gracefully shut down (no ``RunnerFailed``).
"""
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
if not instance_runner_ids:
return False, None
# Build reverse mapping: runner_id -> node_id
runner_to_node: dict[RunnerId, NodeId] = {
runner_id: node_id
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
}
has_any_failed = False
error_messages: list[str] = []
for runner_id in instance_runner_ids:
status = runners.get(runner_id)
if status is None:
# Runner hasn't reported yet — instance is still starting
return False, None
if isinstance(status, RunnerFailed):
has_any_failed = True
if status.error_message:
node_id = runner_to_node.get(runner_id)
name = (
node_identities[node_id].friendly_name
if node_id and node_id in node_identities
else node_id or "unknown"
)
error_messages.append(f"{name}: {status.error_message}")
elif isinstance(status, RunnerShutdown):
pass # Terminal but not a failure indicator on its own
else:
# Runner is still active (connecting, loading, running, etc.)
return False, None
if has_any_failed:
return True, "; ".join(error_messages) if error_messages else "Runner failed"
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
return False, None
def instance_satisfies_meta_instance(
meta_instance: MetaInstance,
instance: Instance,
) -> bool:
"""Check if a single instance satisfies a meta-instance's constraints.
This is a pure constraint check (model, min_nodes, node_ids).
Use ``instance_connections_healthy`` separately for topology health.
"""
if instance.shard_assignments.model_id != meta_instance.model_id:
return False
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if len(instance_nodes) < meta_instance.min_nodes:
return False
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
instance_nodes
)
def find_unsatisfied_meta_instances(
meta_instances: Mapping[MetaInstanceId, MetaInstance],
instances: Mapping[InstanceId, Instance],
topology: Topology,
) -> Sequence[MetaInstance]:
"""Return meta-instances that have no healthy backing instance."""
unsatisfied: list[MetaInstance] = []
for meta_id, meta_instance in meta_instances.items():
has_healthy_backing = any(
instance.meta_instance_id == meta_id
and instance_connections_healthy(instance, topology)
for instance in instances.values()
)
if not has_healthy_backing:
unsatisfied.append(meta_instance)
return unsatisfied
def try_place_for_meta_instance(
meta_instance: MetaInstance,
model_card: ModelCard,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> PlacementResult:
"""Try to place an instance satisfying the meta-instance constraints.
Returns a :class:`PlacementResult` with events on success, or an error
reason on failure.
"""
command = PlaceInstance(
model_card=model_card,
sharding=meta_instance.sharding,
instance_meta=meta_instance.instance_meta,
min_nodes=meta_instance.min_nodes,
)
try:
target_instances = place_instance(
command,
topology,
current_instances,
node_memory,
node_network,
required_nodes=(
set(meta_instance.node_ids) if meta_instance.node_ids else None
),
)
# Tag the new instance with meta_instance_id
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
if new_instance_ids:
new_id = next(iter(new_instance_ids))
target_instances[new_id] = target_instances[new_id].model_copy(
update={"meta_instance_id": meta_instance.meta_instance_id}
)
return PlacementResult(
events=list(get_transition_events(current_instances, target_instances, {})),
error=None,
)
except ValueError as e:
logger.debug(
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
)
return PlacementResult(events=[], error=str(e))

View File

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_string())
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_string()}_sender")
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(

View File

@@ -0,0 +1,778 @@
"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling."""
import pytest
from exo.master.process_managers.instance_health import (
MAX_INSTANCE_RETRIES,
InstanceHealthReconciler,
)
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NodeIdentity
from exo.shared.types.state import State
from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerReady,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
# --- Helpers (copied from test_reconcile.py for independence) ---
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
consecutive_failures: int = 0,
last_failure_error: str | None = None,
placement_error: str | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
consecutive_failures=consecutive_failures,
last_failure_error=last_failure_error,
placement_error=placement_error,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
# =============================================================================
# 1. MetaInstance lifecycle edge cases
# =============================================================================
def test_meta_instance_model_is_frozen():
"""MetaInstance should be immutable (frozen model)."""
meta = _meta_instance()
try:
meta.model_id = ModelId("something-else")
raise AssertionError("Should have raised")
except Exception:
pass # Expected — frozen model
def test_meta_instance_created_then_deleted_roundtrip():
"""Create and delete a MetaInstance through apply — state should be clean."""
state = State()
meta = _meta_instance()
state = apply(
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta))
)
assert meta.meta_instance_id in state.meta_instances
state = apply(
state,
IndexedEvent(
idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
),
)
assert meta.meta_instance_id not in state.meta_instances
assert len(state.meta_instances) == 0
def test_delete_nonexistent_meta_instance_is_safe():
"""Deleting a MetaInstance that doesn't exist should not crash."""
state = State()
event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent"))
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
def test_placement_failed_for_nonexistent_meta_instance_is_safe():
"""MetaInstancePlacementFailed for unknown ID should not crash."""
state = State()
event = MetaInstancePlacementFailed(
meta_instance_id=MetaInstanceId("nonexistent"),
reason="test",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
def test_multiple_meta_instances_for_same_model():
"""Multiple MetaInstances for the same model are tracked independently."""
state = State()
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
state = apply(
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a))
)
state = apply(
state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b))
)
assert len(state.meta_instances) == 2
assert meta_a.meta_instance_id in state.meta_instances
assert meta_b.meta_instance_id in state.meta_instances
# =============================================================================
# 2. Retry logic edge cases
# =============================================================================
def test_retry_counter_resets_on_successful_instance_creation():
"""When a new instance is created for a meta-instance, failures should reset."""
meta = _meta_instance(consecutive_failures=2, last_failure_error="old")
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
mi = state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
# last_failure_error is preserved (for UI display)
assert mi.last_failure_error == "old"
async def test_retry_count_increments_through_full_cycle():
"""Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=topology,
)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
for idx, i in enumerate(range(MAX_INSTANCE_RETRIES)):
# Simulate runners failing
state_with_runners = state.model_copy(
update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}}
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state_with_runners)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying), f"iteration {i}"
state = apply(state, IndexedEvent(idx=idx, event=events[0]))
# After MAX_INSTANCE_RETRIES retries, failure counter should be at max
mi = state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == MAX_INSTANCE_RETRIES
# Next failure should result in deletion
state_with_runners = state.model_copy(
update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}}
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state_with_runners)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_respects_exact_limit():
"""At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry."""
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_at_limit_minus_one_retries():
"""At MAX_INSTANCE_RETRIES - 1, reconciler should still retry."""
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
# =============================================================================
# 3. Error handling edge cases
# =============================================================================
def test_runners_failed_with_empty_error_message():
"""RunnerFailed with empty error_message should still report as failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerFailed(error_message="")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
# Empty error message means we get the fallback
assert error == "Runner failed"
def test_runners_failed_with_none_error_message():
"""RunnerFailed with None error_message should still report as failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerFailed(error_message=None)
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error == "Runner failed"
def test_runners_failed_collects_all_error_messages():
"""With multiple failed runners, all error messages should be collected."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"),
runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"),
runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM on GPU 0" in error
assert "OOM on GPU 1" in error
assert "OOM on GPU 2" in error
def test_runners_failed_includes_friendly_name():
"""Error messages should include node friendly names when available."""
_, inst = _instance(node_ids=["node-a"])
node_id = NodeId("node-a")
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {runner_ids[0]: RunnerFailed(error_message="OOM")}
identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")}
is_failed, error = instance_runners_failed(inst, runners, identities)
assert is_failed is True
assert error is not None
assert "My Mac Studio" in error
def test_instance_retrying_for_missing_instance_is_safe():
"""InstanceRetrying for an instance not in state should not crash.
NOTE: When the instance is missing, the handler returns early WITHOUT
incrementing the MetaInstance failure counter. This means stale retry
events for already-deleted instances are silently dropped. This is
acceptable since the InstanceDeleted handler already increments failures.
"""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceRetrying(
instance_id=InstanceId("nonexistent"),
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Does not crash, but failure count is NOT incremented (early return)
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
# =============================================================================
# 4. Backward compatibility
# =============================================================================
def test_instance_without_meta_instance_id_works():
"""Instances created without meta_instance_id should still function normally."""
_, inst = _instance(node_ids=["node-a"])
assert inst.meta_instance_id is None
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_instance_deleted_without_meta_does_not_affect_meta_instances():
"""Deleting an instance without meta_instance_id should not affect meta_instances."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"]) # no meta_instance_id
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0 # unchanged
def test_satisfies_ignores_meta_instance_id_binding():
"""instance_satisfies_meta_instance checks constraints only, not binding."""
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"]) # no meta_instance_id set
# Should match on constraints (model, min_nodes) regardless of binding
assert instance_satisfies_meta_instance(meta, inst) is True
def test_find_unsatisfied_uses_binding_not_constraints():
"""find_unsatisfied checks meta_instance_id binding, not just constraint matching."""
meta = _meta_instance()
# Instance matches constraints but is NOT bound to this meta_instance
iid, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {iid: inst}, topology
)
# Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id
assert list(result) == [meta]
# =============================================================================
# 5. Concurrent / multi-instance scenarios
# =============================================================================
async def test_health_reconciler_handles_multiple_failing_instances():
"""Multiple instances failing simultaneously should each get their own event."""
meta_a = _meta_instance()
meta_b = _meta_instance()
iid_a, inst_a = _instance(
node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id
)
iid_b, inst_b = _instance(
node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id
)
runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values())
runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values())
state = State(
meta_instances={
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
instances={iid_a: inst_a, iid_b: inst_b},
runners={
runner_ids_a[0]: RunnerFailed(error_message="OOM"),
runner_ids_b[0]: RunnerFailed(error_message="OOM"),
},
topology=_topology("node-a", "node-b"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 2
# Both should be InstanceRetrying since failures < MAX
assert all(isinstance(e, InstanceRetrying) for e in events)
instance_ids = {e.instance_id for e in events} # type: ignore[union-attr]
assert instance_ids == {iid_a, iid_b}
async def test_health_reconciler_mixed_healthy_and_failing():
"""Only failing instances should produce events; healthy ones should not."""
meta_healthy = _meta_instance()
meta_failing = _meta_instance()
iid_h, inst_h = _instance(
node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id
)
iid_f, inst_f = _instance(
node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id
)
runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values())
runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values())
state = State(
meta_instances={
meta_healthy.meta_instance_id: meta_healthy,
meta_failing.meta_instance_id: meta_failing,
},
instances={iid_h: inst_h, iid_f: inst_f},
runners={
runner_ids_h[0]: RunnerReady(),
runner_ids_f[0]: RunnerFailed(error_message="crash"),
},
topology=_topology("node-a", "node-b"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid_f
async def test_meta_instance_reconciler_empty_state():
"""MetaInstanceReconciler with no meta_instances should produce no events."""
state = State()
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 0
# =============================================================================
# 6. Placement error tracking
# =============================================================================
def test_placement_failed_sets_error():
"""MetaInstancePlacementFailed should set placement_error on the MetaInstance."""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstancePlacementFailed(
meta_instance_id=meta.meta_instance_id,
reason="Not enough memory",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.placement_error == "Not enough memory"
def test_instance_created_clears_placement_error():
"""InstanceCreated should clear placement_error on the MetaInstance."""
meta = _meta_instance(placement_error="Not enough memory")
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
mi = state.meta_instances[meta.meta_instance_id]
assert mi.placement_error is None
def test_placement_error_does_not_increment_failures():
"""Placement failures should only set placement_error, not increment consecutive_failures."""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstancePlacementFailed(
meta_instance_id=meta.meta_instance_id,
reason="No resources",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.placement_error == "No resources"
# =============================================================================
# 7. State serialization roundtrip
# =============================================================================
def test_state_with_meta_instances_serializes():
"""State with meta_instances should serialize and deserialize correctly."""
meta = _meta_instance(consecutive_failures=2, last_failure_error="test")
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
json_str = state.model_dump_json()
restored = State.model_validate_json(json_str)
assert meta.meta_instance_id in restored.meta_instances
mi = restored.meta_instances[meta.meta_instance_id]
assert mi.model_id == meta.model_id
assert mi.consecutive_failures == 2
assert mi.last_failure_error == "test"
assert iid in restored.instances
assert restored.instances[iid].meta_instance_id == meta.meta_instance_id
# =============================================================================
# 8. MetaInstanceReconciler error handling
# =============================================================================
async def test_meta_instance_reconciler_model_load_error_emits_placement_failed(
monkeypatch: "pytest.MonkeyPatch",
):
"""When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed."""
import exo.master.process_managers.meta_instance as mi_mod
meta = _meta_instance()
topo = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
topology=topo,
)
async def _failing_load(_model_id: ModelId) -> ModelCard:
raise RuntimeError("Network error")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
assert len(placement_failed) == 1
assert "Failed to load model card" in placement_failed[0].reason
assert meta.meta_instance_id == placement_failed[0].meta_instance_id
async def test_meta_instance_reconciler_model_load_error_skips_dedup(
monkeypatch: "pytest.MonkeyPatch",
):
"""When ModelCard.load error matches existing placement_error, no duplicate event."""
import exo.master.process_managers.meta_instance as mi_mod
meta = _meta_instance(placement_error="Failed to load model card: Network error")
topo = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
topology=topo,
)
async def _failing_load(_model_id: ModelId) -> ModelCard:
raise RuntimeError("Network error")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
# Error matches existing placement_error, so no duplicate event emitted
assert len(events) == 0
async def test_meta_instance_reconciler_continues_after_error(
monkeypatch: "pytest.MonkeyPatch",
):
"""Reconciler should continue to next meta-instance after one fails to load."""
import exo.master.process_managers.meta_instance as mi_mod
meta_a = _meta_instance(model_id="org/model-a")
meta_b = _meta_instance(model_id="org/model-b")
topo = _topology("node-a")
state = State(
meta_instances={
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
topology=topo,
)
call_count = 0
async def _load_second_fails(model_id: ModelId) -> ModelCard:
nonlocal call_count
call_count += 1
raise RuntimeError(f"Cannot load {model_id}")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_second_fails)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
# Both meta-instances should have been attempted (not short-circuited)
assert call_count == 2
# Both should have placement failed events
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
assert len(placement_failed) == 2
# =============================================================================
# 8. Cascade delete with task cancellation
# =============================================================================
def test_cascade_delete_cancels_active_tasks():
"""Deleting a MetaInstance should cancel tasks on backing instances.
Regression test: previously, cascade-deleting backing instances via
DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active
tasks, leaving orphaned task references in state.
"""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
task_id = TaskId()
task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running)
# Build state with meta-instance, backing instance, and active task
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
tasks={task_id: task},
topology=_topology("node-a"),
)
# Simulate the cascade-delete event sequence produced by main.py:
# 1. MetaInstanceDeleted
# 2. TaskStatusUpdated(Cancelled) for active tasks
# 3. InstanceDeleted
idx = 0
state = apply(
state,
IndexedEvent(
idx=idx,
event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id),
),
)
idx += 1
state = apply(
state,
IndexedEvent(
idx=idx,
event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled),
),
)
idx += 1
state = apply(
state,
IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)),
)
# Verify everything is cleaned up
assert len(state.meta_instances) == 0
assert len(state.instances) == 0
assert state.tasks[task_id].task_status == TaskStatus.Cancelled
def test_cascade_delete_skips_completed_tasks():
"""Cascade delete should only cancel Pending/Running tasks, not completed ones."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
running_task_id = TaskId()
completed_task_id = TaskId()
running_task = LoadModel(
task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running
)
completed_task = LoadModel(
task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
tasks={running_task_id: running_task, completed_task_id: completed_task},
topology=_topology("node-a"),
)
# Only the running task should be cancelled — we verify the logic pattern
# by checking which tasks are Pending or Running
active_tasks = [
t
for t in state.tasks.values()
if t.instance_id == iid
and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
]
assert len(active_tasks) == 1
assert active_tasks[0].task_id == running_task_id

View File

@@ -3,10 +3,10 @@ import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_largest_cycles,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
create_node_memory,
@@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory():
}
def test_get_smallest_cycles():
def test_get_largest_cycles():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
@@ -175,12 +175,12 @@ def test_get_smallest_cycles():
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
# act
smallest_cycles = get_smallest_cycles(cycles)
largest_cycles = get_largest_cycles(cycles)
# assert
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
assert len(largest_cycles) == 1
assert len(largest_cycles[0]) == 3
assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id}
@pytest.mark.parametrize(

View File

@@ -0,0 +1,742 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerLoading,
RunnerReady,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
"""Build a topology with nodes connected in a bidirectional ring with unique IPs.
Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
between consecutive nodes (including wrap-around).
"""
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
"""Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
# Build hosts_by_node with IPs matching _topology() convention:
# node at index idx has IP 10.0.0.{idx+1}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
# --- instance_satisfies_meta_instance (pure constraint matching) ---
def test_satisfies_matching_model():
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"])
assert instance_satisfies_meta_instance(meta, inst) is True
def test_not_satisfies_wrong_model():
meta = _meta_instance("test-org/model-a")
_, inst = _instance("test-org/model-b")
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_missing_required_node():
meta = _meta_instance(node_ids=[NodeId("node-c")])
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_fewer_than_min_nodes():
meta = _meta_instance(min_nodes=3)
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_satisfies_with_node_ids_specified():
meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2)
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
assert instance_satisfies_meta_instance(meta, inst) is True
# --- instance_connections_healthy ---
def test_healthy_single_node_present():
_, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_single_node_missing():
_, inst = _instance(node_ids=["node-a"])
topology = Topology() # empty
assert instance_connections_healthy(inst, topology) is False
def test_healthy_two_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_two_node_edge_removed():
"""Nodes present but edge removed — ring broken."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", connect=False)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_two_node_ip_changed():
"""Edge exists but with a different IP than instance was configured with."""
_, inst = _instance(node_ids=["node-a", "node-b"])
# Build topology with different IPs than _instance() expects
topology = Topology()
topology.add_node(NodeId("node-a"))
topology.add_node(NodeId("node-b"))
topology.add_connection(
Connection(
source=NodeId("node-a"),
sink=NodeId("node-b"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=NodeId("node-b"),
sink=NodeId("node-a"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
),
)
)
assert instance_connections_healthy(inst, topology) is False
def test_healthy_three_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_three_node_one_edge_removed():
"""Remove one edge from a three-node ring — instance unhealthy."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
# Build topology with one direction of one edge missing
topology = Topology()
nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
for n in nodes:
topology.add_node(n)
# Add all edges except node-a → node-b
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[1],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[0],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
# Missing: node-a → node-b (ip 10.0.0.2)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_node_missing_from_topology():
"""Instance has a node that's not in the topology at all."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a") # node-b not present
assert instance_connections_healthy(inst, topology) is False
def test_healthy_extra_nodes_in_topology():
"""Extra nodes in topology don't affect instance health."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
# --- find_unsatisfied_meta_instances ---
def test_unsatisfied_no_meta_instances():
result = find_unsatisfied_meta_instances({}, {}, Topology())
assert list(result) == []
def test_unsatisfied_one_satisfied():
meta = _meta_instance()
id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == []
def test_unsatisfied_one_not_satisfied():
meta = _meta_instance("test-org/model-x")
id_a, inst_a = _instance("test-org/model-y")
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {id_a: inst_a}, topology
)
assert list(result) == [meta]
def test_unsatisfied_mix():
meta_satisfied = _meta_instance("test-org/model-a")
meta_unsatisfied = _meta_instance("test-org/model-b")
id_a, inst_a = _instance(
"test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_satisfied.meta_instance_id: meta_satisfied,
meta_unsatisfied.meta_instance_id: meta_unsatisfied,
},
{id_a: inst_a},
topology,
)
assert list(result) == [meta_unsatisfied]
def test_unsatisfied_node_disconnect():
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a") # node-b disconnected
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_edge_break():
"""Instance exists but its connections broke — meta-instance becomes unsatisfied."""
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_idempotent():
meta = _meta_instance("test-org/model-x")
topology = _topology("node-a")
meta_instances = {meta.meta_instance_id: meta}
instances: dict[InstanceId, MlxRingInstance] = {}
result_1 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
result_2 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
assert result_1 == result_2
def test_unsatisfied_exclusive_binding():
"""Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
id_inst, inst = _instance(
"test-org/model-x", meta_instance_id=meta_a.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
{id_inst: inst},
topology,
)
assert list(result) == [meta_b]
# --- apply handlers ---
def test_apply_meta_instance_created():
state = State()
meta = _meta_instance()
event = MetaInstanceCreated(meta_instance=meta)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id in new_state.meta_instances
assert new_state.meta_instances[meta.meta_instance_id] == meta
def test_apply_meta_instance_deleted():
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
def test_apply_meta_instance_deleted_clears_failure_info():
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "OOM"}
)
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
# --- instance_runners_failed ---
def test_runners_failed_all_failed():
"""All runners in RunnerFailed -> instance is failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runners = {
rid: RunnerFailed(error_message="OOM")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM" in error
def test_runners_failed_mixed_failed_shutdown():
"""One Failed + one Shutdown = failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="crash"),
runner_ids[1]: RunnerShutdown(),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "crash" in error
def test_runners_not_failed_all_shutdown():
"""All Shutdown (graceful) = not a failure."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerShutdown() for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_still_active():
"""Some runners still active = not failed yet."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerLoading(),
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_no_status():
"""Runner not yet reported = not failed."""
_, inst = _instance(node_ids=["node-a"])
is_failed, _ = instance_runners_failed(inst, {}, {})
assert is_failed is False
def test_runners_not_failed_healthy():
"""Runners in Ready state = not failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerReady() for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
# --- failure tracking in apply_instance_deleted ---
def test_apply_instance_deleted_tracks_failure():
"""InstanceDeleted with failure_error increments meta instance failure count."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "Runner OOM"
def test_apply_instance_deleted_increments_failure():
"""Subsequent failures increment the counter."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="new error")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 3
assert mi.last_failure_error == "new error"
def test_apply_instance_deleted_no_failure_no_tracking():
"""InstanceDeleted without failure_error does not track."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
def test_apply_instance_deleted_orphan_no_tracking():
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
iid, inst = _instance(node_ids=["node-a"])
state = State(instances={iid: inst})
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
# --- InstanceRetrying ---
def test_apply_instance_retrying_removes_runners():
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerShutdown(),
}
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners=runners,
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="OOM",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Instance still exists
assert iid in new_state.instances
# Runners removed
assert runner_ids[0] not in new_state.runners
assert runner_ids[1] not in new_state.runners
def test_apply_instance_retrying_increments_failure():
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "crash"
def test_apply_instance_retrying_skips_missing_runners():
"""InstanceRetrying doesn't assert if runners haven't reported yet."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
# No runners in state at all
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
# Should not raise
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert iid in new_state.instances
def test_apply_instance_created_resets_failure_counter():
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 3, "last_failure_error": "old error"}
)
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceCreated(instance=inst)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.last_failure_error == "old error"
assert mi.placement_error is None
# --- InstanceHealthReconciler retry-vs-delete ---
async def test_health_reconciler_retries_when_under_limit():
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid
assert events[0].meta_instance_id == meta.meta_instance_id
async def test_health_reconciler_deletes_when_limit_reached():
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_deletes_without_meta_instance():
"""Instances without a MetaInstance are deleted immediately on runner failure."""
iid, inst = _instance(node_ids=["node-a"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_network_failure_always_deletes():
"""Network failure always triggers InstanceDeleted regardless of retry count."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=_topology("node-a"), # node-b missing
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
assert events[0].failure_error == "Network connection lost"

View File

@@ -0,0 +1,37 @@
from enum import Enum
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
)

View File

@@ -16,19 +16,17 @@ from anyio.abc import TaskGroup
from exo_pyo3_bindings import (
AllQueuesFullError,
Keypair,
NetworkingHandle,
NoPeersSubscribedToTopicError,
PyMessage,
PySwarm,
)
from filelock import FileLock
from loguru import logger
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.shared.election import ConnectionMessage
from exo.shared.types.common import NodeId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel
from .connection_message import ConnectionMessage
from .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic
@@ -103,14 +101,18 @@ class TopicRouter[T: CamelCaseModel]:
class Router:
@classmethod
def create(cls, identity: Keypair) -> "Router":
return cls(handle=PySwarm(identity))
def create(
cls,
identity: Keypair,
bootstrap_peers: list[str] | None = None,
) -> "Router":
return cls(handle=NetworkingHandle(identity, bootstrap_peers or []))
def __init__(self, handle: PySwarm):
def __init__(self, handle: NetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
send, recv = channel[tuple[str, bytes]]()
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
self._net = handle
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._id_count = count()
self._tg: TaskGroup | None = None
@@ -156,6 +158,7 @@ class Router:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
@@ -180,44 +183,38 @@ class Router:
async def _networking_recv(self):
while True:
try:
msg = await self._net.recv()
except NoPeersSubscribedToTopicError:
continue
except AllQueuesFullError:
logger.warning("All peer queues full, messages have been lost")
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
match msg:
case PyMessage.Connection():
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(
ConnectionMessage(
node_id=NodeId(msg.node_id), connected=msg.connected
)
)
case PyMessage.Gossip():
if msg.topic not in self.topic_routers:
logger.warning(
f"Received message on unknown or inactive topic {msg.topic}"
)
continue
logger.trace(
f"Received message on {msg.topic} with payload {msg.data}"
)
router = self.topic_routers[msg.topic]
await router.publish_bytes(msg.data)
case _:
raise ValueError("net recv returned something impossible")
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self):
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
async def _networking_publish(self):
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
except NoPeersSubscribedToTopicError:
pass
except AllQueuesFullError:
logger.warning(f"All peer queues full, dropping message on {topic}")
def get_node_id_keypair(
@@ -228,7 +225,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate()
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -242,12 +239,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.deserialize(protobuf_encoded)
return Keypair.from_protobuf_encoding(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate()
f.write(keypair.serialize())
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
return keypair

View File

@@ -1,7 +1,8 @@
from dataclasses import dataclass
from enum import Enum
from exo.shared.election import ConnectionMessage, ElectionMessage
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from loguru import logger
from exo.shared.types.common import NodeId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -12,6 +12,12 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
JacclSideChannelData,
JacclSideChannelGathered,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -28,6 +34,7 @@ from exo.shared.types.events import (
TracesCollected,
TracesMerged,
)
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
@@ -66,12 +73,22 @@ def event_apply(event: Event, state: State) -> State:
| InputChunkReceived()
| TracesCollected()
| TracesMerged()
| JacclSideChannelData()
| JacclSideChannelGathered()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceRetrying():
return apply_instance_retrying(event, state)
case MetaInstanceCreated():
return apply_meta_instance_created(event, state)
case MetaInstanceDeleted():
return apply_meta_instance_deleted(event, state)
case MetaInstancePlacementFailed():
return apply_meta_instance_placement_failed(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -174,20 +191,123 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
def _update_meta_instance(
state: State, mid: MetaInstanceId, **fields: object
) -> Mapping[MetaInstanceId, MetaInstance]:
mi = state.meta_instances[mid]
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = event.instance
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
instance.instance_id: instance,
}
return state.model_copy(update={"instances": new_instances})
update: dict[str, object] = {"instances": new_instances}
# Reset failure tracking when a new instance is created for a meta-instance
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
mi = state.meta_instances[instance.meta_instance_id]
if mi.placement_error is not None or mi.consecutive_failures > 0:
update["meta_instances"] = _update_meta_instance(
state,
instance.meta_instance_id,
placement_error=None,
consecutive_failures=0,
)
return state.model_copy(update=update)
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
deleted_instance = state.instances.get(event.instance_id)
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
}
return state.model_copy(update={"instances": new_instances})
update: dict[str, object] = {"instances": new_instances}
# Track failure on the MetaInstance itself
if (
event.failure_error
and deleted_instance
and deleted_instance.meta_instance_id
and deleted_instance.meta_instance_id in state.meta_instances
):
mid = deleted_instance.meta_instance_id
mi = state.meta_instances[mid]
update["meta_instances"] = {
**state.meta_instances,
mid: mi.model_copy(
update={
"consecutive_failures": mi.consecutive_failures + 1,
"last_failure_error": event.failure_error,
}
),
}
return state.model_copy(update=update)
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
"""Runners failed but retry limit not reached — remove runners, keep instance."""
instance = state.instances.get(event.instance_id)
if instance is None:
# Instance was already deleted (e.g. cascade from DeleteMetaInstance).
# The InstanceDeleted handler already incremented consecutive_failures
# on the MetaInstance, so skipping here avoids double-counting.
return state
# Remove all runners belonging to this instance from state
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
}
update: dict[str, object] = {"runners": new_runners}
# Increment failure count on the MetaInstance
if event.meta_instance_id in state.meta_instances:
update["meta_instances"] = _update_meta_instance(
state,
event.meta_instance_id,
consecutive_failures=state.meta_instances[
event.meta_instance_id
].consecutive_failures
+ 1,
last_failure_error=event.failure_error,
)
return state.model_copy(update=update)
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
**state.meta_instances,
event.meta_instance.meta_instance_id: event.meta_instance,
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
mid: mi
for mid, mi in state.meta_instances.items()
if mid != event.meta_instance_id
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_placement_failed(
event: MetaInstancePlacementFailed, state: State
) -> State:
if event.meta_instance_id not in state.meta_instances:
return state
return state.model_copy(
update={
"meta_instances": _update_meta_instance(
state, event.meta_instance_id, placement_error=event.reason
)
}
)
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
@@ -218,6 +338,11 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -258,6 +383,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,

View File

@@ -10,6 +10,7 @@ from anyio import (
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, Sender
@@ -18,11 +19,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
DEFAULT_ELECTION_TIMEOUT = 3.0
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connected: bool
class ElectionMessage(CamelCaseModel):
clock: int
seniority: int

View File

@@ -1,7 +1,7 @@
import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.router import ConnectionMessage
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
@@ -330,7 +330,9 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
await cm_tx.send(
ConnectionMessage(
node_id=NodeId(),
connected=True,
connection_type=ConnectionMessageType.Connected,
remote_ipv4="",
remote_tcp_port=0,
)
)

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
queue.put(get_node_id_keypair().serialize())
queue.put(get_node_id_keypair().to_protobuf_encoding())
def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -6,7 +6,7 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -262,6 +262,26 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
class CreateMetaInstanceParams(BaseModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
class CreateMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
class DeleteMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None

View File

@@ -6,7 +6,8 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -48,6 +49,14 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class CreateMetaInstance(BaseCommand):
meta_instance: MetaInstance
class DeleteMetaInstance(BaseCommand):
meta_instance_id: MetaInstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
@@ -93,6 +102,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| CreateMetaInstance
| DeleteMetaInstance
| TaskCancelled
| TaskFinished
| SendInputChunk

View File

@@ -42,6 +42,10 @@ class CommandId(Id):
pass
class MetaInstanceId(Id):
"""Identifier for a MetaInstance."""
class Host(CamelCaseModel):
ip: str
port: int

View File

@@ -1,11 +1,14 @@
import base64
from collections.abc import Mapping
from datetime import datetime
from typing import final
from typing import Annotated, final
from pydantic import Field
from pydantic import BeforeValidator, Field, PlainSerializer
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -14,6 +17,28 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
def _decode_base64_bytes(v: bytes | str) -> bytes:
if isinstance(v, bytes):
return v
return base64.b64decode(v)
def _encode_base64_bytes(v: bytes) -> str:
return base64.b64encode(v).decode("ascii")
Base64Bytes = Annotated[
bytes,
BeforeValidator(_decode_base64_bytes),
PlainSerializer(_encode_base64_bytes, return_type=str),
]
"""bytes that serialize to/from base64 strings in JSON.
Needed because TaggedModel's wrap validator converts JSON→Python validation
context, which breaks strict-mode bytes deserialization from JSON strings.
"""
class EventId(Id):
"""
Newtype around `ID`
@@ -66,6 +91,30 @@ class InstanceCreated(BaseEvent):
class InstanceDeleted(BaseEvent):
instance_id: InstanceId
failure_error: str | None = None
class MetaInstanceCreated(BaseEvent):
meta_instance: MetaInstance
class MetaInstanceDeleted(BaseEvent):
meta_instance_id: MetaInstanceId
@final
class MetaInstancePlacementFailed(BaseEvent):
meta_instance_id: MetaInstanceId
reason: str
@final
class InstanceRetrying(BaseEvent):
"""Runners failed but retry count is below the limit — restart runners, keep instance."""
instance_id: InstanceId
meta_instance_id: MetaInstanceId
failure_error: str
class RunnerStatusUpdated(BaseEvent):
@@ -132,6 +181,25 @@ class TracesMerged(BaseEvent):
traces: list[TraceEventData]
@final
class JacclSideChannelData(BaseEvent):
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
instance_id: InstanceId
runner_id: RunnerId
sequence: int
data: Base64Bytes
@final
class JacclSideChannelGathered(BaseEvent):
"""Gathered result of a JACCL SideChannel all_gather round."""
instance_id: InstanceId
sequence: int
gathered_data: Mapping[RunnerId, Base64Bytes]
Event = (
TestEvent
| TaskCreated
@@ -141,6 +209,10 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceRetrying
| MetaInstanceCreated
| MetaInstanceDeleted
| MetaInstancePlacementFailed
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut
@@ -152,6 +224,8 @@ Event = (
| TopologyEdgeDeleted
| TracesCollected
| TracesMerged
| JacclSideChannelData
| JacclSideChannelGathered
)

View File

@@ -0,0 +1,25 @@
from typing import final
from pydantic import Field
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import FrozenModel
@final
class MetaInstance(FrozenModel):
"""Declarative constraint: ensure an instance matching these parameters always exists."""
meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
# Failure tracking
placement_error: str | None = None
consecutive_failures: int = 0
last_failure_error: str | None = None

View File

@@ -6,7 +6,8 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
DiskUsage,
MemoryUsage,
@@ -41,6 +42,7 @@ class State(CamelCaseModel):
arbitrary_types_allowed=True,
)
instances: Mapping[InstanceId, Instance] = {}
meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}

View File

@@ -26,7 +26,6 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""
class DownloadPending(BaseDownloadProgress):

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +19,7 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
meta_instance_id: MetaInstanceId | None = None
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -1,7 +1,5 @@
import sys
def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -29,4 +27,4 @@ def print_startup_banner(port: int) -> None:
"""
print(banner, file=sys.stderr)
print(banner)

View File

@@ -306,7 +306,7 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Starting prefill")
logger.info("Ready to prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(

View File

@@ -353,13 +353,7 @@ def load_tokenizer_for_model_id(
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(
hf_tokenizer,
eos_token_ids=eos_token_ids,
tool_call_start="<|tool_calls_section_begin|>",
tool_call_end="<|tool_calls_section_end|>",
tool_parser=_parse_kimi_tool_calls,
)
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
tokenizer = load_tokenizer(
model_path,
@@ -591,41 +585,3 @@ def mx_barrier(group: Group | None):
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)
def _parse_kimi_tool_calls(text: str):
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
_tool_call_split_regex = re.compile(
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
)
def _parse_single_tool(text: str) -> dict[str, Any]:
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError("No tool call found.")
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
func_name = func_name_match.group(2) # e.g. "get_weather"
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError("No tool call arguments found.")
func_args = func_args_match.group(1)
try:
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
except Exception:
arg_dct = None
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
tool_matches = _tool_call_split_regex.findall(text)
if tool_matches:
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
else:
return [_parse_single_tool(text)]

View File

@@ -24,6 +24,7 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
JacclSideChannelGathered,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -159,6 +160,15 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
if isinstance(event, JacclSideChannelGathered):
for runner in self.runners.values():
if (
runner.bound_instance.instance.instance_id
== event.instance_id
):
runner.notify_gathered(event)
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id

View File

@@ -35,6 +35,7 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerStatus,
RunnerWarmingUp,
)
@@ -56,7 +57,7 @@ def plan(
return (
_cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _create_runner(node_id, runners, instances, all_runners)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
@@ -75,6 +76,12 @@ def _kill_runner(
if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
return Shutdown(instance_id=instance_id, runner_id=runner_id)
# Master removed our runner from state (retry signal) and process is dead
if runner_id not in all_runners and isinstance(
runner.status, (RunnerFailed, RunnerShutdown)
):
return Shutdown(instance_id=instance_id, runner_id=runner_id)
for (
global_runner_id
) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
@@ -92,6 +99,7 @@ def _create_runner(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus],
) -> CreateRunner | None:
for instance in instances.values():
runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
@@ -101,6 +109,16 @@ def _create_runner(
if runner_id in runners:
continue
# Don't create while any peer runner is in a terminal state — wait for
# the master to emit InstanceRetrying which removes them from state.
has_terminal_peer = any(
isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
for peer_rid in instance.shard_assignments.node_to_runner.values()
if peer_rid != runner_id
)
if has_terminal_peer:
continue
shard = instance.shard(runner_id)
assert shard is not None

View File

@@ -17,6 +17,7 @@ def entrypoint(
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
pipe_fifo_paths: tuple[str, str] | None = None,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -30,6 +31,16 @@ def entrypoint(
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
# Open JACCL FIFOs by path and set env vars for C++ SideChannel.
# Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
if pipe_fifo_paths is not None:
fifo_c2p, fifo_p2c = pipe_fifo_paths
# C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)
global logger
logger = _logger

View File

@@ -1,10 +1,11 @@
import base64
import json
import math
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import Any, Callable, Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -15,6 +16,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
@@ -91,8 +93,6 @@ from exo.worker.engines.mlx.utils_mlx import (
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser, make_mlx_parser
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
@@ -138,7 +138,6 @@ def main(
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
tool_parser: ToolParser | None = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
@@ -204,17 +203,8 @@ def main(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
elif (
@@ -320,11 +310,31 @@ def main(
mlx_generator, tokenizer
)
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
if isinstance(inference_model, GptOssModel):
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
completion_tokens = 0
tokens_since_last_cancel_check = 0
@@ -577,8 +587,21 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
@@ -635,9 +658,9 @@ def parse_gpt_oss(
def parse_thinking_models(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -758,55 +781,221 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse], tool_parser: ToolParser
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
if response.text.startswith(tool_parser.start_parsing):
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"tool call parsing interrupted, yield partial tool call as text"
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.finish_reason is not None:
logger.info(
"toll call parsing interrupted, yield partial tool call as text"
)
yield GenerationResponse(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=response.usage,
stats=response.stats,
)
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile(
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
original_func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists.
func_name = original_func_name[original_func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(
id=f"{original_func_name}:{tool_id}",
name=func_name,
arguments=arg_dct, # pyright: ignore[reportAny]
)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -1,6 +1,10 @@
import contextlib
import os
import signal
import struct
import tempfile
from dataclasses import dataclass, field
from functools import partial
from multiprocessing import Process
from typing import Self
@@ -14,12 +18,14 @@ from loguru import logger
from exo.shared.types.events import (
Event,
JacclSideChannelData,
JacclSideChannelGathered,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
RunnerFailed,
@@ -34,6 +40,26 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.worker.runner.bootstrap import entrypoint
def _pipe_read_exact(fd: int, n: int) -> bytes | None:
"""Read exactly n bytes from a file descriptor. Returns None on EOF."""
data = b""
while len(data) < n:
chunk = os.read(fd, n - len(data))
if not chunk:
return None
data += chunk
return data
def _pipe_write_all(fd: int, data: bytes) -> None:
"""Write all bytes to a file descriptor."""
view = memoryview(data)
while view:
written = os.write(fd, view)
view = view[written:]
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -48,10 +74,19 @@ class RunnerSupervisor:
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_pipe_read_fd: int | None = None # Python reads runner's pipe output
_pipe_write_fd: int | None = None # Python writes gathered data to runner
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_gathered_waiters: dict[
int, tuple[anyio.Event, JacclSideChannelGathered | None]
] = field(default_factory=dict, init=False)
@classmethod
def create(
@@ -65,6 +100,23 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
# Named pipes work across multiprocessing.Process spawn (macOS default).
# FIFO c2p: C++ writes local data → Python reads it
# FIFO p2c: Python writes gathered data → C++ reads it
fifo_dir: str | None = None
fifo_c2p: str | None = None
fifo_p2c: str | None = None
pipe_fifo_paths: tuple[str, str] | None = None
if isinstance(bound_instance.instance, MlxJacclInstance):
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
os.mkfifo(fifo_c2p)
os.mkfifo(fifo_p2c)
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
runner_process = Process(
target=entrypoint,
args=(
@@ -73,6 +125,7 @@ class RunnerSupervisor:
task_recv,
cancel_recv,
logger,
pipe_fifo_paths,
),
daemon=True,
)
@@ -88,13 +141,45 @@ class RunnerSupervisor:
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
_fifo_dir=fifo_dir,
_fifo_c2p=fifo_c2p,
_fifo_p2c=fifo_p2c,
)
return self
async def run(self):
self.runner_process.start()
await self._forward_events()
if self._fifo_c2p is not None and self._fifo_p2c is not None:
# Open FIFOs from parent side. These block until child opens the other end,
# so we run them in threads concurrently to avoid deadlock.
fifo_c2p = self._fifo_c2p
fifo_p2c = self._fifo_p2c
async def open_read() -> None:
self._pipe_read_fd = await to_thread.run_sync(
partial(os.open, fifo_c2p, os.O_RDONLY)
)
async def open_write() -> None:
self._pipe_write_fd = await to_thread.run_sync(
partial(os.open, fifo_p2c, os.O_WRONLY)
)
async with anyio.create_task_group() as open_tg:
open_tg.start_soon(open_read)
open_tg.start_soon(open_write)
logger.info(
f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
)
async with anyio.create_task_group() as tg:
tg.start_soon(self._pipe_relay)
tg.start_soon(self._forward_events)
else:
await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
@@ -103,6 +188,7 @@ class RunnerSupervisor:
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self._close_pipe_fds()
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
@@ -181,6 +267,110 @@ class RunnerSupervisor:
for tid in self.pending:
self.pending[tid].set()
def _close_pipe_fds(self) -> None:
if self._pipe_read_fd is not None:
with contextlib.suppress(OSError):
os.close(self._pipe_read_fd)
self._pipe_read_fd = None
if self._pipe_write_fd is not None:
with contextlib.suppress(OSError):
os.close(self._pipe_write_fd)
self._pipe_write_fd = None
if self._child_pipe_fds is not None:
for fd in self._child_pipe_fds:
with contextlib.suppress(OSError):
os.close(fd)
self._child_pipe_fds = None
# Clean up FIFO files
if self._fifo_c2p is not None:
with contextlib.suppress(OSError):
os.unlink(self._fifo_c2p)
self._fifo_c2p = None
if self._fifo_p2c is not None:
with contextlib.suppress(OSError):
os.unlink(self._fifo_p2c)
self._fifo_p2c = None
if self._fifo_dir is not None:
with contextlib.suppress(OSError):
os.rmdir(self._fifo_dir)
self._fifo_dir = None
async def _pipe_relay(self) -> None:
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
assert self._pipe_read_fd is not None
assert self._pipe_write_fd is not None
read_fd = self._pipe_read_fd
write_fd = self._pipe_write_fd
sequence = 0
try:
while True:
# 1. Read local data from runner: [uint32 size][size bytes]
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
if header is None:
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
break
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
local_data = await to_thread.run_sync(
partial(_pipe_read_exact, read_fd, data_size)
)
if local_data is None:
logger.warning("JACCL pipe relay: EOF reading data payload")
break
logger.info(
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
)
# 2. Emit JacclSideChannelData event
waiter = anyio.Event()
self._gathered_waiters[sequence] = (waiter, None)
await self._event_sender.send(
JacclSideChannelData(
instance_id=self.bound_instance.instance.instance_id,
runner_id=self.bound_instance.bound_runner_id,
sequence=sequence,
data=local_data,
)
)
# 3. Wait for gathered result
await waiter.wait()
_, gathered_event = self._gathered_waiters.pop(sequence)
assert gathered_event is not None
# 4. Order gathered data by runner rank and concatenate
instance = self.bound_instance.instance
assert isinstance(instance, MlxJacclInstance)
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
ordered_data = b"".join(
gathered_event.gathered_data[rid] for rid in runner_order
)
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
total_size = len(ordered_data)
response = struct.pack("<I", total_size) + ordered_data
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
logger.info(
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
)
sequence += 1
except OSError as e:
logger.warning(f"JACCL pipe relay: OS error: {e}")
except Exception as e:
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
"""Called by the worker when a JacclSideChannelGathered event arrives."""
seq = event.sequence
if seq not in self._gathered_waiters:
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
return
waiter, _ = self._gathered_waiters[seq]
self._gathered_waiters[seq] = (waiter, event)
waiter.set()
def __del__(self) -> None:
if self.runner_process.is_alive():
logger.warning("RunnerSupervisor was not stopped cleanly.")

View File

@@ -1,72 +0,0 @@
import json
from dataclasses import dataclass
from typing import Any, Callable
from exo.shared.types.api import ToolCallItem
@dataclass
class ToolParser:
start_parsing: str
end_parsing: str
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
def make_mlx_parser(
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> ToolParser:
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix(tool_call_start)
text = text.removesuffix(tool_call_end)
parsed = tool_parser(text)
if isinstance(parsed, list):
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
else:
return [ToolCallItem.model_validate(_flatten(parsed))]
except Exception:
return None
return ToolParser(
start_parsing=tool_call_start,
end_parsing=tool_call_end,
parse_tool_calls=parse_tool_calls,
)
# TODO / example code:
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix("<tool_call>")
text = text.removesuffix("</tool_call>")
top_level = {
k: json.dumps(v) if isinstance(v, (dict, list)) else v
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
}
return [ToolCallItem.model_validate(top_level)]
except Exception:
return None
def _flatten(p: dict[str, Any]) -> dict[str, str]:
return {
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
for k, v in p.items() # pyright: ignore[reportAny]
}
json_tool_parser = ToolParser(
start_parsing="<tool_call>",
end_parsing="</tool_call>",
parse_tool_calls=_parse_json_calls,
)
def infer_tool_parser(chat_template: str) -> ToolParser | None:
"""Attempt to auto-infer a tool parser from the chat template."""
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
return json_tool_parser
return None

View File

@@ -5,13 +5,12 @@ from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser
def _make_responses(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""Create a sequence of GenerationResponses from text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
@@ -23,13 +22,10 @@ def _make_responses(
)
def _dummier_parser(text: str) -> dict[str, Any]:
def _dummy_parser(text: str) -> dict[str, Any]:
return {"name": "test_fn", "arguments": {"arg": text}}
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
class TestParseToolCalls:
"""Tests for parse_tool_calls generator."""
@@ -39,6 +35,8 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -52,6 +50,8 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -76,7 +76,9 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
"<tool_call>",
"</tool_call>",
_failing_parser,
)
)

View File

@@ -43,5 +43,4 @@ for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
echo "all hosts alive!"
wait

48
uv.lock generated
View File

@@ -377,8 +377,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -416,9 +416,9 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.7" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "msgspec", specifier = ">=0.19.0" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
@@ -1020,8 +1020,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1048,12 +1048,18 @@ wheels = [
name = "mlx"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
]
@@ -1066,14 +1072,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.7.dev20260217+50487b41"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.6"
@@ -1100,20 +1098,30 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.7"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/0d/56542e2ae13ec6f542d3977d7cff89a205d4f6c5122e0ce23f33265f61c9/mlx_lm-0.30.7.tar.gz", hash = "sha256:e5f31ac58d9f2381f28e1ba639ff903e64f7cff1bdc245c0bc97f72264be329c", size = 275764, upload-time = "2026-02-12T18:41:11.86Z" }
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/17/a41c798a3d9cbdc47f39c6db5bba4c2cd199203ead26bf911cb03b644070/mlx_lm-0.30.7-py3-none-any.whl", hash = "sha256:17442a4bf01c4c2d3bca1e647712fe44f19890c3f1eadc8589d389e57b44b9bf", size = 386591, upload-time = "2026-02-12T18:41:10.236Z" },
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
]
[[package]]