mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-04 19:22:39 -05:00
Compare commits
3 Commits
rust-explo
...
runner-rac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2a0b9dc727 | ||
|
|
6177550c34 | ||
|
|
7b6cad94c6 |
4
.github/workflows/pipeline.yml
vendored
4
.github/workflows/pipeline.yml
vendored
@@ -142,4 +142,6 @@ jobs:
|
||||
# Run pytest outside sandbox (needs GPU access for MLX)
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
export EXO_DASHBOARD_DIR="$PWD/dashboard/"
|
||||
export EXO_RESOURCES_DIR="$PWD/resources"
|
||||
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
271
Cargo.lock
generated
271
Cargo.lock
generated
@@ -141,6 +141,12 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.7.1"
|
||||
@@ -298,6 +304,19 @@ version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
|
||||
|
||||
[[package]]
|
||||
name = "bigdecimal"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bimap"
|
||||
version = "0.6.3"
|
||||
@@ -334,6 +353,31 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bon"
|
||||
version = "3.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebeb9aaf9329dff6ceb65c689ca3db33dbf15f324909c60e4e5eef5701ce31b1"
|
||||
dependencies = [
|
||||
"bon-macros",
|
||||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bon-macros"
|
||||
version = "3.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77e9d642a7e3a318e37c2c9427b5a6a48aa1ad55dcd986f3034ab2239045a645"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"ident_case",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bs58"
|
||||
version = "0.5.1"
|
||||
@@ -497,6 +541,15 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
|
||||
|
||||
[[package]]
|
||||
name = "convert_case"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
|
||||
dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -647,6 +700,41 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.9.0"
|
||||
@@ -673,6 +761,17 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "delegate"
|
||||
version = "0.13.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.7.10"
|
||||
@@ -707,6 +806,29 @@ dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
|
||||
dependencies = [
|
||||
"derive_more-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more-impl"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustc_version",
|
||||
"syn 2.0.111",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@@ -876,23 +998,37 @@ dependencies = [
|
||||
name = "exo_pyo3_bindings"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"env_logger",
|
||||
"futures-lite",
|
||||
"extend",
|
||||
"futures",
|
||||
"impl-trait-for-tuples",
|
||||
"libp2p",
|
||||
"log",
|
||||
"networking",
|
||||
"once_cell",
|
||||
"pin-project",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"thiserror 2.0.17",
|
||||
"thread_local",
|
||||
"tokio",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
name = "extend"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
checksum = "311a6d2f1f9d60bff73d2c78a0af97ed27f79672f15c238192a5bbb64db56d00"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ff"
|
||||
@@ -1002,10 +1138,7 @@ version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"parking",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
@@ -1492,6 +1625,12 @@ dependencies = [
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "1.1.0"
|
||||
@@ -1567,6 +1706,17 @@ dependencies = [
|
||||
"xmltree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "impl-trait-for-tuples"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.12.1"
|
||||
@@ -1595,6 +1745,15 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "internment"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "636d4b0f6a39fd684effe2a73f5310df16a3fa7954c26d36833e98f44d1977a2"
|
||||
dependencies = [
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inventory"
|
||||
version = "0.3.21"
|
||||
@@ -1721,6 +1880,12 @@ dependencies = [
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "keccak-const"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57d8d8ce877200136358e0bbff3a77965875db3af755a11e1fa6b1b3e2df13ea"
|
||||
|
||||
[[package]]
|
||||
name = "lalrpop-util"
|
||||
version = "0.20.2"
|
||||
@@ -1739,6 +1904,12 @@ version = "0.2.178"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
|
||||
|
||||
[[package]]
|
||||
name = "libp2p"
|
||||
version = "0.56.0"
|
||||
@@ -2727,10 +2898,20 @@ dependencies = [
|
||||
name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2812,6 +2993,17 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
@@ -3115,6 +3307,16 @@ dependencies = [
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "primeorder"
|
||||
version = "0.13.6"
|
||||
@@ -3162,14 +3364,28 @@ version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
|
||||
dependencies = [
|
||||
"bigdecimal",
|
||||
"either",
|
||||
"hashbrown 0.16.1",
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"libc",
|
||||
"lock_api",
|
||||
"memoffset",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"ordered-float",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"rust_decimal",
|
||||
"smallvec",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
@@ -3524,6 +3740,12 @@ dependencies = [
|
||||
"yasna",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "recursion"
|
||||
version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9dba2197bf7b1d87b4dd460c195f4edeb45a94e82e8054f8d5f317c1f0e93ca1"
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.18"
|
||||
@@ -3610,6 +3832,16 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust_decimal"
|
||||
version = "1.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
@@ -4474,12 +4706,24 @@ version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
version = "1.3.0"
|
||||
@@ -4560,6 +4804,19 @@ version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "util"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"bon",
|
||||
"derive_more",
|
||||
"extend",
|
||||
"internment",
|
||||
"once_cell",
|
||||
"recursion",
|
||||
"thiserror 2.0.17",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.19.0"
|
||||
|
||||
108
Cargo.toml
108
Cargo.toml
@@ -3,6 +3,7 @@ resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/util",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -23,18 +24,62 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
|
||||
# networking
|
||||
libp2p = "0.56"
|
||||
libp2p-tcp = "0.44"
|
||||
|
||||
[workspace.lints.rust]
|
||||
static_mut_refs = "warn"
|
||||
static_mut_refs = "warn" # Or use "warn" instead of deny
|
||||
incomplete_features = "allow"
|
||||
|
||||
# Clippy's lint category level configurations;
|
||||
# every member crate needs to inherit these by adding
|
||||
@@ -55,3 +100,64 @@ perf = { level = "warn", priority = -1 }
|
||||
pedantic = { level = "warn", priority = -1 }
|
||||
nursery = { level = "warn", priority = -1 }
|
||||
cargo = { level = "warn", priority = -1 }
|
||||
|
||||
# Individual Clippy lints from the `restriction` category
|
||||
arithmetic_side_effects = "warn"
|
||||
as_conversions = "warn"
|
||||
assertions_on_result_states = "warn"
|
||||
clone_on_ref_ptr = "warn"
|
||||
decimal_literal_representation = "warn"
|
||||
default_union_representation = "warn"
|
||||
deref_by_slicing = "warn"
|
||||
disallowed_script_idents = "deny"
|
||||
else_if_without_else = "warn"
|
||||
empty_enum_variants_with_brackets = "warn"
|
||||
empty_structs_with_brackets = "warn"
|
||||
error_impl_error = "warn"
|
||||
exit = "deny"
|
||||
expect_used = "warn"
|
||||
float_cmp_const = "warn"
|
||||
get_unwrap = "warn"
|
||||
if_then_some_else_none = "warn"
|
||||
impl_trait_in_params = "warn"
|
||||
indexing_slicing = "warn"
|
||||
infinite_loop = "warn"
|
||||
let_underscore_must_use = "warn"
|
||||
let_underscore_untyped = "warn"
|
||||
lossy_float_literal = "warn"
|
||||
mem_forget = "warn"
|
||||
missing_inline_in_public_items = "warn"
|
||||
multiple_inherent_impl = "warn"
|
||||
multiple_unsafe_ops_per_block = "warn"
|
||||
mutex_atomic = "warn"
|
||||
non_zero_suggestions = "warn"
|
||||
panic = "warn"
|
||||
partial_pub_fields = "warn"
|
||||
pattern_type_mismatch = "warn"
|
||||
pub_without_shorthand = "warn"
|
||||
rc_buffer = "warn"
|
||||
rc_mutex = "warn"
|
||||
redundant_type_annotations = "warn"
|
||||
renamed_function_params = "warn"
|
||||
rest_pat_in_fully_bound_structs = "warn"
|
||||
same_name_method = "warn"
|
||||
self_named_module_files = "deny"
|
||||
semicolon_inside_block = "warn"
|
||||
shadow_same = "warn"
|
||||
shadow_unrelated = "warn"
|
||||
str_to_string = "warn"
|
||||
string_add = "warn"
|
||||
string_lit_chars_any = "warn"
|
||||
string_to_string = "warn"
|
||||
tests_outside_test_module = "warn"
|
||||
todo = "warn"
|
||||
try_err = "warn"
|
||||
undocumented_unsafe_blocks = "warn"
|
||||
unnecessary_safety_comment = "warn"
|
||||
unnecessary_safety_doc = "warn"
|
||||
unneeded_field_pattern = "warn"
|
||||
unseparated_literal_suffix = "warn"
|
||||
unused_result_ok = "warn"
|
||||
unused_trait_names = "warn"
|
||||
unwrap_used = "warn"
|
||||
verbose_file_reads = "warn"
|
||||
|
||||
@@ -69,7 +69,8 @@
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
|
||||
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
done
|
||||
'';
|
||||
|
||||
@@ -3,6 +3,7 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
@@ -3,6 +3,7 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
@@ -3,6 +3,7 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
@@ -3,6 +3,7 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
@@ -3,6 +3,7 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
@@ -3,6 +3,7 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
2
rust/clippy.toml
Normal file
2
rust/clippy.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
|
||||
#allowed-duplicate-crates = ["hashbrown"]
|
||||
@@ -25,25 +25,44 @@ workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
"abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
# "nightly", # enables better-supported GIL integration
|
||||
"experimental-async" # async support in #[pyfunction] & #[pymethods]
|
||||
# "experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
# "py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
pyo3 = { version = "0.27.1", features = [
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
|
||||
pyo3-log = "0.13.2"
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
#tracing = "0.1"
|
||||
@@ -56,4 +75,3 @@ env_logger = "0.11"
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
futures-lite = "2.6.1"
|
||||
|
||||
@@ -2,39 +2,220 @@
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
import enum
|
||||
import typing
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
@staticmethod
|
||||
def generate() -> Keypair:
|
||||
class AllQueuesFullError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Generate a new ed25519 keypair
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ecdsa() -> Keypair:
|
||||
r"""
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key to a protobuf structure.
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
def to_string(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PyPeer:
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def new(kp: Keypair, namespace: builtins.str) -> PyPeer: ...
|
||||
async def subscribe(self, topic: builtins.str) -> None: ...
|
||||
async def unsubscribe(self, topic: builtins.str) -> None: ...
|
||||
async def send(self, topic: builtins.str, payload: bytes) -> None: ...
|
||||
async def run(self) -> None: ...
|
||||
async def recv(self) -> PySwarmEvent: ...
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class PySwarmEvent:
|
||||
def downcast_discovered(self) -> typing.Optional[builtins.str]: ...
|
||||
def downcast_expired(self) -> typing.Optional[builtins.str]: ...
|
||||
def downcast_message(self) -> typing.Optional[tuple[builtins.str, builtins.str, bytes]]: ...
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
|
||||
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
"""
|
||||
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Unsubscribes from a `GossipSub` topic.
|
||||
|
||||
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
"""
|
||||
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
|
||||
r"""
|
||||
Publishes a message with multiple topics to the `GossipSub` network.
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
|
||||
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
//!
|
||||
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
@@ -6,17 +10,31 @@ use std::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
pub struct AllowThreads<F>(pub(crate) F);
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
#[pin_project]
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct AllowThreads<F>(#[pin] F);
|
||||
|
||||
impl<F> AllowThreads<F>
|
||||
where
|
||||
Self: Future,
|
||||
{
|
||||
pub fn new(f: F) -> Self {
|
||||
Self(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Unpin + Send,
|
||||
F::Output: Send,
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
{
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,42 +1,216 @@
|
||||
//! TODO: crate documentation
|
||||
pub(crate) mod allow_threading;
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod take_once {
|
||||
use tokio::sync::Mutex;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
pub struct TakeOnce<T>(Mutex<Option<T>>);
|
||||
impl<T> TakeOnce<T> {
|
||||
pub fn new(t: T) -> Self {
|
||||
Self(Mutex::new(Some(t)))
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
/// Namespace for all the constants used by this crate.
|
||||
pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
pub fn take(&self) -> Option<T> {
|
||||
match self.0.try_lock() {
|
||||
Ok(mut o) => o.take(),
|
||||
Err(_) => None,
|
||||
}
|
||||
|
||||
#[ext(pub, name = ResultExt)]
|
||||
impl<T, E> Result<T, E>
|
||||
where
|
||||
E: ToString,
|
||||
{
|
||||
fn pyerr(self) -> PyResult<T> {
|
||||
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FutureExt: Future + Sized {
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
fn allow_threads_py(self) -> AllowThreads<Self>
|
||||
where
|
||||
AllowThreads<Self>: Future,
|
||||
{
|
||||
AllowThreads::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Future> FutureExt for T {}
|
||||
|
||||
#[ext(pub, name = PyErrExt)]
|
||||
impl PyErr {
|
||||
fn receiver_channel_closed() -> Self {
|
||||
PyConnectionError::new_err("Receiver channel closed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
match self {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
// write error back to python
|
||||
e.write_unraisable(py, None);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioRuntimeExt)]
|
||||
impl Runtime {
|
||||
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
|
||||
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscSenderExt)]
|
||||
impl<T> mpsc::Sender<T> {
|
||||
/// Sends a value, waiting until there is capacity.
|
||||
///
|
||||
/// A successful send occurs when it is determined that the other end of the
|
||||
/// channel has not hung up already. An unsuccessful send would be one where
|
||||
/// the corresponding receiver has already been closed.
|
||||
async fn send_py(&self, value: T) -> PyResult<()> {
|
||||
self.send(value)
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscReceiverExt)]
|
||||
impl<T> mpsc::Receiver<T> {
|
||||
/// Receives the next value for this receiver.
|
||||
async fn recv_py(&mut self) -> PyResult<T> {
|
||||
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
|
||||
}
|
||||
|
||||
/// Receives at most `limit` values for this receiver and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
|
||||
// get updates from receiver channel
|
||||
let mut updates = Vec::with_capacity(limit);
|
||||
let received = self.recv_many(&mut updates, limit).await;
|
||||
|
||||
// if we received zero items, then the channel was unexpectedly closed
|
||||
if limit != 0 && received == 0 {
|
||||
return Err(PyErr::receiver_channel_closed());
|
||||
}
|
||||
|
||||
Ok(updates)
|
||||
}
|
||||
|
||||
/// Tries to receive the next value for this receiver.
|
||||
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
|
||||
match self.try_recv() {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(TryRecvError::Empty) => Ok(None),
|
||||
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use pyo3::prelude::*;
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
#[pymodule(name = "exo_pyo3_bindings")]
|
||||
pub fn networking_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
// setup runtime
|
||||
let mut builder = tokio::runtime::Builder::new_multi_thread();
|
||||
builder.enable_all();
|
||||
pyo3_async_runtimes::tokio::init(builder);
|
||||
|
||||
m.add_class::<networking::PyPeer>()?;
|
||||
m.add_class::<networking::PyKeypair>()?;
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
// TODO: ...
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,214 +1,571 @@
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use crate::take_once::TakeOnce;
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
|
||||
use std::pin::pin;
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
use util::ext::VecExt as _;
|
||||
|
||||
use futures_lite::FutureExt;
|
||||
use libp2p::{gossipsub::PublishError, identity::Keypair};
|
||||
use networking::{FromSwarm, Peer, ToSwarm};
|
||||
use pyo3::{
|
||||
coroutine::CancelHandle,
|
||||
exceptions::{PyConnectionError, PyRuntimeError, PyValueError},
|
||||
prelude::*,
|
||||
types::PyBytes,
|
||||
};
|
||||
use pyo3_stub_gen::{
|
||||
derive::{gen_methods_from_python, gen_stub_pyclass, gen_stub_pymethods},
|
||||
inventory::submit,
|
||||
};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[derive(Clone)]
|
||||
pub struct PyKeypair(Keypair);
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
|
||||
pub struct PyNoPeersSubscribedToTopicError {}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyKeypair {
|
||||
/// Generate a new ed25519 keypair
|
||||
#[staticmethod]
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
const MSG: &'static str = "\
|
||||
No peers are currently subscribed to receive messages on this topic. \
|
||||
Wait for peers to subscribe or check your network connectivity.";
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: &Bound<'_, PyBytes>) -> Self {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Self(Keypair::from_protobuf_encoding(&bytes).expect("todo"))
|
||||
}
|
||||
|
||||
/// Encode a private key to a protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
match self.0.to_protobuf_encoding() {
|
||||
Ok(bytes) => Ok(PyBytes::new(py, &bytes)),
|
||||
Err(e) => Err(PyValueError::new_err(e.to_string())),
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
fn to_string(&self) -> String {
|
||||
self.0.public().to_peer_id().to_base58()
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
|
||||
pub struct PyAllQueuesFullError {}
|
||||
|
||||
impl PyAllQueuesFullError {
|
||||
const MSG: &'static str =
|
||||
"All libp2p peers are unresponsive, resend the message or reconnect.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAllQueuesFullError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PeerBuilder(
|
||||
String,
|
||||
Keypair,
|
||||
mpsc::Sender<FromSwarm>,
|
||||
mpsc::Receiver<ToSwarm>,
|
||||
);
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass]
|
||||
pub struct PyPeer {
|
||||
peer: TakeOnce<PeerBuilder>,
|
||||
to_swarm: mpsc::Sender<ToSwarm>,
|
||||
from_swarm: Mutex<mpsc::Receiver<FromSwarm>>,
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyPeer {
|
||||
#[staticmethod]
|
||||
fn new(kp: PyKeypair, namespace: String) -> PyResult<Self> {
|
||||
let (to_client, from_swarm) = mpsc::channel(1024);
|
||||
let (to_swarm, from_client) = mpsc::channel(1024);
|
||||
Ok(Self {
|
||||
peer: TakeOnce::new(PeerBuilder(namespace, kp.0, to_client, from_client)),
|
||||
to_swarm,
|
||||
from_swarm: Mutex::new(from_swarm),
|
||||
})
|
||||
impl PyNetworkingHandle {
|
||||
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
|
||||
// immediately beforehand to release the interpreter.
|
||||
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
|
||||
// ---- Lifecycle management methods ----
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
async fn run(&self, #[pyo3(cancel_handle)] mut cancel: CancelHandle) -> PyResult<()> {
|
||||
let builder = self
|
||||
.peer
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("tried to run peer twice"))?;
|
||||
let jh = pyo3_async_runtimes::tokio::get_runtime()
|
||||
.spawn(async move {
|
||||
let mut peer =
|
||||
Peer::new(builder.0, builder.1, builder.2, builder.3).map_err(|_| {
|
||||
PyConnectionError::new_err("peer failed to listen on default address")
|
||||
})?;
|
||||
peer.run()
|
||||
.await
|
||||
.map_err(|()| PyConnectionError::new_err("peer communication closed"))
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.or(async {
|
||||
cancel.cancelled().await;
|
||||
Ok(Ok(()))
|
||||
});
|
||||
match AllowThreads(pin!(jh)).await {
|
||||
Err(e) if e.is_cancelled() => Ok(()),
|
||||
Err(e) if e.is_panic() => Err(PyRuntimeError::new_err(format!("tokio panic {e}"))),
|
||||
Err(_) => unreachable!(),
|
||||
Ok(res) => res,
|
||||
}
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
}
|
||||
|
||||
async fn subscribe(&self, topic: String) -> PyResult<()> {
|
||||
self.to_swarm
|
||||
.send(ToSwarm::Subscribe(topic))
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & convert any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
|
||||
}
|
||||
async fn unsubscribe(&self, topic: String) -> PyResult<()> {
|
||||
self.to_swarm
|
||||
.send(ToSwarm::Unsubscribe(topic))
|
||||
.await
|
||||
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
|
||||
}
|
||||
async fn send(&self, topic: String, payload: Py<PyBytes>) -> PyResult<()> {
|
||||
// this function attaches to the python interpreter synchronously to avoid holding the GIL
|
||||
let bytes = Python::attach(|py| Vec::from(payload.bind(py).as_bytes()));
|
||||
self.to_swarm
|
||||
.send(ToSwarm::Message(topic, bytes))
|
||||
.await
|
||||
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
async fn recv(
|
||||
&self,
|
||||
#[pyo3(cancel_handle)] mut cancel: CancelHandle,
|
||||
) -> PyResult<PySwarmEvent> {
|
||||
loop {
|
||||
return match AllowThreads(pin!(
|
||||
self.from_swarm
|
||||
.try_lock()
|
||||
.map_err(|_| PyRuntimeError::new_err("tried to recv twice"))?
|
||||
.recv()
|
||||
.or(async {
|
||||
cancel.cancelled().await;
|
||||
None
|
||||
})
|
||||
))
|
||||
/// Publishes a message with multiple topics to the `GossipSub` network.
|
||||
///
|
||||
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors => ignore messageID for now!!!
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
{
|
||||
Some(FromSwarm::PublishError(p)) => match p {
|
||||
PublishError::AllQueuesFull(_) => {
|
||||
Err(PyConnectionError::new_err("swarm overloaded"))
|
||||
}
|
||||
PublishError::MessageTooLarge => {
|
||||
Err(PyValueError::new_err("message too large"))
|
||||
}
|
||||
PublishError::NoPeersSubscribedToTopic => {
|
||||
continue;
|
||||
}
|
||||
// TODO(evan): logs here
|
||||
_ => continue,
|
||||
},
|
||||
None => Err(PyRuntimeError::new_err("swarm communication closed")),
|
||||
Some(fs) => Ok(PySwarmEvent(fs)),
|
||||
};
|
||||
}
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.map(|(t, d)| (t, d.pybytes())))
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
// Manually submit the run()/recv() stub because the cancelhandle is poorly understood
|
||||
submit! {
|
||||
gen_methods_from_python! {
|
||||
r#"
|
||||
class PyPeer:
|
||||
async def run(self): ...
|
||||
async def recv(self) -> PySwarmEvent: ...
|
||||
"#
|
||||
}
|
||||
}
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass]
|
||||
pub struct PySwarmEvent(FromSwarm);
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PySwarmEvent {
|
||||
// probably a better way to do this, but...
|
||||
fn downcast_discovered(&self) -> Option<String> {
|
||||
if let FromSwarm::Discovered(peer_id) = self.0 {
|
||||
Some(peer_id.to_base58())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn downcast_expired(&self) -> Option<String> {
|
||||
if let FromSwarm::Expired(peer_id) = self.0 {
|
||||
Some(peer_id.to_base58())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn downcast_message<'py>(
|
||||
&self,
|
||||
py: Python<'py>,
|
||||
) -> Option<(String, String, Bound<'py, PyBytes>)> {
|
||||
if let FromSwarm::Message(peer_id, topic, data) = &self.0 {
|
||||
Some((peer_id.to_base58(), topic.clone(), PyBytes::new(py, data)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! A module for exposing Rust's libp2p datatypes over Pyo3
|
||||
//!
|
||||
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
|
||||
//! independent identity type of some kind or another. This may require handshaking.
|
||||
//!
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr as _;
|
||||
|
||||
/// Representation of a Multiaddr.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Multiaddr", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyMultiaddr(pub Multiaddr);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyMultiaddr {
|
||||
/// Create a new, empty multiaddress.
|
||||
#[staticmethod]
|
||||
fn empty() -> Self {
|
||||
Self(Multiaddr::empty())
|
||||
}
|
||||
|
||||
/// Create a new, empty multiaddress with the given capacity.
|
||||
#[staticmethod]
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its byte slice representation.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its string representation.
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// Return the length in bytes of this multiaddress.
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns true if the length of this multiaddress is 0.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
/// Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_vec();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Convert a Multiaddr to a string.
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyMultiaddr>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
54
rust/exo_pyo3_bindings/tests/dummy.rs
Normal file
54
rust/exo_pyo3_bindings/tests/dummy.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use core::mem::drop;
|
||||
use core::option::Option::Some;
|
||||
use core::time::Duration;
|
||||
use tokio;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drop_channel() {
|
||||
struct Ping;
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Ping>(10);
|
||||
|
||||
let _ = tokio::spawn(async move {
|
||||
println!("TASK: entered");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = rx.recv() => {
|
||||
match result {
|
||||
Some(_) => {
|
||||
println!("TASK: pinged");
|
||||
}
|
||||
None => {
|
||||
println!("TASK: closing channel");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
|
||||
println!("TASK: heartbeat");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("TASK: exited");
|
||||
});
|
||||
|
||||
let tx2 = tx.clone();
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx.send(Ping).await.expect("Should not fail");
|
||||
drop(tx);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx2.send(Ping).await.expect("Should not fail");
|
||||
drop(tx2);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
}
|
||||
}
|
||||
@@ -13,14 +13,32 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
@@ -1,6 +1,6 @@
|
||||
use libp2p::identity;
|
||||
use networking::{self, FromSwarm, ToSwarm};
|
||||
use tokio::sync::mpsc;
|
||||
use futures::stream::StreamExt as _;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
@@ -12,51 +12,63 @@ async fn main() {
|
||||
.try_init();
|
||||
|
||||
// Configure swarm
|
||||
let (to_client, mut from_swarm) = mpsc::channel(20);
|
||||
let (to_swarm, from_client) = mpsc::channel(20);
|
||||
let mut peer = networking::Peer::new(
|
||||
"chatroom!".to_string(),
|
||||
identity::Keypair::generate_ed25519(),
|
||||
to_client,
|
||||
from_client,
|
||||
)
|
||||
.expect("listen error");
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
let jh = tokio::spawn(async move { peer.run().await });
|
||||
_ = to_swarm
|
||||
.send(ToSwarm::Subscribe("chatting".to_string()))
|
||||
.await;
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
_ = to_swarm.send(ToSwarm::Message("chatting".to_string(), line.into_bytes())).await;
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = from_swarm.recv() => match event {
|
||||
event = swarm.select_next_some() => match event {
|
||||
// on gossipsub incoming
|
||||
Some(FromSwarm::Message(peer_id,_, data)) => println!(
|
||||
"\n\nGot message: '{}' from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&data),
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
Some(FromSwarm::Discovered(peer_id)) => {
|
||||
println!("\n\nConnected to: {peer_id}\n\n");
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
Some(FromSwarm::Expired(peer_id)) => {
|
||||
println!("\n\nDisconnected from: {peer_id}\n\n");
|
||||
}
|
||||
Some(FromSwarm::PublishError(e)) => eprintln!("\n\nError {e:?}\n\n"),
|
||||
None => break,
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = jh.await;
|
||||
}
|
||||
|
||||
127
rust/networking/examples/chatroom_manual.rs
Normal file
127
rust/networking/examples/chatroom_manual.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::stream::StreamExt;
|
||||
use libp2p::{
|
||||
gossipsub, mdns, noise,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use std::{error::Error, hash::Hash};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct MyBehaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.try_init();
|
||||
|
||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_behaviour(|key| {
|
||||
// Set a custom gossipsub configuration
|
||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
||||
.heartbeat_interval(Duration::from_secs(10))
|
||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
||||
.build()
|
||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
||||
gossipsub_config,
|
||||
)?;
|
||||
|
||||
let mdns =
|
||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
||||
Ok(MyBehaviour { gossipsub, mdns })
|
||||
})?
|
||||
.build();
|
||||
|
||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
||||
|
||||
// Create a Gossipsub topic
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
// subscribes to our topic
|
||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
44
rust/networking/src/RESEARCH_NOTES.txt
Normal file
44
rust/networking/src/RESEARCH_NOTES.txt
Normal file
@@ -0,0 +1,44 @@
|
||||
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
|
||||
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
|
||||
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
|
||||
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
|
||||
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
|
||||
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
|
||||
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
|
||||
|
||||
https://stackoverflow.com/questions/17169298/af-packet-on-osx
|
||||
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
|
||||
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
|
||||
|
||||
https://www.unix.com/man_page/mojave/4/ip/
|
||||
^OS X manpages for IP
|
||||
|
||||
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
|
||||
^driver kit, system extensions & kexts for macOS
|
||||
|
||||
----
|
||||
|
||||
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
|
||||
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
|
||||
----> here is a guide on how to set up thunderbolt-ethernet on linux
|
||||
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
|
||||
|
||||
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
|
||||
^GPT discussion about making socket interface
|
||||
|
||||
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
|
||||
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
|
||||
^GPT discussion about accessing TB on MacOS low level interactions
|
||||
|
||||
--------------------------------
|
||||
|
||||
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
|
||||
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
|
||||
|
||||
|
||||
---------------------------------
|
||||
|
||||
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
|
||||
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
|
||||
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
|
||||
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c
|
||||
383
rust/networking/src/discovery.rs
Normal file
383
rust/networking/src/discovery.rs
Normal file
@@ -0,0 +1,383 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
use libp2p::swarm::behaviour::ConnectionEstablished;
|
||||
use libp2p::swarm::dial_opts::DialOpts;
|
||||
use libp2p::swarm::{
|
||||
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
|
||||
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
|
||||
THandlerOutEvent, ToSwarm, dummy,
|
||||
};
|
||||
use libp2p::{Multiaddr, PeerId, identity, mdns};
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use util::wakerdeque::WakerDeque;
|
||||
|
||||
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
mod managed {
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{identity, mdns, ping};
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
|
||||
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
|
||||
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
|
||||
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
ping: ping::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
mdns: mdns_behaviour(keypair)?,
|
||||
ping: ping_behaviour(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
|
||||
use mdns::{Config, tokio};
|
||||
|
||||
// mDNS config => enable IPv6
|
||||
let mdns_config = Config {
|
||||
ttl: MDNS_RECORD_TTL,
|
||||
query_interval: MDNS_QUERY_INTERVAL,
|
||||
|
||||
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
|
||||
Ok(mdns_behaviour?)
|
||||
}
|
||||
|
||||
fn ping_behaviour() -> ping::Behaviour {
|
||||
ping::Behaviour::new(
|
||||
ping::Config::new()
|
||||
.with_timeout(PING_TIMEOUT)
|
||||
.with_interval(PING_INTERVAL),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Events for when a listening connection is truly established and truly closed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
ConnectionEstablished {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
ConnectionClosed {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
|
||||
///
|
||||
/// The behaviour operates as such:
|
||||
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
|
||||
/// to the swarm.
|
||||
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
|
||||
/// immediately, and expired but connected peers are disconnected from immediately.
|
||||
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
|
||||
/// connected peers are disconnected from.
|
||||
pub struct Behaviour {
|
||||
// state-tracking for managed behaviors & mDNS-discovered peers
|
||||
managed: managed::Behaviour,
|
||||
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
|
||||
|
||||
retry_delay: Delay, // retry interval
|
||||
|
||||
// pending events to emmit => waker-backed Deque to control polling
|
||||
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
managed: managed::Behaviour::new(keypair)?,
|
||||
mdns_discovered: HashMap::new(),
|
||||
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
|
||||
pending_events: WakerDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
|
||||
})
|
||||
}
|
||||
|
||||
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
|
||||
// push front to make this IMMEDIATE
|
||||
self.pending_events.push_front(ToSwarm::CloseConnection {
|
||||
peer_id,
|
||||
connection: CloseConnection::One(connection),
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
self.dial(p, ma.clone()); // always connect
|
||||
|
||||
// get peer's multi-addresses or insert if missing
|
||||
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
|
||||
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
|
||||
continue;
|
||||
};
|
||||
|
||||
// multiaddress should never already be present - else something has gone wrong
|
||||
let is_new_addr = mas.insert(ma);
|
||||
assert!(is_new_addr, "cannot discover a discovered peer");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
// at this point, we *must* have the peer
|
||||
let mas = self
|
||||
.mdns_discovered
|
||||
.get_mut(&p)
|
||||
.expect("nonexistent peer cannot expire");
|
||||
|
||||
// at this point, we *must* have the multiaddress
|
||||
let was_present = mas.remove(&ma);
|
||||
assert!(was_present, "nonexistent multiaddress cannot expire");
|
||||
|
||||
// if empty, remove the peer-id entirely
|
||||
if mas.is_empty() {
|
||||
self.mdns_discovered.remove(&p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_connection_established(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out connected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
|
||||
fn on_connection_closed(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out disconnected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkBehaviour for Behaviour {
|
||||
type ConnectionHandler =
|
||||
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
|
||||
type ToSwarm = Event;
|
||||
|
||||
// simply delegate to underlying mDNS behaviour
|
||||
|
||||
delegate! {
|
||||
to self.managed {
|
||||
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
|
||||
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_established_inbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
local_addr: &Multiaddr,
|
||||
remote_addr: &Multiaddr,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_inbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_question_mark)]
|
||||
fn handle_established_outbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
addr: &Multiaddr,
|
||||
role_override: Endpoint,
|
||||
port_use: PortUse,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_outbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
addr,
|
||||
role_override,
|
||||
port_use,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn on_connection_handler_event(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
event: THandlerOutEvent<Self>,
|
||||
) {
|
||||
match event {
|
||||
Either::Left(ev) => libp2p::core::util::unreachable(ev),
|
||||
Either::Right(ev) => {
|
||||
self.managed
|
||||
.on_connection_handler_event(peer_id, connection_id, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hook into these methods to drive behavior
|
||||
|
||||
fn on_swarm_event(&mut self, event: FromSwarm) {
|
||||
self.managed.on_swarm_event(event); // let mDNS handle swarm events
|
||||
|
||||
// handle swarm events to update internal state:
|
||||
match event {
|
||||
FromSwarm::ConnectionEstablished(ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection established event which is filtered correctly
|
||||
self.on_connection_established(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
FromSwarm::ConnectionClosed(ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
|
||||
// since we are running TCP/IP transport layer, we are assuming that
|
||||
// no address changes can occur, hence encountering one is a fatal error
|
||||
FromSwarm::AddressChange(a) => {
|
||||
unreachable!("unhandlable: address change encountered: {:?}", a)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
|
||||
// delegate to managed behaviors for any behaviors they need to perform
|
||||
match self.managed.poll(cx) {
|
||||
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
|
||||
match e {
|
||||
// handle discovered and expired events from mDNS
|
||||
managed::BehaviourEvent::Mdns(e) => match e.clone() {
|
||||
mdns::Event::Discovered(peers) => {
|
||||
self.handle_mdns_discovered(peers);
|
||||
}
|
||||
mdns::Event::Expired(peers) => {
|
||||
self.handle_mdns_expired(peers);
|
||||
}
|
||||
},
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// since we just consumed an event, we should immediately wake just in case
|
||||
// there are more events to come where that came from
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
|
||||
// forward any other mDNS event to the swarm or its connection handler(s)
|
||||
Poll::Ready(e) => {
|
||||
return Poll::Ready(
|
||||
e.map_out(|_| unreachable!("events returning to swarm already handled"))
|
||||
.map_in(Either::Right),
|
||||
);
|
||||
}
|
||||
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
}
|
||||
}
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
|
||||
}
|
||||
|
||||
// send out any pending events from our own service
|
||||
if let Some(e) = self.pending_events.pop_front(cx) {
|
||||
return Poll::Ready(e.map_in(Either::Left));
|
||||
}
|
||||
|
||||
// wait for pending events
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
44
rust/networking/src/keep_alive.rs
Normal file
44
rust/networking/src/keep_alive.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -1,299 +1,64 @@
|
||||
use libp2p::{
|
||||
Multiaddr, PeerId,
|
||||
futures::StreamExt,
|
||||
gossipsub::{self, TopicHash},
|
||||
identify,
|
||||
identity::Keypair,
|
||||
mdns,
|
||||
swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc;
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ListenError;
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
pub enum FromSwarm {
|
||||
PublishError(gossipsub::PublishError),
|
||||
Discovered(PeerId),
|
||||
Expired(PeerId),
|
||||
Message(PeerId, String, Vec<u8>),
|
||||
}
|
||||
pub enum ToSwarm {
|
||||
Message(String, Vec<u8>),
|
||||
Subscribe(String),
|
||||
Unsubscribe(String),
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
pub struct Peer {
|
||||
pub swarm: libp2p::Swarm<Behaviour>,
|
||||
to_client: mpsc::Sender<FromSwarm>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
namespace: String,
|
||||
known_peers: HashMap<PeerId, Vec<Multiaddr>>,
|
||||
}
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
namespace: String,
|
||||
kp: Keypair,
|
||||
to_client: mpsc::Sender<FromSwarm>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
) -> Result<Self, ListenError> {
|
||||
let mut swarm = libp2p::SwarmBuilder::with_existing_identity(kp)
|
||||
.with_tokio()
|
||||
.with_quic()
|
||||
// TODO(evan) .with_bandwidth_metrics()
|
||||
.with_behaviour(|kp| Behaviour::new(namespace.clone(), kp))
|
||||
.expect("invalid swarm behaviour")
|
||||
.build();
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use extend::ext;
|
||||
use libp2p::Multiaddr;
|
||||
use libp2p::multiaddr::Protocol;
|
||||
use std::net::IpAddr;
|
||||
|
||||
swarm
|
||||
.listen_on("/ip6/::/udp/0/quic-v1".parse().expect("invalid multiaddr"))
|
||||
.map_err(|_| ListenError)?;
|
||||
swarm
|
||||
.listen_on(
|
||||
"/ip4/0.0.0.0/udp/0/quic-v1"
|
||||
.parse()
|
||||
.expect("invalid multiaddr"),
|
||||
)
|
||||
.map_err(|_| ListenError)?;
|
||||
Ok(Self {
|
||||
swarm,
|
||||
to_client,
|
||||
from_client,
|
||||
namespace,
|
||||
known_peers: HashMap::default(),
|
||||
})
|
||||
}
|
||||
pub async fn run(&mut self) -> Result<(), ()> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = self.swarm.next() => self.handle_event(event.ok_or(())?).await?,
|
||||
msg = self.from_client.recv() => self.handle_message(msg.ok_or(())?).await?,
|
||||
}
|
||||
}
|
||||
}
|
||||
async fn handle_message(&mut self, message: ToSwarm) -> Result<(), ()> {
|
||||
match message {
|
||||
ToSwarm::Message(topic, data) => {
|
||||
if let Err(e) = self
|
||||
.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.publish(TopicHash::from_raw(topic), data)
|
||||
{
|
||||
self.to_client
|
||||
.send(FromSwarm::PublishError(e))
|
||||
.await
|
||||
.map_err(|_| ())?;
|
||||
#[ext(pub, name = MultiaddrExt)]
|
||||
impl Multiaddr {
|
||||
/// If the multiaddress corresponds to a TCP address, extracts it
|
||||
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
|
||||
let mut ps = self.into_iter();
|
||||
let ip = if let Some(p) = ps.next() {
|
||||
match p {
|
||||
Protocol::Ip4(ip) => IpAddr::V4(ip),
|
||||
Protocol::Ip6(ip) => IpAddr::V6(ip),
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
ToSwarm::Subscribe(topic) => {
|
||||
match self
|
||||
.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&gossipsub::IdentTopic::new(topic))
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(gossipsub::SubscriptionError::NotAllowed) => {
|
||||
unreachable!("subscription filter hit")
|
||||
}
|
||||
Err(gossipsub::SubscriptionError::PublishError(e)) => self
|
||||
.to_client
|
||||
.send(FromSwarm::PublishError(e))
|
||||
.await
|
||||
.map_err(|_| ())?,
|
||||
}
|
||||
}
|
||||
ToSwarm::Unsubscribe(topic) => {
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.unsubscribe(&gossipsub::IdentTopic::new(topic));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
async fn handle_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
|
||||
let SwarmEvent::Behaviour(event) = event else {
|
||||
return Ok(());
|
||||
};
|
||||
match event {
|
||||
BehaviourEvent::Gossipsub(gossipsub::Event::Message { message, .. }) => {
|
||||
if let Some(source) = message.source {
|
||||
self.to_client
|
||||
.send(FromSwarm::Message(
|
||||
source,
|
||||
message.topic.into_string(),
|
||||
message.data,
|
||||
))
|
||||
.await
|
||||
.map_err(|_| ())?;
|
||||
}
|
||||
}
|
||||
BehaviourEvent::Identify(identify::Event::Received { peer_id, info, .. }) => {
|
||||
log::debug!(
|
||||
"identify from {peer_id}: protocol_version='{}' agent_version='{}' (local namespace='{}')",
|
||||
info.protocol_version,
|
||||
info.agent_version,
|
||||
self.namespace
|
||||
);
|
||||
if info.protocol_version == self.namespace {
|
||||
self.passed_namespace(peer_id);
|
||||
self.to_client
|
||||
.send(FromSwarm::Discovered(peer_id))
|
||||
.await
|
||||
.map_err(|_| ())?;
|
||||
} else {
|
||||
self.failed_namespace(peer_id);
|
||||
}
|
||||
}
|
||||
BehaviourEvent::Mdns(mdns::Event::Discovered(v)) => {
|
||||
for (peer_id, addr) in v {
|
||||
self.known_peers.entry(peer_id).or_default().push(addr);
|
||||
}
|
||||
for (peer_id, addrs) in &self.known_peers {
|
||||
// dialopts handles rate limiting, we should check errors if we want to blacklist earlier
|
||||
let _ = self
|
||||
.swarm
|
||||
.dial(DialOpts::peer_id(*peer_id).addresses(addrs.clone()).build());
|
||||
}
|
||||
}
|
||||
BehaviourEvent::Mdns(mdns::Event::Expired(v)) => {
|
||||
for (peer_id, addr) in v {
|
||||
let addrs = self.known_peers.entry(peer_id).or_default();
|
||||
addrs.retain(|a| *a != addr);
|
||||
if addrs.is_empty() {
|
||||
self.known_peers.remove(&peer_id);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.remove_explicit_peer(&peer_id);
|
||||
self.to_client
|
||||
.send(FromSwarm::Expired(peer_id))
|
||||
.await
|
||||
.map_err(|_| ())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn passed_namespace(&mut self, peer_id: PeerId) {
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.remove_blacklisted_peer(&peer_id);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.add_explicit_peer(&peer_id);
|
||||
}
|
||||
fn failed_namespace(&mut self, peer_id: PeerId) {
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.blacklist_peer(&peer_id);
|
||||
self.swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
identify: identify::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
fn new(namespace: String, kp: &Keypair) -> Self {
|
||||
let mdns = mdns::Behaviour::new(mdns::Config::default(), kp.public().to_peer_id())
|
||||
.expect("mdns behaviour failed to build");
|
||||
|
||||
let identify =
|
||||
identify::Behaviour::new(identify::Config::new_with_signed_peer_record(namespace, kp));
|
||||
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(kp.clone()),
|
||||
gossipsub::ConfigBuilder::default()
|
||||
.max_transmit_size(1024 * 1024)
|
||||
.validation_mode(gossipsub::ValidationMode::Strict)
|
||||
.build()
|
||||
.expect("invalid gossipsub configuration"),
|
||||
)
|
||||
.expect("gossipsub behaviour failed ot build");
|
||||
|
||||
Self {
|
||||
gossipsub,
|
||||
mdns,
|
||||
identify,
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
let Some(Protocol::Tcp(port)) = ps.next() else {
|
||||
return None;
|
||||
};
|
||||
Some((ip, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: more tests
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::time::{Duration, timeout};
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
fn make_peer(namespace: &str) -> (Peer, mpsc::Receiver<FromSwarm>, mpsc::Sender<ToSwarm>) {
|
||||
let kp = Keypair::generate_ed25519();
|
||||
|
||||
let (to_client_tx, to_client_rx) = mpsc::channel(64);
|
||||
let (to_peer_tx, to_peer_rx) = mpsc::channel(64);
|
||||
|
||||
let peer = Peer::new(namespace.to_string(), kp, to_client_tx, to_peer_rx)
|
||||
.expect("Peer::new should succeed in tests");
|
||||
|
||||
(peer, to_client_rx, to_peer_tx)
|
||||
}
|
||||
|
||||
async fn next_listen_addr(peer: &mut Peer) -> Multiaddr {
|
||||
loop {
|
||||
match peer.swarm.next().await {
|
||||
Some(SwarmEvent::NewListenAddr { address, .. }) => return address,
|
||||
Some(_) => {}
|
||||
None => panic!("swarm stream ended unexpectedly"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_and_unsubscribe_do_not_error() {
|
||||
let (mut peer, mut events_rx, commands_tx) = make_peer("ns-test");
|
||||
|
||||
// Drive the swarm just enough to get at least one listen address event,
|
||||
// so the background run loop has something initialized.
|
||||
let _addr = next_listen_addr(&mut peer).await;
|
||||
|
||||
// Run the peer loop in the background.
|
||||
let handle = tokio::spawn(async move {
|
||||
let _ = peer.run().await;
|
||||
});
|
||||
|
||||
commands_tx
|
||||
.send(ToSwarm::Subscribe("topic-a".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
commands_tx
|
||||
.send(ToSwarm::Unsubscribe("topic-a".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We don't *require* any FromSwarm events here; this is mainly a
|
||||
// smoke test that the message-handling path doesn't panic/hang.
|
||||
// Still, poll briefly to ensure the task is alive.
|
||||
let _ = timeout(Duration::from_millis(200), events_rx.recv()).await;
|
||||
|
||||
// Shut down: dropping the command sender closes the channel, causing run() to return Err.
|
||||
drop(commands_tx);
|
||||
let _ = handle.await;
|
||||
}
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
145
rust/networking/src/swarm.rs
Normal file
145
rust/networking/src/swarm.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
///
|
||||
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
|
||||
/// even be, and how to inject the right version into this config/initialization. E.g. should
|
||||
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
|
||||
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
.with_behaviour(Behaviour::new)?
|
||||
.build();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
}
|
||||
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
use libp2p::pnet::{PnetError, PnetOutput};
|
||||
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
|
||||
use std::{env, sync::LazyLock};
|
||||
|
||||
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
|
||||
/// See [`pnet_upgrade`] for more.
|
||||
static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
|
||||
let builder = Sha3_256::new().update(b"exo_discovery_network");
|
||||
|
||||
if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
|
||||
let bytes = var.into_bytes();
|
||||
builder.update(&bytes)
|
||||
} else {
|
||||
builder.update(NETWORK_VERSION)
|
||||
}
|
||||
.finalize()
|
||||
});
|
||||
|
||||
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
|
||||
/// also different-versioned instances of this same network.
|
||||
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
|
||||
async fn pnet_upgrade<TSocket>(
|
||||
socket: TSocket,
|
||||
_: impl Sized,
|
||||
) -> Result<PnetOutput<TSocket>, PnetError>
|
||||
where
|
||||
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
use pnet::{PnetConfig, PreSharedKey};
|
||||
PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
|
||||
.handshake(socket)
|
||||
.await
|
||||
}
|
||||
|
||||
/// TCP/IP transport layer configuration.
|
||||
pub fn tcp_transport(
|
||||
keypair: &identity::Keypair,
|
||||
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
|
||||
use libp2p::{
|
||||
core::upgrade::Version,
|
||||
tcp::{Config, tokio},
|
||||
};
|
||||
|
||||
// `TCP_NODELAY` enabled => avoid latency
|
||||
let tcp_config = Config::default().nodelay(true);
|
||||
|
||||
// V1 + lazy flushing => 0-RTT negotiation
|
||||
let upgrade_version = Version::V1Lazy;
|
||||
|
||||
// Noise is faster than TLS + we don't care much for security
|
||||
let noise_config = noise::Config::new(keypair)?;
|
||||
|
||||
// Use default Yamux config for multiplexing
|
||||
let yamux_config = yamux::Config::default();
|
||||
|
||||
// Create new Tokio-driven TCP/IP transport layer
|
||||
let base_transport = tokio::Transport::new(tcp_config)
|
||||
.and_then(pnet_upgrade)
|
||||
.upgrade(upgrade_version)
|
||||
.authenticate(noise_config)
|
||||
.multiplex(yamux_config);
|
||||
|
||||
// Return boxed transport (to flatten complex type)
|
||||
Ok(base_transport.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
mod behaviour {
|
||||
use crate::{alias, discovery};
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{gossipsub, identity};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Behavior of the Swarm which composes all desired behaviors:
|
||||
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
pub discovery: discovery::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
|
||||
Ok(Self {
|
||||
discovery: discovery::Behaviour::new(keypair)?,
|
||||
gossipsub: gossipsub_behaviour(keypair),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
|
||||
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
// => signed message authenticity + strict validation mode means the message-ID is
|
||||
// automatically provided by gossipsub w/out needing to provide custom message-ID function
|
||||
gossipsub::Behaviour::new(
|
||||
MessageAuthenticity::Signed(keypair.clone()),
|
||||
ConfigBuilder::default()
|
||||
.publish_queue_duration(Duration::from_secs(15))
|
||||
.max_transmit_size(1024 * 1024)
|
||||
.validation_mode(ValidationMode::Strict)
|
||||
.build()
|
||||
.expect("the configuration should always be valid"),
|
||||
)
|
||||
.expect("creating gossipsub behavior should always work")
|
||||
}
|
||||
}
|
||||
7
rust/networking/tests/dummy.rs
Normal file
7
rust/networking/tests/dummy.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
// maybe this will hold test in the future...??
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn does_nothing() {}
|
||||
}
|
||||
2
rust/rust-toolchain.toml
Normal file
2
rust/rust-toolchain.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "nightly"
|
||||
25
rust/util/Cargo.toml
Normal file
25
rust/util/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "util"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "util"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
thiserror = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
internment = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
bon = { workspace = true }
|
||||
recursion = { workspace = true }
|
||||
53
rust/util/src/lib.rs
Normal file
53
rust/util/src/lib.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub mod nonempty;
|
||||
pub mod wakerdeque;
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub mod ext {
|
||||
use extend::ext;
|
||||
|
||||
#[ext(pub, name = BoxedSliceExt)]
|
||||
impl<T> Box<[T]> {
|
||||
#[inline]
|
||||
fn map<B, F>(self, f: F) -> Box<[B]>
|
||||
where
|
||||
F: FnMut(T) -> B,
|
||||
{
|
||||
self.into_iter().map(f).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = VecExt)]
|
||||
impl<T> Vec<T> {
|
||||
#[inline]
|
||||
fn map<B, F>(self, f: F) -> Vec<B>
|
||||
where
|
||||
F: FnMut(T) -> B,
|
||||
{
|
||||
self.into_iter().map(f).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
138
rust/util/src/nonempty.rs
Normal file
138
rust/util/src/nonempty.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use std::slice::SliceIndex;
|
||||
use std::{ops, slice};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Cannot create to `NonemptyArray` because the supplied slice is empty")]
|
||||
pub struct EmptySliceError;
|
||||
|
||||
/// A pointer to a non-empty fixed-size slice allocated on the heap.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
#[repr(transparent)]
|
||||
pub struct NonemptyArray<T>(Box<[T]>);
|
||||
|
||||
#[allow(clippy::arbitrary_source_item_ordering)]
|
||||
impl<T> NonemptyArray<T> {
|
||||
#[inline]
|
||||
pub fn singleton(value: T) -> Self {
|
||||
Self(Box::new([value]))
|
||||
}
|
||||
|
||||
#[allow(clippy::missing_errors_doc)]
|
||||
#[inline]
|
||||
pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
|
||||
boxed_slice: S,
|
||||
) -> Result<Self, EmptySliceError> {
|
||||
let boxed_slice = boxed_slice.into();
|
||||
if boxed_slice.is_empty() {
|
||||
Err(EmptySliceError)
|
||||
} else {
|
||||
Ok(Self(boxed_slice))
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn into_boxed_slice(self) -> Box<[T]> {
|
||||
self.0
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn to_vec(&self) -> Vec<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.0.to_vec()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub const fn as_slice(&self) -> &[T] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn first(&self) -> &T {
|
||||
&self.0[0]
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn last(&self) -> &T {
|
||||
&self.0[self.0.len() - 1]
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn get<I>(&self, index: I) -> Option<&I::Output>
|
||||
where
|
||||
I: SliceIndex<[T]>,
|
||||
{
|
||||
self.0.get(index)
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub const fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
#[allow(clippy::iter_without_into_iter)]
|
||||
#[inline]
|
||||
pub fn iter(&self) -> slice::Iter<'_, T> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
#[allow(clippy::iter_without_into_iter)]
|
||||
#[inline]
|
||||
pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
|
||||
NonemptyArray(self.0.into_iter().map(f).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<NonemptyArray<T>> for Box<[T]> {
|
||||
#[inline]
|
||||
fn from(value: NonemptyArray<T>) -> Self {
|
||||
value.into_boxed_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ops::Index<usize> for NonemptyArray<T> {
|
||||
type Output = T;
|
||||
|
||||
#[inline]
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
self.0.index(index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoIterator for NonemptyArray<T> {
|
||||
type Item = T;
|
||||
type IntoIter = std::vec::IntoIter<T>;
|
||||
|
||||
#[inline]
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.into_boxed_slice().into_vec().into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = slice::Iter<'a, T>;
|
||||
|
||||
#[inline]
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
55
rust/util/src/wakerdeque.rs
Normal file
55
rust/util/src/wakerdeque.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::task::{Context, Waker};
|
||||
|
||||
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
|
||||
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
|
||||
pub struct WakerDeque<T> {
|
||||
waker: Option<Waker>,
|
||||
deque: VecDeque<T>,
|
||||
}
|
||||
|
||||
impl<T: Debug> Debug for WakerDeque<T> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
self.deque.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WakerDeque<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
waker: None,
|
||||
deque: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, cx: &mut Context<'_>) {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
fn wake(&mut self) {
|
||||
let Some(ref mut w) = self.waker else { return };
|
||||
w.wake_by_ref();
|
||||
self.waker = None;
|
||||
}
|
||||
|
||||
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_front()
|
||||
}
|
||||
|
||||
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_back()
|
||||
}
|
||||
|
||||
pub fn push_front(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_front(value);
|
||||
}
|
||||
|
||||
pub fn push_back(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_back(value);
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import importlib.metadata
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
@@ -46,9 +45,9 @@ class Node:
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_string())
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = Router.create(keypair, namespace=args.namespace)
|
||||
router = Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
await router.register_topic(topics.LOCAL_EVENTS)
|
||||
await router.register_topic(topics.COMMANDS)
|
||||
@@ -74,7 +73,7 @@ class Node:
|
||||
else:
|
||||
download_coordinator = None
|
||||
|
||||
if not args.no_api:
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id,
|
||||
session_id,
|
||||
@@ -254,7 +253,7 @@ def main():
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"Namespace: {args.namespace}")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
@@ -271,13 +270,13 @@ def main():
|
||||
|
||||
|
||||
class Args(CamelCaseModel):
|
||||
verbosity: int
|
||||
force_master: bool
|
||||
no_api: bool
|
||||
api_port: PositiveInt
|
||||
verbosity: int = 0
|
||||
force_master: bool = False
|
||||
spawn_api: bool = False
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
namespace: str
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -307,15 +306,14 @@ class Args(CamelCaseModel):
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-api",
|
||||
action="store_true",
|
||||
help="Disable the API server for this node",
|
||||
action="store_false",
|
||||
dest="spawn_api",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-port",
|
||||
type=int,
|
||||
dest="api_port",
|
||||
default=52415,
|
||||
help="Which port the API server will be available on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-worker",
|
||||
@@ -326,11 +324,6 @@ class Args(CamelCaseModel):
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--namespace",
|
||||
default=importlib.metadata.version("exo"),
|
||||
help="Set the EXO namespace to run multiple isolated clusters",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from datetime import datetime, timezone
|
||||
@@ -150,6 +151,15 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
return f"image/{image_format or 'png'}"
|
||||
|
||||
|
||||
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
|
||||
"""Ensure advanced params has a seed set for distributed consistency."""
|
||||
if params is None:
|
||||
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
|
||||
if params.seed is None:
|
||||
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
|
||||
return params
|
||||
|
||||
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -709,6 +719,9 @@ class API:
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
task_params=payload,
|
||||
@@ -957,6 +970,9 @@ class API:
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
task_params=payload,
|
||||
@@ -988,6 +1004,7 @@ class API:
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
advanced_params = _ensure_seed(advanced_params)
|
||||
|
||||
image_content = await image.read()
|
||||
image_data = base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
@@ -369,7 +369,7 @@ class Master:
|
||||
await self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.trace(f"Master indexing event: {str(event)[:100]}")
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
ShardMetadata,
|
||||
@@ -74,40 +75,43 @@ def allocate_layers_proportionally(
|
||||
return result
|
||||
|
||||
|
||||
def get_shard_assignments_for_pipeline_parallel(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
):
|
||||
def _validate_cycle(cycle: Cycle) -> None:
|
||||
if not cycle.node_ids:
|
||||
raise ValueError("Cannot create shard assignments for empty node cycle")
|
||||
|
||||
cycle_memory = sum(
|
||||
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
|
||||
|
||||
def _compute_total_memory(
|
||||
node_ids: list[NodeId],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> Memory:
|
||||
total_memory = sum(
|
||||
(node_memory[node_id].ram_available for node_id in node_ids),
|
||||
start=Memory(),
|
||||
)
|
||||
if cycle_memory.in_bytes == 0:
|
||||
if total_memory.in_bytes == 0:
|
||||
raise ValueError("Cannot create shard assignments: total available memory is 0")
|
||||
return total_memory
|
||||
|
||||
total_layers = model_card.n_layers
|
||||
world_size = len(cycle)
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
def _allocate_and_validate_layers(
|
||||
node_ids: list[NodeId],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
total_memory: Memory,
|
||||
model_card: ModelCard,
|
||||
) -> list[int]:
|
||||
layer_allocations = allocate_layers_proportionally(
|
||||
total_layers=total_layers,
|
||||
total_layers=model_card.n_layers,
|
||||
memory_fractions=[
|
||||
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
|
||||
for node_id in cycle.node_ids
|
||||
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
|
||||
for node_id in node_ids
|
||||
],
|
||||
)
|
||||
|
||||
# Validate each node has sufficient memory for its assigned layers
|
||||
memory_per_layer = model_card.storage_size.in_bytes / total_layers
|
||||
for i, (node_id, node_layers) in enumerate(
|
||||
zip(cycle.node_ids, layer_allocations, strict=True)
|
||||
):
|
||||
required_memory = node_layers * memory_per_layer
|
||||
total_storage_bytes = model_card.storage_size.in_bytes
|
||||
total_layers = model_card.n_layers
|
||||
for i, node_id in enumerate(node_ids):
|
||||
node_layers = layer_allocations[i]
|
||||
required_memory = (total_storage_bytes * node_layers) // total_layers
|
||||
available_memory = node_memory[node_id].ram_available.in_bytes
|
||||
if required_memory > available_memory:
|
||||
raise ValueError(
|
||||
@@ -116,32 +120,125 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
f"but only has {available_memory / (1024**3):.2f} GB available"
|
||||
)
|
||||
|
||||
layers_assigned = 0
|
||||
for i, (node_id, node_layers) in enumerate(
|
||||
zip(cycle.node_ids, layer_allocations, strict=True)
|
||||
):
|
||||
runner_id = RunnerId()
|
||||
return layer_allocations
|
||||
|
||||
shard = PipelineShardMetadata(
|
||||
|
||||
def get_shard_assignments_for_pipeline_parallel(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> ShardAssignments:
|
||||
"""Create shard assignments for pipeline parallel execution."""
|
||||
world_size = len(cycle)
|
||||
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
|
||||
|
||||
if use_cfg_parallel:
|
||||
return _get_shard_assignments_for_cfg_parallel(model_card, cycle, node_memory)
|
||||
else:
|
||||
return _get_shard_assignments_for_pure_pipeline(model_card, cycle, node_memory)
|
||||
|
||||
|
||||
def _get_shard_assignments_for_cfg_parallel(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> ShardAssignments:
|
||||
"""Create shard assignments for CFG parallel execution.
|
||||
|
||||
CFG parallel runs two independent pipelines. Group 0 processes the positive
|
||||
prompt, group 1 processes the negative prompt. The ring topology places
|
||||
group 1's ranks in reverse order so both "last stages" are neighbors for
|
||||
efficient CFG exchange.
|
||||
"""
|
||||
_validate_cycle(cycle)
|
||||
|
||||
world_size = len(cycle)
|
||||
cfg_world_size = 2
|
||||
pipeline_world_size = world_size // cfg_world_size
|
||||
|
||||
# Allocate layers for one pipeline group (both groups run the same layers)
|
||||
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
|
||||
pipeline_memory = _compute_total_memory(pipeline_node_ids, node_memory)
|
||||
layer_allocations = _allocate_and_validate_layers(
|
||||
pipeline_node_ids, node_memory, pipeline_memory, model_card
|
||||
)
|
||||
|
||||
# Ring topology: group 0 ascending [0,1,2,...], group 1 descending [...,2,1,0]
|
||||
# This places both last stages as neighbors for CFG exchange.
|
||||
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
|
||||
(1, r) for r in reversed(range(pipeline_world_size))
|
||||
]
|
||||
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
for device_rank, node_id in enumerate(cycle.node_ids):
|
||||
cfg_rank, pipeline_rank = position_to_cfg_pipeline[device_rank]
|
||||
layers_before = sum(layer_allocations[:pipeline_rank])
|
||||
node_layers = layer_allocations[pipeline_rank]
|
||||
|
||||
shard = CfgShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=i,
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
start_layer=layers_assigned,
|
||||
end_layer=layers_assigned + node_layers,
|
||||
n_layers=total_layers,
|
||||
start_layer=layers_before,
|
||||
end_layer=layers_before + node_layers,
|
||||
n_layers=model_card.n_layers,
|
||||
cfg_rank=cfg_rank,
|
||||
cfg_world_size=cfg_world_size,
|
||||
pipeline_rank=pipeline_rank,
|
||||
pipeline_world_size=pipeline_world_size,
|
||||
)
|
||||
|
||||
runner_id = RunnerId()
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node_id] = runner_id
|
||||
layers_assigned += node_layers
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
return ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
|
||||
return shard_assignments
|
||||
|
||||
def _get_shard_assignments_for_pure_pipeline(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> ShardAssignments:
|
||||
"""Create shard assignments for pure pipeline execution."""
|
||||
_validate_cycle(cycle)
|
||||
total_memory = _compute_total_memory(cycle.node_ids, node_memory)
|
||||
|
||||
layer_allocations = _allocate_and_validate_layers(
|
||||
cycle.node_ids, node_memory, total_memory, model_card
|
||||
)
|
||||
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
for pipeline_rank, node_id in enumerate(cycle.node_ids):
|
||||
layers_before = sum(layer_allocations[:pipeline_rank])
|
||||
node_layers = layer_allocations[pipeline_rank]
|
||||
|
||||
shard = PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=pipeline_rank,
|
||||
world_size=len(cycle),
|
||||
start_layer=layers_before,
|
||||
end_layer=layers_before + node_layers,
|
||||
n_layers=model_card.n_layers,
|
||||
)
|
||||
|
||||
runner_id = RunnerId()
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node_id] = runner_id
|
||||
|
||||
return ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
|
||||
|
||||
def get_shard_assignments_for_tensor_parallel(
|
||||
|
||||
@@ -41,7 +41,7 @@ from exo.utils.channels import channel
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_string())
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
@@ -72,7 +72,7 @@ async def test_master():
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_string()}_sender")
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
# inject a NodeGatheredInfo event
|
||||
logger.info("inject a NodeGatheredInfo event")
|
||||
await local_event_sender.send(
|
||||
|
||||
@@ -5,6 +5,7 @@ from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_shard_assignments_for_pipeline_parallel,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
@@ -20,7 +21,11 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
)
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory():
|
||||
@@ -487,3 +492,193 @@ def test_get_shard_assignments_insufficient_memory_raises():
|
||||
get_shard_assignments(
|
||||
model_card, selected_cycle, Sharding.Pipeline, node_memory
|
||||
)
|
||||
|
||||
|
||||
class TestCfgParallelPlacement:
|
||||
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
|
||||
topology = Topology()
|
||||
for node_id in node_ids:
|
||||
topology.add_node(node_id)
|
||||
|
||||
for i, node_id in enumerate(node_ids):
|
||||
next_node = node_ids[(i + 1) % len(node_ids)]
|
||||
conn = Connection(
|
||||
source=node_id,
|
||||
sink=next_node,
|
||||
edge=create_socket_connection(i + 1),
|
||||
)
|
||||
topology.add_connection(conn)
|
||||
|
||||
return topology
|
||||
|
||||
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
|
||||
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
# CFG models should get CfgShardMetadata
|
||||
for shard in shards:
|
||||
assert isinstance(shard, CfgShardMetadata)
|
||||
# Both nodes should have all layers (no pipeline split)
|
||||
assert shard.start_layer == 0
|
||||
assert shard.end_layer == 60
|
||||
assert shard.cfg_world_size == 2
|
||||
# Each node is the only stage in its pipeline group
|
||||
assert shard.pipeline_world_size == 1
|
||||
assert shard.pipeline_rank == 0
|
||||
|
||||
cfg_ranks = sorted(
|
||||
s.cfg_rank for s in shards if isinstance(s, CfgShardMetadata)
|
||||
)
|
||||
assert cfg_ranks == [0, 1]
|
||||
|
||||
def test_four_nodes_cfg_model_uses_hybrid(self):
|
||||
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
|
||||
nodes = [NodeId() for _ in range(4)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 4]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 4
|
||||
|
||||
# CFG models should get CfgShardMetadata
|
||||
for shard in shards:
|
||||
assert isinstance(shard, CfgShardMetadata)
|
||||
assert shard.cfg_world_size == 2
|
||||
assert shard.pipeline_world_size == 2
|
||||
assert shard.pipeline_rank in [0, 1]
|
||||
|
||||
# Check we have 2 nodes in each CFG group
|
||||
cfg_0_shards = [
|
||||
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 0
|
||||
]
|
||||
cfg_1_shards = [
|
||||
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 1
|
||||
]
|
||||
assert len(cfg_0_shards) == 2
|
||||
assert len(cfg_1_shards) == 2
|
||||
|
||||
# Both CFG groups should have the same layer assignments
|
||||
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
|
||||
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
|
||||
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
|
||||
|
||||
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
|
||||
"""Three nodes (odd) with CFG model should use sequential CFG (PipelineShardMetadata)."""
|
||||
nodes = [NodeId() for _ in range(3)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 3]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 3
|
||||
|
||||
# Odd node count with CFG model falls back to PipelineShardMetadata (sequential CFG)
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
|
||||
def test_two_nodes_non_cfg_model_uses_pipeline(self):
|
||||
"""Two nodes with non-CFG model should use pure pipeline (PipelineShardMetadata)."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("flux-test"),
|
||||
n_layers=57,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=False, # Non-CFG model
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
# Non-CFG models should get PipelineShardMetadata
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
|
||||
# Should have actual layer sharding (pipeline)
|
||||
layer_ranges = sorted(
|
||||
(s.start_layer, s.end_layer)
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
# First shard starts at 0, last shard ends at 57
|
||||
assert layer_ranges[0][0] == 0
|
||||
assert layer_ranges[-1][1] == 57
|
||||
|
||||
@@ -1,9 +1,37 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
expired: bool
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import count
|
||||
from math import inf
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
@@ -13,14 +13,15 @@ from anyio import (
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from exo_pyo3_bindings import (
|
||||
AllQueuesFullError,
|
||||
Keypair,
|
||||
PyPeer,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -97,32 +98,28 @@ class TopicRouter[T: CamelCaseModel]:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Router:
|
||||
_peer: PyPeer
|
||||
topic_routers: dict[str, TopicRouter[CamelCaseModel]] = field(
|
||||
init=False, default_factory=dict
|
||||
)
|
||||
networking_receiver: Receiver[tuple[str, bytes]] = field(init=False)
|
||||
_tmp_networking_sender: Sender[tuple[str, bytes]] | None = field(init=False)
|
||||
_tg: TaskGroup | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._tmp_networking_sender, self.networking_receiver = channel()
|
||||
|
||||
@classmethod
|
||||
def create(cls, identity: Keypair, namespace: str) -> "Router":
|
||||
return cls(_peer=PyPeer.new(identity, namespace))
|
||||
def create(cls, identity: Keypair) -> "Router":
|
||||
return cls(handle=NetworkingHandle(identity))
|
||||
|
||||
def __init__(self, handle: NetworkingHandle):
|
||||
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
|
||||
send, recv = channel[tuple[str, bytes]]()
|
||||
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
|
||||
self._net: NetworkingHandle = handle
|
||||
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
|
||||
self._id_count = count()
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
|
||||
assert self._tg is None, "Attempted to register topic after setup time"
|
||||
send = self._tmp_networking_sender
|
||||
if send:
|
||||
self._tmp_networking_sender = None
|
||||
else:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
if self._tg is not None:
|
||||
self._tg.start_soon(router.run)
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
|
||||
@@ -149,18 +146,14 @@ class Router:
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
|
||||
async def _peer_run():
|
||||
await self._peer.run()
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
tg.start_soon(_peer_run)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
@@ -174,57 +167,46 @@ class Router:
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._peer.subscribe(topic)
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._peer.unsubscribe(topic)
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
try:
|
||||
swarm_event = await self._peer.recv()
|
||||
except ValueError:
|
||||
logger.error("Message too large for gossipsub, dropped")
|
||||
continue
|
||||
except ConnectionError:
|
||||
logger.error("All peer queues full, network overloaded")
|
||||
continue
|
||||
except RuntimeError:
|
||||
break
|
||||
|
||||
cm = None
|
||||
if (peer_id := swarm_event.downcast_discovered()) is not None:
|
||||
cm = ConnectionMessage(node_id=NodeId(peer_id), expired=False)
|
||||
if (peer_id := swarm_event.downcast_expired()) is not None:
|
||||
cm = ConnectionMessage(node_id=NodeId(peer_id), expired=True)
|
||||
|
||||
if cm is not None:
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(cm)
|
||||
continue
|
||||
|
||||
assert (msg := swarm_event.downcast_message()) is not None
|
||||
_origin, topic, payload = msg
|
||||
logger.debug(f"Received message on {topic} with payload {payload}")
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(f"Received message on unknown or inactive topic {topic}")
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(payload)
|
||||
await router.publish_bytes(data)
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
async for topic, data in networked_items:
|
||||
try:
|
||||
logger.trace(f"Sending message on {topic} with payload {data}")
|
||||
await self._peer.send(topic, data)
|
||||
except RuntimeError:
|
||||
break
|
||||
await self._net.gossipsub_publish(topic, data)
|
||||
# As a hack, this also catches AllQueuesFull
|
||||
# Need to fix that ASAP.
|
||||
except (NoPeersSubscribedToTopicError, AllQueuesFullError):
|
||||
pass
|
||||
|
||||
|
||||
def get_node_id_keypair(
|
||||
@@ -235,7 +217,7 @@ def get_node_id_keypair(
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate()
|
||||
return Keypair.generate_ed25519()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
@@ -39,7 +39,7 @@ RESOURCES_DIR = (
|
||||
)
|
||||
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
|
||||
DASHBOARD_DIR = (
|
||||
find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
|
||||
find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
|
||||
)
|
||||
|
||||
# Log files (data/logs or cache)
|
||||
|
||||
@@ -65,9 +65,9 @@ class ComponentInfo(CamelCaseModel):
|
||||
component_name: str
|
||||
component_path: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt | None
|
||||
n_layers: PositiveInt | None = None
|
||||
can_shard: bool
|
||||
safetensors_index_filename: str | None
|
||||
safetensors_index_filename: str | None = None
|
||||
|
||||
|
||||
class ModelCard(CamelCaseModel):
|
||||
@@ -82,6 +82,7 @@ class ModelCard(CamelCaseModel):
|
||||
quantization: str = ""
|
||||
base_model: str = ""
|
||||
capabilities: list[str] = []
|
||||
uses_cfg: bool = False
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -155,87 +156,6 @@ def is_custom_card(model_id: ModelId) -> bool:
|
||||
return os.path.isfile(str(card_path))
|
||||
|
||||
|
||||
# TODO: quantizing and dynamically creating model cards
|
||||
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
# quantizations = [8, 6, 5, 4, 3]
|
||||
quantizations = [8, 4]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
model_id = ModelId(base_card.model_id + f"-{quant}bit")
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from anyio import create_task_group, fail_after, move_on_after
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
@@ -330,7 +330,9 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
|
||||
await cm_tx.send(
|
||||
ConnectionMessage(
|
||||
node_id=NodeId(),
|
||||
expired=False,
|
||||
connection_type=ConnectionMessageType.Connected,
|
||||
remote_ipv4="",
|
||||
remote_tcp_port=0,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import TypeAlias, final
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -51,6 +52,7 @@ class BaseShardMetadata(TaggedModel):
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class PipelineShardMetadata(BaseShardMetadata):
|
||||
"""
|
||||
Pipeline parallelism shard meta.
|
||||
@@ -60,8 +62,23 @@ class PipelineShardMetadata(BaseShardMetadata):
|
||||
"""
|
||||
|
||||
|
||||
@final
|
||||
class CfgShardMetadata(BaseShardMetadata):
|
||||
"""Shard metadata for CFG-parallel image generation models."""
|
||||
|
||||
cfg_rank: int # 0 = positive branch, 1 = negative branch
|
||||
cfg_world_size: int = 2
|
||||
|
||||
# Pipeline-relative coordinates (computed at placement time)
|
||||
pipeline_rank: int # rank within the pipeline group (0, 1, 2, ...)
|
||||
pipeline_world_size: int # number of nodes per pipeline group
|
||||
|
||||
|
||||
@final
|
||||
class TensorShardMetadata(BaseShardMetadata):
|
||||
pass
|
||||
|
||||
|
||||
ShardMetadata = PipelineShardMetadata | TensorShardMetadata
|
||||
ShardMetadata: TypeAlias = (
|
||||
PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from PIL import Image
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import AdvancedImageParams
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
@@ -30,14 +30,19 @@ class DistributedImageModel:
|
||||
self,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
if group is not None:
|
||||
has_layer_sharding = (
|
||||
shard_metadata.start_layer != 0
|
||||
or shard_metadata.end_layer != shard_metadata.n_layers
|
||||
)
|
||||
|
||||
if group is not None and has_layer_sharding:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
@@ -75,8 +80,10 @@ class DistributedImageModel:
|
||||
model_path = build_model_path(model_id)
|
||||
|
||||
shard_metadata = bound_instance.bound_shard
|
||||
if not isinstance(shard_metadata, PipelineShardMetadata):
|
||||
raise ValueError("Expected PipelineShardMetadata for image generation")
|
||||
if not isinstance(shard_metadata, (PipelineShardMetadata, CfgShardMetadata)):
|
||||
raise ValueError(
|
||||
"Expected PipelineShardMetadata or CfgShardMetadata for image generation"
|
||||
)
|
||||
|
||||
is_distributed = (
|
||||
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
|
||||
|
||||
@@ -86,6 +86,27 @@ class PromptData(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Get embeddings for a single CFG branch (positive or negative).
|
||||
|
||||
Used for sequential CFG and CFG parallel modes where we process
|
||||
one branch at a time instead of batching.
|
||||
|
||||
Args:
|
||||
positive: True for positive prompt, False for negative prompt
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- embeds: [1, seq, hidden] prompt embeddings
|
||||
- mask: [1, seq] attention mask or None
|
||||
- pooled: [1, hidden] pooled embeddings or None
|
||||
- conditioning_latents: [1, latent_seq, latent_dim] or None
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
|
||||
_config: ImageModelConfig
|
||||
|
||||
@@ -64,6 +64,12 @@ class FluxPromptData(PromptData):
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
return None
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Flux doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
|
||||
|
||||
|
||||
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
|
||||
def __init__(
|
||||
|
||||
@@ -133,6 +133,24 @@ class QwenPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image model.
|
||||
|
||||
@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (30 steps)
|
||||
num_sync_steps_factor=0.25,
|
||||
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
|
||||
)
|
||||
|
||||
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125,
|
||||
num_sync_steps_factor=0.25,
|
||||
guidance_scale=3.5,
|
||||
)
|
||||
|
||||
@@ -153,6 +153,24 @@ class QwenEditPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, batched_cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image-Edit model.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, final
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
@@ -11,7 +13,7 @@ from exo.shared.tracing import (
|
||||
clear_trace_buffer,
|
||||
trace,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
ModelAdapter,
|
||||
@@ -25,6 +27,16 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass(frozen=True)
|
||||
class CfgBranch:
|
||||
positive: bool
|
||||
embeds: mx.array
|
||||
mask: mx.array | None
|
||||
pooled: mx.array | None
|
||||
cond_latents: mx.array | None
|
||||
|
||||
|
||||
def calculate_patch_heights(
|
||||
latent_height: int, num_patches: int
|
||||
) -> tuple[list[int], int]:
|
||||
@@ -70,29 +82,18 @@ class DiffusionRunner:
|
||||
config: ImageModelConfig,
|
||||
adapter: ModelAdapter[Any, Any],
|
||||
group: Optional[mx.distributed.Group],
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
|
||||
num_patches: Optional[int] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.adapter = adapter
|
||||
self.group = group
|
||||
|
||||
if group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.next_rank = 0
|
||||
self.prev_rank = 0
|
||||
self.start_layer = 0
|
||||
self.end_layer = config.total_blocks
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.next_rank = (self.rank + 1) % self.world_size
|
||||
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
self._init_cfg_topology(shard_metadata)
|
||||
|
||||
self.num_patches = num_patches if num_patches else max(1, self.world_size)
|
||||
self.num_patches = (
|
||||
num_patches if num_patches else max(1, self.pipeline_world_size)
|
||||
)
|
||||
|
||||
self.total_joint = config.joint_block_count
|
||||
self.total_single = config.single_block_count
|
||||
@@ -102,6 +103,97 @@ class DiffusionRunner:
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
def _init_cfg_topology(
|
||||
self, shard_metadata: PipelineShardMetadata | CfgShardMetadata
|
||||
) -> None:
|
||||
"""Initialize CFG and pipeline topology from shard metadata.
|
||||
|
||||
Both CfgShardMetadata and PipelineShardMetadata represent pipeline parallel
|
||||
execution. CFG adds a second parallel pipeline for negative prompt processing,
|
||||
but within each pipeline group the communication pattern is identical.
|
||||
"""
|
||||
if self.group is None:
|
||||
# Single node - no distributed communication
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.start_layer = 0
|
||||
self.end_layer = self.config.total_blocks
|
||||
self.cfg_rank = 0
|
||||
self.cfg_world_size = 1
|
||||
self.cfg_parallel = False
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_world_size = 1
|
||||
self.next_pipeline_rank: int | None = None
|
||||
self.prev_pipeline_rank: int | None = None
|
||||
self.cfg_peer_rank: int | None = None
|
||||
self.first_pipeline_rank: int = 0
|
||||
self.last_pipeline_rank: int = 0
|
||||
return
|
||||
|
||||
# Common fields from base metadata
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
if isinstance(shard_metadata, CfgShardMetadata):
|
||||
# CFG parallel: two independent pipelines
|
||||
self.cfg_rank = shard_metadata.cfg_rank
|
||||
self.cfg_world_size = shard_metadata.cfg_world_size
|
||||
self.cfg_parallel = True
|
||||
self.pipeline_rank = shard_metadata.pipeline_rank
|
||||
self.pipeline_world_size = shard_metadata.pipeline_world_size
|
||||
else:
|
||||
# Pure pipeline: single pipeline group, sequential CFG
|
||||
self.cfg_rank = 0
|
||||
self.cfg_world_size = 1
|
||||
self.cfg_parallel = False
|
||||
self.pipeline_rank = shard_metadata.device_rank
|
||||
self.pipeline_world_size = shard_metadata.world_size
|
||||
|
||||
# Pipeline neighbor computation (same logic for both types)
|
||||
is_first = self.pipeline_rank == 0
|
||||
is_last = self.pipeline_rank == self.pipeline_world_size - 1
|
||||
|
||||
self.next_pipeline_rank = (
|
||||
None
|
||||
if is_last
|
||||
else self._device_rank_for(self.cfg_rank, self.pipeline_rank + 1)
|
||||
)
|
||||
self.prev_pipeline_rank = (
|
||||
None
|
||||
if is_first
|
||||
else self._device_rank_for(self.cfg_rank, self.pipeline_rank - 1)
|
||||
)
|
||||
|
||||
# CFG peer is the corresponding last stage in the other CFG group
|
||||
if self.cfg_parallel and is_last:
|
||||
other_cfg_rank = 1 - self.cfg_rank
|
||||
self.cfg_peer_rank = self._device_rank_for(
|
||||
other_cfg_rank, self.pipeline_rank
|
||||
)
|
||||
else:
|
||||
self.cfg_peer_rank = None
|
||||
|
||||
# First/last pipeline ranks for ring communication (latent broadcast)
|
||||
self.first_pipeline_rank = self._device_rank_for(self.cfg_rank, 0)
|
||||
self.last_pipeline_rank = self._device_rank_for(
|
||||
self.cfg_rank, self.pipeline_world_size - 1
|
||||
)
|
||||
|
||||
def _device_rank_for(self, cfg_rank: int, pipeline_rank: int) -> int:
|
||||
"""Convert (cfg_rank, pipeline_rank) to device_rank in the ring topology.
|
||||
|
||||
Ring layout: [cfg0_pipe0, cfg0_pipe1, ..., cfg1_pipeN-1, cfg1_pipeN-2, ..., cfg1_pipe0]
|
||||
Group 0 is in ascending order, group 1 is reversed so last stages are neighbors.
|
||||
"""
|
||||
if not self.cfg_parallel:
|
||||
return pipeline_rank
|
||||
if cfg_rank == 0:
|
||||
return pipeline_rank
|
||||
else:
|
||||
return self.world_size - 1 - pipeline_rank
|
||||
|
||||
def _compute_assigned_blocks(self) -> None:
|
||||
"""Determine which joint/single blocks this stage owns."""
|
||||
start = self.start_layer
|
||||
@@ -138,11 +230,11 @@ class DiffusionRunner:
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self.rank == 0
|
||||
return self.pipeline_rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self.rank == self.world_size - 1
|
||||
return self.pipeline_rank == self.pipeline_world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
@@ -153,6 +245,97 @@ class DiffusionRunner:
|
||||
return self._guidance_override
|
||||
return self.config.guidance_scale
|
||||
|
||||
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
|
||||
"""Yield the CFG branches this node should process.
|
||||
|
||||
- No CFG: yields one branch (positive)
|
||||
- CFG parallel: yields one branch (our assigned branch)
|
||||
- Sequential CFG: yields two branches (positive, then negative)
|
||||
"""
|
||||
if not self.adapter.needs_cfg:
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
elif self.cfg_parallel:
|
||||
positive = self.cfg_rank == 0
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
|
||||
yield CfgBranch(
|
||||
positive=positive,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
else:
|
||||
pos_embeds, pos_mask, pos_pooled, pos_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=True)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=pos_embeds,
|
||||
mask=pos_mask,
|
||||
pooled=pos_pooled,
|
||||
cond_latents=pos_cond,
|
||||
)
|
||||
neg_embeds, neg_mask, neg_pooled, neg_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=False)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=False,
|
||||
embeds=neg_embeds,
|
||||
mask=neg_mask,
|
||||
pooled=neg_pooled,
|
||||
cond_latents=neg_cond,
|
||||
)
|
||||
|
||||
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
|
||||
if len(results) == 1:
|
||||
positive, noise = results[0]
|
||||
if self.cfg_parallel and self.is_last_stage:
|
||||
# TODO(ciaran): try to remove
|
||||
mx.eval(noise)
|
||||
return self._exchange_and_apply_guidance(noise, positive)
|
||||
return noise
|
||||
|
||||
noise_neg = next(n for p, n in results if not p)
|
||||
noise_pos = next(n for p, n in results if p)
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _exchange_and_apply_guidance(
|
||||
self, noise: mx.array, is_positive: bool
|
||||
) -> mx.array:
|
||||
assert self.group is not None
|
||||
assert self.cfg_peer_rank is not None
|
||||
|
||||
if is_positive:
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_neg)
|
||||
noise_pos = noise
|
||||
else:
|
||||
noise_pos = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_pos)
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = noise
|
||||
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
|
||||
scale = self._get_effective_guidance_scale()
|
||||
assert scale is not None
|
||||
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
|
||||
|
||||
def _ensure_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
@@ -470,7 +653,9 @@ class DiffusionRunner:
|
||||
) -> mx.array:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
elif (
|
||||
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
|
||||
):
|
||||
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
@@ -496,42 +681,29 @@ class DiffusionRunner:
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
# Reset caches before each branch to ensure no state contamination
|
||||
self._reset_all_caches()
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
step_latents = mx.concatenate([latents, latents], axis=0)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = latents
|
||||
|
||||
noise = self._forward_pass(
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=cond_latents,
|
||||
)
|
||||
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale=guidance_scale
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
noise = self._forward_pass(
|
||||
latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=branch.cond_latents,
|
||||
)
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
noise = self._combine_cfg_results(results)
|
||||
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
|
||||
|
||||
def _create_patches(
|
||||
@@ -582,7 +754,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
@@ -594,19 +766,22 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
@@ -639,34 +814,45 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
concatenated, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
@@ -689,11 +875,14 @@ class DiffusionRunner:
|
||||
mx.eval(hidden_states)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
@@ -716,83 +905,67 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
|
||||
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
step_latents = mx.concatenate(
|
||||
[scaled_hidden_states, scaled_hidden_states], axis=0
|
||||
|
||||
cond_latents = branch.cond_latents
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
branch.mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
encoder_mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
noise = self._combine_cfg_results(results)
|
||||
|
||||
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise, timestep=t, latents=prev_latents
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
with trace(name="send 0", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, 0, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.first_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
with trace(
|
||||
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.last_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
@@ -809,39 +982,10 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
@@ -853,34 +997,57 @@ class DiffusionRunner:
|
||||
and not is_first_async_step
|
||||
):
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
name=f"recv {self.last_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
patch, src=self.last_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=step_patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=prompt_embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(branch.mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
branch.embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=branch.embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
noise = self._combine_cfg_results(results)
|
||||
|
||||
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise,
|
||||
@@ -890,10 +1057,14 @@ class DiffusionRunner:
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
name=f"send {self.first_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
patch_latents[patch_idx],
|
||||
self.first_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
@@ -933,26 +1104,31 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
@@ -988,39 +1164,54 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
patch_concat, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -1043,15 +1234,20 @@ class DiffusionRunner:
|
||||
mx.eval(patch)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
patch = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch, text_embeddings)
|
||||
patch_img_only = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
|
||||
|
||||
return noise, encoder_hidden_states
|
||||
|
||||
@@ -48,6 +48,7 @@ from exo.shared.types.worker.instances import (
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
@@ -274,6 +275,11 @@ def shard_and_load(
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
case CfgShardMetadata():
|
||||
raise ValueError(
|
||||
"CfgShardMetadata is not supported for text model loading - "
|
||||
"this metadata type is only for image generation models"
|
||||
)
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
@@ -334,7 +334,7 @@ class Worker:
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.trace(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
|
||||
|
||||
@@ -292,14 +292,9 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
|
||||
@@ -66,7 +66,11 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
DistributedImageModel,
|
||||
@@ -87,6 +91,22 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
|
||||
"""Check if this node is the primary output node for image generation.
|
||||
|
||||
For CFG models: the last pipeline stage in CFG group 0 (positive prompt).
|
||||
For non-CFG models: the last pipeline stage.
|
||||
"""
|
||||
if isinstance(shard_metadata, CfgShardMetadata):
|
||||
is_pipeline_last = (
|
||||
shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1
|
||||
)
|
||||
return is_pipeline_last and shard_metadata.cfg_rank == 0
|
||||
elif isinstance(shard_metadata, PipelineShardMetadata):
|
||||
return shard_metadata.device_rank == shard_metadata.world_size - 1
|
||||
return False
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
@@ -120,12 +140,12 @@ def main(
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
logger.warning("repeat task - currently a logic bug, please report")
|
||||
continue
|
||||
seen.add(task.task_id)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
@@ -137,6 +157,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
@@ -153,6 +174,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
@@ -195,6 +217,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
@@ -234,6 +257,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
@@ -365,16 +389,14 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
try:
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
||||
|
||||
if is_primary_output:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
@@ -399,7 +421,7 @@ def main(
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
@@ -430,14 +452,12 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
@@ -461,7 +481,7 @@ def main(
|
||||
)
|
||||
image_index += 1
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
@@ -488,6 +508,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
|
||||
@@ -201,29 +201,29 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
TaskStatusUpdated(
|
||||
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
|
||||
),
|
||||
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
|
||||
TaskStatusUpdated(
|
||||
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
|
||||
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=LOAD_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
|
||||
TaskAcknowledged(task_id=LOAD_TASK_ID),
|
||||
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=WARMUP_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
|
||||
TaskAcknowledged(task_id=WARMUP_TASK_ID),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
|
||||
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
|
||||
expected_chunk,
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
|
||||
@@ -231,10 +231,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
|
||||
),
|
||||
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
|
||||
TaskStatusUpdated(
|
||||
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user