Compare commits

..

23 Commits

Author SHA1 Message Date
Alex Cheema
a288401a7f fix: pass _cancel_sender in RunnerSupervisor test helper
After merging main (api cancellation #1276), the RunnerSupervisor
dataclass requires a _cancel_sender field. Update the test helper
to create and pass this channel.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:25:22 -08:00
Alex Cheema
b36721e6d9 ci: retrigger CI (darwin runner back) 2026-02-16 10:01:27 -08:00
Alex Cheema
671e5de248 ci: retrigger CI (darwin runner stale) 2026-02-16 10:01:27 -08:00
Alex Cheema
5ec7b35841 ci: retrigger CI 2026-02-16 10:01:27 -08:00
Alex Cheema
92b20128a7 fix: retry failed e2e tests once to handle flaky Docker networking
Docker mDNS discovery can be slow on first boot in CI, causing
cluster_formation to timeout on "Nodes discovered each other" while
subsequent tests pass fine. Retry failed tests once before counting
them as real failures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:01:27 -08:00
Alex Cheema
1fa4d3b087 fix: scope e2e CI triggers, add temperature=0, fail on missing snapshots
- Scope e2e workflow to only trigger on pushes to e2e-tests branch
  (not every branch push)
- Add temperature=0 to remaining snapshot test chat calls for
  deterministic output
- Make assert_snapshot fail when no baseline exists instead of silently
  creating one — baselines must be explicitly generated with
  UPDATE_SNAPSHOTS=1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:01:27 -08:00
Alex Cheema
d9c884b9df fix: mark all inference snapshot tests as slow to fix CI timeout
Snapshot tests do MLX inference on x86 CPU in Docker which takes >600s
per test, causing the 45-minute CI job to timeout. Only cluster_formation
and no_internet (non-inference tests) should run in CI. Inference
snapshot tests can be run locally with --slow or E2E_SLOW=1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:01:27 -08:00
Alex Cheema
ba934cf237 fix: resolve lint/format issues after merging main and fix pytest collection
Add root conftest.py to exclude tests/start_distributed_test.py from
pytest collection (it calls sys.exit at module level). Fix ruff lint
issues (import sorting, f-string without placeholders, lambda loop
variable capture) and apply nix fmt formatting to e2e files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:01:27 -08:00
Alex Cheema
54a036223b fix: add health check and heartbeat to RunnerSupervisor
Add proactive monitoring to detect runner process death and unresponsiveness:

- Health check loop polls is_alive() every 1s, detects unexpected exits
- Counter-based heartbeat detects frozen/unresponsive processes
- Emits RunnerFailed event and releases pending task waiters on failure
- Add EXO_RUNNER_MUST_DIE debug trigger for testing abrupt process death
- Add chaos E2E test that kills runner mid-inference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:01:08 -08:00
Alex Cheema
199fb1c9fa fix: enable MLX CPU inference on x86_64 Linux in Docker
Two issues prevented MLX CPU from working on x86_64 in Docker:

1. Missing BLAS/LAPACK libraries: MLX CPU backend requires libblas-dev,
   liblapack-dev, and liblapacke-dev on Linux. Added to apt-get install.

2. g++ wrapper ordering: The -fpermissive wrapper for GCC 14 was installed
   AFTER uv sync, but MLX may compile extensions during install. Moved
   the wrapper BEFORE uv sync so both build-time and runtime JIT
   compilation benefit from the fix.

MLX publishes manylinux_2_35_x86_64 wheels, so this uses the native
CPU backend — no alternative inference framework needed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
cc7850180d Add multi-model snapshot tests for model diversity
Add e2e snapshot test that exercises 3 different model architectures
to catch model-specific regressions:
- SmolLM2-135M-Instruct (tiny llama, bf16, ~269MB)
- Llama-3.2-1B-Instruct-4bit (small llama, 4bit, ~730MB)
- gemma-2-2b-it-4bit (gemma2 architecture, 4bit, ~1.5GB)

Each model gets its own snapshot file. All use the same prompt
("What is the capital of France?"), seed=42, max_tokens=32.

Also adds model cards for SmolLM2-135M-Instruct and gemma-2-2b-it-4bit
(Llama-3.2-1B-Instruct-4bit already had one).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
37af43e52f feat: add Docker layer caching to e2e CI with buildx + GHA cache
Pre-build the Docker image using docker/build-push-action with GitHub
Actions cache (type=gha). On cache hit, the image loads from cache
instead of rebuilding (~12min → seconds).

Changes:
- CI: set up buildx, build image with --cache-from/--cache-to type=gha
- docker-compose.yml: add image tag (exo-e2e:latest) so compose uses
  the pre-built image instead of rebuilding
- conftest.py: Cluster.build() skips if exo-e2e:latest already exists
  (pre-built in CI), falls back to docker compose build for local dev

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
effafc1d48 feat: make snapshot tests run on x86 Ubuntu CI without GPU
MLX already supports x86 CPU via mlx[cpu] and the Dockerfile has the
GCC workaround for CPU JIT. The only barriers were the 'slow' markers
causing tests to be skipped in CI.

Changes:
- Remove 'slow' marker from all snapshot tests so they run by default
- Make snapshots architecture-aware (snapshots/{arch}/{name}.json) since
  floating-point results differ between x86_64 and arm64
- Store architecture in snapshot metadata
- Increase CI timeout from 30 to 45 minutes for model download + CPU inference
- Update docstrings to remove Apple Silicon requirement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
3fb663ec25 feat: add snapshot test cases for code gen, reasoning, long output, and edge cases
Expand e2e snapshot coverage beyond the single 'What is 2+2?' test:
- test_snapshot_code_gen.py: code generation prompt (max_tokens=64)
- test_snapshot_reasoning.py: step-by-step math reasoning (max_tokens=64)
- test_snapshot_long_output.py: longer response with max_tokens=128
- test_snapshot_edge.py: single word, special chars, and unicode prompts

All use seed=42 and the shared assert_snapshot() infrastructure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
c2f9034914 feat: add reusable snapshot regression testing to e2e framework
Add e2e/snapshot.py with assert_snapshot() for deterministic regression
testing. On first run, saves inference output as the expected snapshot.
On subsequent runs, compares against it with unified diff on mismatch.
Set UPDATE_SNAPSHOTS=1 or pass --update-snapshots to regenerate.

Refactor test_inference_snapshot.py to use the shared infrastructure
and drop temperature=0 in favor of seed-only determinism.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
2cce4e8f04 fix: add root conftest.py to exclude start_distributed_test from pytest collection
The tests/start_distributed_test.py script calls sys.exit() at module
level, which crashes pytest collection. Exclude it via collect_ignore.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
ce82486a79 fix: ruff lint and formatting for e2e test files
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
6e7dc0b042 fix: skip slow inference test in CI, run with --slow
MLX CPU inference on x86_64 is too slow for CI runners (~10min+ for
a single request). Mark the inference snapshot test as slow so it's
skipped by default. Run with --slow or E2E_SLOW=1 on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
acdf2751a3 feat: add deterministic inference snapshot test
Launch mlx-community/Qwen3-0.6B-4bit on the cluster, send a chat
completion with seed=42 and temperature=0, and verify the output
matches a committed snapshot. Tests inference determinism end-to-end.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
0035ac3cf2 fix: make no_internet test actually block internet with iptables
Use iptables to block all outbound traffic except private subnets and
multicast (for mDNS discovery). Verify internet is blocked by curling
huggingface.co from inside each container and checking exo logs for
"Internet connectivity: False".

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
26621503c8 fix: reduce Docker image size and free more CI disk space
Clean up Rust target/ and cargo registry after uv sync in the same RUN
command so build artifacts aren't committed to the layer (~1-2 GB saved).
Also remove more unused toolchains from the CI runner.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
b9dcea2b4e fix: free disk space in CI before Docker build
The runner was running out of disk space during the Docker image build
(Rust compilation + Python deps). Remove unused toolchains first.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
Alex Cheema
9bcc0f968b feat: add Docker-based E2E test framework
Add a Python/asyncio E2E test framework that spins up 2-node exo clusters
in Docker Compose and verifies cluster formation, discovery, election, and
API health. Includes a no-internet chaos test using DNS blocking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:00:04 -08:00
71 changed files with 4056 additions and 1610 deletions

15
.dockerignore Normal file
View File

@@ -0,0 +1,15 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

44
.github/workflows/e2e.yml vendored Normal file
View File

@@ -0,0 +1,44 @@
name: e2e-tests
on:
push:
branches:
- e2e-tests
pull_request:
branches:
- staging
- main
jobs:
e2e:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
/opt/microsoft /opt/az
docker system prune -af
df -h /
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build E2E image with cache
uses: docker/build-push-action@v6
with:
context: .
file: e2e/Dockerfile
tags: exo-e2e:latest
load: true
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Run E2E tests
run: python3 e2e/run_all.py

275
Cargo.lock generated
View File

@@ -125,9 +125,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.101"
version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "arc-swap"
@@ -141,6 +141,12 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -165,7 +171,7 @@ checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"synstructure",
]
@@ -177,7 +183,7 @@ checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -224,7 +230,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -298,6 +304,19 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -421,9 +440,9 @@ dependencies = [
[[package]]
name = "chrono"
version = "0.4.43"
version = "0.4.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118"
checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
dependencies = [
"iana-time-zone",
"js-sys",
@@ -497,6 +516,15 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -644,7 +672,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -670,7 +698,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d162beedaa69905488a8da94f5ac3edb4dd4788b732fadb7bd120b2625c1976"
dependencies = [
"data-encoding",
"syn 1.0.109",
"syn 2.0.111",
]
[[package]]
@@ -681,7 +709,7 @@ checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -718,6 +746,29 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -738,7 +789,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -820,7 +871,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -888,17 +939,22 @@ name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures-lite",
"futures",
"impl-trait-for-tuples",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
@@ -911,15 +967,9 @@ checksum = "311a6d2f1f9d60bff73d2c78a0af97ed27f79672f15c238192a5bbb64db56d00"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1028,10 +1078,7 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
@@ -1043,7 +1090,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -1593,6 +1640,17 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1657,7 +1715,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -1711,7 +1769,7 @@ checksum = "980af8b43c3ad5d8d349ace167ec8170839f753a42d233ba19e08afe1850fa69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -1771,6 +1829,12 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2300,7 +2364,7 @@ checksum = "dd297cf53f0cb3dee4d2620bb319ae47ef27c702684309f682bdb7e55a18ae9c"
dependencies = [
"heck",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -2760,13 +2824,16 @@ name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures-lite",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
@@ -2838,9 +2905,9 @@ dependencies = [
[[package]]
name = "num-conv"
version = "0.2.0"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
@@ -2851,6 +2918,17 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3053,7 +3131,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3165,9 +3243,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.106"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
@@ -3192,7 +3270,7 @@ checksum = "440f724eba9f6996b75d63681b0a92b06947f1457076d503a4d2e2c8f56442b8"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3201,14 +3279,28 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3236,7 +3328,7 @@ checksum = "bcd7d70ee0ca1661c40407e6f84e4463ef2658c90a9e2fbbd4515b2bcdfcaeca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3278,7 +3370,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3291,14 +3383,14 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
name = "pyo3-stub-gen"
version = "0.19.0"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b159f7704044f57d058f528a6f1f22a0a0a327dcb595c5fb38beae658e0338d6"
checksum = "398b833826a83ca72c1e26d1b2c7c71f9ca7c3bfc74eacc663901895c362ae33"
dependencies = [
"anyhow",
"chrono",
@@ -3313,25 +3405,22 @@ dependencies = [
"ordered-float",
"pyo3",
"pyo3-stub-gen-derive",
"rustpython-parser",
"serde",
"serde_json",
"time",
"toml",
]
[[package]]
name = "pyo3-stub-gen-derive"
version = "0.19.0"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8c79e7c5b1fcec7c39ab186594658a971c59911eb6fbab5a5932cf2318534be"
checksum = "2426ba759d848787239d80f9fdb1f223786976f87fb6c3da8188ca7c17744b28"
dependencies = [
"heck",
"indexmap",
"proc-macro2",
"quote",
"rustpython-parser",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3414,9 +3503,9 @@ dependencies = [
[[package]]
name = "quote"
version = "1.0.44"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
@@ -3652,6 +3741,16 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -3887,7 +3986,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -3905,9 +4004,9 @@ dependencies = [
[[package]]
name = "serde_spanned"
version = "1.0.4"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776"
checksum = "e24345aa0fe688594e73770a5f6d1b216508b4f93484c0026d521acd30134392"
dependencies = [
"serde_core",
]
@@ -4095,9 +4194,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.116"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
@@ -4112,7 +4211,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4188,7 +4287,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4199,7 +4298,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4213,30 +4312,30 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.47"
version = "0.3.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c"
checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d"
dependencies = [
"deranged",
"itoa",
"num-conv",
"powerfmt",
"serde_core",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.8"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca"
checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b"
[[package]]
name = "time-macros"
version = "0.2.27"
version = "0.2.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215"
checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3"
dependencies = [
"num-conv",
"time-core",
@@ -4312,7 +4411,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4330,9 +4429,9 @@ dependencies = [
[[package]]
name = "toml"
version = "1.0.2+spec-1.1.0"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1dfefef6a142e93f346b64c160934eb13b5594b84ab378133ac6815cb2bd57f"
checksum = "f0dc8b1fb61449e27716ec0e1bdf0f6b8f3e8f6b05391e8497b8b6d7804ea6d8"
dependencies = [
"indexmap",
"serde_core",
@@ -4345,27 +4444,27 @@ dependencies = [
[[package]]
name = "toml_datetime"
version = "1.0.0+spec-1.1.0"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e"
checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533"
dependencies = [
"serde_core",
]
[[package]]
name = "toml_parser"
version = "1.0.9+spec-1.1.0"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4"
checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e"
dependencies = [
"winnow",
]
[[package]]
name = "toml_writer"
version = "1.0.6+spec-1.1.0"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607"
checksum = "df8b2b54733674ad286d16267dcfc7a71ed5c776e4ac7aa3c3e2561f7c637bf2"
[[package]]
name = "tower-service"
@@ -4392,7 +4491,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4516,12 +4615,24 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"
@@ -4704,7 +4815,7 @@ dependencies = [
"bumpalo",
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"wasm-bindgen-shared",
]
@@ -4846,7 +4957,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4857,7 +4968,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4868,7 +4979,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -4879,7 +4990,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -5268,7 +5379,7 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"synstructure",
]
@@ -5289,7 +5400,7 @@ checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -5309,7 +5420,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
"synstructure",
]
@@ -5330,7 +5441,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]
[[package]]
@@ -5363,5 +5474,5 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
"syn 2.0.111",
]

View File

@@ -26,22 +26,49 @@ opt-level = 3
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-util = "0.3"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"

View File

@@ -72,23 +72,16 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
if process.isRunning {
process.terminate()
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
status = .stopped
}
func restart() {

View File

@@ -1,7 +0,0 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
]

View File

@@ -1,189 +0,0 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=1)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
]
[topology]
type = "none"
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

1
conftest.py Normal file
View File

@@ -0,0 +1 @@
collect_ignore = ["tests/start_distributed_test.py"]

View File

File diff suppressed because it is too large Load Diff

58
e2e/Dockerfile Normal file
View File

@@ -0,0 +1,58 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
# libblas-dev/liblapack-dev/liblapacke-dev are required by MLX CPU backend on Linux
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
libblas-dev \
liblapack-dev \
liblapacke-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
# Must be done BEFORE uv sync so any source builds also get the fix
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
CMD [".venv/bin/exo", "-v"]

195
e2e/conftest.py Normal file
View File

@@ -0,0 +1,195 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
# Skip build if the image was pre-built (e.g. in CI with buildx cache)
proc = await asyncio.create_subprocess_exec(
"docker",
"image",
"inspect",
"exo-e2e:latest",
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await proc.wait()
if proc.returncode == 0:
print(" Using pre-built image (exo-e2e:latest)")
return
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")

20
e2e/docker-compose.yml Normal file
View File

@@ -0,0 +1,20 @@
services:
exo-node-1:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]

83
e2e/run_all.py Normal file
View File

@@ -0,0 +1,83 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests with '# slow' on the first line of their docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
if "--update-snapshots" in sys.argv:
os.environ["UPDATE_SNAPSHOTS"] = "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
# Retry once — Docker networking (mDNS) can be slow on first boot
print(f"\n=== {name} === RETRYING (attempt 2/2)")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()

78
e2e/snapshot.py Normal file
View File

@@ -0,0 +1,78 @@
"""Snapshot testing infrastructure for E2E tests.
Provides deterministic regression testing by comparing inference output
against committed baseline snapshots. Tests FAIL if no baseline exists —
baselines must be explicitly generated and committed.
Generate baselines: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow
Update after intentional changes: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow
Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
since floating-point results differ between CPU architectures.
"""
import difflib
import json
import os
import platform
from pathlib import Path
ARCH = platform.machine()
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH
def assert_snapshot(
name: str,
content: str,
metadata: dict,
) -> None:
"""Compare content against a saved snapshot, or create one if missing.
Args:
name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
content: The actual inference output to compare.
metadata: Additional context stored alongside content (model, seed, etc.).
Not used for comparison -- purely documentary.
Raises:
AssertionError: If content doesn't match the saved snapshot.
Environment:
UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
"""
snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
update = os.environ.get("UPDATE_SNAPSHOTS") == "1"
if update:
# Explicitly regenerate snapshot
SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
snapshot_data = {**metadata, "arch": ARCH, "content": content}
snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
print(f" Updated snapshot: {ARCH}/{snapshot_file.name}")
elif not snapshot_file.exists():
raise AssertionError(
f"No baseline snapshot for '{name}' on {ARCH}.\n"
f"Expected file: {snapshot_file}\n\n"
f"Generate baselines with: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
)
else:
snapshot = json.loads(snapshot_file.read_text())
expected = snapshot["content"]
if content != expected:
diff = "\n".join(
difflib.unified_diff(
expected.splitlines(),
content.splitlines(),
fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
tofile="actual",
lineterm="",
)
)
raise AssertionError(
f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
f"{diff}\n\n"
f"Expected: {expected!r}\n"
f"Actual: {content!r}\n\n"
f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
)
print(f" Output matches snapshot ({ARCH}/{snapshot_file.name})")

View File

@@ -0,0 +1,22 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,61 @@
"""Test: Deterministic inference output (snapshot test).
Sends a chat completion request with a fixed seed,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Uses MLX CPU backend in Docker on x86 Linux.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="inference_snapshot",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())

47
e2e/test_no_internet.py Normal file
View File

@@ -0,0 +1,47 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())

65
e2e/test_runner_chaos.py Normal file
View File

@@ -0,0 +1,65 @@
"""Test: Runner chaos — abrupt runner death detection.
slow
Sends a chat completion with the EXO_RUNNER_MUST_DIE trigger, which causes
the runner process to call os._exit(1) (simulating an OOM kill). Verifies that
the RunnerSupervisor health check detects the death and the system doesn't hang.
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import contextlib
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
async def main():
async with Cluster("runner_chaos") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Place the model so a runner is loaded and ready
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
# Send a chat request with the die trigger.
# The runner will call os._exit(1) mid-inference, simulating OOM kill.
# The chat request itself will fail — that's expected.
print(" Sending EXO_RUNNER_MUST_DIE trigger...")
with contextlib.suppress(Exception):
await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": "EXO RUNNER MUST DIE"}],
timeout=60,
)
# Wait for the health check to detect the death and emit RunnerFailed
async def health_check_detected():
log = await cluster.logs()
return "runner process died unexpectedly" in log
await cluster.wait_for(
"Health check detected runner death",
health_check_detected,
timeout=30,
)
# Verify RunnerFailed was emitted (visible in logs)
log = await cluster.logs()
assert "runner process died unexpectedly" in log, (
f"Expected health check to detect runner death but it didn't.\nLogs:\n{log}"
)
print("PASSED: runner_chaos")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,60 @@
"""Test: Code generation snapshot.
slow
Verifies deterministic output for a code generation prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = (
"Write a Python function to reverse a string. Only output the code, no explanation."
)
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_code_gen") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_code_gen",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_code_gen")
if __name__ == "__main__":
asyncio.run(main())

65
e2e/test_snapshot_edge.py Normal file
View File

@@ -0,0 +1,65 @@
"""Test: Edge case snapshots.
slow
Verifies deterministic output for edge-case prompts: single word input,
special characters, and unicode.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
MAX_TOKENS = 32
CASES = [
("edge_single_word", "Hi"),
("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
]
async def main():
async with Cluster("snapshot_edge") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
for snapshot_name, prompt in CASES:
print(f" [{snapshot_name}] Sending: {prompt!r}")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": prompt}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{snapshot_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": prompt,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_edge")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,58 @@
"""Test: Longer output snapshot.
slow
Verifies deterministic output with a higher max_tokens (128).
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "Explain how a binary search algorithm works."
MAX_TOKENS = 128
async def main():
async with Cluster("snapshot_long_output") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_long_output",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_long_output")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,73 @@
"""Test: Multi-model snapshot tests.
slow
Verifies deterministic output across different model architectures to catch
model-specific regressions. Each model uses its own snapshot file.
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
SEED = 42
PROMPT = "What is the capital of France?"
MAX_TOKENS = 32
MODELS = [
"mlx-community/SmolLM2-135M-Instruct",
"mlx-community/Llama-3.2-1B-Instruct-4bit",
"mlx-community/gemma-2-2b-it-4bit",
]
async def main():
async with Cluster("snapshot_multi_model") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
for model in MODELS:
short_name = (
model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
)
snapshot_name = f"snapshot_multi_{short_name}"
print(f" Launching model {model}...")
await cluster.place_model(model)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=model,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{short_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": model,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print(f" [{short_name}] PASSED")
print("PASSED: snapshot_multi_model")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,58 @@
"""Test: Reasoning/math snapshot.
slow
Verifies deterministic output for a simple reasoning prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_reasoning") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_reasoning",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_reasoning")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,32 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v

View File

@@ -115,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260217+50487b41"; in
version = let v = "0.30.6"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
};
patches = [

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.7",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -64,7 +64,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
@@ -133,7 +132,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -58,21 +58,6 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -89,25 +74,14 @@
linuxOverlay
]
);
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/SmolLM2-135M-Instruct"
n_layers = 30
hidden_size = 576
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "SmolLM2 135M"
capabilities = ["text"]
[storage_size]
in_bytes = 269060381

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/gemma-2-2b-it-4bit"
n_layers = 26
hidden_size = 2304
supports_tensor = false
tasks = ["TextGeneration"]
family = "gemma2"
quantization = "4bit"
base_model = "Gemma 2 2B"
capabilities = ["text"]
[storage_size]
in_bytes = 1492755242

View File

@@ -25,38 +25,53 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
# "nightly", # enables better-supported GIL integration
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.19.0" }
pyo3-stub-gen = { version = "0.17.2" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures-lite = { workspace = true }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,85 +1,155 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401, F403, F405
# ruff: noqa: E501, F401
import builtins
import enum
import typing
__all__ = [
"AllQueuesFullError",
"Keypair",
"NoPeersSubscribedToTopicError",
"PyMessage",
"PySwarm",
]
@typing.final
class AllQueuesFullError(builtins.Exception):
def __new__(cls, *_a: typing.Any) -> AllQueuesFullError: ...
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final
class Keypair:
r"""
Identity keypair of a node.
"""
@staticmethod
def generate() -> Keypair:
def generate_ed25519() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def deserialize(bytes: bytes) -> Keypair:
def generate_ecdsa() -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
def serialize(self) -> bytes:
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_string(self) -> builtins.str:
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *_a: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __str__(self) -> builtins.str: ...
class PyMessage:
@typing.final
class Connection(PyMessage):
__match_args__ = ("node_id", "connected",)
@property
def node_id(self) -> builtins.str: ...
@property
def connected(self) -> builtins.bool: ...
def __new__(cls, node_id: builtins.str, connected: builtins.bool) -> PyMessage.Connection: ...
@typing.final
class Gossip(PyMessage):
__match_args__ = ("node_id", "topic", "data",)
@property
def node_id(self) -> builtins.str: ...
@property
def topic(self) -> builtins.str: ...
@property
def data(self) -> bytes: ...
def __new__(cls, node_id: builtins.str, topic: builtins.str, data: bytes) -> PyMessage.Gossip: ...
...
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
"""
def to_string(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
"""
@typing.final
class PySwarm:
def __new__(cls, identity: Keypair) -> PySwarm: ...
async def recv(self) -> PyMessage:
class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next message from networking.
Receives the next `ConnectionUpdate` from networking.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> None:
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
"""
async def gossipsub_unsubscribe(self, topic: builtins.str) -> None:
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Unsubscribes from a `GossipSub` topic.
@@ -87,6 +157,65 @@ class PySwarm:
"""
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
r"""
Publishes a message to the network on a specific topic.
Publishes a message with multiple topics to the `GossipSub` network.
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...

View File

@@ -1,4 +1,8 @@
//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
@@ -6,17 +10,31 @@ use std::{
task::{Context, Poll},
};
pub struct AllowThreads<F>(pub(crate) F);
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F);
impl<F> AllowThreads<F>
where
Self: Future,
{
pub fn new(f: F) -> Self {
Self(f)
}
}
impl<F> Future for AllowThreads<F>
where
F: Future + Unpin + Send,
F::Output: Send,
F: Future + Ungil,
F::Output: Ungil,
{
type Output = F::Output;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
}
}

View File

@@ -0,0 +1,240 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -1,47 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn deserialize(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn serialize<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_string(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
Ok(())
}

View File

@@ -4,13 +4,28 @@
//!
//!
mod allow_threading;
mod ident;
mod networking;
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
use crate::ident::ident_submodule;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -19,18 +34,35 @@ pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyResult, Python};
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::attach(|py| PyBytes::new(py, self).unbind())
Python::with_gil(|py| PyBytes::new(py, self).unbind())
}
}
@@ -45,16 +77,120 @@ pub(crate) mod ext {
}
pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
fn allow_threads_py(self) -> AllowThreads<Self>
where
AllowThreads<Self>: Future,
{
AllowThreads(self)
AllowThreads::new(self)
}
}
impl<T: Future> FutureExt for T {}
#[ext(pub, name = PyErrExt)]
impl PyErr {
fn receiver_channel_closed() -> Self {
PyConnectionError::new_err("Receiver channel closed unexpectedly")
}
}
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
match self {
Ok(v) => Some(v),
Err(e) => {
// write error back to python
e.write_unraisable(py, None);
None
}
}
}
}
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
@@ -65,9 +201,16 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
// top-level constructs
// TODO: ...
Ok(())
}

View File

@@ -1,24 +1,31 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::ResultExt as _;
use crate::ext::{ByteArrayExt as _, FutureExt as _};
use crate::ident::PyKeypair;
use crate::networking::exception::{PyAllQueuesFullError, PyNoPeersSubscribedToTopicError};
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::pyclass;
use futures_lite::FutureExt as _;
use networking::swarm::{FromSwarm, Swarm, ToSwarm};
use pyo3::coroutine::CancelHandle;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::prelude::*;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3_async_runtimes::tokio::get_runtime;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods};
use std::pin::pin;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
@@ -42,11 +49,16 @@ mod exception {
#[pymethods]
impl PyNoPeersSubscribedToTopicError {
#[new]
#[pyo3(signature = (*_a))]
pub(crate) fn new(_a: &Bound<'_, PyTuple>) -> Self {
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
@@ -72,179 +84,489 @@ mod exception {
#[pymethods]
impl PyAllQueuesFullError {
#[new]
#[pyo3(signature = (*_a))]
pub(crate) fn new(_a: &Bound<'_, PyTuple>) -> Self {
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
#[gen_stub_pyclass]
#[pyclass]
struct PySwarm {
swarm: Arc<Mutex<Swarm>>,
from_swarm: Mutex<mpsc::Receiver<FromSwarm>>,
to_swarm: Mutex<mpsc::Sender<ToSwarm>>,
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass_complex_enum]
#[pyclass]
pub enum PyMessage {
Connection {
node_id: String,
connected: bool,
},
Gossip {
node_id: String,
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
data: Py<PyBytes>,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
impl TryFrom<FromSwarm> for PyMessage {
type Error = PyErr;
fn try_from(value: FromSwarm) -> Result<Self, Self::Error> {
match value {
FromSwarm::Discovered(nid) => Ok(PyMessage::Connection {
node_id: nid.to_base58(),
connected: true,
}),
FromSwarm::Expired(nid) => Ok(PyMessage::Connection {
node_id: nid.to_base58(),
connected: false,
}),
FromSwarm::Message(nid, topic, data) => Ok(PyMessage::Gossip {
node_id: nid.to_base58(),
topic,
data: data.pybytes(),
}),
FromSwarm::PublishError(e) => match e {
libp2p::gossipsub::PublishError::NoPeersSubscribedToTopic => {
Err(PyNoPeersSubscribedToTopicError::new_err())
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
libp2p::gossipsub::PublishError::AllQueuesFull(_) => {
Err(PyAllQueuesFullError::new_err())
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
e => Err(PyRuntimeError::new_err(e.to_string())),
},
}
}
}
log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
}
#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
fn new(
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
) -> Self {
Self {
to_task_tx: Some(to_task_tx),
connection_update_rx: Mutex::new(connection_update_rx),
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
}
}
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PySwarm {
impl PyNetworkingHandle {
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
// immediately beforehand to release the interpreter.
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
// ---- Lifecycle management methods ----
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (to_client, from_swarm) = mpsc::channel(MPSC_CHANNEL_SIZE);
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { Swarm::new(identity, from_client, to_client) })
.block_on(async { create_swarm(identity) })
.pyerr()?;
Ok(Self {
swarm: Arc::new(Mutex::new(swarm)),
from_swarm: Mutex::new(from_swarm),
to_swarm: Mutex::new(to_swarm),
})
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
}
#[gen_stub(skip)]
async fn run(&self, #[pyo3(cancel_handle)] mut cancel: CancelHandle) -> PyResult<()> {
let copy = Arc::clone(&self.swarm);
let jh = get_runtime().spawn(async move {
copy.try_lock()
.expect("tried to run swarm twice")
.run()
.await
});
jh.or(async {
cancel.cancelled().await;
Ok(())
})
.await
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
// ---- Connection update receiver methods ----
/// Receives the next message from networking.
async fn recv(&self) -> PyResult<PyMessage> {
let msg = pin!(
self.from_swarm
.try_lock()
.expect("called recv concurrently")
.recv()
)
.allow_threads_py()
.await;
match msg {
None => Err(PyConnectionError::new_err("swarm closed")),
Some(msg) => msg.try_into(),
}
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate`s is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<()> {
///
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
pin!(
self.to_swarm
.try_lock()
.expect("called send concurrently")
.send(ToSwarm::Subscribe(topic))
)
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyConnectionError::new_err("swarm closed"))
self.to_task_tx()
.send_py(ToTask::GossipsubSubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
}
/// Unsubscribes from a `GossipSub` topic.
///
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<()> {
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
pin!(
self.to_swarm
.try_lock()
.expect("called send concurrently")
.send(ToSwarm::Unsubscribe(topic))
)
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyConnectionError::new_err("swarm closed"))
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & convert any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
/// Publishes a message to the network on a specific topic.
/// Publishes a message with multiple topics to the `GossipSub` network.
///
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
pin!(
self.to_swarm
.try_lock()
.expect("called send concurrently")
.send(ToSwarm::Message(topic, data))
)
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyConnectionError::new_err("swarm closed"))
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,
data,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors => ignore messageID for now!!!
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
Ok(())
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.into_iter()
.map(|(t, d)| (t, d.pybytes()))
.collect())
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<PySwarm>()?;
m.add_class::<PyMessage>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
Ok(())
}

View File

@@ -0,0 +1,159 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -0,0 +1,8 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -0,0 +1,81 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -19,14 +19,21 @@ either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures-lite = { workspace = true }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
@@ -34,4 +41,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,6 +1,6 @@
use libp2p::identity;
use networking::swarm::{FromSwarm, Swarm, ToSwarm};
use tokio::sync::mpsc;
use futures::stream::StreamExt as _;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
@@ -11,50 +11,60 @@ async fn main() {
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
let (to_swarm, from_client) = mpsc::channel(20);
let (to_client, mut from_swarm) = mpsc::channel(20);
// Configure swarm
let mut swarm = Swarm::new(
identity::Keypair::generate_ed25519(),
from_client,
to_client,
)
.expect("Swarm creation failed");
let mut swarm =
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
// Create a Gossipsub topic & subscribe
_ = to_swarm
.send(ToSwarm::Subscribe("test-net".to_owned()))
.await;
let topic = gossipsub::IdentTopic::new("test-net");
swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("Subscribing to topic failed");
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
tokio::task::spawn(async move { swarm.run().await });
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
_= to_swarm.send(ToSwarm::Message("test-net".to_owned(), line.into_bytes())).await;
}
event = from_swarm.recv() => match event {
// on gossipsub incoming
Some(FromSwarm::Message(pid, topic, content)) => {
assert_eq!(topic, "test-net");
let fmt = String::from_utf8_lossy(&content);
println!("{pid}: {fmt}");
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
Some(FromSwarm::Discovered(pid)) => {
eprintln!("\n\nConnected to: {pid}\n\n");
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
Some(FromSwarm::Expired(pid)) => {
eprintln!("\n\nDisconnected from: {pid}\n\n");
}
None => break,
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }

View File

@@ -0,0 +1,127 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -1,10 +1,11 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::futures::FutureExt;
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{

View File

@@ -0,0 +1,44 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -4,7 +4,18 @@
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
@@ -43,3 +54,11 @@ pub(crate) mod ext {
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -1,30 +1,9 @@
use crate::alias;
use crate::discovery;
use crate::swarm::transport::tcp_transport;
use behaviour::{Behaviour, BehaviourEvent};
use futures_lite::StreamExt;
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
use tokio::sync::mpsc;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
pub struct Swarm {
swarm: libp2p::Swarm<Behaviour>,
from_client: mpsc::Receiver<ToSwarm>,
to_client: mpsc::Sender<FromSwarm>,
}
#[derive(Debug)]
pub enum FromSwarm {
PublishError(gossipsub::PublishError),
Discovered(PeerId),
Expired(PeerId),
Message(PeerId, String, Vec<u8>),
}
#[derive(Debug)]
pub enum ToSwarm {
Message(String, Vec<u8>),
Subscribe(String),
Unsubscribe(String),
}
pub type Swarm = libp2p::Swarm<Behaviour>;
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
@@ -36,142 +15,23 @@ pub enum ToSwarm {
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
impl Swarm {
/// Create and configure a swarm which listens to all ports on OS
pub fn new(
keypair: identity::Keypair,
from_client: mpsc::Receiver<ToSwarm>,
to_client: mpsc::Sender<FromSwarm>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.build();
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.build();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(Self {
swarm,
from_client,
to_client,
})
}
pub async fn run(&mut self) {
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = self.from_client.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
ToSwarm::Subscribe(topic) => {
// try to subscribe
match self.swarm.behaviour_mut().gossipsub.subscribe(&gossipsub::IdentTopic::new(topic.clone())) {
Err(e) => {
let gossipsub::SubscriptionError::PublishError(e) = e else {
unreachable!("topic filter used")
};
let Ok(()) = self.to_client.send(FromSwarm::PublishError(e)).await else {
log::warn!("RUST: client connection closed");
break
};
},
Ok(false) => log::warn!("RUST: tried to subscribe to topic twice"),
Ok(true) => {},
}
}
ToSwarm::Unsubscribe(topic) => {
// try to subscribe
if !self.swarm.behaviour_mut().gossipsub.unsubscribe(&gossipsub::IdentTopic::new(topic)) {
log::warn!("RUST: tried to unsubscribe from topic twice");
}
}
ToSwarm::Message( topic, data ) => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
match self.swarm.behaviour_mut().gossipsub.publish(
gossipsub::IdentTopic::new(topic), data
) {
Ok(_) => {},
Err(e) => {
let Ok(()) = self.to_client.send(FromSwarm::PublishError(e)).await else {
log::warn!("RUST: client connection closed");
break
};
},
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = self.swarm.next() => {
let Some(swarm_event) = swarm_event else {
log::warn!("RUST: swarm closed communication");
break
};
let SwarmEvent::Behaviour(behaviour_event) = swarm_event else {
continue
};
match behaviour_event {
BehaviourEvent::Gossipsub(gossipsub::Event::Message {
message: gossipsub::Message {
source,
topic,
data,
..
},
..
}) => {
let Some(peer_id) = source else {
log::warn!("RUST: ignoring message with unknown source on {topic}");
continue;
};
// send incoming message to channel (or exit if connection closed)
if let Err(e) = self.to_client.send(FromSwarm::Message(peer_id, topic.into_string(), data)).await {
log::warn!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
break
};
},
BehaviourEvent::Discovery(discovery::Event::ConnectionEstablished { peer_id, .. }) => {
// send connection event to channel (or exit if connection closed)
if let Err(_) = self.to_client.send(FromSwarm::Discovered(peer_id)).await {
log::warn!("RUST: swarm closed communication");
};
},
BehaviourEvent::Discovery(discovery::Event::ConnectionClosed { peer_id, .. }) => {
// send connection event to channel (or exit if connection closed)
if let Err(_) = self.to_client.send(FromSwarm::Expired(peer_id)).await {
log::warn!("RUST: swarm closed communication");
};
},
e => {
log::debug!("RUST: other event {e:?}");
}
}
}
}
}
log::info!("RUST: networking task stopped");
}
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm)
}
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures_lite::{AsyncRead, AsyncWrite};
use futures::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;

View File

@@ -14,7 +14,6 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -64,9 +63,6 @@ class DownloadCoordinator:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -78,7 +74,6 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -98,7 +93,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -176,11 +170,7 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -194,7 +184,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -217,7 +206,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -231,7 +219,6 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -266,7 +253,6 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -309,18 +295,11 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
@@ -329,9 +308,6 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue

View File

@@ -44,7 +44,7 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_string())
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
@@ -136,8 +136,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -149,6 +147,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit

View File

@@ -144,8 +144,8 @@ async def collect_responses_response(
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=tool.id,
call_id=tool.id,
id=f"fc_{tool.id}",
call_id=f"call_{tool.id}",
name=tool.name,
arguments=tool.arguments,
)

View File

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_string())
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_string()}_sender")
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(

View File

@@ -0,0 +1,37 @@
from enum import Enum
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
)

View File

@@ -16,19 +16,17 @@ from anyio.abc import TaskGroup
from exo_pyo3_bindings import (
AllQueuesFullError,
Keypair,
NetworkingHandle,
NoPeersSubscribedToTopicError,
PyMessage,
PySwarm,
)
from filelock import FileLock
from loguru import logger
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.shared.election import ConnectionMessage
from exo.shared.types.common import NodeId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel
from .connection_message import ConnectionMessage
from .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic
@@ -104,13 +102,13 @@ class TopicRouter[T: CamelCaseModel]:
class Router:
@classmethod
def create(cls, identity: Keypair) -> "Router":
return cls(handle=PySwarm(identity))
return cls(handle=NetworkingHandle(identity))
def __init__(self, handle: PySwarm):
def __init__(self, handle: NetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
send, recv = channel[tuple[str, bytes]]()
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
self._net = handle
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._id_count = count()
self._tg: TaskGroup | None = None
@@ -156,6 +154,7 @@ class Router:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
@@ -180,44 +179,38 @@ class Router:
async def _networking_recv(self):
while True:
try:
msg = await self._net.recv()
except NoPeersSubscribedToTopicError:
continue
except AllQueuesFullError:
logger.warning("All peer queues full, messages have been lost")
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
match msg:
case PyMessage.Connection():
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(
ConnectionMessage(
node_id=NodeId(msg.node_id), connected=msg.connected
)
)
case PyMessage.Gossip():
if msg.topic not in self.topic_routers:
logger.warning(
f"Received message on unknown or inactive topic {msg.topic}"
)
continue
logger.trace(
f"Received message on {msg.topic} with payload {msg.data}"
)
router = self.topic_routers[msg.topic]
await router.publish_bytes(msg.data)
case _:
raise ValueError("net recv returned something impossible")
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self):
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
async def _networking_publish(self):
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
except NoPeersSubscribedToTopicError:
pass
except AllQueuesFullError:
logger.warning(f"All peer queues full, dropping message on {topic}")
def get_node_id_keypair(
@@ -228,7 +221,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate()
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -242,12 +235,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.deserialize(protobuf_encoded)
return Keypair.from_protobuf_encoding(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate()
f.write(keypair.serialize())
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
return keypair

View File

@@ -1,7 +1,8 @@
from dataclasses import dataclass
from enum import Enum
from exo.shared.election import ConnectionMessage, ElectionMessage
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,

View File

@@ -218,6 +218,11 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -258,6 +263,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,

View File

@@ -10,6 +10,7 @@ from anyio import (
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, Sender
@@ -18,11 +19,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
DEFAULT_ELECTION_TIMEOUT = 3.0
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connected: bool
class ElectionMessage(CamelCaseModel):
clock: int
seniority: int

View File

@@ -1,7 +1,7 @@
import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.router import ConnectionMessage
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
@@ -330,7 +330,9 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
await cm_tx.send(
ConnectionMessage(
node_id=NodeId(),
connected=True,
connection_type=ConnectionMessageType.Connected,
remote_ipv4="",
remote_tcp_port=0,
)
)

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
queue.put(get_node_id_keypair().serialize())
queue.put(get_node_id_keypair().to_protobuf_encoding())
def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -26,7 +26,6 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""
class DownloadPending(BaseDownloadProgress):

View File

@@ -1,7 +1,5 @@
import sys
def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -29,4 +27,4 @@ def print_startup_banner(port: int) -> None:
"""
print(banner, file=sys.stderr)
print(banner)

View File

@@ -306,7 +306,7 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Starting prefill")
logger.info("Ready to prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(

View File

@@ -353,13 +353,7 @@ def load_tokenizer_for_model_id(
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(
hf_tokenizer,
eos_token_ids=eos_token_ids,
tool_call_start="<|tool_calls_section_begin|>",
tool_call_end="<|tool_calls_section_end|>",
tool_parser=_parse_kimi_tool_calls,
)
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
tokenizer = load_tokenizer(
model_path,
@@ -591,41 +585,3 @@ def mx_barrier(group: Group | None):
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)
def _parse_kimi_tool_calls(text: str):
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
_tool_call_split_regex = re.compile(
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
)
def _parse_single_tool(text: str) -> dict[str, Any]:
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError("No tool call found.")
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
func_name = func_name_match.group(2) # e.g. "get_weather"
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError("No tool call arguments found.")
func_args = func_args_match.group(1)
try:
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
except Exception:
arg_dct = None
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
tool_matches = _tool_call_split_regex.findall(text)
if tool_matches:
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
else:
return [_parse_single_tool(text)]

View File

@@ -1,4 +1,8 @@
from __future__ import annotations
import os
import threading
from multiprocessing.sharedctypes import Synchronized
import loguru
@@ -10,6 +14,15 @@ from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
logger: "loguru.Logger" = loguru.logger
HEARTBEAT_INTERVAL_SECONDS = 0.5
def _heartbeat_loop(heartbeat: Synchronized[int], stop: threading.Event) -> None:
"""Daemon thread that periodically increments the heartbeat counter."""
while not stop.is_set():
heartbeat.value += 1
stop.wait(HEARTBEAT_INTERVAL_SECONDS)
def entrypoint(
bound_instance: BoundInstance,
@@ -17,6 +30,7 @@ def entrypoint(
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
heartbeat: Synchronized[int] | None = None,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -35,6 +49,17 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Start heartbeat thread so the supervisor can detect if we freeze.
stop_heartbeat = threading.Event()
heartbeat_thread: threading.Thread | None = None
if heartbeat is not None:
heartbeat_thread = threading.Thread(
target=_heartbeat_loop,
args=(heartbeat, stop_heartbeat),
daemon=True,
)
heartbeat_thread.start()
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main
@@ -53,6 +78,9 @@ def entrypoint(
)
)
finally:
stop_heartbeat.set()
if heartbeat_thread is not None:
heartbeat_thread.join(timeout=1)
try:
event_sender.close()
task_receiver.close()

View File

@@ -1,10 +1,12 @@
import base64
import json
import math
import os
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import Any, Callable, Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -15,6 +17,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
@@ -91,8 +94,6 @@ from exo.worker.engines.mlx.utils_mlx import (
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser, make_mlx_parser
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
@@ -138,7 +139,6 @@ def main(
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
tool_parser: ToolParser | None = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
@@ -204,17 +204,8 @@ def main(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
elif (
@@ -320,11 +311,31 @@ def main(
mlx_generator, tokenizer
)
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
if isinstance(inference_model, GptOssModel):
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
completion_tokens = 0
tokens_since_last_cancel_check = 0
@@ -577,8 +588,21 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
@@ -635,9 +659,9 @@ def parse_gpt_oss(
def parse_thinking_models(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -758,58 +782,225 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse], tool_parser: ToolParser
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
if response.text.startswith(tool_parser.start_parsing):
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"tool call parsing interrupted, yield partial tool call as text"
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.finish_reason is not None:
logger.info(
"toll call parsing interrupted, yield partial tool call as text"
)
yield GenerationResponse(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=response.usage,
stats=response.stats,
)
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile(
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
original_func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists.
func_name = original_func_name[original_func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(
id=f"{original_func_name}:{tool_id}",
name=func_name,
arguments=arg_dct, # pyright: ignore[reportAny]
)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
EXO_RUNNER_MUST_DIE = "EXO RUNNER MUST DIE"
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
@@ -825,6 +1016,9 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
if not prompt:
return
if EXO_RUNNER_MUST_DIE in prompt:
logger.info("Abrupt process death triggered (simulates OOM kill)")
os._exit(1)
if EXO_RUNNER_MUST_FAIL in prompt:
logger.info("raising exception")
raise Exception("Artificial runner exception - for testing purposes only.")

View File

@@ -1,12 +1,17 @@
from __future__ import annotations
import contextlib
import multiprocessing
import signal
from dataclasses import dataclass, field
from multiprocessing import Process
from multiprocessing.sharedctypes import Synchronized
from typing import Self
import anyio
from anyio import (
BrokenResourceError,
CancelScope,
ClosedResourceError,
to_thread,
)
@@ -26,6 +31,7 @@ from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoading,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
@@ -36,6 +42,8 @@ from exo.worker.runner.bootstrap import entrypoint
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
HEALTH_CHECK_INTERVAL_SECONDS = 1
HEARTBEAT_STALE_CHECKS = 10
@dataclass(eq=False)
@@ -48,10 +56,14 @@ class RunnerSupervisor:
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_heartbeat: Synchronized[int]
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_death_handled: bool = field(default=False, init=False)
_last_heartbeat_value: int = field(default=0, init=False)
_heartbeat_stale_count: int = field(default=0, init=False)
@classmethod
def create(
@@ -65,6 +77,8 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)
runner_process = Process(
target=entrypoint,
args=(
@@ -73,6 +87,7 @@ class RunnerSupervisor:
task_recv,
cancel_recv,
logger,
heartbeat,
),
daemon=True,
)
@@ -88,13 +103,16 @@ class RunnerSupervisor:
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
_heartbeat=heartbeat,
)
return self
async def run(self):
self.runner_process.start()
await self._forward_events()
async with anyio.create_task_group() as tg:
tg.start_soon(self._forward_events)
tg.start_soon(self._health_check, tg.cancel_scope)
def shutdown(self):
logger.info("Runner supervisor shutting down")
@@ -177,9 +195,99 @@ class RunnerSupervisor:
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
if not self._death_handled:
self._death_handled = True
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
async def _health_check(self, cancel_scope: CancelScope) -> None:
"""Periodically check if the runner process is alive and responsive.
Detects two failure modes:
1. Process death (e.g. OOM kill) without cleanly closing the event
channel, which would leave _forward_events blocked on queue.get().
2. Unresponsive process (e.g. frozen by OS memory pressure, deadlock)
detected via a stale heartbeat counter.
"""
while True:
await anyio.sleep(HEALTH_CHECK_INTERVAL_SECONDS)
if not self.runner_process.is_alive():
self._handle_process_exit(cancel_scope)
return
# Check heartbeat counter — if it hasn't changed between
# consecutive checks, the subprocess may be frozen.
current = self._heartbeat.value
if current > 0:
if current == self._last_heartbeat_value:
self._heartbeat_stale_count += 1
if self._heartbeat_stale_count >= HEARTBEAT_STALE_CHECKS:
logger.error(
f"Health check: runner process unresponsive "
f"(heartbeat stale for {self._heartbeat_stale_count} checks), killing"
)
self._handle_unresponsive(cancel_scope)
return
else:
self._heartbeat_stale_count = 0
self._last_heartbeat_value = current
def _handle_process_exit(self, cancel_scope: CancelScope) -> None:
"""Handle runner process that has exited."""
if not self._death_handled:
self._death_handled = True
if isinstance(
self.status, (RunnerShutdown, RunnerShuttingDown, RunnerFailed)
):
logger.info("Health check: runner process exited (expected)")
else:
rc = self.runner_process.exitcode
if isinstance(rc, int) and rc < 0:
sig = -rc
try:
cause = f"signal={sig} ({signal.strsignal(sig)})"
except Exception:
cause = f"signal={sig}"
else:
cause = f"exitcode={rc}"
logger.error(
f"Health check: runner process died unexpectedly ({cause})"
)
self._event_sender.send_nowait(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(
error_message=f"Terminated ({cause})"
),
)
)
self.shutdown()
for tid in self.pending:
self.pending[tid].set()
cancel_scope.cancel()
def _handle_unresponsive(self, cancel_scope: CancelScope) -> None:
"""Handle runner process that is alive but unresponsive."""
if not self._death_handled:
self._death_handled = True
self._event_sender.send_nowait(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(
error_message="Runner process unresponsive (heartbeat timeout)"
),
)
)
for tid in self.pending:
self.pending[tid].set()
self.shutdown()
cancel_scope.cancel()
def __del__(self) -> None:
if self.runner_process.is_alive():

View File

@@ -1,72 +0,0 @@
import json
from dataclasses import dataclass
from typing import Any, Callable
from exo.shared.types.api import ToolCallItem
@dataclass
class ToolParser:
start_parsing: str
end_parsing: str
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
def make_mlx_parser(
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> ToolParser:
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix(tool_call_start)
text = text.removesuffix(tool_call_end)
parsed = tool_parser(text)
if isinstance(parsed, list):
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
else:
return [ToolCallItem.model_validate(_flatten(parsed))]
except Exception:
return None
return ToolParser(
start_parsing=tool_call_start,
end_parsing=tool_call_end,
parse_tool_calls=parse_tool_calls,
)
# TODO / example code:
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix("<tool_call>")
text = text.removesuffix("</tool_call>")
top_level = {
k: json.dumps(v) if isinstance(v, (dict, list)) else v
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
}
return [ToolCallItem.model_validate(top_level)]
except Exception:
return None
def _flatten(p: dict[str, Any]) -> dict[str, str]:
return {
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
for k, v in p.items() # pyright: ignore[reportAny]
}
json_tool_parser = ToolParser(
start_parsing="<tool_call>",
end_parsing="</tool_call>",
parse_tool_calls=_parse_json_calls,
)
def infer_tool_parser(chat_template: str) -> ToolParser | None:
"""Attempt to auto-infer a tool parser from the chat template."""
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
return json_tool_parser
return None

View File

@@ -5,13 +5,12 @@ from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser
def _make_responses(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""Create a sequence of GenerationResponses from text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
@@ -23,13 +22,10 @@ def _make_responses(
)
def _dummier_parser(text: str) -> dict[str, Any]:
def _dummy_parser(text: str) -> dict[str, Any]:
return {"name": "test_fn", "arguments": {"arg": text}}
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
class TestParseToolCalls:
"""Tests for parse_tool_calls generator."""
@@ -39,6 +35,8 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -52,6 +50,8 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -76,7 +76,9 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
"<tool_call>",
"</tool_call>",
_failing_parser,
)
)

View File

@@ -1 +1,204 @@
# TODO:
from __future__ import annotations
import multiprocessing
import os
import signal as signal_module
from collections.abc import Callable
from multiprocessing.sharedctypes import Synchronized
from typing import Any
import anyio
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.runners import RunnerFailed, RunnerIdle, RunnerShutdown
from exo.utils.channels import Receiver, Sender, channel, mp_channel
from exo.worker.runner.runner_supervisor import (
HEALTH_CHECK_INTERVAL_SECONDS,
HEARTBEAT_STALE_CHECKS,
RunnerSupervisor,
)
from ...constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from ..conftest import get_bound_mlx_ring_instance
def _die_immediately() -> None:
"""Subprocess target that exits with a non-zero code."""
os._exit(1)
def _die_with_signal() -> None:
"""Subprocess target that kills itself with SIGKILL (simulates OOM)."""
os.kill(os.getpid(), signal_module.SIGKILL)
def _exit_cleanly() -> None:
"""Subprocess target that exits with code 0."""
os._exit(0)
def _hang_forever() -> None:
"""Subprocess target that hangs without updating heartbeat (simulates freeze)."""
import time
# Write one heartbeat so the supervisor starts tracking, then stop.
time.sleep(100000)
def _build_supervisor(
event_sender: Sender[Event],
target: Callable[..., Any],
) -> RunnerSupervisor:
"""Build a RunnerSupervisor with a custom subprocess target.
Uses a clone of event_sender (matching real Worker behavior) so that
closing the supervisor's copy doesn't close the test's receiver.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)
_ev_send, ev_recv = mp_channel[Event]()
task_sender, _task_recv = mp_channel[Task]()
cancel_sender, _cancel_recv = mp_channel[TaskId]()
runner_process = multiprocessing.Process(target=target, daemon=True)
heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)
return RunnerSupervisor(
bound_instance=bound_instance,
shard_metadata=bound_instance.bound_shard,
runner_process=runner_process,
initialize_timeout=10,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender.clone(),
_heartbeat=heartbeat,
)
def _collect_failed_events(
event_receiver: Receiver[Event],
) -> list[RunnerFailed]:
"""Drain the receiver and return all RunnerFailed statuses."""
out: list[RunnerFailed] = []
while True:
try:
event = event_receiver.receive_nowait()
except Exception:
break
if isinstance(event, RunnerStatusUpdated) and isinstance(
event.runner_status, RunnerFailed
):
out.append(event.runner_status)
return out
async def test_health_check_detects_dead_process():
"""When the runner process dies with a non-zero exit code, the health check
should emit a RunnerFailed event and run() should return."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_immediately)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "exitcode=1" in failures[0].error_message
async def test_health_check_detects_signal_death():
"""When the runner process is killed by a signal (e.g. OOM -> SIGKILL),
the health check should report the signal in the failure message."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_with_signal)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "signal=9" in failures[0].error_message
async def test_health_check_releases_pending_tasks():
"""When the runner dies, any pending start_task() waiters should be unblocked."""
event_sender, _event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_immediately)
# Register a pending waiter as if start_task() was waiting for acknowledgement
task_event = anyio.Event()
tid = TaskId("pending-task")
supervisor.pending[tid] = task_event
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
assert task_event.is_set()
async def test_clean_exit_no_failure_when_shutdown_status():
"""When the runner was in RunnerShutdown status and exits with code 0,
no RunnerFailed event should be emitted."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _exit_cleanly)
# Simulate that the runner had already reported shutdown via events
supervisor.status = RunnerShutdown()
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 0
async def test_unexpected_exit_code_zero_emits_failure():
"""When the runner exits with code 0 but was NOT in a shutdown state,
this is unexpected and should still emit RunnerFailed."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _exit_cleanly)
assert isinstance(supervisor.status, RunnerIdle)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "exitcode=0" in failures[0].error_message
async def test_heartbeat_timeout_detects_unresponsive_process():
"""When the runner process is alive but its heartbeat goes stale,
the health check should kill it and emit RunnerFailed."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _hang_forever)
# Pre-seed the heartbeat counter with a non-zero value and set the
# supervisor's last-seen value to match so it appears stale immediately.
# Set stale count to HEARTBEAT_STALE_CHECKS - 1 so a single check triggers.
supervisor._heartbeat.value = 42 # pyright: ignore[reportPrivateUsage]
supervisor._last_heartbeat_value = 42 # pyright: ignore[reportPrivateUsage]
supervisor._heartbeat_stale_count = HEARTBEAT_STALE_CHECKS - 1 # pyright: ignore[reportPrivateUsage]
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "unresponsive" in failures[0].error_message.lower()

View File

@@ -43,5 +43,4 @@ for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
echo "all hosts alive!"
wait

48
uv.lock generated
View File

@@ -377,8 +377,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -416,9 +416,9 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.7" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "msgspec", specifier = ">=0.19.0" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
@@ -1020,8 +1020,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1048,12 +1048,18 @@ wheels = [
name = "mlx"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
]
@@ -1066,14 +1072,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.7.dev20260217+50487b41"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.6"
@@ -1100,20 +1098,30 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.7"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/0d/56542e2ae13ec6f542d3977d7cff89a205d4f6c5122e0ce23f33265f61c9/mlx_lm-0.30.7.tar.gz", hash = "sha256:e5f31ac58d9f2381f28e1ba639ff903e64f7cff1bdc245c0bc97f72264be329c", size = 275764, upload-time = "2026-02-12T18:41:11.86Z" }
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/17/a41c798a3d9cbdc47f39c6db5bba4c2cd199203ead26bf911cb03b644070/mlx_lm-0.30.7-py3-none-any.whl", hash = "sha256:17442a4bf01c4c2d3bca1e647712fe44f19890c3f1eadc8589d389e57b44b9bf", size = 386591, upload-time = "2026-02-12T18:41:10.236Z" },
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
]
[[package]]