uh i did it i think

This commit is contained in:
Evan
2025-11-27 12:37:29 +00:00
parent 0a7fe5d943
commit 7354578a16
41 changed files with 709 additions and 768 deletions

View File

@@ -1,10 +1,8 @@
[workspace]
resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
"exo_pyo3_bindings",
"util", "iroh_networking",
]
[workspace.package]
@@ -24,60 +22,49 @@ opt-level = 3
# Common configurations include versions, paths, features, etc.
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
iroh_networking = { path = "iroh_networking" }
util = { path = "util" }
# Macro dependencies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-timer = "3.0"
n0-future = "0.3.1"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
postcard = "1.1.3"
n0-error = "0.1.2"
# Tracing/logging
log = "0.4"
blake3 = "1.8.2"
env_logger = "0.11"
tracing-subscriber = "0.3.20"
# networking
libp2p = "0.56"
libp2p-tcp = "0.44"
iroh = "0.95.1"
iroh-gossip = "0.95.0"
# pyo3
pyo3 = "0.27.1"
pyo3-async-runtimes = "0.27.0"
pyo3-log = "0.13.2"
pyo3-stub-gen = "0.17.2"
# other
rand = "0.9.2"
[workspace.lints.rust]
static_mut_refs = "warn" # Reported as a warning rather than "deny"

View File

@@ -19,6 +19,7 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Really need to remove all mlx logic outside of the runner - API has a transitive dependency on engines which imports mlx
Potential refactors:

View File

@@ -69,14 +69,11 @@
basedpyright
# RUST
((fenixToolchain system).withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
cargo
bacon
rust-analyzer
rustc
rustfmt
# NIX
nixpkgs-fmt

View File

@@ -1,6 +1,6 @@
[project]
name = "exo"
version = "0.3.0"
version = "0.10.0"
description = "Exo"
readme = "README.md"
requires-python = ">=3.13"

View File

@@ -5,8 +5,6 @@ edition = { workspace = true }
publish = false
[lib]
doctest = false
path = "src/lib.rs"
name = "exo_pyo3_bindings"
# "cdylib" needed to produce shared library for Python to import
@@ -22,46 +20,25 @@ doc = false
workspace = true
[dependencies]
networking = { workspace = true }
iroh_networking = { workspace = true }
# interop
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"
pyo3 = { workspace = true, features = ["experimental-async"] }
pyo3-stub-gen = { workspace = true }
pyo3-async-runtimes = { workspace = true, features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
postcard = { workspace = true, features = ["use-std"] }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
rand = { workspace = true }
n0-future = { workspace = true }
# Tracing
@@ -70,8 +47,9 @@ thiserror = { workspace = true }
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
env_logger = { workspace = true }
# Networking
libp2p = { workspace = true, features = ["full"] }
iroh = { workspace = true }
iroh-gossip = { workspace = true }

View File

@@ -2,220 +2,63 @@
# ruff: noqa: E501, F401
import builtins
import enum
import typing
@typing.final
class AllQueuesFullError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
def __repr__(self) -> builtins.str: ...
class EndpointId:
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
class IpAddress:
def __str__(self) -> builtins.str: ...
def ip_addr(self) -> builtins.str: ...
def port(self) -> builtins.int: ...
def zone_id(self) -> typing.Optional[builtins.int]: ...
@typing.final
class Keypair:
r"""
Identity keypair of a node.
"""
@staticmethod
def generate_ed25519() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def generate_ecdsa() -> Keypair:
def from_postcard_encoding(bytes: bytes) -> Keypair:
r"""
Generate a new ECDSA keypair.
Decode a postcard structure into a keypair
"""
@staticmethod
def generate_secp256k1() -> Keypair:
def to_postcard_encoding(self) -> bytes:
r"""
Generate a new Secp256k1 keypair.
Encode a private key with the postcard format
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
def endpoint_id(self) -> EndpointId:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
Read out the endpoint id corresponding to this keypair
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
"""
def to_string(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
"""
class RustConnectionMessage:
@property
def endpoint_id(self) -> EndpointId: ...
@property
def current_transport_addrs(self) -> builtins.set[IpAddress]: ...
@typing.final
class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
"""
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Unsubscribes from a `GossipSub` topic.
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
"""
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
r"""
Publishes a message with multiple topics to the `GossipSub` network.
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
class RustConnectionReceiver:
async def receive(self) -> RustConnectionMessage: ...
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
class RustNetworkingHandle:
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
async def create(identity: Keypair, namespace: builtins.str) -> RustNetworkingHandle: ...
async def subscribe(self, topic: builtins.str) -> tuple[RustSender, RustReceiver]: ...
async def get_connection_receiver(self) -> RustConnectionReceiver: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...
class RustReceiver:
async def receive(self) -> builtins.list[builtins.int]: ...
@typing.final
class RustSender:
async def send(self, message: typing.Sequence[builtins.int]) -> None: ...

View File

@@ -1,8 +1,5 @@
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!
//! SEE: https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
@@ -10,10 +7,8 @@ use std::{
task::{Context, Poll},
};
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F);
pub(crate) struct AllowThreads<F>(F);
impl<F> AllowThreads<F>
where
@@ -26,15 +21,13 @@ where
impl<F> Future for AllowThreads<F>
where
F: Future + Ungil,
F::Output: Ungil,
F: Future + Unpin + Send,
F::Output: Send,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
}
}

View File

@@ -158,7 +158,7 @@ impl PyAsyncTaskHandle {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
_ = Python::attach(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
@@ -176,9 +176,9 @@ impl PyAsyncTaskHandle {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
let c = Python::attach(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
if let Some(f) = Python::attach(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
@@ -238,3 +238,13 @@ pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
Ok(())
}
/// Newtype over [`Py`] that implements [`Clone`]: every clone attaches to the
/// interpreter through [`Python::attach`] and takes a new reference with `clone_ref`.
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);

impl<T> Clone for ClonePy<T> {
    fn clone(&self) -> Self {
        let Self(inner) = self;
        Python::attach(|py| Self(inner.clone_ref(py)))
    }
}

View File

@@ -1,9 +1,4 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
#![allow(clippy::multiple_inherent_impl, clippy::missing_const_for_fn)]
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
@@ -21,6 +16,7 @@ use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use std::pin::pin;
use tokio::sync::{Mutex, mpsc, oneshot};
use util::ext::VecExt as _;
@@ -393,11 +389,14 @@ impl PyNetworkingHandle {
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
let mg_fut = self.connection_update_rx.lock();
let mut mg = pin!(mg_fut)
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.await;
let recv_fut = mg.recv_py();
pin!(recv_fut)
.allow_threads_py() // allow-threads-aware async call
.await
}
@@ -486,7 +485,7 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,

View File

@@ -0,0 +1,66 @@
use iroh::{EndpointId, SecretKey, endpoint_info::EndpointIdExt};
use postcard::ser_flavors::StdVec;
use crate::ext::ResultExt as _;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use rand::rng;
// Python-visible `Keypair`: a frozen newtype over an iroh `SecretKey`.
// (Plain `//` comments are used for new notes so the Python-facing docstrings stay unchanged.)
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct PyKeypair(pub(crate) SecretKey);

#[gen_stub_pymethods]
#[pymethods]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate_ed25519() -> Self {
        Self(SecretKey::generate(&mut rng()))
    }

    /// Decode a postcard structure into a keypair
    #[staticmethod]
    fn from_postcard_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        // Deserialize directly from the borrowed Python buffer. Copying it into a
        // temporary `Vec` first only leaves one more copy of the key bytes on the heap.
        // A decoding failure is converted to a Python exception via `pyerr`.
        let secret: SecretKey = postcard::from_bytes(bytes.as_bytes()).pyerr()?;
        Ok(Self(secret))
    }

    /// Encode a private key with the postcard format
    fn to_postcard_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        // Serialize into a std `Vec<u8>` and hand the bytes to Python as a new `bytes` object.
        let bytes = postcard::serialize_with_flavor(&self.0, StdVec::new()).pyerr()?;
        Ok(PyBytes::new(py, &bytes))
    }

    /// Read out the endpoint id corresponding to this keypair
    fn endpoint_id(&self) -> PyEndpointId {
        // The endpoint id is the public half of the secret key.
        PyEndpointId(self.0.public())
    }
}
// Python-visible `EndpointId`: a frozen newtype over an iroh `EndpointId`.
#[gen_stub_pyclass]
#[pyclass(name = "EndpointId", frozen)]
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct PyEndpointId(pub(crate) EndpointId);

#[gen_stub_pymethods]
#[pymethods]
impl PyEndpointId {
    // String form of the id, as produced by `to_z32`.
    pub fn __str__(&self) -> String {
        let Self(id) = self;
        id.to_z32()
    }
}

// Lets Rust-side code wrap an `EndpointId` with `.into()`.
impl From<EndpointId> for PyEndpointId {
    fn from(value: EndpointId) -> Self {
        PyEndpointId(value)
    }
}
/// Registers the identity classes (`Keypair`, `EndpointId`) on the module `m`.
///
/// Stops at, and returns, the first registration error.
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyKeypair>()?;
    // The last registration's result is the function's result.
    m.add_class::<PyEndpointId>()
}

View File

@@ -0,0 +1,195 @@
use crate::ext::{FutureExt, ResultExt};
use crate::identity::{PyEndpointId, PyKeypair};
use iroh::SecretKey;
use iroh::discovery::EndpointInfo;
use iroh::discovery::mdns::DiscoveryEvent;
use iroh_gossip::api::{ApiError, Event, GossipReceiver, GossipSender, Message};
use iroh_networking::ExoNet;
use n0_future::{Stream, StreamExt};
use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration};
use pyo3::prelude::*;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::collections::BTreeSet;
use std::net::SocketAddr;
use std::pin::{Pin, pin};
use std::sync::LazyLock;
use tokio::runtime::Runtime;
use tokio::sync::Mutex;
/// Tokio runtime shared by this module, built lazily on first access.
///
/// Panics with "Failed to create tokio runtime" if `Runtime::new()` fails.
static RUNTIME: LazyLock<Runtime> =
    LazyLock::new(|| Runtime::new().expect("Failed to create tokio runtime"));
// Python-visible `IpAddress`: wraps a `SocketAddr`, i.e. an IP address plus a port.
#[gen_stub_pyclass]
#[pyclass(name = "IpAddress")]
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PyIpAddress {
    inner: SocketAddr,
}

#[gen_stub_pymethods]
#[pymethods]
impl PyIpAddress {
    // Full socket address (IP and port) as text.
    pub fn __str__(&self) -> String {
        format!("{}", self.inner)
    }

    // Only the IP part, as text.
    pub fn ip_addr(&self) -> String {
        let ip = self.inner.ip();
        ip.to_string()
    }

    // Only the port part.
    pub fn port(&self) -> u16 {
        self.inner.port()
    }

    // IPv6 scope id; `None` for IPv4 addresses.
    pub fn zone_id(&self) -> Option<u32> {
        if let SocketAddr::V6(v6) = self.inner {
            Some(v6.scope_id())
        } else {
            None
        }
    }
}
// Python-visible `RustNetworkingHandle`: owns the `ExoNet` instance behind an async mutex.
// (Plain `//` comments are used so the Python-facing docstrings stay unchanged.)
#[gen_stub_pyclass]
#[pyclass(name = "RustNetworkingHandle")]
pub struct PyNetworkingHandle {
    net: Mutex<ExoNet>,
}

#[gen_stub_pymethods]
#[pymethods]
impl PyNetworkingHandle {
    // Creates the networking stack for `identity` within `namespace`.
    //
    // `ExoNet::init_iroh` is spawned onto the module's `RUNTIME`; both a failed join of that
    // task and an error returned by `init_iroh` are converted to Python exceptions.
    #[staticmethod]
    pub async fn create(identity: PyKeypair, namespace: String) -> PyResult<Self> {
        // `identity` is owned here, so the secret key is moved out rather than cloned.
        let loc: SecretKey = identity.0;
        let net = RUNTIME
            .spawn(async move { ExoNet::init_iroh(loc, &namespace).await })
            .await
            // todo: pyerr better
            .pyerr()? // outer result: the spawned task's join result
            .pyerr()?; // inner result: `init_iroh` itself
        Ok(Self {
            net: Mutex::new(net),
        })
    }

    // Subscribes to `topic` and returns the sender/receiver pair for it.
    //
    // NOTE(review): takes `&mut self` although all state sits behind `Mutex`;
    // confirm whether `&self` would be sufficient.
    async fn subscribe(&mut self, topic: String) -> PyResult<(PySender, PyReceiver)> {
        let mut lock = self.net.lock().await;
        let fut = lock.subscribe(&topic);
        // `allow_threads_py` needs an `Unpin` future, hence the `pin!`.
        let (send, recv) = pin!(fut).allow_threads_py().await.pyerr()?;
        Ok((PySender { inner: send }, PyReceiver { inner: recv }))
    }

    // Returns a receiver over the stream produced by `ExoNet::connection_info`.
    //
    // NOTE(review): unlike `subscribe`, the awaits here are not wrapped in
    // `allow_threads_py` — confirm this is intentional.
    async fn get_connection_receiver(&mut self) -> PyResult<PyConnectionReceiver> {
        let mut lock = self.net.lock().await;
        let fut = lock.connection_info();
        let stream = fut.await;
        Ok(PyConnectionReceiver {
            // Boxed and pinned so the concrete stream type is erased.
            inner: Mutex::new(Box::pin(stream)),
        })
    }
}
// Python-visible `RustConnectionMessage`: one discovery update about a single endpoint.
// Both fields are exposed to Python as read-only properties via `#[pyo3(get)]`.
#[gen_stub_pyclass]
#[pyclass(name = "RustConnectionMessage")]
pub struct PyConnectionMessage {
    // The endpoint this update is about.
    #[pyo3(get)]
    pub endpoint_id: PyEndpointId,
    // Addresses reported for the endpoint; `PyConnectionReceiver::receive` leaves this
    // empty when the discovery entry has expired.
    #[pyo3(get)]
    pub current_transport_addrs: BTreeSet<PyIpAddress>,
}
// Python-visible `RustSender`: the sending half of a gossip topic subscription.
#[gen_stub_pyclass]
#[pyclass(name = "RustSender")]
struct PySender {
    inner: GossipSender,
}

#[gen_stub_pymethods]
#[pymethods]
impl PySender {
    // Broadcasts `message` on the topic; a failure is converted to a Python exception.
    async fn send(&mut self, message: Vec<u8>) -> PyResult<()> {
        let outcome = self.inner.broadcast(message.into()).await;
        outcome.pyerr()
    }
}
// Python-visible `RustReceiver`: the receiving half of a gossip topic subscription.
#[gen_stub_pyclass]
#[pyclass(name = "RustReceiver")]
struct PyReceiver {
    inner: GossipReceiver,
}

#[gen_stub_pymethods]
#[pymethods]
impl PyReceiver {
    // Waits for the next received message and returns its content.
    //
    // Other gossip events are logged and skipped. An exhausted stream or a closed API
    // raises `StopAsyncIteration`; any other error raises `RuntimeError`.
    async fn receive(&mut self) -> PyResult<Vec<u8>> {
        loop {
            // First peel off the termination and error outcomes.
            let event = match self.inner.next().await {
                None | Some(Err(ApiError::Closed { .. })) => {
                    return Err(PyStopAsyncIteration::new_err(""));
                }
                Some(Err(other)) => {
                    return Err(PyRuntimeError::new_err(other.to_string()));
                }
                Some(Ok(event)) => event,
            };
            // Then keep only actual messages; everything else is dropped with a log line.
            match event {
                Event::Received(Message { content, .. }) => return Ok(content.to_vec()),
                other => log::info!("Dropping gossip event {other:?}"),
            }
        }
    }
}
// Python-visible `RustConnectionReceiver`: yields discovery events as connection messages.
// The stream is type-erased (boxed, pinned) and kept behind an async mutex.
#[gen_stub_pyclass]
#[pyclass(name = "RustConnectionReceiver")]
struct PyConnectionReceiver {
    inner: Mutex<Pin<Box<dyn Stream<Item = DiscoveryEvent> + Send>>>,
}

#[gen_stub_pymethods]
#[pymethods]
impl PyConnectionReceiver {
    // Waits for the next discovery event and converts it to a `PyConnectionMessage`.
    //
    // * `Discovered` -> the endpoint id together with its reported IP addresses.
    // * `Expired`    -> the endpoint id with an empty address set.
    // * end of stream -> `StopAsyncIteration`.
    //
    // Every outcome returns immediately, so no loop is needed around the match
    // (the previous `loop` never iterated more than once).
    async fn receive(&mut self) -> PyResult<PyConnectionMessage> {
        let mg_fut = self.inner.lock();
        // `allow_threads_py` needs an `Unpin` future, hence the `pin!`.
        let mut lock = pin!(mg_fut).allow_threads_py().await;
        match lock.next().allow_threads_py().await {
            // Successful cases
            Some(DiscoveryEvent::Discovered {
                endpoint_info: EndpointInfo { endpoint_id, data },
                ..
            }) => Ok(PyConnectionMessage {
                endpoint_id: endpoint_id.into(),
                current_transport_addrs: data
                    .ip_addrs()
                    .map(|it| PyIpAddress { inner: it.clone() })
                    .collect(),
            }),
            Some(DiscoveryEvent::Expired { endpoint_id }) => Ok(PyConnectionMessage {
                endpoint_id: endpoint_id.into(),
                current_transport_addrs: BTreeSet::new(),
            }),
            // Failure case
            None => Err(PyStopAsyncIteration::new_err("")),
        }
    }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyConnectionMessage>()?;
m.add_class::<PyReceiver>()?;
m.add_class::<PySender>()?;
m.add_class::<PyConnectionReceiver>()?;
m.add_class::<PyNetworkingHandle>()?;
Ok(())
}

View File

@@ -4,65 +4,28 @@
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
mod identity;
mod iroh_networking;
// mod examples;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use crate::identity::ident_submodule;
use crate::iroh_networking::networking_submodule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
/// Namespace for all the constants used by this crate.
pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind())
Python::attach(|py| PyBytes::new(py, self).unbind())
}
}
@@ -77,7 +40,7 @@ pub(crate) mod ext {
}
pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
/// SEE: https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await
fn allow_threads_py(self) -> AllowThreads<Self>
where
AllowThreads<Self>: Future,
@@ -98,7 +61,7 @@ pub(crate) mod ext {
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
Python::attach(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
@@ -112,85 +75,6 @@ pub(crate) mod ext {
}
}
}
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
@@ -201,16 +85,9 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
// top-level constructs
// TODO: ...
Ok(())
}

View File

@@ -0,0 +1,18 @@
[package]
name = "iroh_networking"
version.workspace = true
edition.workspace = true
[dependencies]
blake3 = { workspace = true, features = ["neon", "rayon"] }
iroh = { workspace = true, features = ["discovery-local-network"] }
iroh-gossip = { workspace = true }
n0-error = { workspace = true }
n0-future = { workspace = true }
rand = { workspace = true }
thiserror.workspace = true
tokio = { workspace = true, features = ["full"] }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
[lints]
workspace = true

View File

@@ -0,0 +1,28 @@
use iroh::SecretKey;
use iroh_networking::ExoNet;
use n0_future::StreamExt;
// Launch a mock version of iroh for testing purposes
//
// Initialises `ExoNet` with a freshly generated key and an empty namespace, then prints
// every discovery event until the event stream ends.
#[tokio::main]
async fn main() {
    // Log filtering is taken from the environment (see `EnvFilter::from_default_env`).
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    let mut net = ExoNet::init_iroh(SecretKey::generate(&mut rand::rng()), "")
        .await
        .unwrap();
    let mut conn_info = net.connection_info().await;

    let task = tokio::task::spawn(async move {
        println!("Inner task started!");
        // Stop when the stream is exhausted: polling a finished stream in an
        // unconditional `loop` would spin forever printing `None`.
        while let Some(event) = conn_info.next().await {
            dbg!(event);
        }
    });

    println!("Task started!");
    task.await.unwrap();
}

View File

@@ -0,0 +1,103 @@
use iroh::{
Endpoint, SecretKey,
discovery::{
IntoDiscoveryError,
mdns::{DiscoveryEvent, MdnsDiscovery},
},
endpoint::BindError,
protocol::Router,
};
use iroh_gossip::{
Gossip, TopicId,
api::{ApiError, GossipReceiver, GossipSender},
};
use n0_error::stack_error;
use n0_future::Stream;
/// Errors returned by [`ExoNet`] operations; each variant wraps the underlying source error.
#[stack_error(derive, add_meta, from_sources)]
pub enum Error {
    /// Binding the endpoint failed.
    #[error(transparent)]
    FailedBinding { source: BindError },
    /// The gossip topic was closed.
    // NOTE(review): the source is a general `ApiError`; confirm that "closed" is the only case.
    #[error(transparent)]
    FailedCommunication { source: ApiError },
    /// Building the discovery service failed.
    #[error("No IP Protocol supported on device")]
    IPNotSupported { source: IntoDiscoveryError },
}
/// Bundle of the iroh components created by `ExoNet::init_iroh`.
#[derive(Debug)]
pub struct ExoNet {
    // Protocol router; `shutdown` shuts it down.
    router: Router,
    // Gossip handle; `subscribe` uses it to join topics.
    gossip: Gossip,
    // mDNS discovery service; `connection_info` subscribes to its events.
    mdns: MdnsDiscovery,
}
impl ExoNet {
    /// Builds the networking stack for the secret key `sk`.
    ///
    /// Binds an endpoint with relays disabled, attaches mDNS discovery to it, and spawns
    /// gossip plus a router that accepts it. The ALPN is
    /// `/exo_discovery_network/<namespace>`.
    ///
    /// # Errors
    /// Returns an [`Error`] if binding the endpoint or building the mDNS discovery fails.
    pub async fn init_iroh(sk: SecretKey, namespace: &str) -> Result<Self, Error> {
        let endpoint = Endpoint::empty_builder(iroh::RelayMode::Disabled)
            .secret_key(sk)
            .bind()
            .await?;
        let mdns = MdnsDiscovery::builder().build(endpoint.id())?;
        // Keep one handle for `connection_info`; the endpoint gets its own clone.
        endpoint.discovery().add(mdns.clone());
        // `format!` already yields an owned `String`, so no further copy is needed.
        let alpn = format!("/exo_discovery_network/{namespace}");
        let gossip = Gossip::builder().alpn(&alpn).spawn(endpoint.clone());
        let router = Router::builder(endpoint)
            .accept(&alpn, gossip.clone())
            .spawn();
        Ok(Self {
            router,
            gossip,
            mdns,
        })
    }

    /// Returns a stream of mDNS discovery events, obtained from `MdnsDiscovery::subscribe`.
    pub async fn connection_info(&mut self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
        self.mdns.subscribe().await
    }

    /// Subscribes to the gossip topic derived from `topic` and returns the split
    /// sender/receiver pair.
    ///
    /// The second argument to the gossip `subscribe` call is an empty `vec![]`
    /// (presumably the list of bootstrap peers — confirm against iroh-gossip).
    ///
    /// # Errors
    /// Returns an [`Error`] if the gossip subscription fails.
    pub async fn subscribe(
        &mut self,
        topic: &str,
    ) -> Result<(GossipSender, GossipReceiver), Error> {
        Ok(self
            .gossip
            .subscribe(str_to_topic_id(topic), vec![])
            .await?
            .split())
    }

    /// Shuts the router down.
    ///
    /// # Panics
    /// Panics with "Iroh Router failed to shutdown" if the shutdown returns an error.
    pub async fn shutdown(&mut self) {
        self.router
            .shutdown()
            .await
            .expect("Iroh Router failed to shutdown");
    }
}
/// Maps a topic name to a `TopicId` by taking the BLAKE3 hash of its UTF-8 bytes.
fn str_to_topic_id(data: &str) -> TopicId {
    let digest = blake3::hash(data.as_bytes());
    TopicId::from_bytes(*digest.as_bytes())
}
// Dead code here is for asserting these compile
// (the checks below are compile-time `Send` assertions; nothing is awaited or executed
// beyond constructing the future).
#[allow(dead_code)]
#[cfg(test)]
mod test {
    use iroh::{SecretKey, discovery::mdns::DiscoveryEvent};

    use crate::ExoNet;

    // Compiles only when the referenced value's type is `Send`.
    fn is_send<T: Send>(_: &T) {}

    // Implementing a `Send`-bounded trait fails to compile for non-`Send` types.
    trait Probe: Send {}
    impl Probe for ExoNet {}
    impl Probe for DiscoveryEvent {}

    // Asserts that the future returned by `ExoNet::init_iroh` is `Send`; the future is
    // created but never polled.
    #[test]
    fn test_is_send() {
        // todo: make rand a dev dep.
        let fut = ExoNet::init_iroh(SecretKey::generate(&mut rand::rng()), "");
        is_send(&fut);
    }
}

View File

@@ -41,4 +41,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -2,7 +2,7 @@ use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures::FutureExt as _;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
@@ -62,8 +62,7 @@ mod managed {
..Default::default()
};
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
Ok(mdns_behaviour?)
tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id())
}
fn ping_behaviour() -> ping::Behaviour {
@@ -125,7 +124,7 @@ impl Behaviour {
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
})
});
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
@@ -133,7 +132,7 @@ impl Behaviour {
self.pending_events.push_front(ToSwarm::CloseConnection {
peer_id,
connection: CloseConnection::One(connection),
})
});
}
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
@@ -294,7 +293,7 @@ impl NetworkBehaviour for Behaviour {
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection established event which is filtered correctly
self.on_connection_established(peer_id, connection_id, ip, port)
self.on_connection_established(peer_id, connection_id, ip, port);
}
}
FromSwarm::ConnectionClosed(ConnectionClosed {
@@ -310,7 +309,7 @@ impl NetworkBehaviour for Behaviour {
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
self.on_connection_closed(peer_id, connection_id, ip, port);
}
}
@@ -329,7 +328,7 @@ impl NetworkBehaviour for Behaviour {
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
match e {
// handle discovered and expired events from mDNS
managed::BehaviourEvent::Mdns(e) => match e.clone() {
managed::BehaviourEvent::Mdns(e) => match e {
mdns::Event::Discovered(peers) => {
self.handle_mdns_discovered(peers);
}
@@ -340,8 +339,8 @@ impl NetworkBehaviour for Behaviour {
// handle ping events => if error then disconnect
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
if e.result.is_err() {
self.close_connection(e.peer, e.connection);
}
}
}
@@ -366,10 +365,10 @@ impl NetworkBehaviour for Behaviour {
if self.retry_delay.poll_unpin(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
self.dial(p, ma);
}
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
self.retry_delay.reset(RETRY_CONNECT_INTERVAL); // reset timeout
}
// send out any pending events from our own service

View File

@@ -5,7 +5,7 @@
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
@@ -54,11 +54,3 @@ pub(crate) mod ext {
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -1,4 +0,0 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which the Exo application uses to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state collected -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and proper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -5,26 +5,16 @@
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
// #![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(type_alias_impl_trait)]
// #![feature(specialization)]
// #![feature(unboxed_closures)]
// #![feature(const_trait_impl)]
// #![feature(fn_traits)]
pub mod nonempty;
pub mod wakerdeque;
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub mod ext {
use extend::ext;

View File

@@ -0,0 +1,3 @@
from importlib.metadata import version
__version__ = version("exo")

View File

@@ -39,9 +39,9 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(str(keypair.endpoint_id()))
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
router = await Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
await router.register_topic(topics.COMMANDS)

View File

@@ -154,6 +154,7 @@ def get_shard_assignments(
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
# this function is wrong.
cycles = cycle_digraph.get_cycles()
expected_length = len(list(cycle_digraph.list_nodes()))
cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
@@ -178,15 +179,15 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
connection.source_id == current_node.node_id
and connection.sink_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
assert connection.sink_addr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
ip=str(connection.sink_addr.ip),
port=connection.sink_addr.port,
)
hosts.append(host)
break
@@ -242,10 +243,10 @@ def _find_connection_ip(
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
connection.source_id == node_i.node_id
and connection.sink_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address
yield str(connection.sink_addr.ip)
def _find_interface_name_for_ip(

View File

@@ -11,6 +11,7 @@ from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
CommandId,
CreateInstance,
ForwarderCommand,
PlaceInstance,
)
@@ -140,7 +141,6 @@ async def test_master():
origin=node_id,
command=(
ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[

View File

@@ -1,37 +1,37 @@
from enum import Enum
from ipaddress import IPv4Address, IPv6Address, ip_address
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo_pyo3_bindings import RustConnectionMessage
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
IpAddress = IPv4Address | IPv6Address
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class SocketAddress(CamelCaseModel):
# could be the python IpAddress type if we're feeling fancy
ip: IpAddress
port: int
zone_id: int | None
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
ips: set[SocketAddress]
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
def from_rust(cls, message: RustConnectionMessage) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
node_id=NodeId(str(message.endpoint_id)),
ips=set(
# TODO: better handle fallible conversion
SocketAddress(
ip=ip_address(addr.ip_addr()),
port=addr.port(),
zone_id=addr.zone_id(),
)
for addr in message.current_transport_addrs
),
)

View File

@@ -13,14 +13,15 @@ from anyio import (
)
from anyio.abc import TaskGroup
from exo_pyo3_bindings import (
AllQueuesFullError,
Keypair,
NetworkingHandle,
NoPeersSubscribedToTopicError,
RustNetworkingHandle,
RustReceiver,
RustSender,
)
from filelock import FileLock
from loguru import logger
from exo import __version__
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel
@@ -37,7 +38,8 @@ class TopicRouter[T: CamelCaseModel]:
def __init__(
self,
topic: TypedTopic[T],
networking_sender: Sender[tuple[str, bytes]],
networking_sender: RustSender,
networking_receiver: RustReceiver,
max_buffer_size: float = inf,
):
self.topic: TypedTopic[T] = topic
@@ -45,7 +47,7 @@ class TopicRouter[T: CamelCaseModel]:
send, recv = channel[T]()
self.receiver: Receiver[T] = recv
self._sender: Sender[T] = send
self.networking_sender: Sender[tuple[str, bytes]] = networking_sender
self.networking_sender: RustSender = networking_sender
async def run(self):
logger.debug(f"Topic Router {self.topic} ready to send")
@@ -93,35 +95,24 @@ class TopicRouter[T: CamelCaseModel]:
async def _send_out(self, item: T):
logger.trace(f"TopicRouter {self.topic.topic} sending {item}")
await self.networking_sender.send(
(str(self.topic.topic), self.topic.serialize(item))
)
await self.networking_sender.send(self.topic.serialize(item))
class Router:
@classmethod
def create(cls, identity: Keypair) -> "Router":
return cls(handle=NetworkingHandle(identity))
async def create(cls, identity: Keypair) -> "Router":
return cls(handle=await RustNetworkingHandle.create(identity, __version__))
def __init__(self, handle: NetworkingHandle):
def __init__(self, handle: RustNetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
send, recv = channel[tuple[str, bytes]]()
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._net: RustNetworkingHandle = handle
self._id_count = count()
self._tg: TaskGroup | None = None
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
assert self._tg is None, "Attempted to register topic after setup time"
send = self._tmp_networking_sender
if send:
self._tmp_networking_sender = None
else:
send = self.networking_receiver.clone_sender()
router = TopicRouter[T](topic, send)
router = TopicRouter[T](topic, *await self._net.subscribe(str(topic.topic)))
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
await self._networking_subscribe(str(topic.topic))
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
router = self.topic_routers.get(topic.topic, None)
@@ -151,13 +142,9 @@ class Router:
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
@@ -165,29 +152,10 @@ class Router:
return
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
async def _networking_recv(self):
while True:
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self):
recv = await self._net.get_connection_receiver()
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
message = await recv.receive()
logger.trace(
f"Received message on connection_messages with payload {message}"
)
@@ -195,18 +163,7 @@ class Router:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
async def _networking_publish(self):
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
# As a hack, this also catches AllQueuesFull
# Need to fix that ASAP.
except (NoPeersSubscribedToTopicError, AllQueuesFullError):
pass
await router.publish(ConnectionMessage.from_rust(message))
def get_node_id_keypair(
@@ -225,16 +182,16 @@ def get_node_id_keypair(
with open(path, "a+b") as f: # opens in append-mode => starts at EOF
# if non-zero EOF, then file exists => use to get node-ID
if f.tell() != 0:
f.seek(0) # go to start & read protobuf-encoded bytes
protobuf_encoded = f.read()
f.seek(0) # go to start & read postcard-encoded bytes
postcard_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.from_protobuf_encoding(protobuf_encoded)
return Keypair.from_postcard_encoding(postcard_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
f.write(keypair.to_postcard_encoding())
return keypair

View File

@@ -81,16 +81,16 @@ class Topology:
self,
connection: Connection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection.source_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.source_id))
if connection.sink_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.sink_id))
if connection in self._edge_id_to_rx_id_map:
return
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._node_id_to_rx_id_map[connection.source_id]
sink_id = self._node_id_to_rx_id_map[connection.sink_id]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
@@ -188,10 +188,7 @@ class Topology:
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
if connection.source_id in node_idxs and connection.sink_id in node_idxs:
topology.add_connection(connection)
return topology

View File

@@ -1,3 +1,5 @@
from typing import Self
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
@@ -26,6 +28,17 @@ class PlaceInstance(BaseCommand):
instance_meta: InstanceMeta
min_nodes: int
# Decision point - I like this syntax better than the typical fixtures,
# but it's """bloat"""
@classmethod
def fixture(cls) -> Self:
    """Return a sample command built from fixed values.

    Uses the ``ModelMetadata`` fixture, pipeline sharding, the MLX ring
    instance type and a minimum of one node.
    """
    sample_model = ModelMetadata.fixture()
    return cls(
        min_nodes=1,
        instance_meta=InstanceMeta.MlxRing,
        sharding=Sharding.Pipeline,
        model_meta=sample_model,
    )
class CreateInstance(BaseCommand):
instance: Instance

View File

@@ -1,3 +1,5 @@
from typing import Self
from pydantic import PositiveInt
from exo.shared.types.common import Id
@@ -14,3 +16,12 @@ class ModelMetadata(CamelCaseModel):
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
@classmethod
def fixture(cls) -> Self:
    """Return sample metadata with fixed values for the ``llama-3.2-1b`` model id."""
    sample_id = ModelId("llama-3.2-1b")
    sample_size = Memory.from_bytes(678948)
    return cls(
        model_id=sample_id,
        pretty_name="Llama 3.2 1B",
        storage_size=sample_size,
        n_layers=16,
    )

View File

@@ -2,6 +2,7 @@ from typing import Self
import psutil
from exo.routing.connection_message import IpAddress
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
@@ -49,7 +50,7 @@ class SystemPerformanceProfile(CamelCaseModel):
class NetworkInterfaceInfo(CamelCaseModel):
name: str
ip_address: str
ip_address: IpAddress
class NodePerformanceProfile(CamelCaseModel):

View File

@@ -1,5 +1,5 @@
from exo.routing.connection_message import SocketAddress
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
@@ -10,17 +10,17 @@ class NodeInfo(CamelCaseModel):
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
source_id: NodeId
sink_id: NodeId
sink_addr: SocketAddress
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
self.source_id,
self.sink_id,
self.sink_addr,
)
)
@@ -28,10 +28,10 @@ class Connection(CamelCaseModel):
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
self.source_id == other.source_id
and self.sink_id == other.sink_id
and self.sink_addr == other.sink_addr
)
def is_thunderbolt(self) -> bool:
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
return str(self.sink_addr.ip).startswith("169.254")

View File

@@ -0,0 +1,31 @@
from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, TopologyEdgeCreated, TopologyEdgeDeleted
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
def check_connections(
    local_id: NodeId, msg: ConnectionMessage, state: State
) -> list[Event]:
    """Diff the addresses reported for a peer against the known topology edges.

    A candidate edge ``local_id -> msg.node_id`` is formed for every socket in
    ``msg.ips`` whose IP equals the address of one of the peer's profiled
    network interfaces.

    Args:
        local_id: Id of the node on whose behalf edges are computed (edge source).
        msg: Connection message carrying the peer id and its socket addresses.
        state: Current state, providing the topology and the node profiles.

    Returns:
        ``TopologyEdgeCreated`` events for candidate edges missing from the
        topology, followed by ``TopologyEdgeDeleted`` events for existing
        ``local_id -> peer`` edges that are no longer candidates. The list is
        empty when the peer is not in the topology or has no profile yet.
    """
    remote_id = msg.node_id
    sockets = msg.ips

    # The peer's profile is indexed below, so it must be present. The previous
    # check was inverted: it bailed out when the profile existed and hit a
    # KeyError when it did not.
    if (
        not state.topology.contains_node(remote_id)
        or remote_id not in state.node_profiles
    ):
        return []

    # Only edges from this node to this peer may be retired by this message;
    # edges between other nodes are unrelated to it and must be left alone.
    known = [
        conn
        for conn in state.topology.list_connections()
        if conn.source_id == local_id and conn.sink_id == remote_id
    ]

    live: set[Connection] = set()
    events: list[Event] = []
    for iface in state.node_profiles[remote_id].network_interfaces:
        for sock in sockets:
            if iface.ip_address != sock.ip:
                continue
            conn = Connection(source_id=local_id, sink_id=remote_id, sink_addr=sock)
            if conn in live:
                # Two interfaces can report the same IP; handle each edge once.
                continue
            live.add(conn)
            if not state.topology.contains_connection(conn):
                events.append(TopologyEdgeCreated(edge=conn))

    events.extend(
        TopologyEdgeDeleted(edge=conn) for conn in known if conn not in live
    )
    return events

View File

@@ -1,43 +1,41 @@
from typing import Any
from typing import TYPE_CHECKING, Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
if TYPE_CHECKING:
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
class Model(nn.Module):
layers: list[nn.Module]
class Model(nn.Module):
layers: list[nn.Module]
def __call__(
self,
x: mx.array,
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...
def __call__(
self,
x: mx.array,
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...
class Detokenizer:
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self) -> str: ...
class Detokenizer:
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
class TokenizerWrapper:
bos_token: str | None
eos_token_ids: list[int]
detokenizer: Detokenizer
@property
def last_segment(self) -> str: ...
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
class TokenizerWrapper:
bos_token: str | None
eos_token_ids: list[int]
detokenizer: Detokenizer
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
def apply_chat_template(
self,
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
) -> str: ...
def apply_chat_template(
self,
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
) -> str: ...

View File

@@ -6,7 +6,7 @@ from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.routing.connection_message import ConnectionMessage
from exo.shared.apply import apply
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
@@ -20,10 +20,7 @@ from exo.shared.types.events import (
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -33,7 +30,6 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -258,33 +254,13 @@ class Worker:
async def _connection_message_event_writer(self):
with self.connection_message_receiver as connection_messages:
async for msg in connection_messages:
await self.event_sender.send(
self._convert_connection_message_to_event(msg)
)
for event in self._convert_connection_message_to_event(msg):
await self.event_sender.send(event)
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
)
def _convert_connection_message_to_event(
self, msg: ConnectionMessage
) -> list[Event]:
return check_connections(self.node_id, msg, self.state)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.

View File

@@ -1,5 +1,6 @@
import socket
import sys
from ipaddress import ip_address
from subprocess import CalledProcessError
import psutil
@@ -29,12 +30,6 @@ async def get_friendly_name() -> str:
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
"""
Retrieves detailed network interface information on macOS.
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
to determine interface names, IP addresses, and types (ethernet, wifi, vpn, other).
Returns a list of NetworkInterfaceInfo objects.
"""
interfaces_info: list[NetworkInterfaceInfo] = []
for iface, services in psutil.net_if_addrs().items():
@@ -42,7 +37,9 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
match service.family:
case socket.AF_INET | socket.AF_INET6:
interfaces_info.append(
NetworkInterfaceInfo(name=iface, ip_address=service.address)
NetworkInterfaceInfo(
name=iface, ip_address=ip_address(service.address)
)
)
case _:
pass

2
uv.lock generated
View File

@@ -316,7 +316,7 @@ wheels = [
[[package]]
name = "exo"
version = "0.3.0"
version = "0.10.0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },