mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
uh i did it i think
This commit is contained in:
55
Cargo.toml
55
Cargo.toml
@@ -1,10 +1,8 @@
|
||||
[workspace]
|
||||
resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/system_custodian",
|
||||
"rust/util",
|
||||
"exo_pyo3_bindings",
|
||||
"util", "iroh_networking",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -24,60 +22,49 @@ opt-level = 3
|
||||
# Common configurations include versions, paths, features, etc.
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
system_custodian = { path = "rust/system_custodian" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
iroh_networking = { path = "iroh_networking" }
|
||||
util = { path = "util" }
|
||||
|
||||
# Macro dependencies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
n0-future = "0.3.1"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
postcard = "1.1.3"
|
||||
n0-error = "0.1.2"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
blake3 = "1.8.2"
|
||||
env_logger = "0.11"
|
||||
tracing-subscriber = "0.3.20"
|
||||
|
||||
# networking
|
||||
libp2p = "0.56"
|
||||
libp2p-tcp = "0.44"
|
||||
iroh = "0.95.1"
|
||||
iroh-gossip = "0.95.0"
|
||||
|
||||
# pyo3
|
||||
pyo3 = "0.27.1"
|
||||
pyo3-async-runtimes = "0.27.0"
|
||||
pyo3-log = "0.13.2"
|
||||
pyo3-stub-gen = "0.17.2"
|
||||
|
||||
# other
|
||||
rand = "0.9.2"
|
||||
|
||||
[workspace.lints.rust]
|
||||
static_mut_refs = "warn" # Use "deny" instead to turn this lint into a hard error
|
||||
|
||||
1
TODO.md
1
TODO.md
@@ -19,6 +19,7 @@
|
||||
25. Rethink retry logic
|
||||
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
|
||||
27. Log cleanup - per-module log filters and default to DEBUG log levels
|
||||
28. Really need to remove all mlx logic outside of the runner - API has a transitive dependency on engines which imports mlx
|
||||
|
||||
Potential refactors:
|
||||
|
||||
|
||||
13
flake.nix
13
flake.nix
@@ -69,14 +69,11 @@
|
||||
basedpyright
|
||||
|
||||
# RUST
|
||||
((fenixToolchain system).withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
rustup # Just here to make RustRover happy
|
||||
cargo
|
||||
bacon
|
||||
rust-analyzer
|
||||
rustc
|
||||
rustfmt
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "exo"
|
||||
version = "0.3.0"
|
||||
version = "0.10.0"
|
||||
description = "Exo"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
|
||||
@@ -5,8 +5,6 @@ edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
path = "src/lib.rs"
|
||||
name = "exo_pyo3_bindings"
|
||||
|
||||
# "cdylib" needed to produce shared library for Python to import
|
||||
@@ -22,46 +20,25 @@ doc = false
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
networking = { workspace = true }
|
||||
iroh_networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.27.1", features = [
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
|
||||
pyo3-log = "0.13.2"
|
||||
pyo3 = { workspace = true, features = ["experimental-async"] }
|
||||
pyo3-stub-gen = { workspace = true }
|
||||
pyo3-async-runtimes = { workspace = true, features = ["attributes", "tokio-runtime", "testing"] }
|
||||
pyo3-log = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
postcard = { workspace = true, features = ["use-std"] }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
n0-future = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
@@ -70,8 +47,9 @@ thiserror = { workspace = true }
|
||||
#console-subscriber = "0.1.5"
|
||||
#tracing-log = "0.2.0"
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
env_logger = { workspace = true }
|
||||
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
iroh = { workspace = true }
|
||||
iroh-gossip = { workspace = true }
|
||||
|
||||
@@ -2,220 +2,63 @@
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
import enum
|
||||
import typing
|
||||
|
||||
@typing.final
|
||||
class AllQueuesFullError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
class EndpointId:
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
class IpAddress:
|
||||
def __str__(self) -> builtins.str: ...
|
||||
def ip_addr(self) -> builtins.str: ...
|
||||
def port(self) -> builtins.int: ...
|
||||
def zone_id(self) -> typing.Optional[builtins.int]: ...
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ecdsa() -> Keypair:
|
||||
def from_postcard_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Generate a new ECDSA keypair.
|
||||
Decode a postcard structure into a keypair
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
def to_postcard_encoding(self) -> bytes:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
Encode a private key with the postcard format
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
def endpoint_id(self) -> EndpointId:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
Read out the endpoint id corresponding to this keypair
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
"""
|
||||
class RustConnectionMessage:
|
||||
@property
|
||||
def endpoint_id(self) -> EndpointId: ...
|
||||
@property
|
||||
def current_transport_addrs(self) -> builtins.set[IpAddress]: ...
|
||||
|
||||
@typing.final
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
|
||||
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
"""
|
||||
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Unsubscribes from a `GossipSub` topic.
|
||||
|
||||
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
"""
|
||||
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
|
||||
r"""
|
||||
Publishes a message with multiple topics to the `GossipSub` network.
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
class RustConnectionReceiver:
|
||||
async def receive(self) -> RustConnectionMessage: ...
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
class RustNetworkingHandle:
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
async def create(identity: Keypair, namespace: builtins.str) -> RustNetworkingHandle: ...
|
||||
async def subscribe(self, topic: builtins.str) -> tuple[RustSender, RustReceiver]: ...
|
||||
async def get_connection_receiver(self) -> RustConnectionReceiver: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
class RustReceiver:
|
||||
async def receive(self) -> builtins.list[builtins.int]: ...
|
||||
|
||||
@typing.final
|
||||
class RustSender:
|
||||
async def send(self, message: typing.Sequence[builtins.int]) -> None: ...
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
//!
|
||||
//! SEE: https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await
|
||||
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
@@ -10,10 +7,8 @@ use std::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
#[pin_project]
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct AllowThreads<F>(#[pin] F);
|
||||
pub(crate) struct AllowThreads<F>(F);
|
||||
|
||||
impl<F> AllowThreads<F>
|
||||
where
|
||||
@@ -26,15 +21,13 @@ where
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
F: Future + Unpin + Send,
|
||||
F::Output: Send,
|
||||
{
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ impl PyAsyncTaskHandle {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
_ = Python::attach(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
@@ -176,9 +176,9 @@ impl PyAsyncTaskHandle {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
let c = Python::attach(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
if let Some(f) = Python::attach(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
@@ -238,3 +238,13 @@ pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::attach(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
#![allow(clippy::multiple_inherent_impl, clippy::missing_const_for_fn)]
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
@@ -21,6 +16,7 @@ use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use std::pin::pin;
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
use util::ext::VecExt as _;
|
||||
|
||||
@@ -393,11 +389,14 @@ impl PyNetworkingHandle {
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
let mg_fut = self.connection_update_rx.lock();
|
||||
|
||||
let mut mg = pin!(mg_fut)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.await;
|
||||
|
||||
let recv_fut = mg.recv_py();
|
||||
pin!(recv_fut)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
@@ -486,7 +485,7 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
66
rust/exo_pyo3_bindings/src/identity.rs
Normal file
66
rust/exo_pyo3_bindings/src/identity.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use iroh::{EndpointId, SecretKey, endpoint_info::EndpointIdExt};
|
||||
use postcard::ser_flavors::StdVec;
|
||||
|
||||
use crate::ext::ResultExt as _;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use rand::rng;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PyKeypair(pub(crate) SecretKey);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(SecretKey::generate(&mut rng()))
|
||||
}
|
||||
/// Decode a postcard structure into a keypair
|
||||
#[staticmethod]
|
||||
fn from_postcard_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(postcard::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
/// Encode a private key with the postcard format
|
||||
fn to_postcard_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = postcard::serialize_with_flavor(&self.0, StdVec::new()).pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
/// Read out the endpoint id corresponding to this keypair
|
||||
fn endpoint_id(&self) -> PyEndpointId {
|
||||
PyEndpointId(self.0.public())
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "EndpointId", frozen)]
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PyEndpointId(pub(crate) EndpointId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyEndpointId {
|
||||
pub fn __str__(&self) -> String {
|
||||
self.0.to_z32()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EndpointId> for PyEndpointId {
|
||||
fn from(value: EndpointId) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyEndpointId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
195
rust/exo_pyo3_bindings/src/iroh_networking.rs
Normal file
195
rust/exo_pyo3_bindings/src/iroh_networking.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
use crate::ext::{FutureExt, ResultExt};
|
||||
use crate::identity::{PyEndpointId, PyKeypair};
|
||||
use iroh::SecretKey;
|
||||
use iroh::discovery::EndpointInfo;
|
||||
use iroh::discovery::mdns::DiscoveryEvent;
|
||||
use iroh_gossip::api::{ApiError, Event, GossipReceiver, GossipSender, Message};
|
||||
use iroh_networking::ExoNet;
|
||||
use n0_future::{Stream, StreamExt};
|
||||
use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::{Pin, pin};
|
||||
use std::sync::LazyLock;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
static RUNTIME: LazyLock<Runtime> =
|
||||
LazyLock::new(|| Runtime::new().expect("Failed to create tokio runtime"));
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "IpAddress")]
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct PyIpAddress {
|
||||
inner: SocketAddr,
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyIpAddress {
|
||||
pub fn __str__(&self) -> String {
|
||||
self.inner.to_string()
|
||||
}
|
||||
|
||||
pub fn ip_addr(&self) -> String {
|
||||
self.inner.ip().to_string()
|
||||
}
|
||||
|
||||
pub fn port(&self) -> u16 {
|
||||
self.inner.port()
|
||||
}
|
||||
|
||||
pub fn zone_id(&self) -> Option<u32> {
|
||||
match self.inner {
|
||||
SocketAddr::V6(ip) => Some(ip.scope_id()),
|
||||
SocketAddr::V4(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustNetworkingHandle")]
|
||||
pub struct PyNetworkingHandle {
|
||||
net: Mutex<ExoNet>,
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNetworkingHandle {
|
||||
#[staticmethod]
|
||||
pub async fn create(identity: PyKeypair, namespace: String) -> PyResult<Self> {
|
||||
let loc: SecretKey = identity.0.clone();
|
||||
let net = RUNTIME
|
||||
.spawn(async move { ExoNet::init_iroh(loc, &namespace).await })
|
||||
.await
|
||||
// todo: pyerr better
|
||||
.pyerr()?
|
||||
.pyerr()?;
|
||||
Ok(Self {
|
||||
net: Mutex::new(net),
|
||||
})
|
||||
}
|
||||
|
||||
async fn subscribe(&mut self, topic: String) -> PyResult<(PySender, PyReceiver)> {
|
||||
let mut lock = self.net.lock().await;
|
||||
let fut = lock.subscribe(&topic);
|
||||
let (send, recv) = pin!(fut).allow_threads_py().await.pyerr()?;
|
||||
Ok((PySender { inner: send }, PyReceiver { inner: recv }))
|
||||
}
|
||||
|
||||
async fn get_connection_receiver(&mut self) -> PyResult<PyConnectionReceiver> {
|
||||
let mut lock = self.net.lock().await;
|
||||
let fut = lock.connection_info();
|
||||
let stream = fut.await;
|
||||
Ok(PyConnectionReceiver {
|
||||
inner: Mutex::new(Box::pin(stream)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustConnectionMessage")]
|
||||
pub struct PyConnectionMessage {
|
||||
#[pyo3(get)]
|
||||
pub endpoint_id: PyEndpointId,
|
||||
#[pyo3(get)]
|
||||
pub current_transport_addrs: BTreeSet<PyIpAddress>,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustSender")]
|
||||
struct PySender {
|
||||
inner: GossipSender,
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PySender {
|
||||
async fn send(&mut self, message: Vec<u8>) -> PyResult<()> {
|
||||
self.inner.broadcast(message.into()).await.pyerr()
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustReceiver")]
|
||||
struct PyReceiver {
|
||||
inner: GossipReceiver,
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyReceiver {
|
||||
async fn receive(&mut self) -> PyResult<Vec<u8>> {
|
||||
loop {
|
||||
match self.inner.next().await {
|
||||
// Successful cases
|
||||
Some(Ok(Event::Received(Message { content, .. }))) => {
|
||||
return Ok(content.to_vec());
|
||||
}
|
||||
Some(Ok(other)) => log::info!("Dropping gossip event {other:?}"),
|
||||
None => return Err(PyStopAsyncIteration::new_err("")),
|
||||
Some(Err(ApiError::Closed { .. })) => {
|
||||
return Err(PyStopAsyncIteration::new_err(""));
|
||||
}
|
||||
|
||||
// Failure case
|
||||
Some(Err(other)) => {
|
||||
return Err(PyRuntimeError::new_err(other.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustConnectionReceiver")]
|
||||
struct PyConnectionReceiver {
|
||||
inner: Mutex<Pin<Box<dyn Stream<Item = DiscoveryEvent> + Send>>>,
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyConnectionReceiver {
|
||||
async fn receive(&mut self) -> PyResult<PyConnectionMessage> {
|
||||
loop {
|
||||
let mg_fut = self.inner.lock();
|
||||
let mut lock = pin!(mg_fut).allow_threads_py().await;
|
||||
match lock.next().allow_threads_py().await {
|
||||
// Successful cases
|
||||
Some(DiscoveryEvent::Discovered {
|
||||
endpoint_info: EndpointInfo { endpoint_id, data },
|
||||
..
|
||||
}) => {
|
||||
return Ok(PyConnectionMessage {
|
||||
endpoint_id: endpoint_id.into(),
|
||||
current_transport_addrs: data
|
||||
.ip_addrs()
|
||||
.map(|it| PyIpAddress { inner: it.clone() })
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
Some(DiscoveryEvent::Expired { endpoint_id }) => {
|
||||
return Ok(PyConnectionMessage {
|
||||
endpoint_id: endpoint_id.into(),
|
||||
current_transport_addrs: BTreeSet::new(),
|
||||
});
|
||||
}
|
||||
|
||||
// Failure case
|
||||
None => return Err(PyStopAsyncIteration::new_err("")),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyConnectionMessage>()?;
|
||||
m.add_class::<PyReceiver>()?;
|
||||
m.add_class::<PySender>()?;
|
||||
m.add_class::<PyConnectionReceiver>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -4,65 +4,28 @@
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
mod identity;
|
||||
mod iroh_networking;
|
||||
// mod examples;
|
||||
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use crate::identity::ident_submodule;
|
||||
use crate::iroh_networking::networking_submodule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
/// Namespace for all the constants used by this crate.
|
||||
pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
Python::attach(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +40,7 @@ pub(crate) mod ext {
|
||||
}
|
||||
|
||||
pub trait FutureExt: Future + Sized {
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
/// SEE: https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await
|
||||
fn allow_threads_py(self) -> AllowThreads<Self>
|
||||
where
|
||||
AllowThreads<Self>: Future,
|
||||
@@ -98,7 +61,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
Python::attach(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
@@ -112,85 +75,6 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioRuntimeExt)]
|
||||
impl Runtime {
|
||||
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
|
||||
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscSenderExt)]
|
||||
impl<T> mpsc::Sender<T> {
|
||||
/// Sends a value, waiting until there is capacity.
|
||||
///
|
||||
/// A successful send occurs when it is determined that the other end of the
|
||||
/// channel has not hung up already. An unsuccessful send would be one where
|
||||
/// the corresponding receiver has already been closed.
|
||||
async fn send_py(&self, value: T) -> PyResult<()> {
|
||||
self.send(value)
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscReceiverExt)]
|
||||
impl<T> mpsc::Receiver<T> {
|
||||
/// Receives the next value for this receiver.
|
||||
async fn recv_py(&mut self) -> PyResult<T> {
|
||||
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
|
||||
}
|
||||
|
||||
/// Receives at most `limit` values for this receiver and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
|
||||
// get updates from receiver channel
|
||||
let mut updates = Vec::with_capacity(limit);
|
||||
let received = self.recv_many(&mut updates, limit).await;
|
||||
|
||||
// if we received zero items, then the channel was unexpectedly closed
|
||||
if limit != 0 && received == 0 {
|
||||
return Err(PyErr::receiver_channel_closed());
|
||||
}
|
||||
|
||||
Ok(updates)
|
||||
}
|
||||
|
||||
/// Tries to receive the next value for this receiver.
|
||||
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
|
||||
match self.try_recv() {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(TryRecvError::Empty) => Ok(None),
|
||||
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
@@ -201,16 +85,9 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
// TODO: ...
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
18
rust/iroh_networking/Cargo.toml
Normal file
18
rust/iroh_networking/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "iroh_networking"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
blake3 = { workspace = true, features = ["neon", "rayon"] }
|
||||
iroh = { workspace = true, features = ["discovery-local-network"] }
|
||||
iroh-gossip = { workspace = true }
|
||||
n0-error = { workspace = true }
|
||||
n0-future = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
thiserror.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
28
rust/iroh_networking/src/bin/main.rs
Normal file
28
rust/iroh_networking/src/bin/main.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use iroh::SecretKey;
|
||||
use iroh_networking::ExoNet;
|
||||
use n0_future::StreamExt;
|
||||
|
||||
// Launch a mock version of iroh for testing purposes
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
let mut net = ExoNet::init_iroh(SecretKey::generate(&mut rand::rng()), "")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut conn_info = net.connection_info().await;
|
||||
|
||||
let task = tokio::task::spawn(async move {
|
||||
println!("Inner task started!");
|
||||
loop {
|
||||
dbg!(conn_info.next().await);
|
||||
}
|
||||
});
|
||||
|
||||
println!("Task started!");
|
||||
|
||||
task.await.unwrap();
|
||||
}
|
||||
103
rust/iroh_networking/src/lib.rs
Normal file
103
rust/iroh_networking/src/lib.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use iroh::{
|
||||
Endpoint, SecretKey,
|
||||
discovery::{
|
||||
IntoDiscoveryError,
|
||||
mdns::{DiscoveryEvent, MdnsDiscovery},
|
||||
},
|
||||
endpoint::BindError,
|
||||
protocol::Router,
|
||||
};
|
||||
use iroh_gossip::{
|
||||
Gossip, TopicId,
|
||||
api::{ApiError, GossipReceiver, GossipSender},
|
||||
};
|
||||
|
||||
use n0_error::stack_error;
|
||||
use n0_future::Stream;
|
||||
|
||||
/// Errors returned by the fallible [`ExoNet`] operations in this crate.
///
/// Every variant wraps the underlying library error as its `source`.
// NOTE(review): `from_sources` presumably generates the `From<source>` impls
// that let `?` convert into this type — confirm against the n0-error docs.
#[stack_error(derive, add_meta, from_sources)]
|
||||
pub enum Error {
|
||||
/// Wraps the iroh endpoint's `BindError`; displayed transparently.
#[error(transparent)]
|
||||
FailedBinding { source: BindError },
|
||||
/// A gossip API call failed, e.g. because the gossip topic was closed.
|
||||
#[error(transparent)]
|
||||
FailedCommunication { source: ApiError },
|
||||
/// Wraps an `IntoDiscoveryError`; presumably raised when the mDNS discovery
/// service cannot be built — TODO confirm that this always means "no IP
/// protocol available", as the message claims.
#[error("No IP Protocol supported on device")]
|
||||
IPNotSupported { source: IntoDiscoveryError },
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ExoNet {
|
||||
router: Router,
|
||||
gossip: Gossip,
|
||||
mdns: MdnsDiscovery,
|
||||
}
|
||||
|
||||
impl ExoNet {
|
||||
pub async fn init_iroh(sk: SecretKey, namespace: &str) -> Result<Self, Error> {
|
||||
let endpoint = Endpoint::empty_builder(iroh::RelayMode::Disabled)
|
||||
.secret_key(sk)
|
||||
.bind()
|
||||
.await?;
|
||||
let mdns = MdnsDiscovery::builder().build(endpoint.id())?;
|
||||
endpoint.discovery().add(mdns.clone());
|
||||
let alpn = format!("/exo_discovery_network/{}", namespace).to_owned();
|
||||
let gossip = Gossip::builder().alpn(&alpn).spawn(endpoint.clone());
|
||||
let router = Router::builder(endpoint)
|
||||
.accept(&alpn, gossip.clone())
|
||||
.spawn();
|
||||
Ok(Self {
|
||||
router,
|
||||
gossip,
|
||||
mdns,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn connection_info(&mut self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
|
||||
self.mdns.subscribe().await
|
||||
}
|
||||
|
||||
pub async fn subscribe(
|
||||
&mut self,
|
||||
topic: &str,
|
||||
) -> Result<(GossipSender, GossipReceiver), Error> {
|
||||
Ok(self
|
||||
.gossip
|
||||
.subscribe(str_to_topic_id(topic), vec![])
|
||||
.await?
|
||||
.split())
|
||||
}
|
||||
|
||||
pub async fn shutdown(&mut self) {
|
||||
self.router
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("Iroh Router failed to shutdown");
|
||||
}
|
||||
}
|
||||
|
||||
fn str_to_topic_id(data: &str) -> TopicId {
|
||||
TopicId::from_bytes(*blake3::hash(data.as_bytes()).as_bytes())
|
||||
}
|
||||
|
||||
// Compile-time checks only: the items below exist to prove that the crate's
// types are `Send`, so some of them are never called (hence `dead_code`).
#[allow(dead_code)]
#[cfg(test)]
mod test {
    use crate::ExoNet;
    use iroh::{SecretKey, discovery::mdns::DiscoveryEvent};

    /// Accepts only references to `Send` values; anything else fails to compile.
    fn assert_send<T: Send>(_value: &T) {}

    /// Implementing this trait for a type only compiles if the type is `Send`.
    trait MustBeSend: Send {}
    impl MustBeSend for ExoNet {}
    impl MustBeSend for DiscoveryEvent {}

    #[test]
    fn test_is_send() {
        // todo: make rand a dev dep.
        let secret = SecretKey::generate(&mut rand::rng());
        // The future is never awaited; it only has to be `Send`.
        let init = ExoNet::init_iroh(secret, "");
        assert_send(&init);
    }
}
|
||||
@@ -41,4 +41,4 @@ keccak-const = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures::FutureExt as _;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
@@ -62,8 +62,7 @@ mod managed {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
|
||||
Ok(mdns_behaviour?)
|
||||
tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id())
|
||||
}
|
||||
|
||||
fn ping_behaviour() -> ping::Behaviour {
|
||||
@@ -125,7 +124,7 @@ impl Behaviour {
|
||||
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
|
||||
@@ -133,7 +132,7 @@ impl Behaviour {
|
||||
self.pending_events.push_front(ToSwarm::CloseConnection {
|
||||
peer_id,
|
||||
connection: CloseConnection::One(connection),
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
@@ -294,7 +293,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection established event which is filtered correctly
|
||||
self.on_connection_established(peer_id, connection_id, ip, port)
|
||||
self.on_connection_established(peer_id, connection_id, ip, port);
|
||||
}
|
||||
}
|
||||
FromSwarm::ConnectionClosed(ConnectionClosed {
|
||||
@@ -310,7 +309,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,7 +328,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
|
||||
match e {
|
||||
// handle discovered and expired events from mDNS
|
||||
managed::BehaviourEvent::Mdns(e) => match e.clone() {
|
||||
managed::BehaviourEvent::Mdns(e) => match e {
|
||||
mdns::Event::Discovered(peers) => {
|
||||
self.handle_mdns_discovered(peers);
|
||||
}
|
||||
@@ -340,8 +339,8 @@ impl NetworkBehaviour for Behaviour {
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
if e.result.is_err() {
|
||||
self.close_connection(e.peer, e.connection);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -366,10 +365,10 @@ impl NetworkBehaviour for Behaviour {
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
self.dial(p, ma);
|
||||
}
|
||||
}
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL); // reset timeout
|
||||
}
|
||||
|
||||
// send out any pending events from our own service
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
@@ -54,11 +54,3 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
[package]
|
||||
name = "system_custodian"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "system_custodian"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/main.rs"
|
||||
name = "system_custodian"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
//! TODO: documentation
|
||||
//!
|
||||
|
||||
fn main() {}
|
||||
@@ -1,69 +0,0 @@
|
||||
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
|
||||
//!
|
||||
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
|
||||
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
|
||||
//! etc.) is in an appropriate state to facilitate the running of the Exo application.
|
||||
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
|
||||
//! service which the Exo application uses to _control & query_ it.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
|
||||
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
|
||||
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
|
||||
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
|
||||
//! destructive to the user's pre-existing configurations.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! TODO: these are purely on MacOS, but change to be more broad
|
||||
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
|
||||
//! 1. duplicate the current network set
|
||||
//! 2. modify existing services to turn on IPv6 if not there
|
||||
//! 3. remove any bridge services & add any missing services that AREN'T bridge
|
||||
//! TODO: In the future:
|
||||
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
|
||||
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
|
||||
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
|
||||
//! logic, this would be the place to spin that up.
|
||||
//!
|
||||
//! Then it will watch the SCDynamicStore for:
|
||||
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
|
||||
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
|
||||
//! interface of any changes
|
||||
//! 2. watch for any __undesirable__ changes to configuration and revert it
|
||||
//!
|
||||
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
|
||||
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
|
||||
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
|
||||
//! 1. all that info should coalesce back to the overall state collected -> should be queryable
|
||||
//! over D-Bus
|
||||
//! TODO:
|
||||
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
|
||||
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
|
||||
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
|
||||
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
|
||||
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
|
||||
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
|
||||
//! then this would be the place to carry out discovery and proper handshakes with devices
|
||||
//! on the other end of the link.
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {}
|
||||
@@ -5,26 +5,16 @@
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
// #![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(type_alias_impl_trait)]
|
||||
// #![feature(specialization)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(const_trait_impl)]
|
||||
// #![feature(fn_traits)]
|
||||
|
||||
pub mod nonempty;
|
||||
pub mod wakerdeque;
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub mod ext {
|
||||
use extend::ext;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from importlib.metadata import version
|
||||
|
||||
__version__ = version("exo")
|
||||
|
||||
@@ -39,9 +39,9 @@ class Node:
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(str(keypair.endpoint_id()))
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = Router.create(keypair)
|
||||
router = await Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
await router.register_topic(topics.LOCAL_EVENTS)
|
||||
await router.register_topic(topics.COMMANDS)
|
||||
|
||||
@@ -154,6 +154,7 @@ def get_shard_assignments(
|
||||
|
||||
|
||||
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
# this function is wrong.
|
||||
cycles = cycle_digraph.get_cycles()
|
||||
expected_length = len(list(cycle_digraph.list_nodes()))
|
||||
cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
|
||||
@@ -178,15 +179,15 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
|
||||
for connection in cycle_digraph.list_connections():
|
||||
if (
|
||||
connection.local_node_id == current_node.node_id
|
||||
and connection.send_back_node_id == next_node.node_id
|
||||
connection.source_id == current_node.node_id
|
||||
and connection.sink_id == next_node.node_id
|
||||
):
|
||||
if get_thunderbolt and not connection.is_thunderbolt():
|
||||
continue
|
||||
assert connection.send_back_multiaddr is not None
|
||||
assert connection.sink_addr is not None
|
||||
host = Host(
|
||||
ip=connection.send_back_multiaddr.ip_address,
|
||||
port=connection.send_back_multiaddr.port,
|
||||
ip=str(connection.sink_addr.ip),
|
||||
port=connection.sink_addr.port,
|
||||
)
|
||||
hosts.append(host)
|
||||
break
|
||||
@@ -242,10 +243,10 @@ def _find_connection_ip(
|
||||
"""Find all IP addresses that connect node i to node j."""
|
||||
for connection in cycle_digraph.list_connections():
|
||||
if (
|
||||
connection.local_node_id == node_i.node_id
|
||||
and connection.send_back_node_id == node_j.node_id
|
||||
connection.source_id == node_i.node_id
|
||||
and connection.sink_id == node_j.node_id
|
||||
):
|
||||
yield connection.send_back_multiaddr.ip_address
|
||||
yield str(connection.sink_addr.ip)
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
|
||||
|
||||
@@ -11,6 +11,7 @@ from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CommandId,
|
||||
CreateInstance,
|
||||
ForwarderCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
@@ -140,7 +141,6 @@ async def test_master():
|
||||
origin=node_id,
|
||||
command=(
|
||||
ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
request_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
|
||||
@@ -1,37 +1,37 @@
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
from exo_pyo3_bindings import RustConnectionMessage
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
IpAddress = IPv4Address | IPv6Address
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
class SocketAddress(CamelCaseModel):
|
||||
# could be the python IpAddress type if we're feeling fancy
|
||||
ip: IpAddress
|
||||
port: int
|
||||
zone_id: int | None
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
ips: set[SocketAddress]
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
def from_rust(cls, message: RustConnectionMessage) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
node_id=NodeId(str(message.endpoint_id)),
|
||||
ips=set(
|
||||
# TODO: better handle fallible conversion
|
||||
SocketAddress(
|
||||
ip=ip_address(addr.ip_addr()),
|
||||
port=addr.port(),
|
||||
zone_id=addr.zone_id(),
|
||||
)
|
||||
for addr in message.current_transport_addrs
|
||||
),
|
||||
)
|
||||
|
||||
@@ -13,14 +13,15 @@ from anyio import (
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from exo_pyo3_bindings import (
|
||||
AllQueuesFullError,
|
||||
Keypair,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
RustNetworkingHandle,
|
||||
RustReceiver,
|
||||
RustSender,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
|
||||
from exo import __version__
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -37,7 +38,8 @@ class TopicRouter[T: CamelCaseModel]:
|
||||
def __init__(
|
||||
self,
|
||||
topic: TypedTopic[T],
|
||||
networking_sender: Sender[tuple[str, bytes]],
|
||||
networking_sender: RustSender,
|
||||
networking_receiver: RustReceiver,
|
||||
max_buffer_size: float = inf,
|
||||
):
|
||||
self.topic: TypedTopic[T] = topic
|
||||
@@ -45,7 +47,7 @@ class TopicRouter[T: CamelCaseModel]:
|
||||
send, recv = channel[T]()
|
||||
self.receiver: Receiver[T] = recv
|
||||
self._sender: Sender[T] = send
|
||||
self.networking_sender: Sender[tuple[str, bytes]] = networking_sender
|
||||
self.networking_sender: RustSender = networking_sender
|
||||
|
||||
async def run(self):
|
||||
logger.debug(f"Topic Router {self.topic} ready to send")
|
||||
@@ -93,35 +95,24 @@ class TopicRouter[T: CamelCaseModel]:
|
||||
|
||||
async def _send_out(self, item: T):
|
||||
logger.trace(f"TopicRouter {self.topic.topic} sending {item}")
|
||||
await self.networking_sender.send(
|
||||
(str(self.topic.topic), self.topic.serialize(item))
|
||||
)
|
||||
await self.networking_sender.send(self.topic.serialize(item))
|
||||
|
||||
|
||||
class Router:
|
||||
@classmethod
|
||||
def create(cls, identity: Keypair) -> "Router":
|
||||
return cls(handle=NetworkingHandle(identity))
|
||||
async def create(cls, identity: Keypair) -> "Router":
|
||||
return cls(handle=await RustNetworkingHandle.create(identity, __version__))
|
||||
|
||||
def __init__(self, handle: NetworkingHandle):
|
||||
def __init__(self, handle: RustNetworkingHandle):
|
||||
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
|
||||
send, recv = channel[tuple[str, bytes]]()
|
||||
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
|
||||
self._net: NetworkingHandle = handle
|
||||
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
|
||||
self._net: RustNetworkingHandle = handle
|
||||
self._id_count = count()
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
|
||||
assert self._tg is None, "Attempted to register topic after setup time"
|
||||
send = self._tmp_networking_sender
|
||||
if send:
|
||||
self._tmp_networking_sender = None
|
||||
else:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
router = TopicRouter[T](topic, *await self._net.subscribe(str(topic.topic)))
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
|
||||
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
@@ -151,13 +142,9 @@ class Router:
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -165,29 +152,10 @@ class Router:
|
||||
return
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(f"Received message on unknown or inactive topic {topic}")
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
recv = await self._net.get_connection_receiver()
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
message = await recv.receive()
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
@@ -195,18 +163,7 @@ class Router:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
async for topic, data in networked_items:
|
||||
try:
|
||||
logger.trace(f"Sending message on {topic} with payload {data}")
|
||||
await self._net.gossipsub_publish(topic, data)
|
||||
# As a hack, this also catches AllQueuesFull
|
||||
# Need to fix that ASAP.
|
||||
except (NoPeersSubscribedToTopicError, AllQueuesFullError):
|
||||
pass
|
||||
await router.publish(ConnectionMessage.from_rust(message))
|
||||
|
||||
|
||||
def get_node_id_keypair(
|
||||
@@ -225,16 +182,16 @@ def get_node_id_keypair(
|
||||
with open(path, "a+b") as f: # opens in append-mode => starts at EOF
|
||||
# if non-zero EOF, then file exists => use to get node-ID
|
||||
if f.tell() != 0:
|
||||
f.seek(0) # go to start & read protobuf-encoded bytes
|
||||
protobuf_encoded = f.read()
|
||||
f.seek(0) # go to start & read postcard-encoded bytes
|
||||
postcard_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
return Keypair.from_postcard_encoding(postcard_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
f.write(keypair.to_postcard_encoding())
|
||||
return keypair
|
||||
|
||||
@@ -81,16 +81,16 @@ class Topology:
|
||||
self,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
if connection.local_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.local_node_id))
|
||||
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
|
||||
if connection.source_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.source_id))
|
||||
if connection.sink_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.sink_id))
|
||||
|
||||
if connection in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
|
||||
src_id = self._node_id_to_rx_id_map[connection.source_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.sink_id]
|
||||
|
||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
||||
@@ -188,10 +188,7 @@ class Topology:
|
||||
for rx_idx in rx_idxs:
|
||||
topology.add_node(self._graph[rx_idx])
|
||||
for connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id in node_idxs
|
||||
and connection.send_back_node_id in node_idxs
|
||||
):
|
||||
if connection.source_id in node_idxs and connection.sink_id in node_idxs:
|
||||
topology.add_connection(connection)
|
||||
return topology
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Self
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
@@ -26,6 +28,17 @@ class PlaceInstance(BaseCommand):
|
||||
instance_meta: InstanceMeta
|
||||
min_nodes: int
|
||||
|
||||
# Decision point - I like this syntax better than the typical fixtures,
|
||||
# but it's """bloat"""
|
||||
@classmethod
def fixture(cls) -> Self:
    """Return a ready-made instance populated with fixed sample values.

    The model metadata comes from ``ModelMetadata.fixture()``; the remaining
    fields are pipeline sharding, the MlxRing instance type and a minimum of
    one node.
    """
    sample_model = ModelMetadata.fixture()
    return cls(
        model_meta=sample_model,
        sharding=Sharding.Pipeline,
        instance_meta=InstanceMeta.MlxRing,
        min_nodes=1,
    )
|
||||
|
||||
|
||||
class CreateInstance(BaseCommand):
|
||||
instance: Instance
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Self
|
||||
|
||||
from pydantic import PositiveInt
|
||||
|
||||
from exo.shared.types.common import Id
|
||||
@@ -14,3 +16,12 @@ class ModelMetadata(CamelCaseModel):
|
||||
pretty_name: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt
|
||||
|
||||
@classmethod
def fixture(cls) -> Self:
    """Return a ready-made instance describing a small sample model.

    All values are fixed constants: id ``llama-3.2-1b``, display name
    ``Llama 3.2 1B``, 16 layers and a storage size of 678948 bytes.
    """
    sample_id = ModelId("llama-3.2-1b")
    sample_size = Memory.from_bytes(678948)
    return cls(
        model_id=sample_id,
        pretty_name="Llama 3.2 1B",
        storage_size=sample_size,
        n_layers=16,
    )
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Self
|
||||
|
||||
import psutil
|
||||
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -49,7 +50,7 @@ class SystemPerformanceProfile(CamelCaseModel):
|
||||
|
||||
class NetworkInterfaceInfo(CamelCaseModel):
|
||||
name: str
|
||||
ip_address: str
|
||||
ip_address: IpAddress
|
||||
|
||||
|
||||
class NodePerformanceProfile(CamelCaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from exo.routing.connection_message import SocketAddress
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -10,17 +10,17 @@ class NodeInfo(CamelCaseModel):
|
||||
|
||||
|
||||
class Connection(CamelCaseModel):
|
||||
local_node_id: NodeId
|
||||
send_back_node_id: NodeId
|
||||
send_back_multiaddr: Multiaddr
|
||||
source_id: NodeId
|
||||
sink_id: NodeId
|
||||
sink_addr: SocketAddress
|
||||
connection_profile: ConnectionProfile | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.local_node_id,
|
||||
self.send_back_node_id,
|
||||
self.send_back_multiaddr.address,
|
||||
self.source_id,
|
||||
self.sink_id,
|
||||
self.sink_addr,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -28,10 +28,10 @@ class Connection(CamelCaseModel):
|
||||
if not isinstance(other, Connection):
|
||||
raise ValueError("Cannot compare Connection with non-Connection")
|
||||
return (
|
||||
self.local_node_id == other.local_node_id
|
||||
and self.send_back_node_id == other.send_back_node_id
|
||||
and self.send_back_multiaddr == other.send_back_multiaddr
|
||||
self.source_id == other.source_id
|
||||
and self.sink_id == other.sink_id
|
||||
and self.sink_addr == other.sink_addr
|
||||
)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
|
||||
return str(self.sink_addr.ip).startswith("169.254")
|
||||
|
||||
31
src/exo/worker/_connection_handler.py
Normal file
31
src/exo/worker/_connection_handler.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, TopologyEdgeCreated, TopologyEdgeDeleted
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection
|
||||
|
||||
|
||||
def check_connections(
    local_id: NodeId, msg: ConnectionMessage, state: State
) -> list[Event]:
    """Reconcile the topology edges from this node to the peer named in ``msg``.

    Each socket in ``msg.ips`` whose ``ip`` equals one of the addresses in the
    peer's profiled network interfaces is treated as a live connection
    ``local_id -> msg.node_id``.

    Args:
        local_id: Id of the node running this check (the edge source).
        msg: Connection message carrying the peer's ``node_id`` and the
            sockets (``ips``) it was seen on.
        state: Current state; only ``topology`` and ``node_profiles`` are read.

    Returns:
        Events to apply, in order: a ``TopologyEdgeCreated`` for every live
        connection not yet in the topology, followed by a
        ``TopologyEdgeDeleted`` for every existing ``local_id -> peer`` edge
        that is no longer backed by a reported socket. The list is empty when
        the peer is not in the topology or has no performance profile yet.
    """
    remote_id = msg.node_id
    sockets = msg.ips
    del msg  # everything needed has been extracted; avoid accidental reuse

    # Interfaces are looked up in the peer's profile below, so the profile
    # must be present (and the node known) before any matching can happen.
    if (
        not state.topology.contains_node(remote_id)
        or remote_id not in state.node_profiles
    ):
        return []

    out: list[Event] = []
    # Connections confirmed by this message. A set makes repeated matches
    # (duplicate sockets, or two interfaces sharing an address) harmless.
    live: set[Connection] = set()
    for iface in state.node_profiles[remote_id].network_interfaces:
        for sock in sockets:
            if iface.ip_address != sock.ip:
                continue
            conn = Connection(source_id=local_id, sink_id=remote_id, sink_addr=sock)
            if conn in live:
                continue
            live.add(conn)
            if not state.topology.contains_connection(conn):
                out.append(TopologyEdgeCreated(edge=conn))

    # Only edges from this node to this peer are candidates for deletion: a
    # message about one peer says nothing about edges between other nodes.
    for conn in state.topology.list_connections():
        if (
            conn.source_id == local_id
            and conn.sink_id == remote_id
            and conn not in live
        ):
            out.append(TopologyEdgeDeleted(edge=conn))

    return out
|
||||
@@ -1,43 +1,41 @@
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
if TYPE_CHECKING:
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
|
||||
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: list[KVCache] | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: list[KVCache] | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
class Detokenizer:
|
||||
def reset(self) -> None: ...
|
||||
def add_token(self, token: int) -> None: ...
|
||||
def finalize(self) -> None: ...
|
||||
|
||||
@property
|
||||
def last_segment(self) -> str: ...
|
||||
|
||||
class Detokenizer:
|
||||
def reset(self) -> None: ...
|
||||
def add_token(self, token: int) -> None: ...
|
||||
def finalize(self) -> None: ...
|
||||
class TokenizerWrapper:
|
||||
bos_token: str | None
|
||||
eos_token_ids: list[int]
|
||||
detokenizer: Detokenizer
|
||||
|
||||
@property
|
||||
def last_segment(self) -> str: ...
|
||||
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
|
||||
|
||||
|
||||
class TokenizerWrapper:
|
||||
bos_token: str | None
|
||||
eos_token_ids: list[int]
|
||||
detokenizer: Detokenizer
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages_dicts: list[dict[str, Any]],
|
||||
tokenize: bool = False,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> str: ...
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages_dicts: list[dict[str, Any]],
|
||||
tokenize: bool = False,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> str: ...
|
||||
|
||||
@@ -6,7 +6,7 @@ from anyio import CancelScope, create_task_group, current_time, fail_after
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
@@ -20,10 +20,7 @@ from exo.shared.types.events import (
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
@@ -33,7 +30,6 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
@@ -258,33 +254,13 @@ class Worker:
|
||||
async def _connection_message_event_writer(self):
|
||||
with self.connection_message_receiver as connection_messages:
|
||||
async for msg in connection_messages:
|
||||
await self.event_sender.send(
|
||||
self._convert_connection_message_to_event(msg)
|
||||
)
|
||||
for event in self._convert_connection_message_to_event(msg):
|
||||
await self.event_sender.send(event)
|
||||
|
||||
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
|
||||
match msg.connection_type:
|
||||
case ConnectionMessageType.Connected:
|
||||
return TopologyEdgeCreated(
|
||||
edge=Connection(
|
||||
local_node_id=self.node_id,
|
||||
send_back_node_id=msg.node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
case ConnectionMessageType.Disconnected:
|
||||
return TopologyEdgeDeleted(
|
||||
edge=Connection(
|
||||
local_node_id=self.node_id,
|
||||
send_back_node_id=msg.node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
)
|
||||
)
|
||||
def _convert_connection_message_to_event(
    self, msg: ConnectionMessage
) -> list[Event]:
    """Turn one connection message into the topology events it implies.

    Delegates to ``check_connections`` with this worker's node id and its
    current state, and returns that function's result unchanged.
    """
    # NOTE(review): none of the visible import hunks for this module adds
    # `check_connections` (defined in exo/worker/_connection_handler.py) -
    # confirm it is imported at the top of the file.
    events = check_connections(self.node_id, msg, self.state)
    return events
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import socket
|
||||
import sys
|
||||
from ipaddress import ip_address
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
import psutil
|
||||
@@ -29,12 +30,6 @@ async def get_friendly_name() -> str:
|
||||
|
||||
|
||||
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
"""
|
||||
Retrieves detailed network interface information on macOS.
|
||||
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
|
||||
to determine interface names, IP addresses, and types (ethernet, wifi, vpn, other).
|
||||
Returns a list of NetworkInterfaceInfo objects.
|
||||
"""
|
||||
interfaces_info: list[NetworkInterfaceInfo] = []
|
||||
|
||||
for iface, services in psutil.net_if_addrs().items():
|
||||
@@ -42,7 +37,9 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
match service.family:
|
||||
case socket.AF_INET | socket.AF_INET6:
|
||||
interfaces_info.append(
|
||||
NetworkInterfaceInfo(name=iface, ip_address=service.address)
|
||||
NetworkInterfaceInfo(
|
||||
name=iface, ip_address=ip_address(service.address)
|
||||
)
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user