Compare commits


4 Commits

Author SHA1 Message Date
Evan
cc3b8392e6 workin on it
print spam
2026-02-19 17:19:53 +00:00
Evan Quiney
4c4c6ce99f simplify rust ident module
This is partly dead-code removal, partly narrowing the Rust-Python boundary in
prep for future rewrites. No testing, as this is all type-safe refactoring.
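
A rough sketch of the narrowed Python surface after this change, based on the
regenerated stubs further down (the exo_pyo3_bindings import path appears
elsewhere in this diff; the snippet itself is illustrative, not part of the
commit):

```python
# Hypothetical usage of the slimmed-down Keypair binding; method names follow
# the regenerated .pyi stub shown later in this compare view.
from exo_pyo3_bindings import Keypair

keypair = Keypair.generate()           # Ed25519 is now the only variant
secret = keypair.to_bytes()            # raw secret key bytes
restored = Keypair.from_bytes(secret)  # round-trip from secret key bytes
node_id = restored.to_node_id()        # base58 PeerId string used as the NodeId
```
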
2026-02-19 17:19:31 +00:00
Jake Hillion
42e1e7322b bench: restore --danger-delete-downloads planning phase (#1542)
c2f2111b extracted shared utilities from exo_bench.py into harness.py
but accidentally dropped the run_planning_phase function and
--danger-delete-downloads CLI argument in the process.

Restored run_planning_phase in harness.py (where its dependencies now
live) and re-added the --danger-delete-downloads argument to
add_common_instance_args. Re-wired the planning phase call in
exo_bench.py's main() before the benchmark loop.
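
As a sketch of that wiring (argument names mirror the diff below; the wrapper
function itself is hypothetical and only illustrates where the call sits
relative to the benchmark loop):

```python
import time

from harness import run_planning_phase


def wire_planning_phase(client, args, full_model_id, selected):
    """Hypothetical helper showing the restored call-site before the benchmark loop."""
    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )
    run_planning_phase(
        client,                        # ExoClient talking to the cluster API
        full_model_id,                 # model being benchmarked
        selected[0],                   # chosen placement preview
        args.danger_delete_downloads,  # allow freeing disk by deleting old models
        args.timeout,                  # how long to wait for downloads
        settle_deadline,               # optional deadline for disk info to settle
    )
    # ...the benchmark loop over `selected` runs only after this returns.
```
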
2026-02-19 15:42:02 +00:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls
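
For reference, a minimal sketch of the helper and the call pattern it replaces
(the full version, typed against `ResponsesStreamEvent`, is in the diff below):

```python
# Sketch of the SSE helper described above; in the actual change the event is
# typed as ResponsesStreamEvent from exo.shared.types.openai_responses.
def _format_sse(event) -> str:
    """Format a streaming event as an SSE message."""
    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"

# Before: 13 hardcoded strings like
#     yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# After:
#     yield _format_sse(created_event)
```
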

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
20 changed files with 595 additions and 830 deletions

Cargo.lock generated

@@ -2767,6 +2767,7 @@ dependencies = [
  "keccak-const",
  "libp2p",
  "log",
+ "pin-project",
  "tokio",
  "tracing-subscriber",
  "util",


@@ -20,6 +20,7 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
+    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:
     selected.sort(key=_placement_sort_key)
     preview = selected[0]
+    settle_deadline = (
+        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+    )
+    print("Planning phase: checking downloads...", file=log)
+    run_planning_phase(
+        exo,
+        full_model_id,
+        preview,
+        args.danger_delete_downloads,
+        args.timeout,
+        settle_deadline,
+    )
     instance = preview["instance"]
     instance_id = instance_id_from_instance(instance)
     sharding = str(preview["sharding"])


@@ -35,6 +35,7 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
+    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
     if args.dry_run:
         return 0
+    settle_deadline = (
+        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+    )
+    logger.info("Planning phase: checking downloads...")
+    run_planning_phase(
+        client,
+        full_model_id,
+        selected[0],
+        args.danger_delete_downloads,
+        args.timeout,
+        settle_deadline,
+    )
     all_rows: list[dict[str, Any]] = []
     for preview in selected:


@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
     return selected


+def run_planning_phase(
+    client: ExoClient,
+    full_model_id: str,
+    preview: dict[str, Any],
+    danger_delete: bool,
+    timeout: float,
+    settle_deadline: float | None,
+) -> None:
+    """Check disk space and ensure model is downloaded before benchmarking."""
+    # Get model size from /models
+    models = client.request_json("GET", "/models") or {}
+    model_bytes = 0
+    for m in models.get("data", []):
+        if m.get("hugging_face_id") == full_model_id:
+            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
+            break
+    if not model_bytes:
+        logger.warning(
+            f"Could not determine size for {full_model_id}, skipping disk check"
+        )
+        return
+
+    # Get nodes from preview
+    inner = unwrap_instance(preview["instance"])
+    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
+    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
+
+    state = client.request_json("GET", "/state")
+    downloads = state.get("downloads", {})
+    node_disk = state.get("nodeDisk", {})
+
+    for node_id in node_ids:
+        node_downloads = downloads.get(node_id, [])
+        # Check if model already downloaded on this node
+        already_downloaded = any(
+            "DownloadCompleted" in p
+            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
+                "modelId"
+            ]
+            == full_model_id
+            for p in node_downloads
+        )
+        if already_downloaded:
+            continue
+
+        # Wait for disk info if settle_deadline is set
+        disk_info = node_disk.get(node_id, {})
+        backoff = _SETTLE_INITIAL_BACKOFF_S
+        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
+            remaining = settle_deadline - time.monotonic()
+            logger.info(
+                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
+            )
+            time.sleep(min(backoff, remaining))
+            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
+            state = client.request_json("GET", "/state")
+            node_disk = state.get("nodeDisk", {})
+            disk_info = node_disk.get(node_id, {})
+        if not disk_info:
+            logger.warning(f"No disk info for {node_id}, skipping space check")
+            continue
+
+        avail = disk_info.get("available", {}).get("inBytes", 0)
+        if avail >= model_bytes:
+            continue
+        if not danger_delete:
+            raise RuntimeError(
+                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
+                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
+            )
+
+        # Delete from smallest to largest
+        completed = [
+            (
+                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
+                    "modelId"
+                ],
+                p["DownloadCompleted"]["totalBytes"]["inBytes"],
+            )
+            for p in node_downloads
+            if "DownloadCompleted" in p
+        ]
+        for del_model, size in sorted(completed, key=lambda x: x[1]):
+            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
+            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
+            avail += size
+            if avail >= model_bytes:
+                break
+        if avail < model_bytes:
+            raise RuntimeError(f"Could not free enough space on {node_id}")
+
+    # Start downloads (idempotent)
+    for node_id in node_ids:
+        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
+        shard = runner_to_shard[runner_id]
+        client.request_json(
+            "POST",
+            "/download/start",
+            body={
+                "targetNodeId": node_id,
+                "shardMetadata": shard,
+            },
+        )
+        logger.info(f"Started download on {node_id}")
+
+    # Wait for downloads
+    start = time.time()
+    while time.time() - start < timeout:
+        state = client.request_json("GET", "/state")
+        downloads = state.get("downloads", {})
+        all_done = True
+        for node_id in node_ids:
+            done = any(
+                "DownloadCompleted" in p
+                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
+                    "modelCard"
+                ]["modelId"]
+                == full_model_id
+                for p in downloads.get(node_id, [])
+            )
+            failed = [
+                p["DownloadFailed"]["errorMessage"]
+                for p in downloads.get(node_id, [])
+                if "DownloadFailed" in p
+                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
+                    "modelId"
+                ]
+                == full_model_id
+            ]
+            if failed:
+                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
+            if not done:
+                all_done = False
+        if all_done:
+            return
+        time.sleep(1)
+    raise TimeoutError("Downloads did not complete in time")
+
+
 def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
     ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
     ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
         default=0,
         help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
     )
+    ap.add_argument(
+        "--danger-delete-downloads",
+        action="store_true",
+        help="Delete existing models from smallest to largest to make room for benchmark model.",
+    )


@@ -2,7 +2,6 @@
# ruff: noqa: E501, F401 # ruff: noqa: E501, F401
import builtins import builtins
import enum
import typing import typing
@typing.final @typing.final
@@ -11,138 +10,33 @@ class AllQueuesFullError(builtins.Exception):
def __repr__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ... def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final @typing.final
class Keypair: class Keypair:
r""" r"""
Identity keypair of a node. Identity keypair of a node.
""" """
@staticmethod @staticmethod
def generate_ed25519() -> Keypair: def generate() -> Keypair:
r""" r"""
Generate a new Ed25519 keypair. Generate a new Ed25519 keypair.
""" """
@staticmethod @staticmethod
def generate_ecdsa() -> Keypair: def from_bytes(bytes: bytes) -> Keypair:
r""" r"""
Generate a new ECDSA keypair. Construct an Ed25519 keypair from secret key bytes
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
""" """
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
r""" r"""
Return a copy of this [`Multiaddr`]'s byte representation. Get the secret key bytes underlying the keypair
""" """
def to_string(self) -> builtins.str: def to_node_id(self) -> builtins.str:
r""" r"""
Convert a Multiaddr to a string. Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
""" """
@typing.final @typing.final
class NetworkingHandle: class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ... def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool: async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r""" r"""
Subscribe to a `GossipSub` topic. Subscribe to a `GossipSub` topic.
@@ -161,18 +55,7 @@ class NetworkingHandle:
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception. If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
""" """
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]: async def recv(self) -> PyFromSwarm: ...
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
@typing.final @typing.final
class NoPeersSubscribedToTopicError(builtins.Exception): class NoPeersSubscribedToTopicError(builtins.Exception):
@@ -180,42 +63,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ... def __str__(self) -> builtins.str: ...
@typing.final class PyFromSwarm:
class PeerId: @typing.final
r""" class Connection(PyFromSwarm):
Identifier of a peer of the network. __match_args__ = ("peer_id", "connected",)
@property
def peer_id(self) -> builtins.str: ...
@property
def connected(self) -> builtins.bool: ...
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer @typing.final
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md). class Message(PyFromSwarm):
""" __match_args__ = ("origin", "topic", "data",)
@staticmethod @property
def random() -> PeerId: def origin(self) -> builtins.str: ...
r""" @property
Generates a random peer ID from a cryptographically secure PRNG. def topic(self) -> builtins.str: ...
@property
This is useful for randomly walking on a DHT, or for testing purposes. def data(self) -> bytes: ...
""" def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
@staticmethod
def from_bytes(bytes: bytes) -> PeerId: ...
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...


@@ -1,8 +1,6 @@
use crate::ext::ResultExt as _; use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair; use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _}; use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods}; use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
@@ -18,142 +16,32 @@ pub struct PyKeypair(pub Keypair);
impl PyKeypair { impl PyKeypair {
/// Generate a new Ed25519 keypair. /// Generate a new Ed25519 keypair.
#[staticmethod] #[staticmethod]
fn generate_ed25519() -> Self { fn generate() -> Self {
Self(Keypair::generate_ed25519()) Self(Keypair::generate_ed25519())
} }
/// Generate a new ECDSA keypair. /// Construct an Ed25519 keypair from secret key bytes
#[staticmethod] #[staticmethod]
fn generate_ecdsa() -> Self { fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes()); let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?)) Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
} }
/// Encode a private key as protobuf structure. /// Get the secret key bytes underlying the keypair
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> { fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?; let bytes = self
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes)) Ok(PyBytes::new(py, &bytes))
} }
/// Convert the `Keypair` into the corresponding `PeerId`. /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
fn to_peer_id(&self) -> PyPeerId { fn to_node_id(&self) -> String {
PyPeerId(self.0.public().to_peer_id()) self.0.public().to_peer_id().to_base58()
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
} }
} }
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}


@@ -8,9 +8,10 @@ mod allow_threading;
 mod ident;
 mod networking;
 
-use crate::ident::ident_submodule;
+use crate::ident::PyKeypair;
 use crate::networking::networking_submodule;
 use pyo3::prelude::PyModule;
+use pyo3::types::PyModuleMethods;
 use pyo3::{Bound, PyResult, pyclass, pymodule};
 use pyo3_stub_gen::define_stub_info_gatherer;
@@ -154,11 +155,14 @@ pub(crate) mod ext {
 fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // install logger
     pyo3_log::init();
+    let mut builder = tokio::runtime::Builder::new_multi_thread();
+    builder.enable_all();
+    pyo3_async_runtimes::tokio::init(builder);
     // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
     // work with maturin, where the types generate correctly, in the right folder, without
     // too many importing issues...
-    ident_submodule(m)?;
+    m.add_class::<PyKeypair>()?;
     networking_submodule(m)?;
 
     // top-level constructs


@@ -1,26 +1,21 @@
#![allow( use std::sync::Arc;
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use crate::r#const::MPSC_CHANNEL_SIZE; use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _}; use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _}; use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
use crate::ident::{PyKeypair, PyPeerId}; use crate::ident::PyKeypair;
use crate::networking::exception::{PyAllQueuesFullError, PyNoPeersSubscribedToTopicError};
use crate::pyclass; use crate::pyclass;
use libp2p::futures::StreamExt as _; use futures_lite::StreamExt as _;
use libp2p::gossipsub; use libp2p::gossipsub::PublishError;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError}; use networking::swarm::{FromSwarm, Swarm, ToSwarm, create_swarm};
use libp2p::swarm::SwarmEvent; use pyo3::exceptions::PyRuntimeError;
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _}; use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes; use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods}; use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods}; use pyo3_stub_gen::derive::{
use std::net::IpAddr; gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
};
use tokio::sync::{Mutex, mpsc, oneshot}; use tokio::sync::{Mutex, mpsc, oneshot};
mod exception { mod exception {
@@ -100,235 +95,45 @@ mod exception {
} }
} }
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
}
}
}
log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass] #[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")] #[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle { struct PyNetworkingHandle {
// channels // channels
to_task_tx: Option<mpsc::Sender<ToTask>>, pub to_swarm: mpsc::Sender<ToSwarm>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>, pub swarm: Arc<Mutex<Swarm>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
} }
impl Drop for PyNetworkingHandle { #[gen_stub_pyclass_complex_enum]
fn drop(&mut self) { #[pyclass]
// TODO: may or may not need to await a "kill-signal" oneshot channel message, enum PyFromSwarm {
// to ensure that the networking task is done BEFORE exiting the clear function... Connection {
// but this may require GIL?? and it may not be safe to call GIL here?? peer_id: String,
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped connected: bool,
} },
Message {
origin: String,
topic: String,
data: Py<PyBytes>,
},
} }
impl From<FromSwarm> for PyFromSwarm {
#[allow(clippy::expect_used)] fn from(value: FromSwarm) -> Self {
impl PyNetworkingHandle { match value {
fn new( FromSwarm::Discovered { peer_id } => Self::Connection {
to_task_tx: mpsc::Sender<ToTask>, peer_id: peer_id.to_base58(),
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>, connected: true,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>, },
) -> Self { FromSwarm::Expired { peer_id } => Self::Connection {
Self { peer_id: peer_id.to_base58(),
to_task_tx: Some(to_task_tx), connected: false,
connection_update_rx: Mutex::new(connection_update_rx), },
gossipsub_message_rx: Mutex::new(gossipsub_message_rx), FromSwarm::Message { from, topic, data } => Self::Message {
origin: from.to_base58(),
topic: topic,
data: data.pybytes(),
},
} }
} }
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
} }
#[gen_stub_pymethods] #[gen_stub_pymethods]
@@ -342,97 +147,36 @@ impl PyNetworkingHandle {
#[new] #[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> { fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels // create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE); let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity // get identity
let identity = identity.borrow().0.clone(); let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes) // create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime() let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
.block_on(async { create_swarm(identity) }) let swarm = { create_swarm(identity, from_client).pyerr()? };
.pyerr()?;
// spawn tokio task running the networking logic Ok(Self {
get_runtime().spawn(async move { swarm: Arc::new(Mutex::new(swarm)),
networking_task( to_swarm,
swarm, })
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
} }
#[gen_stub(skip)] #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
Ok(()) // This is needed purely so `__clear__` can work let swarm = Arc::clone(&self.swarm);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
swarm
.try_lock()
.map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
.next()
.await
.ok_or(PyErr::receiver_channel_closed())
.map(PyFromSwarm::from)
})
} }
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
// ---- Connection update receiver methods ----
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate`s is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ---- // ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic. /// Subscribe to a `GossipSub` topic.
@@ -442,10 +186,10 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
// send off request to subscribe // send off request to subscribe
self.to_task_tx() self.to_swarm
.send_py(ToTask::GossipsubSubscribe { .send_py(ToSwarm::Subscribe {
topic, topic,
result_tx: tx, result_sender: tx,
}) })
.allow_threads_py() // allow-threads-aware async call .allow_threads_py() // allow-threads-aware async call
.await?; .await?;
@@ -454,6 +198,7 @@ impl PyNetworkingHandle {
rx.allow_threads_py() // allow-threads-aware async call rx.allow_threads_py() // allow-threads-aware async call
.await .await
.map_err(|_| PyErr::receiver_channel_closed())? .map_err(|_| PyErr::receiver_channel_closed())?
.pyerr()
} }
/// Unsubscribes from a `GossipSub` topic. /// Unsubscribes from a `GossipSub` topic.
@@ -463,10 +208,10 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
// send off request to unsubscribe // send off request to unsubscribe
self.to_task_tx() self.to_swarm
.send_py(ToTask::GossipsubUnsubscribe { .send_py(ToSwarm::Unsubscribe {
topic, topic,
result_tx: tx, result_sender: tx,
}) })
.allow_threads_py() // allow-threads-aware async call .allow_threads_py() // allow-threads-aware async call
.await?; .await?;
@@ -485,11 +230,11 @@ impl PyNetworkingHandle {
// send off request to subscribe // send off request to subscribe
let data = Python::attach(|py| Vec::from(data.as_bytes(py))); let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx() self.to_swarm
.send_py(ToTask::GossipsubPublish { .send_py(ToSwarm::Publish {
topic, topic,
data, data,
result_tx: tx, result_sender: tx,
}) })
.allow_threads_py() // allow-threads-aware async call .allow_threads_py() // allow-threads-aware async call
.await?; .await?;
@@ -498,74 +243,33 @@ impl PyNetworkingHandle {
let _ = rx let _ = rx
.allow_threads_py() // allow-threads-aware async call .allow_threads_py() // allow-threads-aware async call
.await .await
.map_err(|_| PyErr::receiver_channel_closed())??; .map_err(|_| PyErr::receiver_channel_closed())?
.map_err(|e| match e {
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
PublishError::NoPeersSubscribedToTopic => {
PyNoPeersSubscribedToTopicError::new_err()
}
e => PyRuntimeError::new_err(e.to_string()),
})?;
Ok(()) Ok(())
} }
}
// ---- Gossipsub message receiver methods ---- pyo3_stub_gen::inventory::submit! {
gen_methods_from_python! {
/// Receives the next message from the `GossipSub` network. r#"
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> { class PyNetworkingHandle:
self.gossipsub_message_rx async def recv() -> PyFromSwarm: ...
.lock() "#
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
} }
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.into_iter()
.map(|(t, d)| (t, d.pybytes()))
.collect())
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
} }
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?; m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<exception::PyAllQueuesFullError>()?; m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?; m.add_class::<PyNetworkingHandle>()?;
m.add_class::<PyFromSwarm>()?;
Ok(()) Ok(())
} }


@@ -35,3 +35,4 @@ log = { workspace = true }
 
 # networking
 libp2p = { workspace = true, features = ["full"] }
+pin-project = "1.1.10"


@@ -1,7 +1,9 @@
use futures_lite::StreamExt; use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent}; use libp2p::identity;
use networking::{discovery, swarm}; use networking::swarm;
use tokio::{io, io::AsyncBufReadExt as _, select}; use networking::swarm::{FromSwarm, ToSwarm};
use tokio::sync::{mpsc, oneshot};
use tokio::{io, io::AsyncBufReadExt as _};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::filter::LevelFilter;
@@ -11,64 +13,68 @@ async fn main() {
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into())) .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init(); .try_init();
let (to_swarm, from_client) = mpsc::channel(20);
// Configure swarm // Configure swarm
let mut swarm = let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed"); .expect("Swarm creation failed");
// Create a Gossipsub topic & subscribe // Create a Gossipsub topic & subscribe
let topic = gossipsub::IdentTopic::new("test-net"); let (tx, rx) = oneshot::channel();
swarm _ = to_swarm
.behaviour_mut() .send(ToSwarm::Subscribe {
.gossipsub topic: "test-net".to_string(),
.subscribe(&topic) result_sender: tx,
.expect("Subscribing to topic failed"); })
.await
.expect("should send");
// Read full lines from stdin // Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines(); let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub"); println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
tokio::task::spawn(async move {
rx.await
.expect("tx not dropped")
.expect("subscribe shouldn't fail");
loop {
if let Ok(Some(line)) = stdin.next_line().await {
let (tx, rx) = oneshot::channel();
if let Err(e) = to_swarm
.send(swarm::ToSwarm::Publish {
topic: "test-net".to_string(),
data: line.as_bytes().to_vec(),
result_sender: tx,
})
.await
{
println!("Send error: {e:?}");
return;
};
match rx.await {
Ok(Err(e)) => println!("Publish error: {e:?}"),
Err(e) => println!("Publish error: {e:?}"),
Ok(_) => {}
}
}
}
});
// Kick it off // Kick it off
loop { loop {
select! { // on gossipsub outgoing
// on gossipsub outgoing match swarm.next().await {
Ok(Some(line)) = stdin.next_line() => { // on gossipsub incoming
if let Err(e) = swarm Some(FromSwarm::Discovered { peer_id }) => {
.behaviour_mut().gossipsub println!("\n\nconnected to {peer_id}\n\n")
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
} }
event = swarm.next() => match event { Some(FromSwarm::Expired { peer_id }) => {
// on gossipsub incoming println!("\n\ndisconnected from {peer_id}\n\n")
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
}))) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
}
// ignore outgoing errors: those are normal
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }
} }
Some(FromSwarm::Message { from, topic, data }) => {
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
}
None => {}
} }
} }
} }


@@ -1,9 +1,12 @@
use crate::alias; use std::pin::Pin;
use crate::swarm::transport::tcp_transport; use std::task::Poll;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
pub type Swarm = libp2p::Swarm<Behaviour>; use crate::swarm::transport::tcp_transport;
use crate::{alias, discovery};
pub use behaviour::{Behaviour, BehaviourEvent};
use futures_lite::Stream;
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
use tokio::sync::{mpsc, oneshot};
/// The current version of the network: this prevents devices running different versions of the /// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other. /// software from interacting with each other.
@@ -15,8 +18,144 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
pub const NETWORK_VERSION: &[u8] = b"v0.0.1"; pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE"; pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
pub enum ToSwarm {
Unsubscribe {
topic: String,
result_sender: oneshot::Sender<bool>,
},
Subscribe {
topic: String,
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
},
Publish {
topic: String,
data: Vec<u8>,
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
},
}
pub enum FromSwarm {
Message {
from: PeerId,
topic: String,
data: Vec<u8>,
},
Discovered {
peer_id: PeerId,
},
Expired {
peer_id: PeerId,
},
}
#[pin_project::pin_project]
pub struct Swarm {
#[pin]
inner: libp2p::Swarm<Behaviour>,
from_client: mpsc::Receiver<ToSwarm>,
}
impl Swarm {
fn on_message(mut self: Pin<&mut Self>, message: ToSwarm) {
match message {
ToSwarm::Subscribe {
topic,
result_sender,
} => {
// try to subscribe
let result = self
.inner
.behaviour_mut()
.gossipsub
.subscribe(&gossipsub::IdentTopic::new(topic));
// send response oneshot
_ = result_sender.send(result)
}
ToSwarm::Unsubscribe {
topic,
result_sender,
} => {
// try to unsubscribe from the topic
let result = self
.inner
.behaviour_mut()
.gossipsub
.unsubscribe(&gossipsub::IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
_ = result_sender.send(result)
}
ToSwarm::Publish {
topic,
data,
result_sender,
} => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = self
.inner
.behaviour_mut()
.gossipsub
.publish(gossipsub::IdentTopic::new(topic), data);
// send response oneshot (or exit if connection closed)
_ = result_sender.send(result)
}
}
}
}
impl Stream for Swarm {
type Item = FromSwarm;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
loop {
let recv = self.as_mut().project().from_client;
match recv.poll_recv(cx) {
Poll::Ready(Some(msg)) => {
self.as_mut().on_message(msg);
// continue to re-poll after consumption
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {}
}
let inner = self.as_mut().project().inner;
return match inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(swarm_event)) => match swarm_event {
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(
gossipsub::Event::Message {
message:
gossipsub::Message {
source: Some(peer_id),
topic,
data,
..
},
..
},
)) => Poll::Ready(Some(FromSwarm::Message {
from: peer_id,
topic: topic.into_string(),
data,
})),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
discovery::Event::ConnectionEstablished { peer_id, .. },
)) => Poll::Ready(Some(FromSwarm::Discovered { peer_id })),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
discovery::Event::ConnectionClosed { peer_id, .. },
)) => Poll::Ready(Some(FromSwarm::Expired { peer_id })),
// continue to re-poll after consumption
_ => continue,
},
};
}
}
}
/// Create and configure a swarm which listens to all ports on OS /// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> { pub fn create_swarm(
keypair: identity::Keypair,
from_client: mpsc::Receiver<ToSwarm>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair) let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio() .with_tokio()
.with_other_transport(tcp_transport)? .with_other_transport(tcp_transport)?
@@ -25,7 +164,10 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
// Listen on all interfaces and whatever port the OS assigns // Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?; swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm) Ok(Swarm {
inner: swarm,
from_client,
})
} }
mod transport { mod transport {


@@ -338,17 +338,7 @@ class DownloadCoordinator:
                     ),
                 )
             elif progress.status in ["in_progress", "not_started"]:
-                if (
-                    progress.downloaded_bytes.in_bytes
-                    >= progress.total_bytes.in_bytes
-                    > 0
-                ):
-                    status = DownloadCompleted(
-                        node_id=self.node_id,
-                        shard_metadata=progress.shard,
-                        total_bytes=progress.total_bytes,
-                    )
-                elif progress.downloaded_bytes_this_session.in_bytes == 0:
+                if progress.downloaded_bytes_this_session.in_bytes == 0:
                     status = DownloadPending(
                         node_id=self.node_id,
                         shard_metadata=progress.shard,


@@ -45,7 +45,7 @@ class Node:
     @classmethod
     async def create(cls, args: "Args") -> "Self":
         keypair = get_node_id_keypair()
-        node_id = NodeId(keypair.to_peer_id().to_base58())
+        node_id = NodeId(keypair.to_node_id())
         session_id = SessionId(master_node_id=node_id, election_clock=0)
         router = Router.create(keypair)
         await router.register_topic(topics.GLOBAL_EVENTS)
@@ -258,7 +258,7 @@ def main():
     target = min(max(soft, 65535), hard)
     resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
-    mp.set_start_method("spawn", force=True)
+    mp.set_start_method("spawn")
 
     # TODO: Refactor the current verbosity system
     logger_setup(EXO_LOG, args.verbosity)
     logger.info("Starting EXO")


@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
ResponseOutputText, ResponseOutputText,
ResponsesRequest, ResponsesRequest,
ResponsesResponse, ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent, ResponseTextDeltaEvent,
ResponseTextDoneEvent, ResponseTextDoneEvent,
ResponseUsage, ResponseUsage,
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str: def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts.""" """Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str): if isinstance(content, str):
@@ -219,13 +225,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent( created_event = ResponseCreatedEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
+    yield _format_sse(created_event)
     # response.in_progress
     in_progress_event = ResponseInProgressEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
+    yield _format_sse(in_progress_event)
     # response.output_item.added
     initial_item = ResponseMessageItem(
@@ -236,7 +242,7 @@ async def generate_responses_stream(
     item_added = ResponseOutputItemAddedEvent(
         sequence_number=next(seq), output_index=0, item=initial_item
     )
-    yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
+    yield _format_sse(item_added)
     # response.content_part.added
     initial_part = ResponseOutputText(text="")
@@ -247,7 +253,7 @@ async def generate_responses_stream(
         content_index=0,
         part=initial_part,
     )
-    yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
+    yield _format_sse(part_added)
     accumulated_text = ""
     function_call_items: list[ResponseFunctionCallItem] = []
@@ -281,7 +287,7 @@ async def generate_responses_stream(
         output_index=next_output_index,
         item=fc_item,
     )
-    yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
+    yield _format_sse(fc_added)
     # response.function_call_arguments.delta
     args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +296,7 @@ async def generate_responses_stream(
         output_index=next_output_index,
         delta=tool.arguments,
     )
-    yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
+    yield _format_sse(args_delta)
     # response.function_call_arguments.done
     args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
         name=tool.name,
         arguments=tool.arguments,
     )
-    yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
+    yield _format_sse(args_done)
     # response.output_item.done
     fc_done_item = ResponseFunctionCallItem(
@@ -315,7 +321,7 @@ async def generate_responses_stream(
         output_index=next_output_index,
         item=fc_done_item,
     )
-    yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
+    yield _format_sse(fc_item_done)
     function_call_items.append(fc_done_item)
     next_output_index += 1
@@ -331,7 +337,7 @@ async def generate_responses_stream(
         content_index=0,
         delta=chunk.text,
     )
-    yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
+    yield _format_sse(delta_event)
     # response.output_text.done
     text_done = ResponseTextDoneEvent(
@@ -341,7 +347,7 @@ async def generate_responses_stream(
         content_index=0,
         text=accumulated_text,
     )
-    yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
+    yield _format_sse(text_done)
     # response.content_part.done
     final_part = ResponseOutputText(text=accumulated_text)
@@ -352,7 +358,7 @@ async def generate_responses_stream(
         content_index=0,
         part=final_part,
     )
-    yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
+    yield _format_sse(part_done)
     # response.output_item.done
     final_message_item = ResponseMessageItem(
@@ -363,7 +369,7 @@ async def generate_responses_stream(
     item_done = ResponseOutputItemDoneEvent(
         sequence_number=next(seq), output_index=0, item=final_message_item
     )
-    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
+    yield _format_sse(item_done)
     # Create usage from usage data if available
     usage = None
@@ -388,4 +394,4 @@ async def generate_responses_stream(
     completed_event = ResponseCompletedEvent(
         sequence_number=next(seq), response=final_response
     )
-    yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
+    yield _format_sse(completed_event)

View File

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
 @pytest.mark.asyncio
 async def test_master():
     keypair = get_node_id_keypair()
-    node_id = NodeId(keypair.to_peer_id().to_base58())
+    node_id = NodeId(keypair.to_node_id())
     session_id = SessionId(master_node_id=node_id, election_clock=0)
     ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
     async with anyio.create_task_group() as tg:
         tg.start_soon(master.run)
-        sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
+        sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
         # inject a NodeGatheredInfo event
         logger.info("inject a NodeGatheredInfo event")
         await local_event_sender.send(

View File

@@ -1,6 +1,4 @@
-from enum import Enum
-from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
+from exo_pyo3_bindings import PyFromSwarm
 from exo.shared.types.common import NodeId
 from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
 """Serialisable types for Connection Updates/Messages"""
-class ConnectionMessageType(Enum):
-    Connected = 0
-    Disconnected = 1
-    @staticmethod
-    def from_update_type(update_type: ConnectionUpdateType):
-        match update_type:
-            case ConnectionUpdateType.Connected:
-                return ConnectionMessageType.Connected
-            case ConnectionUpdateType.Disconnected:
-                return ConnectionMessageType.Disconnected
 class ConnectionMessage(CamelCaseModel):
     node_id: NodeId
-    connection_type: ConnectionMessageType
-    remote_ipv4: str
-    remote_tcp_port: int
+    connected: bool
     @classmethod
-    def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
-        return cls(
-            node_id=NodeId(update.peer_id.to_base58()),
-            connection_type=ConnectionMessageType.from_update_type(update.update_type),
-            remote_ipv4=update.remote_ipv4,
-            remote_tcp_port=update.remote_tcp_port,
-        )
+    def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
+        return cls(node_id=NodeId(update.peer_id), connected=update.connected)
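The message type is now just a peer id plus an up/down flag; the enum and the remote address fields are gone. A small usage sketch of the slimmed-down model (the peer id value is illustrative):

```python
from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.common import NodeId

# Construct the two-field message directly.
msg = ConnectionMessage(node_id=NodeId("QmExamplePeer"), connected=True)

# Consumers branch on a plain bool instead of the removed ConnectionMessageType enum.
if msg.connected:
    print(f"{msg.node_id} joined")
else:
    print(f"{msg.node_id} dropped")
```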

View File

@@ -18,6 +18,7 @@ from exo_pyo3_bindings import (
     Keypair,
     NetworkingHandle,
     NoPeersSubscribedToTopicError,
+    PyFromSwarm,
 )
 from filelock import FileLock
 from loguru import logger
@@ -114,7 +115,6 @@ class Router:
         self._tg: TaskGroup | None = None
     async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
-        assert self._tg is None, "Attempted to register topic after setup time"
         send = self._tmp_networking_sender
         if send:
             self._tmp_networking_sender = None
@@ -122,7 +122,8 @@ class Router:
             send = self.networking_receiver.clone_sender()
         router = TopicRouter[T](topic, send)
         self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
-        await self._networking_subscribe(str(topic.topic))
+        if self._tg is not None:
+            await self._networking_subscribe(topic.topic)
     def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
         router = self.topic_routers.get(topic.topic, None)
@@ -154,8 +155,10 @@ class Router:
                 router = self.topic_routers[topic]
                 tg.start_soon(router.run)
             tg.start_soon(self._networking_recv)
-            tg.start_soon(self._networking_recv_connection_messages)
             tg.start_soon(self._networking_publish)
+            # subscribe to pending topics
+            for topic in self.topic_routers:
+                await self._networking_subscribe(topic)
             # Router only shuts down if you cancel it.
             await sleep_forever()
         finally:
@@ -179,27 +182,33 @@ class Router:
     async def _networking_recv(self):
         while True:
-            topic, data = await self._net.gossipsub_recv()
-            logger.trace(f"Received message on {topic} with payload {data}")
-            if topic not in self.topic_routers:
-                logger.warning(f"Received message on unknown or inactive topic {topic}")
-                continue
-            router = self.topic_routers[topic]
-            await router.publish_bytes(data)
-    async def _networking_recv_connection_messages(self):
-        while True:
-            update = await self._net.connection_update_recv()
-            message = ConnectionMessage.from_update(update)
-            logger.trace(
-                f"Received message on connection_messages with payload {message}"
-            )
-            if CONNECTION_MESSAGES.topic in self.topic_routers:
-                router = self.topic_routers[CONNECTION_MESSAGES.topic]
-                assert router.topic.model_type == ConnectionMessage
-                router = cast(TopicRouter[ConnectionMessage], router)
-                await router.publish(message)
+            from_swarm = await self._net.recv()
+            logger.debug(from_swarm)
+            match from_swarm:
+                case PyFromSwarm.Message(origin, topic, data):
+                    logger.trace(
+                        f"Received message on {topic} from {origin} with payload {data}"
+                    )
+                    if topic not in self.topic_routers:
+                        logger.warning(
+                            f"Received message on unknown or inactive topic {topic}"
+                        )
+                        continue
+                    router = self.topic_routers[topic]
+                    await router.publish_bytes(data)
+                case PyFromSwarm.Connection():
+                    message = ConnectionMessage.from_update(from_swarm)
+                    logger.trace(
+                        f"Received message on connection_messages with payload {message}"
+                    )
+                    if CONNECTION_MESSAGES.topic in self.topic_routers:
+                        router = self.topic_routers[CONNECTION_MESSAGES.topic]
+                        assert router.topic.model_type == ConnectionMessage
+                        router = cast(TopicRouter[ConnectionMessage], router)
+                        await router.publish(message)
+                case _:
+                    raise AssertionError("exhaustive net messages have been checked")
     async def _networking_publish(self):
         with self.networking_receiver as networked_items:
@@ -221,7 +230,7 @@ def get_node_id_keypair(
     Obtain the :class:`PeerId` by from it.
     """
     # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
-    return Keypair.generate_ed25519()
+    return Keypair.generate()
 def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
     return Path(str(path) + ".lock")
@@ -235,12 +244,12 @@ def get_node_id_keypair(
         protobuf_encoded = f.read()
     try: # if decoded successfully, save & return
-        return Keypair.from_protobuf_encoding(protobuf_encoded)
+        return Keypair.from_bytes(protobuf_encoded)
     except ValueError as e: # on runtime error, assume corrupt file
         logger.warning(f"Encountered error when trying to get keypair: {e}")
     # if no valid credentials, create new ones and persist
     with open(path, "w+b") as f:
         keypair = Keypair.generate_ed25519()
-        f.write(keypair.to_protobuf_encoding())
+        f.write(keypair.to_bytes())
     return keypair
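With the setup-time assertion gone, `register_topic` now works both before and after `run()` has started: topics registered early are queued and subscribed when the task group comes up, later ones are subscribed immediately. A rough usage sketch under that reading of the diff (router construction is elided, and the topic arguments stand in for real `TypedTopic` constants):

```python
import anyio


async def wire_up(router, early_topic, late_topic):
    # Registered before run(): no task group yet, so the subscription is deferred
    # and picked up by the "subscribe to pending topics" loop at startup.
    await router.register_topic(early_topic)

    async with anyio.create_task_group() as tg:
        tg.start_soon(router.run)
        # Registered after run() has started: self._tg is set, so this subscribes immediately.
        await router.register_topic(late_topic)
```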

View File

@@ -1,7 +1,7 @@
 import pytest
 from anyio import create_task_group, fail_after, move_on_after
-from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
+from exo.routing.connection_message import ConnectionMessage
 from exo.shared.election import Election, ElectionMessage, ElectionResult
 from exo.shared.types.commands import ForwarderCommand, TestCommand
 from exo.shared.types.common import NodeId, SessionId
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
         tg.start_soon(election.run)
         # Send any connection message object; we close quickly to cancel before result creation
-        await cm_tx.send(
-            ConnectionMessage(
-                node_id=NodeId(),
-                connection_type=ConnectionMessageType.Connected,
-                remote_ipv4="",
-                remote_tcp_port=0,
-            )
-        )
+        await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
         # Expect a broadcast for the new round at clock=1
         while True:

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
     sem.release()
     # wait to be told to begin simultaneous read
     ev.wait()
-    queue.put(get_node_id_keypair().to_protobuf_encoding())
+    queue.put(get_node_id_keypair().to_bytes())
 def _get_keypair_concurrent(num_procs: int) -> bytes:
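The byte-encoding methods on the binding are renamed (`to_protobuf_encoding`/`from_protobuf_encoding` become `to_bytes`/`from_bytes`), which is what this test and the persistence path above now exercise. A minimal round-trip sketch under the new names; the final equality check is an assumed invariant, not something the diff shows:

```python
from exo_pyo3_bindings import Keypair

keypair = Keypair.generate()           # new constructor name used in get_node_id_keypair()
blob = keypair.to_bytes()              # previously to_protobuf_encoding()
restored = Keypair.from_bytes(blob)    # previously from_protobuf_encoding()

# Assumed invariant: decoding the persisted bytes yields the same node identity.
assert restored.to_node_id() == keypair.to_node_id()
```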

View File

@@ -98,16 +98,11 @@ class RunnerSupervisor:
     def shutdown(self):
         logger.info("Runner supervisor shutting down")
-        with contextlib.suppress(ClosedResourceError):
-            self._ev_recv.close()
-        with contextlib.suppress(ClosedResourceError):
-            self._task_sender.close()
-        with contextlib.suppress(ClosedResourceError):
-            self._event_sender.close()
-        with contextlib.suppress(ClosedResourceError):
-            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
-        with contextlib.suppress(ClosedResourceError):
-            self._cancel_sender.close()
+        self._ev_recv.close()
+        self._task_sender.close()
+        self._event_sender.close()
+        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+        self._cancel_sender.close()
         self.runner_process.join(5)
         if not self.runner_process.is_alive():
             logger.info("Runner process succesfully terminated")