Compare commits

...

34 Commits

Author SHA1 Message Date
Evan
e9cfdee9b8 fix test 2025-12-24 19:54:48 +00:00
Evan
6a8182f5f3 fmt, lint 2025-12-24 19:54:48 +00:00
Evan
67a6045a88 too many logs 2025-12-24 19:54:48 +00:00
Evan
0f58a36061 no rust log for now 2025-12-24 19:54:48 +00:00
Evan
a05b0b91ad fixup 2025-12-24 19:54:48 +00:00
Evan
780fd1da1a update example 2025-12-24 19:54:48 +00:00
Evan
0cb0a237a5 this was stupid 2025-12-24 19:54:48 +00:00
Evan
240db4466b add py example 2025-12-24 19:54:48 +00:00
Evan
4465ebd176 filter the logger 2025-12-24 19:54:48 +00:00
Evan
c5425d44df stuff to fix 2025-12-24 19:54:48 +00:00
Evan
5eb7bd5ba3 renamed 2025-12-24 19:54:48 +00:00
Evan
922632b48d fix merge 2025-12-24 19:54:48 +00:00
Evan
427e0af1e4 clippy lints 2025-12-24 19:54:48 +00:00
Evan
12a7bc0fda cleanup 2025-12-24 19:54:48 +00:00
Evan
8cf3dac484 actually use it 2025-12-24 19:54:48 +00:00
Evan
1f6ebd875b add retry count 2025-12-24 19:54:48 +00:00
Evan
0dcd27568d more testing 2025-12-24 19:54:48 +00:00
Evan
9959357cab add log 2025-12-24 19:54:48 +00:00
Evan
182cab12b3 publish with mdns 2025-12-24 19:54:48 +00:00
Evan
30885e30ed print secret key 2025-12-24 19:54:48 +00:00
Evan
387a6bdda3 patch 2025-12-24 19:54:48 +00:00
Evan
8ab7103024 move networking bin to an example 2025-12-24 19:54:48 +00:00
Evan
e498898ced unfinished chunking 2025-12-24 19:54:48 +00:00
Evan
f53ab8d018 found the max message size 2025-12-24 19:54:48 +00:00
Evan
c8753c6237 think there's a bindings issue 2025-12-24 19:54:48 +00:00
Evan
d7e63a1dfb fix 2025-12-24 19:54:48 +00:00
Evan
e7ccf3474f logging 2025-12-24 19:54:48 +00:00
Evan
9fa07990cf thisll do 2025-12-24 19:54:48 +00:00
Evan
e95a57ddb1 actually got some comms throguh 2025-12-24 19:54:48 +00:00
Evan
c015975671 patch net-tools 2025-12-24 19:54:48 +00:00
Evan
e3c020b5f1 freeze that socket 2025-12-24 19:54:48 +00:00
Evan
57a65e36c8 auto dial 2025-12-24 19:54:48 +00:00
Evan
af38e8ceb8 andrei is so cool 2025-12-24 19:54:48 +00:00
Evan
424a3eea21 uh i did it i think 2025-12-24 19:54:48 +00:00
67 changed files with 2561 additions and 5412 deletions

3754
Cargo.lock generated
View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,8 @@
[workspace] [workspace]
resolver = "3" resolver = "3"
members = [ members = [
"rust/networking",
"rust/exo_pyo3_bindings", "rust/exo_pyo3_bindings",
"rust/system_custodian", "rust/networking",
"rust/util",
] ]
[workspace.package] [workspace.package]
@@ -25,63 +23,38 @@ opt-level = 3
[workspace.dependencies] [workspace.dependencies]
## Crate members as common dependencies ## Crate members as common dependencies
networking = { path = "rust/networking" } networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies # Async dependencies
tokio = "1.46" tokio = "1.46"
futures = "0.3" n0-future = "0.3.1"
futures-util = "0.3" postcard = "1.1.3"
futures-timer = "3.0" n0-error = "0.1.2"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging # Tracing/logging
log = "0.4" log = "0.4"
blake3 = "1.8.2"
env_logger = "0.11"
tracing-subscriber = "0.3.20"
# networking # networking
libp2p = "0.56" iroh = "0.95.1"
libp2p-tcp = "0.44" iroh-gossip = "0.95.0"
bytes = "1.11.0"
# pyo3
pyo3 = "0.27.1"
# pyo3-async-runtimes = "0.27.0"
pyo3-log = "0.13.2"
pyo3-stub-gen = "0.17.2"
# util
rand = "0.9.2"
extend = "1.2"
[patch.crates-io]
netwatch = { git = "https://github.com/Evanev7/net-tools.git", branch="patch-for-exo" }
[workspace.lints.rust] [workspace.lints.rust]
static_mut_refs = "warn" # Or use "warn" instead of deny
incomplete_features = "allow"
# Clippy's lint category level configurations; # Clippy's lint category level configurations;
# every member crate needs to inherit these by adding # every member crate needs to inherit these by adding

View File

@@ -19,6 +19,7 @@
25. Rethink retry logic 25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task. 26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels 27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Really need to remove all mlx logic outside of the runner - API has a transitive dependency on engines which imports mlx
Potential refactors: Potential refactors:

View File

@@ -31,22 +31,24 @@
"aarch64-darwin" "aarch64-darwin"
"aarch64-linux" "aarch64-linux"
]; ];
fenixToolchain = system: inputs.fenix.packages.${system}.complete; fenixToolchain = system: inputs.fenix.packages.${system}.stable;
in in
inputs.flake-utils.lib.eachSystem systems ( inputs.flake-utils.lib.eachSystem systems (
system: system:
let let
pkgs = import inputs.nixpkgs { pkgs = import inputs.nixpkgs {
inherit system; inherit system;
overlays = [ inputs.fenix.overlays.default ]; overlays = [ ];
}; };
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs { treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix"; projectRootFile = "flake.nix";
programs.ruff-format.enable = true; programs = {
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ]; ruff-format.enable = true;
programs.rustfmt.enable = true; ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
programs.rustfmt.package = (fenixToolchain system).rustfmt; rustfmt.enable = true;
programs.nixpkgs-fmt.enable = true; rustfmt.package = (fenixToolchain system).rustfmt;
nixpkgs-fmt.enable = true;
};
}; };
in in
{ {
@@ -76,6 +78,8 @@
"rustfmt" "rustfmt"
"rust-src" "rust-src"
]) ])
cargo-machete
bacon
rustup # Just here to make RustRover happy rustup # Just here to make RustRover happy
# NIX # NIX

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "exo" name = "exo"
version = "0.3.0" version = "0.10.0"
description = "Exo" description = "Exo"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"

View File

@@ -1,2 +0,0 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]

View File

@@ -5,8 +5,6 @@ edition = { workspace = true }
publish = false publish = false
[lib] [lib]
doctest = false
path = "src/lib.rs"
name = "exo_pyo3_bindings" name = "exo_pyo3_bindings"
# "cdylib" needed to produce shared library for Python to import # "cdylib" needed to produce shared library for Python to import
@@ -22,46 +20,24 @@ doc = false
workspace = true workspace = true
[dependencies] [dependencies]
networking = { workspace = true } networking.workspace = true
# interop # interop
pyo3 = { version = "0.27.1", features = [ pyo3 = { workspace = true, features = ["experimental-async"] }
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11 pyo3-stub-gen.workspace = true
"nightly", # enables better-supported GIL integration # pyo3-async-runtimes = { workspace = true, features = ["attributes", "tokio-runtime", "testing"] }
"experimental-async", # async support in #[pyfunction] & #[pymethods] pyo3-log.workspace = true
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"
# macro dependencies # macro dependencies
extend = { workspace = true } extend.workspace = true
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime # async runtime
tokio = { workspace = true, features = ["full", "tracing"] } tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
# utility dependencies # utility dependencies
once_cell = "1.21.3" postcard = { workspace = true, features = ["use-std"] }
thread_local = "1.1.9" rand.workspace = true
util = { workspace = true } n0-future.workspace = true
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing # Tracing
@@ -70,8 +46,9 @@ thiserror = { workspace = true }
#console-subscriber = "0.1.5" #console-subscriber = "0.1.5"
#tracing-log = "0.2.0" #tracing-log = "0.2.0"
log = { workspace = true } log = { workspace = true }
env_logger = "0.11" env_logger = { workspace = true }
# Networking # Networking
libp2p = { workspace = true, features = ["full"] } iroh = { workspace = true }
iroh-gossip = { workspace = true }

View File

@@ -0,0 +1,16 @@
from exo_pyo3_bindings import RustNetworkingHandle, Keypair
from asyncio import run
async def main():
nh = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "mdns_example")
recv = await nh.get_connection_receiver()
while True:
cm = await recv.receive()
print(
f"Endpoint({cm.endpoint_id}, reachable={list(map(lambda it: it.ip_addr(), cm.current_transport_addrs)) if cm.current_transport_addrs is not None else None})"
)
if __name__ == "__main__":
run(main())

View File

@@ -2,220 +2,63 @@
# ruff: noqa: E501, F401 # ruff: noqa: E501, F401
import builtins import builtins
import enum
import typing import typing
@typing.final @typing.final
class AllQueuesFullError(builtins.Exception): class EndpointId:
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ... def __str__(self) -> builtins.str: ...
@typing.final @typing.final
class ConnectionUpdate: class IpAddress:
@property def __str__(self) -> builtins.str: ...
def update_type(self) -> ConnectionUpdateType: def ip_addr(self) -> builtins.str: ...
r""" def port(self) -> builtins.int: ...
Whether this is a connection or disconnection event def zone_id(self) -> typing.Optional[builtins.int]: ...
"""
@property
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final @typing.final
class Keypair: class Keypair:
r"""
Identity keypair of a node.
"""
@staticmethod @staticmethod
def generate_ed25519() -> Keypair: def generate_ed25519() -> Keypair:
r""" r"""
Generate a new Ed25519 keypair. Generate a new Ed25519 keypair.
""" """
@staticmethod @staticmethod
def generate_ecdsa() -> Keypair: def from_postcard_encoding(bytes: bytes) -> Keypair:
r""" r"""
Generate a new ECDSA keypair. Decode a postcard structure into a keypair
""" """
@staticmethod def to_postcard_encoding(self) -> bytes:
def generate_secp256k1() -> Keypair:
r""" r"""
Generate a new Secp256k1 keypair. Encode a private key with the postcard format
""" """
@staticmethod def endpoint_id(self) -> EndpointId:
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r""" r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`. Read out the endpoint id corresponding to this keypair
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
""" """
@typing.final @typing.final
class Multiaddr: class RustConnectionMessage:
r""" @property
Representation of a Multiaddr. def endpoint_id(self) -> EndpointId: ...
""" @property
@staticmethod def current_transport_addrs(self) -> typing.Optional[builtins.set[IpAddress]]: ...
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
"""
def to_string(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
"""
@typing.final @typing.final
class NetworkingHandle: class RustConnectionReceiver:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ... async def receive(self) -> RustConnectionMessage: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
"""
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Unsubscribes from a `GossipSub` topic.
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
"""
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
r"""
Publishes a message with multiple topics to the `GossipSub` network.
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
@typing.final @typing.final
class NoPeersSubscribedToTopicError(builtins.Exception): class RustNetworkingHandle:
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod @staticmethod
def random() -> PeerId: async def create(identity: Keypair, namespace: builtins.str) -> RustNetworkingHandle: ...
r""" async def subscribe(self, topic: builtins.str) -> tuple[RustSender, RustReceiver]: ...
Generates a random peer ID from a cryptographically secure PRNG. async def get_connection_receiver(self) -> RustConnectionReceiver: ...
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final @typing.final
class ConnectionUpdateType(enum.Enum): class RustReceiver:
r""" async def receive(self) -> bytes: ...
Connection or disconnection event discriminant type.
""" @typing.final
Connected = ... class RustSender:
Disconnected = ... async def send(self, message: bytes) -> None: ...

View File

@@ -8,7 +8,8 @@ version = "0.1.0"
description = "Add your description here" description = "Add your description here"
readme = "README.md" readme = "README.md"
authors = [ authors = [
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" } { name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
{ name = "Evan Quiney", email = "evanev7@gmail.com" }
] ]
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [] dependencies = []

View File

@@ -1,8 +1,6 @@
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await //! SEE: <https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await>
//!
use pin_project::pin_project; use pyo3::exceptions::PyRuntimeError;
use pyo3::marker::Ungil;
use pyo3::prelude::*; use pyo3::prelude::*;
use std::{ use std::{
future::Future, future::Future,
@@ -10,31 +8,36 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)] #[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F); pub struct AllowThreads<F>(F);
impl<F> AllowThreads<F> impl<F> AllowThreads<F>
where where
Self: Future, Self: Future,
{ {
pub fn new(f: F) -> Self { pub(crate) const fn new(f: F) -> Self {
Self(f) Self(f)
} }
} }
impl<F> Future for AllowThreads<F> impl<F> Future for AllowThreads<F>
where where
F: Future + Ungil, F: Future + Unpin + Send,
F::Output: Ungil, F::Output: Send,
{ {
type Output = F::Output; type Output = Result<F::Output, PyErr>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker(); let waker = cx.waker();
Python::with_gil(|py| { match Python::try_attach(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker))) py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker)))
}) }) {
Some(Poll::Pending) => Poll::Pending,
Some(Poll::Ready(t)) => Poll::Ready(Ok(t)),
// TODO: this doesn't actually work - graceful py shutdown handling
None => Poll::Ready(Err(PyRuntimeError::new_err(
"Python runtime shutdown while awaiting a future",
))),
}
} }
} }

View File

@@ -1,7 +1,6 @@
use pyo3_stub_gen::Result; use pyo3_stub_gen::Result;
fn main() -> Result<()> { fn main() -> Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
let stub = exo_pyo3_bindings::stub_info()?; let stub = exo_pyo3_bindings::stub_info()?;
stub.generate()?; stub.generate()?;
Ok(()) Ok(())

View File

@@ -1,240 +0,0 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -0,0 +1,66 @@
use iroh::{EndpointId, SecretKey, endpoint_info::EndpointIdExt as _};
use postcard::ser_flavors::StdVec;
use crate::ext::ResultExt as _;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use rand::rng;
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct PyKeypair(pub(crate) SecretKey);
#[gen_stub_pymethods]
#[pymethods]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(SecretKey::generate(&mut rng()))
}
/// Decode a postcard structure into a keypair
#[staticmethod]
fn from_postcard_encoding(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(postcard::from_bytes(&bytes).pyerr()?))
}
/// Encode a private key with the postcard format
fn to_postcard_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = postcard::serialize_with_flavor(&self.0, StdVec::new()).pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Read out the endpoint id corresponding to this keypair
fn endpoint_id(&self) -> PyEndpointId {
PyEndpointId(self.0.public())
}
}
#[gen_stub_pyclass]
#[pyclass(name = "EndpointId", frozen)]
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct PyEndpointId(pub(crate) EndpointId);
#[gen_stub_pymethods]
#[pymethods]
impl PyEndpointId {
pub fn __str__(&self) -> String {
self.0.to_z32()
}
}
impl From<EndpointId> for PyEndpointId {
fn from(value: EndpointId) -> Self {
Self(value)
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyEndpointId>()?;
Ok(())
}

View File

@@ -4,65 +4,27 @@
//! //!
//! //!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading; mod allow_threading;
mod examples; mod identity;
pub(crate) mod networking; mod networking;
pub(crate) mod pylibp2p;
use crate::identity::ident_submodule;
use crate::networking::networking_submodule; use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer; use pyo3_stub_gen::define_stub_info_gatherer;
/// Namespace for all the constants used by this crate.
pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods /// Namespace for crate-wide extension traits/methods
pub(crate) mod ext { pub(crate) mod ext {
use crate::allow_threading::AllowThreads; use crate::allow_threading::AllowThreads;
use extend::ext; use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError}; use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes; use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python}; use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)] #[ext(pub, name = ByteArrayExt)]
impl [u8] { impl [u8] {
fn pybytes(&self) -> Py<PyBytes> { fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind()) Python::attach(|py| PyBytes::new(py, self).unbind())
} }
} }
@@ -77,7 +39,9 @@ pub(crate) mod ext {
} }
pub trait FutureExt: Future + Sized { pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await /// SEE: <https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await>
/// An [`AllowThreads`] returns a Future with an Err output if python has shutdown while we
/// were awaiting something
fn allow_threads_py(self) -> AllowThreads<Self> fn allow_threads_py(self) -> AllowThreads<Self>
where where
AllowThreads<Self>: Future, AllowThreads<Self>: Future,
@@ -98,7 +62,7 @@ pub(crate) mod ext {
#[ext(pub, name = PyResultExt)] #[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> { impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> { fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py)) Python::attach(|py| self.write_unraisable_with(py))
} }
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> { fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
@@ -112,85 +76,6 @@ pub(crate) mod ext {
} }
} }
} }
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
} }
/// A Python module implemented in Rust. The name of this function must match /// A Python module implemented in Rust. The name of this function must match
@@ -199,18 +84,18 @@ impl<T> Clone for ClonePy<T> {
#[pymodule(name = "exo_pyo3_bindings")] #[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> { fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger // install logger
pyo3_log::init(); /*
use log::LevelFilter;
#[allow(clippy::expect_used)]
pyo3_log::Logger::default()
.filter(LevelFilter::Warn)
.install()
.expect("logger install");
*/
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?; ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?; networking_submodule(m)?;
// top-level constructs
// TODO: ...
Ok(()) Ok(())
} }

View File

@@ -1,570 +1,194 @@
#![allow( use crate::ext::{ByteArrayExt as _, FutureExt as _, ResultExt as _};
clippy::multiple_inherent_impl, use crate::identity::{PyEndpointId, PyKeypair};
clippy::unnecessary_wraps, use iroh::SecretKey;
clippy::unused_self, use iroh::discovery::EndpointInfo;
clippy::needless_pass_by_value use iroh::discovery::mdns::DiscoveryEvent;
)] use iroh_gossip::api::{ApiError, Event, GossipReceiver, GossipSender, Message};
use n0_future::{Stream, StreamExt as _};
use crate::r#const::MPSC_CHANNEL_SIZE; use networking::ExoNet;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _}; use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _}; use pyo3::prelude::*;
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes; use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods}; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods}; use std::collections::BTreeSet;
use std::net::IpAddr; use std::net::SocketAddr;
use tokio::sync::{Mutex, mpsc, oneshot}; use std::pin::{Pin, pin};
use util::ext::VecExt as _; use std::sync::{Arc, LazyLock};
use tokio::runtime::Runtime;
mod exception { use tokio::sync::Mutex;
use pyo3::types::PyTuple;
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
pub struct PyNoPeersSubscribedToTopicError {}
impl PyNoPeersSubscribedToTopicError {
const MSG: &'static str = "\
No peers are currently subscribed to receive messages on this topic. \
Wait for peers to subscribe or check your network connectivity.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNoPeersSubscribedToTopicError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
pub struct PyAllQueuesFullError {}
impl PyAllQueuesFullError {
const MSG: &'static str =
"All libp2p peers are unresponsive, resend the message or reconnect.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyAllQueuesFullError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
}
}
}
log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
}
#[allow(clippy::expect_used)] #[allow(clippy::expect_used)]
impl PyNetworkingHandle { static RUNTIME: LazyLock<Runtime> =
fn new( LazyLock::new(|| Runtime::new().expect("Failed to create tokio runtime"));
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>, #[gen_stub_pyclass]
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>, #[pyclass(name = "IpAddress")]
) -> Self { #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
Self { pub struct PyIpAddress {
to_task_tx: Some(to_task_tx), inner: SocketAddr,
connection_update_rx: Mutex::new(connection_update_rx), }
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
} #[gen_stub_pymethods]
#[pymethods]
impl PyIpAddress {
pub fn __str__(&self) -> String {
self.inner.to_string()
} }
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> { pub fn ip_addr(&self) -> String {
self.to_task_tx self.inner.ip().to_string()
.as_ref()
.expect("The sender should only be None after de-initialization.")
} }
pub const fn port(&self) -> u16 {
self.inner.port()
}
pub const fn zone_id(&self) -> Option<u32> {
match self.inner {
SocketAddr::V6(ip) => Some(ip.scope_id()),
SocketAddr::V4(_) => None,
}
}
}
#[gen_stub_pyclass]
#[pyclass(name = "RustNetworkingHandle")]
pub struct PyNetworkingHandle {
net: Arc<ExoNet>,
} }
#[gen_stub_pymethods] #[gen_stub_pymethods]
#[pymethods] #[pymethods]
impl PyNetworkingHandle { impl PyNetworkingHandle {
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()` #[staticmethod]
// immediately beforehand to release the interpreter. pub async fn create(identity: PyKeypair, namespace: String) -> PyResult<Self> {
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await let loc: SecretKey = identity.0.clone();
let net = Arc::new(
RUNTIME
.spawn(async move { ExoNet::init_iroh(loc, &namespace).await })
.await
// todo: pyerr better
.pyerr()?
.pyerr()?,
);
let cloned = Arc::clone(&net);
RUNTIME.spawn(async move { cloned.start_auto_dialer().await });
// ---- Lifecycle management methods ---- Ok(Self { net })
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity) })
.pyerr()?;
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
} }
#[gen_stub(skip)] async fn subscribe(&self, topic: String) -> PyResult<(PySender, PyReceiver)> {
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { let fut = self.net.subscribe(&topic);
Ok(()) // This is needed purely so `__clear__` can work let (send, recv) = fut.await.pyerr()?;
Ok((PySender { inner: send }, PyReceiver { inner: recv }))
} }
#[gen_stub(skip)] async fn get_connection_receiver(&self) -> PyResult<PyConnectionReceiver> {
fn __clear__(&mut self) { let stream = self.net.connection_info().await;
// TODO: may or may not need to await a "kill-signal" oneshot channel message, Ok(PyConnectionReceiver {
// to ensure that the networking task is done BEFORE exiting the clear function... inner: Mutex::new(Box::pin(stream)),
// but this may require GIL?? and it may not be safe to call GIL here?? })
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
} }
}
// ---- Connection update receiver methods ---- #[gen_stub_pyclass]
#[pyclass(name = "RustConnectionMessage")]
pub struct PyConnectionMessage {
#[pyo3(get)]
pub endpoint_id: PyEndpointId,
#[pyo3(get)]
pub current_transport_addrs: Option<BTreeSet<PyIpAddress>>,
}
/// Receives the next `ConnectionUpdate` from networking. #[gen_stub_pyclass]
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> { #[pyclass(name = "RustSender")]
self.connection_update_rx struct PySender {
.lock() inner: GossipSender,
.allow_threads_py() // allow-threads-aware async call }
.await
.recv_py() #[gen_stub_pymethods]
.allow_threads_py() // allow-threads-aware async call #[pymethods]
.await impl PySender {
async fn send(&mut self, message: Py<PyBytes>) -> PyResult<()> {
let bytes = Python::attach(|py| message.as_bytes(py).to_vec());
let broadcast_fut = self.inner.broadcast(bytes.into());
pin!(broadcast_fut).allow_threads_py().await?.pyerr()
} }
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them. #[gen_stub_pyclass]
/// #[pyclass(name = "RustReceiver")]
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately. struct PyReceiver {
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method inner: GossipReceiver,
/// will sleep until a `ConnectionUpdate`s is sent. }
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx #[gen_stub_pymethods]
.lock() #[pymethods]
.allow_threads_py() // allow-threads-aware async call impl PyReceiver {
.await async fn receive(&mut self) -> PyResult<Py<PyBytes>> {
.recv_many_py(limit) loop {
.allow_threads_py() // allow-threads-aware async call let next_fut = self.inner.next();
.await match pin!(next_fut).allow_threads_py().await? {
// Successful cases
Some(Ok(Event::Received(Message { content, .. }))) => {
return Ok(content.to_vec().pybytes());
}
Some(Ok(other)) => log::info!("Dropping gossip event {other:?}"),
None => return Err(PyStopAsyncIteration::new_err("")),
Some(Err(ApiError::Closed { .. })) => {
return Err(PyStopAsyncIteration::new_err(""));
}
// Failure case
Some(Err(other)) => {
return Err(PyRuntimeError::new_err(other.to_string()));
}
}
}
} }
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex) #[gen_stub_pyclass]
// so its too dangerous to expose just yet. figure out a better semantics for handling this, #[pyclass(name = "RustConnectionReceiver")]
// so things don't randomly block struct PyConnectionReceiver {
// /// Tries to receive the next `ConnectionUpdate` from networking. inner: Mutex<Pin<Box<dyn Stream<Item = DiscoveryEvent> + Send>>>,
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> { }
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ---- #[gen_stub_pymethods]
#[pymethods]
/// Subscribe to a `GossipSub` topic. impl PyConnectionReceiver {
/// async fn receive(&mut self) -> PyResult<PyConnectionMessage> {
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed. // Errors on trying to receive twice - which is a dev error. This could just block the
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> { // async task, but I want the error to persist
let (tx, rx) = oneshot::channel(); let mut lock = self.inner.try_lock().pyerr()?;
match lock.next().allow_threads_py().await? {
// send off request to subscribe // Successful cases
self.to_task_tx() Some(DiscoveryEvent::Discovered {
.send_py(ToTask::GossipsubSubscribe { endpoint_info: EndpointInfo { endpoint_id, data },
topic, ..
result_tx: tx, }) => Ok(PyConnectionMessage {
}) endpoint_id: endpoint_id.into(),
.allow_threads_py() // allow-threads-aware async call current_transport_addrs: Some(
.await?; data.ip_addrs()
.map(|inner| PyIpAddress { inner: *inner })
// wait for response & return any errors .collect(),
rx.allow_threads_py() // allow-threads-aware async call ),
.await }),
.map_err(|_| PyErr::receiver_channel_closed())? Some(DiscoveryEvent::Expired { endpoint_id }) => Ok(PyConnectionMessage {
endpoint_id: endpoint_id.into(),
current_transport_addrs: None,
}),
// Failure case
None => Err(PyStopAsyncIteration::new_err("")),
}
} }
/// Unsubscribes from a `GossipSub` topic.
///
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & convert any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
/// Publishes a message with multiple topics to the `GossipSub` network.
///
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,
data,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors => ignore messageID for now!!!
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
Ok(())
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.map(|(t, d)| (t, d.pybytes())))
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
} }
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?; m.add_class::<PyConnectionMessage>()?;
m.add_class::<exception::PyAllQueuesFullError>()?; m.add_class::<PyReceiver>()?;
m.add_class::<PySender>()?;
m.add_class::<PyConnectionUpdateType>()?; m.add_class::<PyConnectionReceiver>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?; m.add_class::<PyNetworkingHandle>()?;
Ok(()) Ok(())

View File

@@ -1,159 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
    /// Create a new, empty multiaddress.
    #[staticmethod]
    fn empty() -> Self {
        Self(Multiaddr::empty())
    }

    /// Create a new, empty multiaddress with the given capacity.
    #[staticmethod]
    fn with_capacity(n: usize) -> Self {
        Self(Multiaddr::with_capacity(n))
    }

    /// Parse a `Multiaddr` value from its byte slice representation.
    ///
    /// Raises a Python exception (via `.pyerr()`) if the bytes are not a
    /// valid multiaddr encoding.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        // Copy out of the Python-owned buffer before parsing.
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
    }

    /// Parse a `Multiaddr` value from its string representation.
    ///
    /// Raises a Python exception if the string is not a valid multiaddr.
    #[staticmethod]
    fn from_string(string: String) -> PyResult<Self> {
        Ok(Self(Multiaddr::from_str(&string).pyerr()?))
    }

    /// Return the length in bytes of this multiaddress.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns true if the length of this multiaddress is 0.
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Return a copy of this [`Multiaddr`]'s byte representation.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_vec();
        PyBytes::new(py, &bytes)
    }

    /// Convert a Multiaddr to a string.
    fn to_string(&self) -> String {
        self.0.to_string()
    }

    // Dunder methods are excluded from stub generation: Python knows them.
    #[gen_stub(skip)]
    fn __repr__(&self) -> String {
        format!("Multiaddr({})", self.0)
    }

    #[gen_stub(skip)]
    fn __str__(&self) -> String {
        self.to_string()
    }
}
/// Register the multiaddr wrapper type on the given Python (sub)module.
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // `add_class` already yields `PyResult<()>`, so return it directly.
    m.add_class::<PyMultiaddr>()
}

View File

@@ -1,54 +0,0 @@
#[cfg(test)]
mod tests {
    use core::mem::drop;
    use core::option::Option::Some;
    use core::time::Duration;
    use tokio;
    use tokio::sync::mpsc;

    /// Checks mpsc channel-close semantics under `select!`: the receiving
    /// task keeps running while *any* sender clone is alive and only
    /// observes `None` (and breaks out) after every sender is dropped.
    #[tokio::test]
    async fn test_drop_channel() {
        // zero-sized marker message
        struct Ping;
        let (tx, mut rx) = mpsc::channel::<Ping>(10);
        let _ = tokio::spawn(async move {
            println!("TASK: entered");
            loop {
                tokio::select! {
                    result = rx.recv() => {
                        match result {
                            Some(_) => {
                                println!("TASK: pinged");
                            }
                            None => {
                                // all senders dropped -> channel closed
                                println!("TASK: closing channel");
                                break;
                            }
                        }
                    }
                    _ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
                        // periodic branch proving the task is still alive
                        println!("TASK: heartbeat");
                    }
                }
            }
            println!("TASK: exited");
        });
        // a second sender keeps the channel open after `tx` is dropped
        let tx2 = tx.clone();
        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
        tx.send(Ping).await.expect("Should not fail");
        drop(tx);
        // channel must still be open here: `tx2` is alive
        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
        tx2.send(Ping).await.expect("Should not fail");
        drop(tx2);
        // give the task time to observe the close before the test ends
        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
    }
}

View File

@@ -1,34 +1,47 @@
import asyncio import asyncio
import pytest import pytest
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError from exo_pyo3_bindings import (
Keypair,
RustNetworkingHandle,
RustReceiver,
RustConnectionReceiver,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sleep_on_multiple_items() -> None: async def test_sleep_on_multiple_items() -> None:
print("PYTHON: starting handle") print("PYTHON: starting handle")
h = NetworkingHandle(Keypair.generate_ed25519()) s_h = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "test")
r_h = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "test")
ct = asyncio.create_task(_await_cons(h)) await asyncio.sleep(1)
mt = asyncio.create_task(_await_msg(h))
cm = await r_h.get_connection_receiver()
_, recv = await r_h.subscribe("topic")
send, _ = await s_h.subscribe("topic")
ct = asyncio.create_task(_await_cons(cm))
mt = asyncio.create_task(_await_msg(recv))
# sleep for 4 ticks # sleep for 4 ticks
for i in range(4): for i in range(4):
await asyncio.sleep(1) await asyncio.sleep(1)
try: await send.send(b"somehting or other")
await h.gossipsub_publish("topic", b"somehting or other")
except NoPeersSubscribedToTopicError as e: await ct
print("caught it", e) await mt
async def _await_cons(h: NetworkingHandle): async def _await_cons(h: RustConnectionReceiver):
while True: while True:
c = await h.connection_update_recv() c = await h.receive()
print(f"PYTHON: connection update: {c}") print(f"PYTHON: connection update: {c}")
async def _await_msg(h: NetworkingHandle): async def _await_msg(r: RustReceiver):
while True: while True:
m = await h.gossipsub_recv() m = await r.receive()
print(f"PYTHON: message: {m}") print(f"PYTHON: message: {m}")

View File

@@ -1,44 +1,18 @@
[package] [package]
name = "networking" name = "networking"
version = { workspace = true } version.workspace = true
edition = { workspace = true } edition.workspace = true
publish = false
[lib] [dependencies]
doctest = false blake3 = { workspace = true, features = ["neon", "rayon"] }
name = "networking" iroh = { workspace = true, features = ["discovery-local-network"] }
path = "src/lib.rs" iroh-gossip.workspace = true
log.workspace = true
n0-error.workspace = true
n0-future.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["full"] }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
[lints] [lints]
workspace = true workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,74 +1,85 @@
use futures::stream::StreamExt as _; #![allow(clippy::expect_used, clippy::unwrap_used, clippy::cargo)]
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm}; use std::sync::Arc;
use tokio::{io, io::AsyncBufReadExt as _, select}; use std::time::Duration;
use iroh::SecretKey;
use iroh_gossip::api::{Event, Message};
use n0_future::StreamExt as _;
use networking::ExoNet;
use tokio::time::sleep;
use tokio::{io, io::AsyncBufReadExt as _};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::filter::LevelFilter;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let _ = tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into())) .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init(); .try_init()
.expect("logger");
// Configure swarm // Configure swarm
let mut swarm = let net = Arc::new(
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed"); ExoNet::init_iroh(SecretKey::generate(&mut rand::rng()), "chatroom")
.await
.expect("iroh init shouldn't fail"),
);
let innet = Arc::clone(&net);
let jh1 = tokio::spawn(async move { innet.start_auto_dialer().await });
while net.known_peers.lock().await.is_empty() {
sleep(Duration::from_secs(1)).await;
}
// Create a Gossipsub topic & subscribe // Create a Gossipsub topic & subscribe
let topic = gossipsub::IdentTopic::new("test-net"); let (send, mut recv) = net
swarm .subscribe("chatting")
.behaviour_mut() .await
.gossipsub .expect("topic shouldn't fail");
.subscribe(&topic)
.expect("Subscribing to topic failed");
// Read full lines from stdin // Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines(); let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub"); println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off let jh2 = tokio::spawn(async move {
loop { loop {
select! { if let Ok(Some(line)) = stdin.next_line().await
// on gossipsub outgoing && let Err(e) = send.broadcast(line.into()).await
Ok(Some(line)) = stdin.next_line() => { {
if let Err(e) = swarm println!("Publish error: {e:?}");
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }
} }
} }
} });
tokio::spawn(async move {
while let Some(Ok(event)) = recv.next().await {
match event {
// on gossipsub incoming
Event::Received(Message {
content,
delivered_from,
..
}) => println!(
"\n\nGot message: '{}' with from peer: {delivered_from}\n\n",
String::from_utf8_lossy(&content),
),
// on discovery
Event::NeighborUp(peer_id) => {
println!("\n\nConnected to: {peer_id}\n\n");
}
Event::NeighborDown(peer_id) => {
eprintln!("\n\nDisconnected from: {peer_id}\n\n");
}
Event::Lagged => {
eprintln!("\n\nLagged\n\n");
}
}
}
})
.await
.unwrap();
jh1.await.unwrap();
jh2.await.unwrap();
} }

View File

@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
// `#[derive(NetworkBehaviour)]` also generates a `MyBehaviourEvent` enum
// (one variant per field) that the swarm emits to the event loop.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
    // pub/sub messaging between connected peers
    gossipsub: gossipsub::Behaviour,
    // LAN peer discovery via multicast DNS
    mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // Best-effort logger init; ignore the error if a logger is already set.
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init();

    // Build a TCP (noise + yamux) swarm with a fresh identity.
    let mut swarm = libp2p::SwarmBuilder::with_new_identity()
        .with_tokio()
        .with_tcp(
            tcp::Config::default(),
            noise::Config::new,
            yamux::Config::default,
        )?
        .with_behaviour(|key| {
            // Set a custom gossipsub configuration
            let gossipsub_config = gossipsub::ConfigBuilder::default()
                .heartbeat_interval(Duration::from_secs(10))
                .validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
                .build()
                .map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.

            // build a gossipsub network behaviour
            let gossipsub = gossipsub::Behaviour::new(
                gossipsub::MessageAuthenticity::Signed(key.clone()),
                gossipsub_config,
            )?;

            let mdns =
                mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
            Ok(MyBehaviour { gossipsub, mdns })
        })?
        .build();

    println!("Running swarm with identity {}", swarm.local_peer_id());

    // Create a Gossipsub topic
    let topic = gossipsub::IdentTopic::new("test-net");
    // subscribes to our topic
    swarm.behaviour_mut().gossipsub.subscribe(&topic)?;

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;

    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    // Kick it off: multiplex stdin input with swarm events forever.
    loop {
        select! {
            // a full line typed by the user -> publish it to the topic
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.select_next_some() => match event {
                // mDNS found peers on the LAN -> add them as gossipsub peers
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
                    }
                },
                // mDNS records expired -> stop gossiping with those peers
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
                    }
                },
                // an incoming gossipsub message from some peer
                SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                })) => println!(
                    "Got message: '{}' with id: {id} from peer: {peer_id}",
                    String::from_utf8_lossy(&message.data),
                ),
                SwarmEvent::NewListenAddr { address, .. } => {
                    println!("Local node is listening on {address}");
                }
                e => {
                    println!("Other swarm event: {:?}", e);
                }
            }
        }
    }
}

View File

@@ -0,0 +1,30 @@
#![allow(clippy::cargo, clippy::unwrap_used)]
use iroh::{SecretKey, endpoint_info::EndpointIdExt as _};
use n0_future::StreamExt as _;
use networking::ExoNet;
// Launch a mock version of iroh for testing purposes
/// Launch a mock iroh endpoint for manual testing: prints the node's public
/// key and then logs every mDNS discovery event until the stream terminates.
#[tokio::main]
async fn main() {
    // Honor RUST_LOG-style filtering from the environment.
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();
    let key = SecretKey::generate(&mut rand::rng());
    let dbg_key = key.public().to_z32();
    println!("Starting with pk: {dbg_key}");
    // Empty namespace is fine for a mock node; we only observe discovery.
    let net = ExoNet::init_iroh(key, "").await.unwrap();
    let mut conn_info = net.connection_info().await;
    let task = tokio::task::spawn(async move {
        println!("Inner task started!");
        // Drain discovery events as they arrive. The previous
        // `loop { dbg!(conn_info.next().await); }` busy-looped on `None`
        // forever once the stream ended; `while let` exits cleanly instead.
        while let Some(event) = conn_info.next().await {
            dbg!(event);
        }
        println!("Discovery stream ended");
    });
    println!("Task started!");
    task.await.unwrap();
}

View File

@@ -1,44 +0,0 @@
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
https://stackoverflow.com/questions/17169298/af-packet-on-osx
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
https://www.unix.com/man_page/mojave/4/ip/
^OS X manpages for IP
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
^driver kit, system extensions & kexts for macOS
----
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
----> here is a guide on how to set up thunderbolt-ethernet on linux
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
^GPT discussion about making socket interface
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
^GPT discussion about accessing TB on MacOS low level interactions
--------------------------------
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
---------------------------------
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c

View File

@@ -1,383 +0,0 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
/// Third-party behaviours (mDNS + ping) bundled behind one `NetworkBehaviour`
/// so the outer discovery behaviour can delegate to them as a unit.
mod managed {
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{identity, mdns, ping};
    use std::io;
    use std::time::Duration;

    // mDNS record lifetime and re-query cadence.
    const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
    const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
    // Ping liveness probing; a failed ping triggers a disconnect upstream.
    const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
    const PING_INTERVAL: Duration = Duration::from_millis(2_500);

    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
        mdns: mdns::tokio::Behaviour,
        ping: ping::Behaviour,
    }

    impl Behaviour {
        /// Build the managed behaviours from the node keypair.
        ///
        /// # Errors
        /// Fails if the mDNS behaviour cannot be constructed.
        pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
            Ok(Self {
                mdns: mdns_behaviour(keypair)?,
                ping: ping_behaviour(),
            })
        }
    }

    /// Construct the tokio mDNS behaviour with our TTL/query settings.
    fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
        use mdns::{Config, tokio};
        // mDNS config => enable IPv6
        let mdns_config = Config {
            ttl: MDNS_RECORD_TTL,
            query_interval: MDNS_QUERY_INTERVAL,
            // enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
            ..Default::default()
        };
        // `Behaviour::new` already returns `io::Result`, so return it
        // directly instead of the redundant `Ok(x?)` round-trip.
        tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id())
    }

    /// Ping behaviour used purely as a liveness probe for connections.
    fn ping_behaviour() -> ping::Behaviour {
        ping::Behaviour::new(
            ping::Config::new()
                .with_timeout(PING_TIMEOUT)
                .with_interval(PING_INTERVAL),
        )
    }
}
/// Events for when a listening connection is truly established and truly closed.
///
/// Only connections whose remote multiaddress parses as an IP address plus a
/// TCP port are reported; connections on other transports are filtered out
/// by the behaviour before these events are generated.
#[derive(Debug, Clone)]
pub enum Event {
    /// A connection to `peer_id` was fully established over TCP.
    ConnectionEstablished {
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    },
    /// A previously established TCP connection to `peer_id` was closed.
    ConnectionClosed {
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    },
}
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
///
/// The behaviour operates as such:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
///    to the swarm.
/// 2) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
///    immediately, and expired but connected peers are disconnected from immediately.
/// 3) Every fixed interval: discovered but not connected peers are dialed, and expired but
///    connected peers are disconnected from.
pub struct Behaviour {
    // state-tracking for managed behaviors & mDNS-discovered peers
    managed: managed::Behaviour,
    // every multiaddress each mDNS-visible peer was discovered on
    mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
    retry_delay: Delay, // retry interval
    // pending events to emit => waker-backed Deque to control polling
    pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}
impl Behaviour {
    /// Build the discovery behaviour from the node keypair.
    ///
    /// # Errors
    /// Fails if the underlying managed (mDNS) behaviour cannot be created.
    pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
        Ok(Self {
            managed: managed::Behaviour::new(keypair)?,
            mdns_discovered: HashMap::new(),
            retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
            pending_events: WakerDeque::new(),
        })
    }

    /// Queue a dial of `peer_id` at `addr` (safe if already connected).
    fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
        self.pending_events.push_back(ToSwarm::Dial {
            opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
        })
    }

    /// Queue the close of a single connection.
    fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
        // push front to make this IMMEDIATE
        self.pending_events.push_front(ToSwarm::CloseConnection {
            peer_id,
            connection: CloseConnection::One(connection),
        })
    }

    /// Record newly mDNS-discovered peers and dial each of them.
    fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
        for (p, ma) in peers {
            self.dial(p, ma.clone()); // always connect
            // Entry API replaces the previous get_mut/insert/continue dance:
            // an empty address set is created on first sight of the peer,
            // then `ma` is recorded in it.
            let is_new_addr = self.mdns_discovered.entry(p).or_default().insert(ma);
            // multiaddress should never already be present - else something has gone wrong
            assert!(is_new_addr, "cannot discover a discovered peer");
        }
    }

    /// Remove expired peers/addresses from the discovery map.
    fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
        for (p, ma) in peers {
            // at this point, we *must* have the peer
            let mas = self
                .mdns_discovered
                .get_mut(&p)
                .expect("nonexistent peer cannot expire");
            // at this point, we *must* have the multiaddress
            let was_present = mas.remove(&ma);
            assert!(was_present, "nonexistent multiaddress cannot expire");
            // if empty, remove the peer-id entirely
            if mas.is_empty() {
                self.mdns_discovered.remove(&p);
            }
        }
    }

    /// Emit a [`Event::ConnectionEstablished`] to the swarm.
    fn on_connection_established(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    ) {
        // send out connected event
        self.pending_events
            .push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
                peer_id,
                connection_id,
                remote_ip,
                remote_tcp_port,
            }));
    }

    /// Emit a [`Event::ConnectionClosed`] to the swarm.
    fn on_connection_closed(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    ) {
        // send out disconnected event
        self.pending_events
            .push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
                peer_id,
                connection_id,
                remote_ip,
                remote_tcp_port,
            }));
    }
}
impl NetworkBehaviour for Behaviour {
    // left = our keep-alive dummy handler, right = the managed behaviours'
    type ConnectionHandler =
        ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
    type ToSwarm = Event;

    // simply delegate to underlying mDNS behaviour
    delegate! {
        to self.managed {
            fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
            fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
        }
    }

    fn handle_established_inbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        local_addr: &Multiaddr,
        remote_addr: &Multiaddr,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        // pair our dummy handler with the managed behaviours' handler
        Ok(ConnectionHandler::select(
            dummy::ConnectionHandler,
            self.managed.handle_established_inbound_connection(
                connection_id,
                peer,
                local_addr,
                remote_addr,
            )?,
        ))
    }

    #[allow(clippy::needless_question_mark)]
    fn handle_established_outbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        addr: &Multiaddr,
        role_override: Endpoint,
        port_use: PortUse,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        Ok(ConnectionHandler::select(
            dummy::ConnectionHandler,
            self.managed.handle_established_outbound_connection(
                connection_id,
                peer,
                addr,
                role_override,
                port_use,
            )?,
        ))
    }

    fn on_connection_handler_event(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        event: THandlerOutEvent<Self>,
    ) {
        match event {
            // the dummy handler's event type is uninhabited
            Either::Left(ev) => libp2p::core::util::unreachable(ev),
            Either::Right(ev) => {
                self.managed
                    .on_connection_handler_event(peer_id, connection_id, ev)
            }
        }
    }

    // hook into these methods to drive behavior
    fn on_swarm_event(&mut self, event: FromSwarm) {
        self.managed.on_swarm_event(event); // let mDNS handle swarm events
        // handle swarm events to update internal state:
        match event {
            FromSwarm::ConnectionEstablished(ConnectionEstablished {
                peer_id,
                connection_id,
                endpoint,
                ..
            }) => {
                let remote_address = match endpoint {
                    ConnectedPoint::Dialer { address, .. } => address,
                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
                };
                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
                    // handle connection established event which is filtered correctly
                    self.on_connection_established(peer_id, connection_id, ip, port)
                }
            }
            FromSwarm::ConnectionClosed(ConnectionClosed {
                peer_id,
                connection_id,
                endpoint,
                ..
            }) => {
                let remote_address = match endpoint {
                    ConnectedPoint::Dialer { address, .. } => address,
                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
                };
                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
                    // handle connection closed event which is filtered correctly
                    self.on_connection_closed(peer_id, connection_id, ip, port)
                }
            }
            // since we are running TCP/IP transport layer, we are assuming that
            // no address changes can occur, hence encountering one is a fatal error
            FromSwarm::AddressChange(a) => {
                unreachable!("unhandlable: address change encountered: {:?}", a)
            }
            _ => {}
        }
    }

    fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
        // delegate to managed behaviors for any behaviors they need to perform
        match self.managed.poll(cx) {
            Poll::Ready(ToSwarm::GenerateEvent(e)) => {
                match e {
                    // handle discovered and expired events from mDNS
                    // (`e` is owned here, so no clone is needed to match on it)
                    managed::BehaviourEvent::Mdns(e) => match e {
                        mdns::Event::Discovered(peers) => {
                            self.handle_mdns_discovered(peers);
                        }
                        mdns::Event::Expired(peers) => {
                            self.handle_mdns_expired(peers);
                        }
                    },
                    // handle ping events => if error then disconnect
                    managed::BehaviourEvent::Ping(e) => {
                        // `is_err()` over `if let Err(_)` (redundant pattern
                        // match); `ConnectionId` is `Copy`, no clone needed
                        if e.result.is_err() {
                            self.close_connection(e.peer, e.connection)
                        }
                    }
                }
                // since we just consumed an event, we should immediately wake just in case
                // there are more events to come where that came from
                cx.waker().wake_by_ref();
            }
            // forward any other mDNS event to the swarm or its connection handler(s)
            Poll::Ready(e) => {
                return Poll::Ready(
                    e.map_out(|_| unreachable!("events returning to swarm already handled"))
                        .map_in(Either::Right),
                );
            }
            Poll::Pending => {}
        }
        // retry connecting to all mDNS peers periodically (fails safely if already connected)
        if self.retry_delay.poll_unpin(cx).is_ready() {
            for (p, mas) in self.mdns_discovered.clone() {
                for ma in mas {
                    self.dial(p, ma)
                }
            }
            self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
        }
        // send out any pending events from our own service
        if let Some(e) = self.pending_events.pop_front(cx) {
            return Poll::Ready(e.map_in(Either::Left));
        }
        // wait for pending events
        Poll::Pending
    }
}

View File

@@ -1,44 +0,0 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
///
/// Newtype over [`dummy::ConnectionHandler`]; only `connection_keep_alive`
/// is overridden in the trait impl below.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
    /// Construct a keep-alive handler wrapping the dummy handler.
    pub fn new() -> Self {
        ConnectionHandler(dummy::ConnectionHandler)
    }
}

// `Default` mirrors `new()`; satisfies `clippy::new_without_default` and
// lets the handler be used in `Default`-bounded contexts.
impl Default for ConnectionHandler {
    fn default() -> Self {
        Self::new()
    }
}
impl handler::ConnectionHandler for ConnectionHandler {
    // delegate types and implementation mostly to dummy handler
    type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
    type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
    type InboundProtocol =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
    type OutboundProtocol =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
    type InboundOpenInfo =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
    type OutboundOpenInfo =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;

    // All protocol plumbing is forwarded verbatim to the inner dummy handler.
    delegate! {
        to self.0 {
            fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
            fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
            fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
            fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
        }
    }

    // specifically override this to force connection to stay alive
    // (overriding whatever the delegated dummy handler would report)
    fn connection_keep_alive(&self) -> bool {
        true
    }
}

View File

@@ -1,64 +1,149 @@
//! TODO: crate documentation use std::collections::BTreeSet;
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience use iroh::{
#![feature(trait_alias)] Endpoint, EndpointId, SecretKey, TransportAddr,
// #![feature(stmt_expr_attributes)] discovery::{
// #![feature(unboxed_closures)] Discovery as _, EndpointData, IntoDiscoveryError,
// #![feature(assert_matches)] mdns::{DiscoveryEvent, MdnsDiscovery},
// #![feature(async_fn_in_dyn_trait)] },
// #![feature(async_for_loop)] endpoint::BindError,
// #![feature(auto_traits)] endpoint_info::EndpointIdExt as _,
// #![feature(negative_impls)] protocol::Router,
};
use iroh_gossip::{
Gossip, TopicId,
api::{ApiError, GossipReceiver, GossipSender},
};
pub mod discovery; use n0_error::{e, stack_error};
pub mod keep_alive; use n0_future::{Stream, StreamExt as _};
pub mod swarm; use tokio::sync::Mutex;
/// Namespace for all the type/trait aliases used by this crate. #[stack_error(derive, add_meta, from_sources)]
pub(crate) mod alias { pub enum ExoError {
use std::error::Error; #[error(transparent)]
FailedBinding { source: BindError },
pub type AnyError = Box<dyn Error + Send + Sync + 'static>; /// The gossip topic was closed.
pub type AnyResult<T> = Result<T, AnyError>; #[error(transparent)]
FailedCommunication { source: ApiError },
#[error("No IP Protocol supported on device")]
IPNotSupported { source: IntoDiscoveryError },
#[error("No peers found before subscribing")]
NoPeers,
} }
/// Namespace for crate-wide extension traits/methods #[derive(Debug)]
pub(crate) mod ext { pub struct ExoNet {
use extend::ext; pub alpn: String,
use libp2p::Multiaddr; pub router: Router,
use libp2p::multiaddr::Protocol; pub gossip: Gossip,
use std::net::IpAddr; pub mdns: MdnsDiscovery,
pub known_peers: Mutex<BTreeSet<EndpointId>>,
}
#[ext(pub, name = MultiaddrExt)] impl ExoNet {
impl Multiaddr { #[inline]
/// If the multiaddress corresponds to a TCP address, extracts it pub async fn init_iroh(sk: SecretKey, namespace: &str) -> Result<Self, ExoError> {
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> { let endpoint = Endpoint::empty_builder(iroh::RelayMode::Disabled)
let mut ps = self.into_iter(); .secret_key(sk)
let ip = if let Some(p) = ps.next() { .bind()
match p { .await?;
Protocol::Ip4(ip) => IpAddr::V4(ip), let mdns = MdnsDiscovery::builder().build(endpoint.id())?;
Protocol::Ip6(ip) => IpAddr::V6(ip), let endpoint_addr = endpoint.addr();
_ => return None,
let bound = endpoint_addr.ip_addrs().map(|it| TransportAddr::Ip(*it));
log::info!("publishing {endpoint_addr:?} with mdns");
mdns.publish(&EndpointData::new(bound));
endpoint.discovery().add(mdns.clone());
let alpn = format!("/exo_discovery_network/{namespace}");
// max msg size 4MB
let gossip = Gossip::builder()
.max_message_size(4 * 1024 * 1024)
.alpn(&alpn)
.spawn(endpoint.clone());
let router = Router::builder(endpoint)
.accept(&alpn, gossip.clone())
.spawn();
Ok(Self {
alpn,
router,
gossip,
mdns,
known_peers: Mutex::new(BTreeSet::new()),
})
}
#[inline]
pub async fn start_auto_dialer(&self) {
let mut recv = self.connection_info().await;
log::info!(
"Starting auto dialer for id {}",
self.router.endpoint().id().to_z32()
);
while let Some(item) = recv.next().await {
match item {
DiscoveryEvent::Discovered { endpoint_info, .. } => {
let id = endpoint_info.endpoint_id;
if id == self.router.endpoint().id() {
continue;
}
if !self
.known_peers
.lock()
.await
.contains(&endpoint_info.endpoint_id)
&& let Ok(conn) = self
.router
.endpoint()
.connect(endpoint_info, self.alpn.as_bytes())
.await
&& conn.alpn() == self.alpn.as_bytes()
{
self.known_peers.lock().await.insert(id);
match self.gossip.handle_connection(conn).await {
Ok(()) => log::info!("Successfully dialled"),
Err(_) => log::info!("Failed to dial peer"),
}
}
} }
} else { DiscoveryEvent::Expired { endpoint_id } => {
return None; log::info!("Peer expired {}", endpoint_id.to_z32());
}; self.known_peers.lock().await.remove(&endpoint_id);
let Some(Protocol::Tcp(port)) = ps.next() else { }
return None; }
};
Some((ip, port))
} }
log::info!("Auto dialer stopping");
}
#[inline]
pub async fn connection_info(&self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
self.mdns.subscribe().await
}
#[inline]
pub async fn subscribe(&self, topic: &str) -> Result<(GossipSender, GossipReceiver), ExoError> {
if self.known_peers.lock().await.is_empty() {
return Err(e!(ExoError::NoPeers));
}
Ok(self
.gossip
.subscribe_and_join(
str_to_topic_id(topic),
self.known_peers.lock().await.clone().into_iter().collect(),
)
.await?
.split())
}
#[inline]
#[allow(clippy::expect_used)]
pub async fn shutdown(&self) {
self.router.shutdown().await.expect("router panic");
} }
} }
pub(crate) mod private { fn str_to_topic_id(data: &str) -> TopicId {
#![allow(dead_code)] TopicId::from_bytes(*blake3::hash(data.as_bytes()).as_bytes())
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
} }

View File

@@ -1,145 +0,0 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
/// The concrete swarm type used by this crate, driven by [`Behaviour`].
pub type Swarm = libp2p::Swarm<Behaviour>;

/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
///
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
/// even be, and how to inject the right version into this config/initialization. E.g. should
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";

/// Environment variable that, when set, overrides [`NETWORK_VERSION`] as the
/// namespace mixed into the private-network pre-shared key (see the transport module).
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
/// Build a [`Swarm`] from an existing identity keypair and start listening on
/// every interface with an OS-assigned TCP port.
///
/// # Errors
/// Propagates failures from transport/behaviour construction and from
/// parsing or binding the listen address.
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
    let mut swarm = SwarmBuilder::with_existing_identity(keypair)
        .with_tokio()
        .with_other_transport(tcp_transport)?
        .with_behaviour(Behaviour::new)?
        .build();
    // `0.0.0.0` + port 0: bind all interfaces and let the OS choose the port.
    let listen_addr = "/ip4/0.0.0.0/tcp/0".parse()?;
    swarm.listen_on(listen_addr)?;
    Ok(swarm)
}
/// Transport-layer configuration for the swarm: TCP + private-network handshake
/// + Noise authentication + Yamux multiplexing.
mod transport {
    use crate::alias;
    use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
    use futures::{AsyncRead, AsyncWrite};
    use keccak_const::Sha3_256;
    use libp2p::core::muxing;
    use libp2p::core::transport::Boxed;
    use libp2p::pnet::{PnetError, PnetOutput};
    use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
    use std::{env, sync::LazyLock};

    /// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`]
    /// (or the [`OVERRIDE_VERSION_ENV_VAR`] value, when that variable is set).
    /// Computed lazily as Sha3-256 over `"exo_discovery_network"` followed by the version bytes.
    /// See [`pnet_upgrade`] for more.
    static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
        let builder = Sha3_256::new().update(b"exo_discovery_network");
        if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
            // Environment override wins: lets builds join a custom namespace.
            let bytes = var.into_bytes();
            builder.update(&bytes)
        } else {
            builder.update(NETWORK_VERSION)
        }
        .finalize()
    });

    /// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
    /// also different-versioned instances of this same network.
    /// This is implemented as an additional "upgrade" on top of existing [`libp2p::Transport`] layers.
    async fn pnet_upgrade<TSocket>(
        socket: TSocket,
        _: impl Sized,
    ) -> Result<PnetOutput<TSocket>, PnetError>
    where
        TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
    {
        use pnet::{PnetConfig, PreSharedKey};
        // Run the pre-shared-key handshake before any other protocol sees the socket.
        PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
            .handshake(socket)
            .await
    }

    /// TCP/IP transport layer configuration.
    ///
    /// Layers, innermost first: Tokio TCP (with `TCP_NODELAY`) -> pnet handshake
    /// -> Noise authentication -> Yamux stream multiplexing, returned boxed.
    pub fn tcp_transport(
        keypair: &identity::Keypair,
    ) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
        use libp2p::{
            core::upgrade::Version,
            tcp::{Config, tokio},
        };
        // `TCP_NODELAY` enabled => avoid latency
        let tcp_config = Config::default().nodelay(true);
        // V1 + lazy flushing => 0-RTT negotiation
        let upgrade_version = Version::V1Lazy;
        // Noise is faster than TLS + we don't care much for security
        let noise_config = noise::Config::new(keypair)?;
        // Use default Yamux config for multiplexing
        let yamux_config = yamux::Config::default();
        // Create new Tokio-driven TCP/IP transport layer
        let base_transport = tokio::Transport::new(tcp_config)
            .and_then(pnet_upgrade)
            .upgrade(upgrade_version)
            .authenticate(noise_config)
            .multiplex(yamux_config);
        // Return boxed transport (to flatten complex type)
        Ok(base_transport.boxed())
    }
}
/// Composition of the network behaviours the swarm runs.
mod behaviour {
    use crate::{alias, discovery};
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{gossipsub, identity};
    use std::time::Duration;

    /// Behavior of the Swarm which composes all desired behaviors:
    /// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
        /// Peer discovery behaviour (see `crate::discovery`).
        pub discovery: discovery::Behaviour,
        /// Pub/sub messaging between discovered peers.
        pub gossipsub: gossipsub::Behaviour,
    }

    impl Behaviour {
        /// Build the composed behaviour from the node's identity keypair.
        ///
        /// Errors propagate only from the discovery behaviour; gossipsub
        /// construction panics on failure (see [`gossipsub_behaviour`]).
        pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
            Ok(Self {
                discovery: discovery::Behaviour::new(keypair)?,
                gossipsub: gossipsub_behaviour(keypair),
            })
        }
    }

    /// Construct the gossipsub behaviour used for pub/sub messaging.
    fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
        use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
        // build a gossipsub network behaviour
        // => signed message authenticity + strict validation mode means the message-ID is
        // automatically provided by gossipsub w/out needing to provide custom message-ID function
        gossipsub::Behaviour::new(
            MessageAuthenticity::Signed(keypair.clone()),
            ConfigBuilder::default()
                // Keep unpublished messages queued for up to 15 seconds.
                .publish_queue_duration(Duration::from_secs(15))
                // Cap transmitted messages at 1 MiB.
                .max_transmit_size(1024 * 1024)
                .validation_mode(ValidationMode::Strict)
                .build()
                .expect("the configuration should always be valid"),
        )
        .expect("creating gossipsub behavior should always work")
    }
}

View File

@@ -1,7 +0,0 @@
// maybe this will hold test in the future...??
#[cfg(test)]
mod tests {
    /// Placeholder so the crate has a runnable test target.
    #[test]
    fn does_nothing() {}
}

View File

@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"

View File

@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -1,4 +0,0 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which the Exo application uses to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state collected -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and proper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
    // sealed traits support: traits bounded on `Sealed` cannot be
    // implemented outside this crate, since the module is `pub(crate)`.
    pub trait Sealed {}
    // Blanket impl: every type (sized or not) is automatically sealed.
    impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -1,25 +0,0 @@
[package]
name = "util"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "util"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
# macro dependencies
extend = { workspace = true }
# utility dependencies
thiserror = { workspace = true }
once_cell = { workspace = true }
internment = { workspace = true }
derive_more = { workspace = true }
bon = { workspace = true }
recursion = { workspace = true }

View File

@@ -1,53 +0,0 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub mod nonempty;
pub mod wakerdeque;
pub(crate) mod private {
    // sealed traits support: traits bounded on `Sealed` cannot be
    // implemented outside this crate, since the module is `pub(crate)`.
    pub trait Sealed {}
    // Blanket impl: every type (sized or not) is automatically sealed.
    impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub mod ext {
    use extend::ext;

    #[ext(pub, name = BoxedSliceExt)]
    impl<T> Box<[T]> {
        /// Apply `f` to every element, producing a new boxed slice.
        #[inline]
        fn map<B, F>(self, f: F) -> Box<[B]>
        where
            F: FnMut(T) -> B,
        {
            self.into_vec().into_iter().map(f).collect()
        }
    }

    #[ext(pub, name = VecExt)]
    impl<T> Vec<T> {
        /// Apply `f` to every element, producing a new vector.
        #[inline]
        fn map<B, F>(self, f: F) -> Vec<B>
        where
            F: FnMut(T) -> B,
        {
            Vec::from_iter(self.into_iter().map(f))
        }
    }
}

View File

@@ -1,138 +0,0 @@
use std::slice::SliceIndex;
use std::{ops, slice};
use thiserror::Error;
#[derive(Error, Debug)]
#[error("Cannot create to `NonemptyArray` because the supplied slice is empty")]
pub struct EmptySliceError;
/// A pointer to a non-empty fixed-size slice allocated on the heap.
///
/// Invariant: the inner boxed slice always holds at least one element,
/// enforced by the constructors (`singleton`, `try_from_boxed_slice`).
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct NonemptyArray<T>(Box<[T]>);
#[allow(clippy::arbitrary_source_item_ordering)]
impl<T> NonemptyArray<T> {
    /// Creates a `NonemptyArray` holding exactly one element.
    #[inline]
    pub fn singleton(value: T) -> Self {
        Self(Box::new([value]))
    }

    /// Builds a `NonemptyArray` from anything convertible into a boxed slice.
    ///
    /// # Errors
    /// Returns [`EmptySliceError`] if the supplied slice has no elements.
    #[allow(clippy::missing_errors_doc)]
    #[inline]
    pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
        boxed_slice: S,
    ) -> Result<Self, EmptySliceError> {
        let boxed_slice = boxed_slice.into();
        if boxed_slice.is_empty() {
            Err(EmptySliceError)
        } else {
            Ok(Self(boxed_slice))
        }
    }

    /// Consumes `self`, returning the underlying boxed slice.
    #[must_use]
    #[inline]
    pub fn into_boxed_slice(self) -> Box<[T]> {
        self.0
    }

    /// Clones the contents into a fresh `Vec`.
    #[must_use]
    #[inline]
    pub fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.0.to_vec()
    }

    /// Borrows the contents as a plain slice.
    #[must_use]
    #[inline]
    pub const fn as_slice(&self) -> &[T] {
        &self.0
    }

    /// Returns the first element; infallible thanks to the non-empty invariant.
    #[allow(clippy::expect_used)]
    #[must_use]
    #[inline]
    pub fn first(&self) -> &T {
        // Idiomatic `slice::first` instead of manual indexing; the invariant
        // guarantees this never panics.
        self.0.first().expect("NonemptyArray is never empty")
    }

    /// Returns the last element; infallible thanks to the non-empty invariant.
    #[allow(clippy::expect_used)]
    #[must_use]
    #[inline]
    pub fn last(&self) -> &T {
        // Idiomatic `slice::last` avoids the `len() - 1` arithmetic.
        self.0.last().expect("NonemptyArray is never empty")
    }

    /// Index-style access returning `None` when out of bounds, mirroring `slice::get`.
    #[must_use]
    #[inline]
    pub fn get<I>(&self, index: I) -> Option<&I::Output>
    where
        I: SliceIndex<[T]>,
    {
        self.0.get(index)
    }

    /// Number of elements; always at least 1, hence no `is_empty`.
    #[allow(clippy::len_without_is_empty)]
    #[must_use]
    #[inline]
    pub const fn len(&self) -> usize {
        self.0.len()
    }

    /// Iterator over shared references.
    #[allow(clippy::iter_without_into_iter)]
    #[inline]
    pub fn iter(&self) -> slice::Iter<'_, T> {
        self.0.iter()
    }

    /// Iterator over mutable references.
    #[allow(clippy::iter_without_into_iter)]
    #[inline]
    pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
        self.0.iter_mut()
    }

    /// Maps each element through `f`, preserving non-emptiness.
    #[inline]
    #[must_use]
    pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
        NonemptyArray(self.0.into_iter().map(f).collect())
    }
}
impl<T> From<NonemptyArray<T>> for Box<[T]> {
    /// Unwrap the array into its underlying boxed slice.
    #[inline]
    fn from(value: NonemptyArray<T>) -> Self {
        value.0
    }
}
impl<T> ops::Index<usize> for NonemptyArray<T> {
    type Output = T;

    /// Index into the underlying slice; panics on out-of-bounds like `[T]` does.
    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        &self.0[index]
    }
}
impl<T> IntoIterator for NonemptyArray<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    /// Iterate by value, consuming the array.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        self.0.into_vec().into_iter()
    }
}
impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
    type Item = &'a T;
    type IntoIter = slice::Iter<'a, T>;

    /// Iterate by shared reference.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        self.0.iter()
    }
}

View File

@@ -1,55 +0,0 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::task::{Context, Waker};
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
pub struct WakerDeque<T> {
    /// Waker of the task that last polled a `pop_*`; cleared after waking.
    waker: Option<Waker>,
    /// The underlying FIFO storage.
    deque: VecDeque<T>,
}
impl<T: Debug> Debug for WakerDeque<T> {
    /// Format as the inner deque; the stored waker is not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(&self.deque, f)
    }
}
impl<T> WakerDeque<T> {
pub fn new() -> Self {
Self {
waker: None,
deque: VecDeque::new(),
}
}
fn update(&mut self, cx: &mut Context<'_>) {
self.waker = Some(cx.waker().clone());
}
fn wake(&mut self) {
let Some(ref mut w) = self.waker else { return };
w.wake_by_ref();
self.waker = None;
}
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_front()
}
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_back()
}
pub fn push_front(&mut self, value: T) {
self.wake();
self.deque.push_front(value);
}
pub fn push_back(&mut self, value: T) {
self.wake();
self.deque.push_back(value);
}
}

View File

@@ -0,0 +1,3 @@
from importlib.metadata import version
__version__ = version("exo")

View File

@@ -39,9 +39,9 @@ class Node:
@classmethod @classmethod
async def create(cls, args: "Args") -> "Self": async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair() keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58()) node_id = NodeId(str(keypair.endpoint_id()))
session_id = SessionId(master_node_id=node_id, election_clock=0) session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair) router = await Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS) await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS) await router.register_topic(topics.LOCAL_EVENTS)
await router.register_topic(topics.COMMANDS) await router.register_topic(topics.COMMANDS)

View File

@@ -95,6 +95,7 @@ class Master:
self._tg.cancel_scope.cancel() self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None: async def _command_processor(self) -> None:
retry_num = 0
with self.command_receiver as commands: with self.command_receiver as commands:
async for forwarder_command in commands: async for forwarder_command in commands:
try: try:
@@ -186,11 +187,12 @@ class Master:
command.finished_command_id command.finished_command_id
] ]
case RequestEventLog(): case RequestEventLog():
retry_num += 1
# We should just be able to send everything, since other buffers will ignore old messages # We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)): for i in range(command.since_idx, len(self._event_log)):
await self._send_event( event = self._event_log[i]
IndexedEvent(idx=i, event=self._event_log[i]) event.retry = retry_num
) await self._send_event(IndexedEvent(idx=i, event=event))
for event in generated_events: for event in generated_events:
await self.event_sender.send(event) await self.event_sender.send(event)
except ValueError as e: except ValueError as e:
@@ -232,7 +234,7 @@ class Master:
local_event.origin, local_event.origin,
) )
for event in self._multi_buffer.drain(): for event in self._multi_buffer.drain():
logger.debug(f"Master indexing event: {str(event)[:100]}") logger.trace(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log)) indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed) self.state = apply(self.state, indexed)

View File

@@ -13,6 +13,7 @@ from exo.master.placement_utils import (
get_shard_assignments, get_shard_assignments,
get_smallest_cycles, get_smallest_cycles,
) )
from exo.routing.connection_message import IpAddress
from exo.shared.topology import Topology from exo.shared.topology import Topology
from exo.shared.types.commands import ( from exo.shared.types.commands import (
CreateInstance, CreateInstance,
@@ -130,13 +131,13 @@ def place_instance(
jaccl_coordinators=mlx_jaccl_coordinators, jaccl_coordinators=mlx_jaccl_coordinators,
) )
case InstanceMeta.MlxRing: case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph) hosts: list[IpAddress] = get_hosts_from_subgraph(cycle_digraph)
target_instances[instance_id] = MlxRingInstance( target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id, instance_id=instance_id,
shard_assignments=shard_assignments, shard_assignments=shard_assignments,
hosts=[ hosts=[
Host( Host(
ip=host.ip, ip=str(host),
port=random_ephemeral_port(), port=random_ephemeral_port(),
) )
for host in hosts for host in hosts

View File

@@ -4,8 +4,9 @@ from typing import TypeGuard, cast
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from exo.routing.connection_message import IpAddress
from exo.shared.topology import Topology from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile from exo.shared.types.profiling import NodePerformanceProfile
@@ -153,7 +154,8 @@ def get_shard_assignments(
) )
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]: def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[IpAddress]:
# this function is wrong.
cycles = cycle_digraph.get_cycles() cycles = cycle_digraph.get_cycles()
expected_length = len(list(cycle_digraph.list_nodes())) expected_length = len(list(cycle_digraph.list_nodes()))
cycles = [cycle for cycle in cycles if len(cycle) == expected_length] cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
@@ -171,24 +173,20 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
logger.info(f"Using thunderbolt cycle: {get_thunderbolt}") logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")
cycle = cycles[0] cycle = cycles[0]
hosts: list[Host] = [] hosts: list[IpAddress] = []
for i in range(len(cycle)): for i in range(len(cycle)):
current_node = cycle[i] current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)] next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.list_connections(): for connection in cycle_digraph.list_connections():
if ( if (
connection.local_node_id == current_node.node_id connection.source_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id and connection.sink_id == next_node.node_id
): ):
if get_thunderbolt and not connection.is_thunderbolt(): if get_thunderbolt and not connection.is_thunderbolt():
continue continue
assert connection.send_back_multiaddr is not None assert connection.sink_addr is not None
host = Host( hosts.append(connection.sink_addr)
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break break
return hosts return hosts
@@ -242,10 +240,10 @@ def _find_connection_ip(
"""Find all IP addresses that connect node i to node j.""" """Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.list_connections(): for connection in cycle_digraph.list_connections():
if ( if (
connection.local_node_id == node_i.node_id connection.source_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id and connection.sink_id == node_j.node_id
): ):
yield connection.send_back_multiaddr.ip_address yield str(connection.sink_addr)
def _find_interface_name_for_ip( def _find_interface_name_for_ip(

View File

@@ -1,9 +1,7 @@
from typing import Callable from ipaddress import ip_address
from itertools import count
import pytest
from exo.shared.types.common import NodeId from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ( from exo.shared.types.profiling import (
MemoryPerformanceProfile, MemoryPerformanceProfile,
NodePerformanceProfile, NodePerformanceProfile,
@@ -11,57 +9,45 @@ from exo.shared.types.profiling import (
) )
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
ip_octet_iter = count()
@pytest.fixture
def create_node(): def create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo: if node_id is None:
if node_id is None: node_id = NodeId()
node_id = NodeId() return NodeInfo(
return NodeInfo( node_id=node_id,
node_id=node_id, node_profile=NodePerformanceProfile(
node_profile=NodePerformanceProfile( model_id="test",
model_id="test", chip_id="test",
chip_id="test", friendly_name="test",
friendly_name="test", memory=MemoryPerformanceProfile.from_bytes(
memory=MemoryPerformanceProfile.from_bytes( ram_total=1000,
ram_total=1000, ram_available=memory,
ram_available=memory, swap_total=1000,
swap_total=1000, swap_available=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
), ),
) network_interfaces=[],
system=SystemPerformanceProfile(),
return _create_node ),
)
# TODO: this is a hack to get the port for the send_back_multiaddr def create_connection(
@pytest.fixture source_node_id: NodeId,
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]: sink_node_id: NodeId,
port_counter = 1235 *,
ip_counter = 1 ip_octet: int | None = None,
) -> Connection:
global ip_octet_iter
def _create_connection( return Connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None source_id=source_node_id,
) -> Connection: sink_id=sink_node_id,
nonlocal port_counter sink_addr=ip_address(
nonlocal ip_counter f"169.254.0.{ip_octet if ip_octet is not None else next(ip_octet_iter)}"
# assign unique ips ),
ip_counter += 1 connection_profile=ConnectionProfile(
if send_back_port is None: throughput=1000, latency=1000, jitter=1000
send_back_port = port_counter ),
port_counter += 1 )
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
return _create_connection

View File

@@ -43,7 +43,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_master(): async def test_master():
keypair = get_node_id_keypair() keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58()) node_id = NodeId(str(keypair.endpoint_id()))
session_id = SessionId(master_node_id=node_id, election_clock=0) session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]() ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -74,7 +74,7 @@ async def test_master():
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon(master.run) tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender") sender_node_id = NodeId(f"{keypair.to_postcard_encoding()}_sender")
# inject a NodePerformanceProfile event # inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event") logger.info("inject a NodePerformanceProfile event")
await local_event_sender.send( await local_event_sender.send(
@@ -140,7 +140,6 @@ async def test_master():
origin=node_id, origin=node_id,
command=( command=(
ChatCompletion( ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams( request_params=ChatCompletionTaskParams(
model="llama-3.2-1b", model="llama-3.2-1b",
messages=[ messages=[

View File

@@ -1,4 +1,4 @@
from typing import Callable from ipaddress import ip_address
import pytest import pytest
from loguru import logger from loguru import logger
@@ -7,6 +7,7 @@ from exo.master.placement import (
get_transition_events, get_transition_events,
place_instance, place_instance,
) )
from exo.master.tests.conftest import create_connection, create_node
from exo.shared.topology import Topology from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId from exo.shared.types.common import CommandId, NodeId
@@ -14,7 +15,6 @@ from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import ( from exo.shared.types.worker.instances import (
Instance, Instance,
InstanceId, InstanceId,
@@ -76,8 +76,6 @@ def test_get_instance_placements_create_instance(
expected_layers: tuple[int, int, int], expected_layers: tuple[int, int, int],
topology: Topology, topology: Topology,
model_meta: ModelMetadata, model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
): ):
# arrange # arrange
model_meta.n_layers = total_layers model_meta.n_layers = total_layers
@@ -123,9 +121,7 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit( def test_get_instance_placements_one_node_exact_fit() -> None:
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology() topology = Topology()
node_id = NodeId() node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id)) topology.add_node(create_node(1000 * 1024, node_id))
@@ -148,9 +144,7 @@ def test_get_instance_placements_one_node_exact_fit(
assert len(instance.shard_assignments.runner_to_shard) == 1 assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory( def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology() topology = Topology()
node_id = NodeId() node_id = NodeId()
topology.add_node(create_node(1001 * 1024, node_id)) topology.add_node(create_node(1001 * 1024, node_id))
@@ -173,9 +167,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
assert len(instance.shard_assignments.runner_to_shard) == 1 assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit( def test_get_instance_placements_one_node_not_fit() -> None:
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology() topology = Topology()
node_id = NodeId() node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id)) topology.add_node(create_node(1000 * 1024, node_id))
@@ -237,8 +229,6 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_prioritizes_leaf_cycle_with_less_memory( def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology: Topology, topology: Topology,
model_meta: ModelMetadata, model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
): ):
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing # Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes # neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
@@ -316,8 +306,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
def test_tensor_rdma_backend_connectivity_matrix( def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology, topology: Topology,
model_meta: ModelMetadata, model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
): ):
model_meta.n_layers = 12 model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500 model_meta.storage_size.in_bytes = 1500
@@ -332,7 +320,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
ethernet_interface = NetworkInterfaceInfo( ethernet_interface = NetworkInterfaceInfo(
name="en0", name="en0",
ip_address="192.168.1.100", ip_address=ip_address("192.168.1.100"),
) )
assert node_a.node_profile is not None assert node_a.node_profile is not None
@@ -347,13 +335,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
conn_c_b = create_connection(node_id_c, node_id_b) conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c) conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None assert conn_a_b.sink_addr is not None
assert conn_b_c.send_back_multiaddr is not None assert conn_b_c.sink_addr is not None
assert conn_c_a.send_back_multiaddr is not None assert conn_c_a.sink_addr is not None
assert conn_b_a.send_back_multiaddr is not None assert conn_b_a.sink_addr is not None
assert conn_c_b.send_back_multiaddr is not None assert conn_c_b.sink_addr is not None
assert conn_a_c.send_back_multiaddr is not None assert conn_a_c.sink_addr is not None
node_a.node_profile = NodePerformanceProfile( node_a.node_profile = NodePerformanceProfile(
model_id="test", model_id="test",
@@ -363,11 +351,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
network_interfaces=[ network_interfaces=[
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en3", name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address, ip_address=conn_c_a.sink_addr,
), ),
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en4", name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address, ip_address=conn_b_a.sink_addr,
), ),
ethernet_interface, ethernet_interface,
], ],
@@ -381,11 +369,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
network_interfaces=[ network_interfaces=[
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en3", name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address, ip_address=conn_c_b.sink_addr,
), ),
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en4", name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address, ip_address=conn_a_b.sink_addr,
), ),
ethernet_interface, ethernet_interface,
], ],
@@ -399,11 +387,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
network_interfaces=[ network_interfaces=[
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en3", name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address, ip_address=conn_a_c.sink_addr,
), ),
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en4", name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address, ip_address=conn_b_c.sink_addr,
), ),
ethernet_interface, ethernet_interface,
], ],

View File

@@ -1,5 +1,3 @@
from typing import Callable
import pytest import pytest
from exo.master.placement_utils import ( from exo.master.placement_utils import (
@@ -9,12 +7,12 @@ from exo.master.placement_utils import (
get_shard_assignments, get_shard_assignments,
get_smallest_cycles, get_smallest_cycles,
) )
from exo.master.tests.conftest import create_connection, create_node
from exo.shared.topology import Topology from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding from exo.shared.types.worker.shards import Sharding
@@ -26,8 +24,6 @@ def topology() -> Topology:
def test_filter_cycles_by_memory( def test_filter_cycles_by_memory(
topology: Topology, topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
): ):
# arrange # arrange
node1_id = NodeId() node1_id = NodeId()
@@ -60,8 +56,6 @@ def test_filter_cycles_by_memory(
def test_filter_cycles_by_insufficient_memory( def test_filter_cycles_by_insufficient_memory(
topology: Topology, topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
): ):
# arrange # arrange
node1_id = NodeId() node1_id = NodeId()
@@ -90,8 +84,6 @@ def test_filter_cycles_by_insufficient_memory(
def test_filter_multiple_cycles_by_memory( def test_filter_multiple_cycles_by_memory(
topology: Topology, topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
): ):
# arrange # arrange
node_a_id = NodeId() node_a_id = NodeId()
@@ -129,8 +121,6 @@ def test_filter_multiple_cycles_by_memory(
def test_get_smallest_cycles( def test_get_smallest_cycles(
topology: Topology, topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
): ):
# arrange # arrange
node_a_id = NodeId() node_a_id = NodeId()
@@ -169,8 +159,6 @@ def test_get_smallest_cycles(
) )
def test_get_shard_assignments( def test_get_shard_assignments(
topology: Topology, topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int], available_memory: tuple[int, int, int],
total_layers: int, total_layers: int,
expected_layers: tuple[int, int, int], expected_layers: tuple[int, int, int],
@@ -230,8 +218,6 @@ def test_get_shard_assignments(
def test_get_hosts_from_subgraph( def test_get_hosts_from_subgraph(
topology: Topology, topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
): ):
# arrange # arrange
node_a_id = NodeId() node_a_id = NodeId()
@@ -246,10 +232,10 @@ def test_get_hosts_from_subgraph(
topology.add_node(node_b) topology.add_node(node_b)
topology.add_node(node_c) topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001)) topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002)) topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003)) topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004)) topology.add_connection(create_connection(node_b_id, node_a_id))
# act # act
hosts = get_hosts_from_subgraph(topology) hosts = get_hosts_from_subgraph(topology)
@@ -267,8 +253,6 @@ def test_get_hosts_from_subgraph(
def test_get_mlx_jaccl_coordinators( def test_get_mlx_jaccl_coordinators(
topology: Topology, topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
): ):
# arrange # arrange
node_a_id = NodeId() node_a_id = NodeId()
@@ -279,12 +263,12 @@ def test_get_mlx_jaccl_coordinators(
node_b = create_node(500 * 1024, node_b_id) node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id) node_c = create_node(1000 * 1024, node_c_id)
conn_a_b = create_connection(node_a_id, node_b_id, 5001) conn_a_b = create_connection(node_a_id, node_b_id)
conn_b_a = create_connection(node_b_id, node_a_id, 5002) conn_b_a = create_connection(node_b_id, node_a_id)
conn_b_c = create_connection(node_b_id, node_c_id, 5003) conn_b_c = create_connection(node_b_id, node_c_id)
conn_c_b = create_connection(node_c_id, node_b_id, 5004) conn_c_b = create_connection(node_c_id, node_b_id)
conn_c_a = create_connection(node_c_id, node_a_id, 5005) conn_c_a = create_connection(node_c_id, node_a_id)
conn_a_c = create_connection(node_a_id, node_c_id, 5006) conn_a_c = create_connection(node_a_id, node_c_id)
# Update node profiles with network interfaces before adding to topology # Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None assert node_a.node_profile is not None
@@ -299,11 +283,11 @@ def test_get_mlx_jaccl_coordinators(
network_interfaces=[ network_interfaces=[
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en3", name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address, ip_address=conn_a_b.sink_addr.ip,
), ),
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en4", name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address, ip_address=conn_a_c.sink_addr.ip,
), ),
], ],
system=node_a.node_profile.system, system=node_a.node_profile.system,
@@ -316,11 +300,11 @@ def test_get_mlx_jaccl_coordinators(
network_interfaces=[ network_interfaces=[
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en3", name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address, ip_address=conn_b_a.sink_addr.ip,
), ),
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en4", name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address, ip_address=conn_b_c.sink_addr.ip,
), ),
], ],
system=node_b.node_profile.system, system=node_b.node_profile.system,
@@ -333,11 +317,11 @@ def test_get_mlx_jaccl_coordinators(
network_interfaces=[ network_interfaces=[
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en3", name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address, ip_address=conn_c_b.sink_addr.ip,
), ),
NetworkInterfaceInfo( NetworkInterfaceInfo(
name="en4", name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address, ip_address=conn_c_a.sink_addr.ip,
), ),
], ],
system=node_c.node_profile.system, system=node_c.node_profile.system,
@@ -387,11 +371,11 @@ def test_get_mlx_jaccl_coordinators(
# Non-rank-0 nodes should use the specific IP from their connection to rank 0 # Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a) # node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == ( assert coordinators[node_b_id] == (f"{conn_b_a.sink_addr.ip}:5000"), (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000" "node_b should use the IP from conn_b_a"
), "node_b should use the IP from conn_b_a" )
# node_c uses the IP from conn_c_a (node_c -> node_a) # node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == ( assert coordinators[node_c_id] == (f"{conn_c_a.sink_addr.ip}:5000"), (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000" "node_c should use the IP from conn_c_a"
), "node_c should use the IP from conn_c_a" )

