mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-26 03:06:05 -05:00
Compare commits
1 Commits
leo/handle
...
move-messa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
541930d813 |
@@ -249,8 +249,7 @@ class ChunkedKVCache(KVCache):
|
||||
...
|
||||
|
||||
class CacheList(_BaseCache):
|
||||
caches: tuple[_BaseCache, ...]
|
||||
def __init__(self, *caches: _BaseCache) -> None: ...
|
||||
def __init__(self, *caches) -> None: ...
|
||||
def __getitem__(self, idx): ...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
|
||||
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -216,6 +216,28 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
@@ -2759,6 +2781,7 @@ dependencies = [
|
||||
name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"delegate",
|
||||
"either",
|
||||
"extend",
|
||||
@@ -2767,6 +2790,7 @@ dependencies = [
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
|
||||
@@ -34,6 +34,7 @@ delegate = "0.13"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Async dependencies
|
||||
async-stream = "0.3"
|
||||
tokio = "1.46"
|
||||
futures-lite = "2.6.1"
|
||||
futures-timer = "3.0"
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
import enum
|
||||
import typing
|
||||
|
||||
@typing.final
|
||||
@@ -11,29 +10,6 @@ class AllQueuesFullError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> builtins.str:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
@@ -58,21 +34,15 @@ class Keypair:
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class MessageTooLargeError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
@@ -91,24 +61,7 @@ class NetworkingHandle:
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class MessageTooLargeError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
async def recv(self) -> PyFromSwarm: ...
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
@@ -116,11 +69,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
class PyFromSwarm:
|
||||
@typing.final
|
||||
class Connection(PyFromSwarm):
|
||||
__match_args__ = ("peer_id", "connected",)
|
||||
@property
|
||||
def peer_id(self) -> builtins.str: ...
|
||||
@property
|
||||
def connected(self) -> builtins.bool: ...
|
||||
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
|
||||
|
||||
@typing.final
|
||||
class Message(PyFromSwarm):
|
||||
__match_args__ = ("origin", "topic", "data",)
|
||||
@property
|
||||
def origin(self) -> builtins.str: ...
|
||||
@property
|
||||
def topic(self) -> builtins.str: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
|
||||
|
||||
...
|
||||
|
||||
|
||||
@@ -4,11 +4,12 @@ build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "exo_pyo3_bindings"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
|
||||
{ name = "Evan Quiney", email = "evanev7@gmail.com" }
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
|
||||
@@ -155,6 +155,9 @@ pub(crate) mod ext {
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
let mut builder = tokio::runtime::Builder::new_multi_thread();
|
||||
builder.enable_all();
|
||||
pyo3_async_runtimes::tokio::init(builder);
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::exception::{
|
||||
PyAllQueuesFullError, PyMessageTooLargeError, PyNoPeersSubscribedToTopicError,
|
||||
};
|
||||
use crate::pyclass;
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use futures_lite::{Stream, StreamExt as _};
|
||||
use libp2p::gossipsub::PublishError;
|
||||
use networking::swarm::{FromSwarm, ToSwarm, create_swarm};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{
|
||||
gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
|
||||
};
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
@@ -131,237 +129,45 @@ mod exception {
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: String,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else if let Err(PublishError::MessageTooLarge) = result {
|
||||
Err(exception::PyMessageTooLargeError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
pub to_swarm: mpsc::Sender<ToSwarm>,
|
||||
pub swarm: Arc<Mutex<Pin<Box<dyn Stream<Item = FromSwarm> + Send>>>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
#[gen_stub_pyclass_complex_enum]
|
||||
#[pyclass]
|
||||
enum PyFromSwarm {
|
||||
Connection {
|
||||
peer_id: String,
|
||||
connected: bool,
|
||||
},
|
||||
Message {
|
||||
origin: String,
|
||||
topic: String,
|
||||
data: Py<PyBytes>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
impl From<FromSwarm> for PyFromSwarm {
|
||||
fn from(value: FromSwarm) -> Self {
|
||||
match value {
|
||||
FromSwarm::Discovered { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: true,
|
||||
},
|
||||
FromSwarm::Expired { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: false,
|
||||
},
|
||||
FromSwarm::Message { from, topic, data } => Self::Message {
|
||||
origin: from.to_base58(),
|
||||
topic: topic,
|
||||
data: data.pybytes(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
@@ -375,97 +181,36 @@ impl PyNetworkingHandle {
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
|
||||
let swarm = create_swarm(identity, from_client).pyerr()?.into_stream();
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
Ok(Self {
|
||||
swarm: Arc::new(Mutex::new(swarm)),
|
||||
to_swarm,
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
|
||||
let swarm = Arc::clone(&self.swarm);
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
swarm
|
||||
.try_lock()
|
||||
.map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
|
||||
.next()
|
||||
.await
|
||||
.ok_or(PyErr::receiver_channel_closed())
|
||||
.map(PyFromSwarm::from)
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
@@ -475,10 +220,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -487,6 +232,7 @@ impl PyNetworkingHandle {
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.pyerr()
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
@@ -496,10 +242,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -518,11 +264,11 @@ impl PyNetworkingHandle {
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -531,64 +277,26 @@ impl PyNetworkingHandle {
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.map_err(|e| match e {
|
||||
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
|
||||
PublishError::MessageTooLarge => PyMessageTooLargeError::new_err(),
|
||||
PublishError::NoPeersSubscribedToTopic => {
|
||||
PyNoPeersSubscribedToTopicError::new_err()
|
||||
}
|
||||
e => PyRuntimeError::new_err(e.to_string()),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
pyo3_stub_gen::inventory::submit! {
|
||||
gen_methods_from_python! {
|
||||
r#"
|
||||
class PyNetworkingHandle:
|
||||
async def recv() -> PyFromSwarm: ...
|
||||
"#
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
@@ -596,10 +304,8 @@ pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
m.add_class::<exception::PyMessageTooLargeError>()?;
|
||||
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
m.add_class::<PyFromSwarm>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -21,9 +21,10 @@ extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
async-stream = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
@@ -35,3 +36,4 @@ log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use libp2p::identity;
|
||||
use networking::swarm;
|
||||
use networking::swarm::{FromSwarm, ToSwarm};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::{io, io::AsyncBufReadExt as _};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
@@ -11,64 +13,69 @@ async fn main() {
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (to_swarm, from_client) = mpsc::channel(20);
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
|
||||
.expect("Swarm creation failed")
|
||||
.into_stream();
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
let (tx, rx) = oneshot::channel();
|
||||
_ = to_swarm
|
||||
.send(ToSwarm::Subscribe {
|
||||
topic: "test-net".to_string(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
.expect("should send");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
rx.await
|
||||
.expect("tx not dropped")
|
||||
.expect("subscribe shouldn't fail");
|
||||
loop {
|
||||
if let Ok(Some(line)) = stdin.next_line().await {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = to_swarm
|
||||
.send(swarm::ToSwarm::Publish {
|
||||
topic: "test-net".to_string(),
|
||||
data: line.as_bytes().to_vec(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
{
|
||||
println!("Send error: {e:?}");
|
||||
return;
|
||||
};
|
||||
match rx.await {
|
||||
Ok(Err(e)) => println!("Publish error: {e:?}"),
|
||||
Err(e) => println!("Publish error: {e:?}"),
|
||||
Ok(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
// on gossipsub outgoing
|
||||
match swarm.next().await {
|
||||
// on gossipsub incoming
|
||||
Some(FromSwarm::Discovered { peer_id }) => {
|
||||
println!("\n\nconnected to {peer_id}\n\n")
|
||||
}
|
||||
event = swarm.next() => match event {
|
||||
// on gossipsub incoming
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
}))) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
Some(FromSwarm::Expired { peer_id }) => {
|
||||
println!("\n\ndisconnected from {peer_id}\n\n")
|
||||
}
|
||||
Some(FromSwarm::Message { from, topic, data }) => {
|
||||
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
use std::pin::Pin;
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
use crate::{alias, discovery};
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use futures_lite::{Stream, StreamExt};
|
||||
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
@@ -15,8 +17,139 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
pub enum ToSwarm {
|
||||
Unsubscribe {
|
||||
topic: String,
|
||||
// Sender for the unsubscribe result (False = not subscribed)
|
||||
result_sender: oneshot::Sender<bool>,
|
||||
},
|
||||
Subscribe {
|
||||
topic: String,
|
||||
// Sender for the subscribe result (False = not subscribed), errors if we can't publish our
|
||||
// subscription to peers
|
||||
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
|
||||
},
|
||||
Publish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
// Sender for the publish result, makes it easier to correlate publish with publish
|
||||
// errors
|
||||
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
|
||||
},
|
||||
}
|
||||
pub enum FromSwarm {
|
||||
Message {
|
||||
from: PeerId,
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
},
|
||||
Discovered {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
Expired {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Swarm {
|
||||
swarm: libp2p::Swarm<Behaviour>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
}
|
||||
|
||||
impl Swarm {
|
||||
pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = FromSwarm> + Send>> {
|
||||
let Swarm {
|
||||
mut swarm,
|
||||
mut from_client,
|
||||
} = self;
|
||||
let stream = async_stream::stream! {
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = from_client.recv() => {
|
||||
let Some(msg) = msg else { break };
|
||||
on_message(&mut swarm, msg);
|
||||
}
|
||||
event = swarm.next() => {
|
||||
let Some(event) = event else { break };
|
||||
if let Some(item) = filter_swarm_event(event) {
|
||||
yield item;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
Box::pin(stream)
|
||||
}
|
||||
}
|
||||
|
||||
fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
|
||||
match message {
|
||||
ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.unsubscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.publish(gossipsub::IdentTopic::new(topic), data);
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
|
||||
match event {
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
message:
|
||||
gossipsub::Message {
|
||||
source: Some(peer_id),
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => Some(FromSwarm::Message {
|
||||
from: peer_id,
|
||||
topic: topic.into_string(),
|
||||
data,
|
||||
}),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionEstablished { peer_id, .. },
|
||||
)) => Some(FromSwarm::Discovered { peer_id }),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
|
||||
peer_id,
|
||||
..
|
||||
})) => Some(FromSwarm::Expired { peer_id }),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
pub fn create_swarm(
|
||||
keypair: identity::Keypair,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
@@ -25,7 +158,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
Ok(Swarm { swarm, from_client })
|
||||
}
|
||||
|
||||
mod transport {
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
from exo_pyo3_bindings import PyFromSwarm
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
connected: bool
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
)
|
||||
def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
|
||||
return cls(node_id=NodeId(update.peer_id), connected=update.connected)
|
||||
|
||||
@@ -17,6 +17,7 @@ from exo_pyo3_bindings import (
|
||||
MessageTooLargeError,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
PyFromSwarm,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
@@ -121,7 +122,8 @@ class Router:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
if self._tg.is_running():
|
||||
await self._networking_subscribe(topic.topic)
|
||||
|
||||
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
@@ -152,8 +154,10 @@ class Router:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# subscribe to pending topics
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_subscribe(topic)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
@@ -176,41 +180,40 @@ class Router:
|
||||
async def _networking_recv(self):
|
||||
try:
|
||||
while True:
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(
|
||||
f"Received message on unknown or inactive topic {topic}"
|
||||
)
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
from_swarm = await self._net.recv()
|
||||
logger.debug(from_swarm)
|
||||
match from_swarm:
|
||||
case PyFromSwarm.Message(origin, topic, data):
|
||||
logger.trace(
|
||||
f"Received message on {topic} from {origin} with payload {data}"
|
||||
)
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(
|
||||
f"Received message on unknown or inactive topic {topic}"
|
||||
)
|
||||
continue
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
case PyFromSwarm.Connection():
|
||||
message = ConnectionMessage.from_update(from_swarm)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
case _:
|
||||
logger.critical(
|
||||
"failed to exhaustively check FromSwarm messages - logic error"
|
||||
)
|
||||
except Exception as exception:
|
||||
logger.opt(exception=exception).error(
|
||||
"Gossipsub receive loop terminated unexpectedly"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
try:
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
except Exception as exception:
|
||||
logger.opt(exception=exception).error(
|
||||
"Connection update receive loop terminated unexpectedly"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
async for topic, data in networked_items:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from anyio import create_task_group, fail_after, move_on_after
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Send any connection message object; we close quickly to cancel before result creation
|
||||
await cm_tx.send(
|
||||
ConnectionMessage(
|
||||
node_id=NodeId(),
|
||||
connection_type=ConnectionMessageType.Connected,
|
||||
remote_ipv4="",
|
||||
remote_tcp_port=0,
|
||||
)
|
||||
)
|
||||
await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
|
||||
|
||||
# Expect a broadcast for the new round at clock=1
|
||||
while True:
|
||||
|
||||
@@ -542,10 +542,13 @@ class InfoGatherer:
|
||||
if not p.stdout:
|
||||
logger.critical("MacMon closed stdout")
|
||||
return
|
||||
async for text in TextReceiveStream(
|
||||
BufferedByteReceiveStream(p.stdout)
|
||||
):
|
||||
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
|
||||
t = TextReceiveStream(BufferedByteReceiveStream(p.stdout))
|
||||
while True:
|
||||
with anyio.fail_after(self.macmon_interval * 3):
|
||||
macmon_output = await t.receive()
|
||||
await self.info_sender.send(
|
||||
MacmonMetrics.from_raw_json(macmon_output)
|
||||
)
|
||||
except CalledProcessError as e:
|
||||
stderr_msg = "no stderr"
|
||||
stderr_output = cast(bytes | str | None, e.stderr)
|
||||
@@ -556,8 +559,12 @@ class InfoGatherer:
|
||||
else str(stderr_output)
|
||||
)
|
||||
logger.warning(
|
||||
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
|
||||
f"memory monitor failed with return code {e.returncode}: {stderr_msg}"
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"memory monitor silent for {self.macmon_interval * 3}s - reloading"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in macmon monitor: {e}")
|
||||
logger.opt(exception=e).warning("Error in memory monitor")
|
||||
await anyio.sleep(self.macmon_interval)
|
||||
|
||||
@@ -32,7 +32,7 @@ def _default_memory_threshold() -> float:
|
||||
return 0.70
|
||||
|
||||
|
||||
MEMORY_THRESHOLD = float(
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _default_memory_threshold())
|
||||
)
|
||||
|
||||
@@ -92,15 +92,6 @@ class KVPrefixCache:
|
||||
self._snapshots.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def force_evict_all(self) -> int:
|
||||
count = len(self.caches)
|
||||
self.clear()
|
||||
if count > 0:
|
||||
logger.info(
|
||||
f"Force-evicted all {count} prefix cache entries due to memory pressure"
|
||||
)
|
||||
return count
|
||||
|
||||
def add_kv_cache(
|
||||
self,
|
||||
prompt_tokens: mx.array,
|
||||
@@ -226,7 +217,7 @@ class KVPrefixCache:
|
||||
# Evict LRU entries until below threshold
|
||||
while (
|
||||
len(self.caches) > 0
|
||||
and self.get_memory_used_percentage() > MEMORY_THRESHOLD
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
@@ -319,59 +310,6 @@ def get_memory_used_percentage() -> float:
|
||||
return float(mem.percent / 100)
|
||||
|
||||
|
||||
def get_safety_floor() -> int:
|
||||
total = psutil.virtual_memory().total
|
||||
return min(int(total * 0.10), 5 * 1024**3)
|
||||
|
||||
|
||||
def get_memory_pressure_threshold() -> float:
|
||||
total = psutil.virtual_memory().total
|
||||
return 1.0 - get_safety_floor() / total
|
||||
|
||||
|
||||
def _measure_single_cache_bytes(
|
||||
entry: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,
|
||||
) -> int:
|
||||
if isinstance(entry, CacheList):
|
||||
return sum(
|
||||
_measure_single_cache_bytes(c) # pyright: ignore[reportArgumentType]
|
||||
for c in entry.caches
|
||||
)
|
||||
|
||||
total = 0
|
||||
if isinstance(entry, ArraysCache):
|
||||
state = entry.state # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
|
||||
for arr in state: # pyright: ignore[reportUnknownVariableType]
|
||||
if isinstance(arr, mx.array):
|
||||
total += arr.nbytes
|
||||
return total
|
||||
|
||||
total = 0
|
||||
for attr_name in ("keys", "values"):
|
||||
val: object = getattr(entry, attr_name, None)
|
||||
if val is None:
|
||||
continue
|
||||
if isinstance(val, mx.array):
|
||||
total += val.nbytes
|
||||
elif isinstance(val, (tuple, list)):
|
||||
for arr in val: # pyright: ignore[reportUnknownVariableType]
|
||||
if isinstance(arr, mx.array):
|
||||
total += arr.nbytes
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def measure_cache_bytes(cache: KVCacheType) -> int:
|
||||
return sum(_measure_single_cache_bytes(c) for c in cache)
|
||||
|
||||
|
||||
def measure_kv_cache_bytes_per_token(cache: KVCacheType) -> int:
|
||||
offset = cache_length(cache)
|
||||
if offset == 0:
|
||||
return 0
|
||||
return measure_cache_bytes(cache) // offset
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
|
||||
@@ -4,7 +4,6 @@ from copy import deepcopy
|
||||
from typing import Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
@@ -31,10 +30,8 @@ from exo.worker.engines.mlx.cache import (
|
||||
CacheSnapshot,
|
||||
KVPrefixCache,
|
||||
encode_prompt,
|
||||
get_memory_pressure_threshold,
|
||||
has_non_kv_caches,
|
||||
make_kv_cache,
|
||||
measure_kv_cache_bytes_per_token,
|
||||
snapshot_ssm_states,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
@@ -46,7 +43,6 @@ from exo.worker.engines.mlx.constants import (
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
fix_unmatched_think_end_tokens,
|
||||
mx_any,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -152,8 +148,7 @@ def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> tuple[int, int]:
|
||||
"""Run warmup inference and tokens_generated and bytes_per_token"""
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
warmup_prompt = apply_chat_template(
|
||||
@@ -192,12 +187,9 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
bytes_per_token = measure_kv_cache_bytes_per_token(cache)
|
||||
logger.info(f"Measured KV cache cost: {bytes_per_token} bytes per token")
|
||||
|
||||
mx_barrier(group)
|
||||
|
||||
return tokens_generated, bytes_per_token
|
||||
return tokens_generated
|
||||
|
||||
|
||||
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
|
||||
@@ -275,37 +267,6 @@ def extract_top_logprobs(
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def _check_memory_budget(
|
||||
bytes_per_token: int,
|
||||
total_sequence_tokens: int,
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
) -> str | None:
|
||||
if bytes_per_token == 0:
|
||||
return None
|
||||
|
||||
mem = psutil.virtual_memory()
|
||||
estimated = bytes_per_token * total_sequence_tokens / mem.total
|
||||
projected = mem.percent / 100 + estimated
|
||||
threshold = get_memory_pressure_threshold()
|
||||
|
||||
if not mx_any(projected > threshold, group):
|
||||
return None
|
||||
|
||||
if kv_prefix_cache is not None and kv_prefix_cache.force_evict_all() > 0:
|
||||
mx.clear_cache()
|
||||
mem = psutil.virtual_memory()
|
||||
projected = mem.percent / 100 + estimated
|
||||
if not mx_any(projected > threshold, group):
|
||||
return None
|
||||
|
||||
return (
|
||||
f"Not enough memory for this conversation ({projected:.0%} projected, "
|
||||
f"{threshold:.0%} limit). "
|
||||
f"Please start a new conversation or compact your messages."
|
||||
)
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -314,7 +275,6 @@ def mlx_generate(
|
||||
kv_prefix_cache: KVPrefixCache | None,
|
||||
group: mx.distributed.Group | None,
|
||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||
bytes_per_token: int = 0,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -347,23 +307,6 @@ def mlx_generate(
|
||||
f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
|
||||
)
|
||||
|
||||
if bytes_per_token > 0:
|
||||
oom_error = _check_memory_budget(
|
||||
bytes_per_token=bytes_per_token,
|
||||
total_sequence_tokens=len(all_prompt_tokens),
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=group,
|
||||
)
|
||||
if oom_error is not None:
|
||||
logger.warning(f"OOM prevention (prefill): {oom_error}")
|
||||
yield GenerationResponse(
|
||||
text=oom_error,
|
||||
token=0,
|
||||
finish_reason="error",
|
||||
usage=None,
|
||||
)
|
||||
return
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Only sample length eos tokens
|
||||
|
||||
@@ -6,7 +6,6 @@ from functools import cache
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
@@ -65,7 +64,7 @@ from exo.shared.types.worker.runners import (
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, get_memory_pressure_threshold
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import (
|
||||
PrefillCancelled,
|
||||
mlx_generate,
|
||||
@@ -115,7 +114,6 @@ def main(
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
check_for_cancel_every: int | None = None
|
||||
bytes_per_token: int = 0
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -227,14 +225,12 @@ def main(
|
||||
assert tokenizer
|
||||
|
||||
t = time.monotonic()
|
||||
toks, bytes_per_token = warmup_inference(
|
||||
toks = warmup_inference(
|
||||
model=cast(Model, inference_model),
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
)
|
||||
logger.info(
|
||||
f"warmed up by generating {toks} tokens, {bytes_per_token} bytes/token for KV cache"
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
check_for_cancel_every = min(
|
||||
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
|
||||
)
|
||||
@@ -314,7 +310,6 @@ def main(
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
on_prefill_progress=on_prefill_progress,
|
||||
group=group,
|
||||
bytes_per_token=bytes_per_token,
|
||||
)
|
||||
|
||||
if tokenizer.has_thinking:
|
||||
@@ -341,7 +336,6 @@ def main(
|
||||
|
||||
completion_tokens = 0
|
||||
tokens_since_last_cancel_check = check_for_cancel_every
|
||||
oom_stopped = False
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= check_for_cancel_every:
|
||||
@@ -350,14 +344,7 @@ def main(
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
oom_local = (
|
||||
bytes_per_token > 0
|
||||
and psutil.virtual_memory().percent / 100
|
||||
> get_memory_pressure_threshold()
|
||||
)
|
||||
if mx_any(want_to_cancel or oom_local, group):
|
||||
if not want_to_cancel:
|
||||
oom_stopped = True
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
@@ -413,21 +400,6 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
if oom_stopped and device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=model_id,
|
||||
error_message=(
|
||||
"Generation stopped: running out of memory. "
|
||||
"Please start a new conversation or compact "
|
||||
"your messages."
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
except PrefillCancelled:
|
||||
logger.info(f"Prefill cancelled for task {task.task_id}")
|
||||
# can we make this more explicit?
|
||||
|
||||
@@ -114,7 +114,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin((1, 0)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
|
||||
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
|
||||
|
||||
Reference in New Issue
Block a user