View File

@@ -1,7 +1,9 @@
from ipaddress import ip_address
import pytest import pytest
from exo.routing.connection_message import SocketAddress
from exo.shared.topology import Topology from exo.shared.topology import Topology
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ( from exo.shared.types.profiling import (
MemoryPerformanceProfile, MemoryPerformanceProfile,
NodePerformanceProfile, NodePerformanceProfile,
@@ -18,9 +20,9 @@ def topology() -> Topology:
@pytest.fixture @pytest.fixture
def connection() -> Connection: def connection() -> Connection:
return Connection( return Connection(
local_node_id=NodeId(), source_id=NodeId(),
send_back_node_id=NodeId(), sink_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"), sink_addr=SocketAddress(ip=ip_address("127.0.0.1"), port=1235, zone_id=None),
connection_profile=ConnectionProfile( connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000 throughput=1000, latency=1000, jitter=1000
), ),
@@ -64,12 +66,8 @@ def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
): ):
# arrange # arrange
topology.add_node( topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection) topology.add_connection(connection)
# act # act
@@ -83,12 +81,8 @@ def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
): ):
# arrange # arrange
topology.add_node( topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection) topology.add_connection(connection)
new_node_profile = NodePerformanceProfile( new_node_profile = NodePerformanceProfile(
@@ -103,12 +97,10 @@ def test_update_node_profile(
) )
# act # act
topology.update_node_profile( topology.update_node_profile(connection.source_id, node_profile=new_node_profile)
connection.local_node_id, node_profile=new_node_profile
)
# assert # assert
data = topology.get_node_profile(connection.local_node_id) data = topology.get_node_profile(connection.source_id)
assert data == new_node_profile assert data == new_node_profile
@@ -116,21 +108,17 @@ def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
): ):
# arrange # arrange
topology.add_node( topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection) topology.add_connection(connection)
new_connection_profile = ConnectionProfile( new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000 throughput=2000, latency=2000, jitter=2000
) )
connection = Connection( connection = Connection(
local_node_id=connection.local_node_id, source_id=connection.source_id,
send_back_node_id=connection.send_back_node_id, sink_id=connection.sink_id,
send_back_multiaddr=connection.send_back_multiaddr, sink_addr=connection.sink_addr,
connection_profile=new_connection_profile, connection_profile=new_connection_profile,
) )
@@ -146,12 +134,8 @@ def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
): ):
# arrange # arrange
topology.add_node( topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection) topology.add_connection(connection)
# act # act
@@ -165,31 +149,23 @@ def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
): ):
# arrange # arrange
topology.add_node( topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection) topology.add_connection(connection)
# act # act
topology.remove_node(connection.local_node_id) topology.remove_node(connection.source_id)
# assert # assert
assert topology.get_node_profile(connection.local_node_id) is None assert topology.get_node_profile(connection.source_id) is None
def test_list_nodes( def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
): ):
# arrange # arrange
topology.add_node( topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection) topology.add_connection(connection)
# act # act
@@ -199,6 +175,6 @@ def test_list_nodes(
assert len(nodes) == 2 assert len(nodes) == 2
assert all(isinstance(node, NodeInfo) for node in nodes) assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == { assert {node.node_id for node in nodes} == {
connection.local_node_id, connection.source_id,
connection.send_back_node_id, connection.sink_id,
} }

View File

@@ -1,37 +1,44 @@
from enum import Enum from ipaddress import IPv4Address, IPv6Address, ip_address
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType from exo_pyo3_bindings import RustConnectionMessage
from pydantic import ConfigDict
from exo.shared.types.common import NodeId from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages""" """Serialisable types for Connection Updates/Messages"""
IpAddress = IPv4Address | IPv6Address
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod class SocketAddress(CamelCaseModel):
def from_update_type(update_type: ConnectionUpdateType): # could be the python IpAddress type if we're feeling fancy
match update_type: ip: IpAddress
case ConnectionUpdateType.Connected: port: int
return ConnectionMessageType.Connected zone_id: int | None
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected model_config = ConfigDict(
frozen=True,
)
class ConnectionMessage(CamelCaseModel): class ConnectionMessage(CamelCaseModel):
node_id: NodeId node_id: NodeId
connection_type: ConnectionMessageType ips: set[SocketAddress] | None
remote_ipv4: str
remote_tcp_port: int
@classmethod @classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage": def from_rust(cls, message: RustConnectionMessage) -> "ConnectionMessage":
return cls( return cls(
node_id=NodeId(update.peer_id.to_base58()), node_id=NodeId(str(message.endpoint_id)),
connection_type=ConnectionMessageType.from_update_type(update.update_type), ips=None
remote_ipv4=update.remote_ipv4, if message.current_transport_addrs is None
remote_tcp_port=update.remote_tcp_port, else set(
# TODO: better handle fallible conversion
SocketAddress(
ip=ip_address(addr.ip_addr()),
port=addr.port(),
zone_id=addr.zone_id(),
)
for addr in message.current_transport_addrs
),
) )

View File

@@ -5,6 +5,7 @@ from os import PathLike
from pathlib import Path from pathlib import Path
from typing import cast from typing import cast
import anyio
from anyio import ( from anyio import (
BrokenResourceError, BrokenResourceError,
ClosedResourceError, ClosedResourceError,
@@ -13,14 +14,15 @@ from anyio import (
) )
from anyio.abc import TaskGroup from anyio.abc import TaskGroup
from exo_pyo3_bindings import ( from exo_pyo3_bindings import (
AllQueuesFullError,
Keypair, Keypair,
NetworkingHandle, RustNetworkingHandle,
NoPeersSubscribedToTopicError, RustReceiver,
RustSender,
) )
from filelock import FileLock from filelock import FileLock
from loguru import logger from loguru import logger
from exo import __version__
from exo.shared.constants import EXO_NODE_ID_KEYPAIR from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.utils.channels import Receiver, Sender, channel from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel from exo.utils.pydantic_ext import CamelCaseModel
@@ -37,7 +39,6 @@ class TopicRouter[T: CamelCaseModel]:
def __init__( def __init__(
self, self,
topic: TypedTopic[T], topic: TypedTopic[T],
networking_sender: Sender[tuple[str, bytes]],
max_buffer_size: float = inf, max_buffer_size: float = inf,
): ):
self.topic: TypedTopic[T] = topic self.topic: TypedTopic[T] = topic
@@ -45,26 +46,41 @@ class TopicRouter[T: CamelCaseModel]:
send, recv = channel[T]() send, recv = channel[T]()
self.receiver: Receiver[T] = recv self.receiver: Receiver[T] = recv
self._sender: Sender[T] = send self._sender: Sender[T] = send
self.networking_sender: Sender[tuple[str, bytes]] = networking_sender self.networking_sender: RustSender | None = None
self.networking_receiver: RustReceiver | None = None
self._tg: TaskGroup = create_task_group()
async def run(self): async def run(self):
async with self._tg as tg:
tg.start_soon(self.receive_loop)
async def receive_loop(self):
logger.debug(f"Topic Router {self.topic} ready to send") logger.debug(f"Topic Router {self.topic} ready to send")
with self.receiver as items: with self.receiver as items:
async for item in items: async for item in items:
# Check if we should send to network # Check if we should send to network
if ( if (
len(self.senders) == 0 self.topic.publish_policy is PublishPolicy.Always
and self.topic.publish_policy is PublishPolicy.Minimal and self.networking_sender is not None
): ):
await self._send_out(item) await self._send_out(item)
continue
if self.topic.publish_policy is PublishPolicy.Always:
await self._send_out(item)
# Then publish to all senders # Then publish to all senders
await self.publish(item) await self.publish(item)
logger.debug(f"Shut down Topic Router {self.topic}")
async def net_receive_loop(self):
assert self.networking_receiver is not None
while True:
item = self.topic.deserialize(await self.networking_receiver.receive())
await self.publish(item)
def subscribe_with(self, net_send: RustSender, net_recv: RustReceiver):
self.networking_sender = net_send
self.networking_receiver = net_recv
self._tg.start_soon(self.net_receive_loop)
async def shutdown(self): async def shutdown(self):
logger.debug(f"Shutting down Topic Router {self.topic}")
# Close all the things! # Close all the things!
for sender in self.senders: for sender in self.senders:
sender.close() sender.close()
@@ -85,43 +101,32 @@ class TopicRouter[T: CamelCaseModel]:
to_clear.add(sender) to_clear.add(sender)
self.senders -= to_clear self.senders -= to_clear
async def publish_bytes(self, data: bytes):
await self.publish(self.topic.deserialize(data))
def new_sender(self) -> Sender[T]: def new_sender(self) -> Sender[T]:
return self._sender.clone() return self._sender.clone()
async def _send_out(self, item: T): async def _send_out(self, item: T):
assert self.networking_sender is not None
logger.trace(f"TopicRouter {self.topic.topic} sending {item}") logger.trace(f"TopicRouter {self.topic.topic} sending {item}")
await self.networking_sender.send( await self.networking_sender.send(self.topic.serialize(item))
(str(self.topic.topic), self.topic.serialize(item))
)
class Router: class Router:
@classmethod @classmethod
def create(cls, identity: Keypair) -> "Router": async def create(cls, identity: Keypair) -> "Router":
return cls(handle=NetworkingHandle(identity)) return cls(handle=await RustNetworkingHandle.create(identity, __version__))
def __init__(self, handle: NetworkingHandle): def __init__(self, handle: RustNetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {} self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
send, recv = channel[tuple[str, bytes]]() self._unsubbed: list[str] = []
self.networking_receiver: Receiver[tuple[str, bytes]] = recv self._net: RustNetworkingHandle = handle
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._id_count = count() self._id_count = count()
self._tg: TaskGroup | None = None self._tg: TaskGroup | None = None
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]): async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
assert self._tg is None, "Attempted to register topic after setup time" assert self._tg is None, "Attempted to register topic after setup time"
send = self._tmp_networking_sender router = TopicRouter[T](topic)
if send:
self._tmp_networking_sender = None
else:
send = self.networking_receiver.clone_sender()
router = TopicRouter[T](topic, send)
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router) self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
await self._networking_subscribe(str(topic.topic)) self._unsubbed.append(topic.topic)
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]: def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
router = self.topic_routers.get(topic.topic, None) router = self.topic_routers.get(topic.topic, None)
@@ -151,13 +156,9 @@ class Router:
for topic in self.topic_routers: for topic in self.topic_routers:
router = self.topic_routers[topic] router = self.topic_routers[topic]
tg.start_soon(router.run) tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages) tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it. # Router only shuts down if you cancel it.
await sleep_forever() await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self): async def shutdown(self):
logger.debug("Shutting down Router") logger.debug("Shutting down Router")
@@ -165,48 +166,33 @@ class Router:
return return
self._tg.cancel_scope.cancel() self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
async def _networking_recv(self):
while True:
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self): async def _networking_recv_connection_messages(self):
recv = await self._net.get_connection_receiver()
while True: while True:
update = await self._net.connection_update_recv() message = await recv.receive()
message = ConnectionMessage.from_update(update) await anyio.sleep(0.2)
logger.trace( logger.trace(
f"Received message on connection_messages with payload {message}" f"Received message on connection_messages with payload {message}"
) )
to_clear: list[str] = []
for topic in self._unsubbed:
try:
rsend, rrecv = await self._net.subscribe(topic)
logger.info(f"Subscribed to peer on {topic}")
to_clear.append(topic)
self.topic_routers[topic].subscribe_with(rsend, rrecv)
# TODO: real error
except RuntimeError:
pass
if to_clear:
assert to_clear == self._unsubbed
self._unsubbed = [i for i in self._unsubbed if i not in to_clear]
if CONNECTION_MESSAGES.topic in self.topic_routers: if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic] router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router) router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message) await router.publish(ConnectionMessage.from_rust(message))
async def _networking_publish(self):
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
# As a hack, this also catches AllQueuesFull
# Need to fix that ASAP.
except (NoPeersSubscribedToTopicError, AllQueuesFullError):
pass
def get_node_id_keypair( def get_node_id_keypair(
@@ -225,16 +211,16 @@ def get_node_id_keypair(
with open(path, "a+b") as f: # opens in append-mode => starts at EOF with open(path, "a+b") as f: # opens in append-mode => starts at EOF
# if non-zero EOF, then file exists => use to get node-ID # if non-zero EOF, then file exists => use to get node-ID
if f.tell() != 0: if f.tell() != 0:
f.seek(0) # go to start & read protobuf-encoded bytes f.seek(0) # go to start & read postcard-encoded bytes
protobuf_encoded = f.read() postcard_encoded = f.read()
try: # if decoded successfully, save & return try: # if decoded successfully, save & return
return Keypair.from_protobuf_encoding(protobuf_encoded) return Keypair.from_postcard_encoding(postcard_encoded)
except ValueError as e: # on runtime error, assume corrupt file except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}") logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist # if no valid credentials, create new ones and persist
with open(path, "w+b") as f: with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519() keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding()) f.write(keypair.to_postcard_encoding())
return keypair return keypair

View File

@@ -13,8 +13,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
class PublishPolicy(str, Enum): class PublishPolicy(str, Enum):
Never = "Never" Never = "Never"
"""Never publish to the network - this is a local message""" """Never publish to the network - this is a local message"""
Minimal = "Minimal"
"""Only publish when there is no local receiver for this type of message"""
Always = "Always" Always = "Always"
"""Always publish to the network""" """Always publish to the network"""

View File

@@ -164,28 +164,38 @@ class Election:
self._candidates.append(message) self._candidates.append(message)
async def _connection_receiver(self) -> None: async def _connection_receiver(self) -> None:
current_peers: set[NodeId] = set()
with self._cm_receiver as connection_messages: with self._cm_receiver as connection_messages:
async for first in connection_messages: async for first in connection_messages:
# Delay after connection message for time to symmetrically setup if first.node_id not in current_peers or first.ips is None:
await anyio.sleep(0.2) if first.node_id not in current_peers:
rest = connection_messages.collect() current_peers.add(first.node_id)
if first.ips is None:
current_peers.remove(first.node_id)
# Delay after connection message for time to symmetrically setup
await anyio.sleep(0.2)
rest = connection_messages.collect()
for msg in rest:
if msg.node_id not in current_peers:
current_peers.add(first.node_id)
if msg.ips is None:
current_peers.remove(first.node_id)
logger.debug( logger.info(
f"Connection messages received: {first} followed by {rest}" f"Connection messages received: {first} followed by {rest}"
) )
logger.debug(f"Current clock: {self.clock}") logger.info(f"Current clock: {self.clock}")
# These messages are strictly peer to peer # These messages are strictly peer to peer
self.clock += 1 self.clock += 1
logger.debug(f"New clock: {self.clock}") logger.info(f"New clock: {self.clock}")
assert self._tg is not None candidates: list[ElectionMessage] = []
candidates: list[ElectionMessage] = [] self._candidates = candidates
self._candidates = candidates logger.info("Starting new campaign")
logger.debug("Starting new campaign") assert self._tg is not None
self._tg.start_soon( self._tg.start_soon(
self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
) )
logger.debug("Campaign started") logger.info("Campaign started")
logger.debug("Connection message added")
async def _command_counter(self) -> None: async def _command_counter(self) -> None:
with self._co_receiver as commands: with self._co_receiver as commands:

View File

@@ -1,7 +1,7 @@
import pytest import pytest
from anyio import create_task_group, fail_after, move_on_after from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import Election, ElectionMessage, ElectionResult from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId from exo.shared.types.common import NodeId, SessionId
@@ -330,9 +330,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
await cm_tx.send( await cm_tx.send(
ConnectionMessage( ConnectionMessage(
node_id=NodeId(), node_id=NodeId(),
connection_type=ConnectionMessageType.Connected, ips=set(),
remote_ipv4="",
remote_tcp_port=0,
) )
) )

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release() sem.release()
# wait to be told to begin simultaneous read # wait to be told to begin simultaneous read
ev.wait() ev.wait()
queue.put(get_node_id_keypair().to_protobuf_encoding()) queue.put(get_node_id_keypair().to_postcard_encoding())
def _get_keypair_concurrent(num_procs: int) -> bytes: def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -1,5 +1,7 @@
from ipaddress import ip_address
from exo.routing.connection_message import SocketAddress
from exo.shared.types.common import NodeId from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State from exo.shared.types.state import State
from exo.shared.types.topology import Connection from exo.shared.types.topology import Connection
@@ -12,9 +14,9 @@ def test_state_serialization_roundtrip() -> None:
node_b = NodeId("node-b") node_b = NodeId("node-b")
connection = Connection( connection = Connection(
local_node_id=node_a, sink_id=node_a,
send_back_node_id=node_b, source_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"), sink_addr=SocketAddress(ip=ip_address("127.0.0.1"), port=5354, zone_id=None),
) )
state = State() state = State()

View File

@@ -81,16 +81,16 @@ class Topology:
self, self,
connection: Connection, connection: Connection,
) -> None: ) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map: if connection.source_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id)) self.add_node(NodeInfo(node_id=connection.source_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map: if connection.sink_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id)) self.add_node(NodeInfo(node_id=connection.sink_id))
if connection in self._edge_id_to_rx_id_map: if connection in self._edge_id_to_rx_id_map:
return return
src_id = self._node_id_to_rx_id_map[connection.local_node_id] src_id = self._node_id_to_rx_id_map[connection.source_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id] sink_id = self._node_id_to_rx_id_map[connection.sink_id]
rx_id = self._graph.add_edge(src_id, sink_id, connection) rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id self._edge_id_to_rx_id_map[connection] = rx_id
@@ -132,10 +132,7 @@ class Topology:
return return
for connection in self.list_connections(): for connection in self.list_connections():
if ( if connection.source_id == node_id or connection.sink_id == node_id:
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection) self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id] rx_idx = self._node_id_to_rx_id_map[node_id]
@@ -188,10 +185,7 @@ class Topology:
for rx_idx in rx_idxs: for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx]) topology.add_node(self._graph[rx_idx])
for connection in self.list_connections(): for connection in self.list_connections():
if ( if connection.source_id in node_idxs and connection.sink_id in node_idxs:
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection) topology.add_connection(connection)
return topology return topology

View File

@@ -1,3 +1,5 @@
from typing import Self
from pydantic import Field from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams from exo.shared.types.api import ChatCompletionTaskParams
@@ -26,6 +28,17 @@ class PlaceInstance(BaseCommand):
instance_meta: InstanceMeta instance_meta: InstanceMeta
min_nodes: int min_nodes: int
# Decision point - I like this syntax better than the typical fixtures,
# but it's """bloat"""
@classmethod
def fixture(cls) -> Self:
return cls(
model_meta=ModelMetadata.fixture(),
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
)
class CreateInstance(BaseCommand): class CreateInstance(BaseCommand):
instance: Instance instance: Instance

View File

@@ -22,7 +22,8 @@ class EventId(Id):
class BaseEvent(TaggedModel): class BaseEvent(TaggedModel):
event_id: EventId = Field(default_factory=EventId) event_id: EventId = Field(default_factory=EventId)
# Internal, for debugging. Please don't rely on this field for anything! # Internal, for debugging. Please don't rely on this field for anything!
_master_time_stamp: None | datetime = None _master_time_stamp: datetime | None = None
retry: int | None = None
class TestEvent(BaseEvent): class TestEvent(BaseEvent):

View File

@@ -1,3 +1,5 @@
from typing import Self
from pydantic import PositiveInt from pydantic import PositiveInt
from exo.shared.types.common import Id from exo.shared.types.common import Id
@@ -14,3 +16,12 @@ class ModelMetadata(CamelCaseModel):
pretty_name: str pretty_name: str
storage_size: Memory storage_size: Memory
n_layers: PositiveInt n_layers: PositiveInt
@classmethod
def fixture(cls) -> Self:
return cls(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
)

View File

@@ -2,6 +2,7 @@ from typing import Self
import psutil import psutil
from exo.routing.connection_message import IpAddress
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel from exo.utils.pydantic_ext import CamelCaseModel
@@ -49,7 +50,7 @@ class SystemPerformanceProfile(CamelCaseModel):
class NetworkInterfaceInfo(CamelCaseModel): class NetworkInterfaceInfo(CamelCaseModel):
name: str name: str
ip_address: str ip_address: IpAddress
class NodePerformanceProfile(CamelCaseModel): class NodePerformanceProfile(CamelCaseModel):

View File

@@ -1,5 +1,5 @@
from exo.routing.connection_message import IpAddress
from exo.shared.types.common import NodeId from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel from exo.utils.pydantic_ext import CamelCaseModel
@@ -10,17 +10,17 @@ class NodeInfo(CamelCaseModel):
class Connection(CamelCaseModel): class Connection(CamelCaseModel):
local_node_id: NodeId source_id: NodeId
send_back_node_id: NodeId sink_id: NodeId
send_back_multiaddr: Multiaddr sink_addr: IpAddress
connection_profile: ConnectionProfile | None = None connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int: def __hash__(self) -> int:
return hash( return hash(
( (
self.local_node_id, self.source_id,
self.send_back_node_id, self.sink_id,
self.send_back_multiaddr.address, self.sink_addr,
) )
) )
@@ -28,10 +28,10 @@ class Connection(CamelCaseModel):
if not isinstance(other, Connection): if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection") raise ValueError("Cannot compare Connection with non-Connection")
return ( return (
self.local_node_id == other.local_node_id self.source_id == other.source_id
and self.send_back_node_id == other.send_back_node_id and self.sink_id == other.sink_id
and self.send_back_multiaddr == other.send_back_multiaddr and self.sink_addr == other.sink_addr
) )
def is_thunderbolt(self) -> bool: def is_thunderbolt(self) -> bool:
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254") return str(self.sink_addr).startswith("169.254")

View File

@@ -0,0 +1,32 @@
from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, TopologyEdgeCreated
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
def check_connections(
local_id: NodeId, msg: ConnectionMessage, state: State
) -> list[Event]:
remote_id = msg.node_id
sockets = msg.ips
if (
not state.topology.contains_node(remote_id)
or remote_id not in state.node_profiles
):
return []
out: list[Event] = []
conns = list(state.topology.list_connections())
for iface in state.node_profiles[remote_id].network_interfaces:
if sockets is None:
continue
for sock in sockets:
if iface.ip_address == sock.ip:
conn = Connection(source_id=local_id, sink_id=remote_id, sink_addr=sock)
if state.topology.contains_connection(conn):
conns.remove(conn)
continue
out.append(TopologyEdgeCreated(edge=conn))
return out

View File

@@ -6,7 +6,7 @@ from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio.abc import TaskGroup from anyio.abc import TaskGroup
from loguru import logger from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType from exo.routing.connection_message import ConnectionMessage
from exo.shared.apply import apply from exo.shared.apply import apply
from exo.shared.types.commands import ForwarderCommand, RequestEventLog from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId from exo.shared.types.common import NodeId, SessionId
@@ -23,7 +23,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated, TopologyEdgeCreated,
TopologyEdgeDeleted, TopologyEdgeDeleted,
) )
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State from exo.shared.types.state import State
from exo.shared.types.tasks import ( from exo.shared.types.tasks import (
@@ -257,34 +256,12 @@ class Worker:
async def _connection_message_event_writer(self): async def _connection_message_event_writer(self):
with self.connection_message_receiver as connection_messages: with self.connection_message_receiver as connection_messages:
async for msg in connection_messages: async for _msg in connection_messages:
await self.event_sender.send( break
self._convert_connection_message_to_event(msg) # TODO: use mdns for partial discovery
) # for event in check_connections(self.node_id, msg, self.state):
# logger.info(f"Worker discovered connection {event}")
def _convert_connection_message_to_event(self, msg: ConnectionMessage): # await self.event_sender.send(event)
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
)
async def _nack_request(self, since_idx: int) -> None: async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index. # We request all events after (and including) the missing index.
@@ -403,7 +380,7 @@ class Worker:
session=self.session_id, session=self.session_id,
event=event, event=event,
) )
logger.debug( logger.trace(
f"Worker published event {self.local_event_index}: {str(event)[:100]}" f"Worker published event {self.local_event_index}: {str(event)[:100]}"
) )
self.local_event_index += 1 self.local_event_index += 1
@@ -412,30 +389,21 @@ class Worker:
async def _poll_connection_updates(self): async def _poll_connection_updates(self):
while True: while True:
# TODO: EdgeDeleted edges = self.state.topology.out_edges(self.node_id)
edges = set(self.state.topology.list_connections()) pure_edges = set(edge for _, edge in edges)
conns = await check_reachable(self.state.topology) conns = await check_reachable(self.state.topology)
for nid, conn in edges:
if nid in conns and conn.sink_addr in conns.get(nid, set()):
continue
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
for nid in conns: for nid in conns:
for ip in conns[nid]: for ip in conns[nid]:
edge = Connection( edge = Connection(sink_id=self.node_id, source_id=nid, sink_addr=ip)
local_node_id=self.node_id, if edge not in pure_edges:
send_back_node_id=nid,
# nonsense multiaddr
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}") logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(TopologyEdgeCreated(edge=edge)) await self.event_sender.send(TopologyEdgeCreated(edge=edge))
for nid, conn in self.state.topology.out_edges(self.node_id):
if (
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10) await anyio.sleep(10)

View File

@@ -1,19 +1,21 @@
import socket import socket
from ipaddress import ip_address
from anyio import create_task_group, to_thread from anyio import create_task_group, to_thread
from exo.routing.connection_message import IpAddress
from exo.shared.topology import Topology from exo.shared.topology import Topology
from exo.shared.types.common import NodeId from exo.shared.types.common import NodeId
# TODO: ref. api port # TODO: ref. api port
async def check_reachability( async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]] target_ip: IpAddress, target_node_id: NodeId, out: dict[NodeId, set[IpAddress]]
) -> None: ) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout sock.settimeout(1) # 1 second timeout
try: try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415)) result = await to_thread.run_sync(sock.connect_ex, (str(target_ip), 52415))
except socket.gaierror: except socket.gaierror:
# seems to throw on ipv6 loopback. oh well # seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}") # logger.warning(f"invalid {target_ip=}")
@@ -24,11 +26,11 @@ async def check_reachability(
if result == 0: if result == 0:
if target_node_id not in out: if target_node_id not in out:
out[target_node_id] = set() out[target_node_id] = set()
out[target_node_id].add(target_ip) out[target_node_id].add(ip_address(target_ip))
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]: async def check_reachable(topology: Topology) -> dict[NodeId, set[IpAddress]]:
reachable: dict[NodeId, set[str]] = {} reachable: dict[NodeId, set[IpAddress]] = {}
async with create_task_group() as tg: async with create_task_group() as tg:
for node in topology.list_nodes(): for node in topology.list_nodes():
if not node.node_profile: if not node.node_profile:

View File

@@ -1,5 +1,6 @@
import socket import socket
import sys import sys
from ipaddress import ip_address
from subprocess import CalledProcessError from subprocess import CalledProcessError
import psutil import psutil
@@ -29,12 +30,6 @@ async def get_friendly_name() -> str:
def get_network_interfaces() -> list[NetworkInterfaceInfo]: def get_network_interfaces() -> list[NetworkInterfaceInfo]:
"""
Retrieves detailed network interface information on macOS.
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
to determine interface names, IP addresses, and types (ethernet, wifi, vpn, other).
Returns a list of NetworkInterfaceInfo objects.
"""
interfaces_info: list[NetworkInterfaceInfo] = [] interfaces_info: list[NetworkInterfaceInfo] = []
for iface, services in psutil.net_if_addrs().items(): for iface, services in psutil.net_if_addrs().items():
@@ -42,7 +37,9 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
match service.family: match service.family:
case socket.AF_INET | socket.AF_INET6: case socket.AF_INET | socket.AF_INET6:
interfaces_info.append( interfaces_info.append(
NetworkInterfaceInfo(name=iface, ip_address=service.address) NetworkInterfaceInfo(
name=iface, ip_address=ip_address(service.address)
)
) )
case _: case _:
pass pass

2
uv.lock generated
View File

@@ -316,7 +316,7 @@ wheels = [
[[package]] [[package]]
name = "exo" name = "exo"
version = "0.3.0" version = "0.10.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },