mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 15:27:02 -05:00
Compare commits
3 Commits
meta-insta
...
move-messa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc3b8392e6 | ||
|
|
4c4c6ce99f | ||
|
|
42e1e7322b |
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2767,6 +2767,7 @@ dependencies = [
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
|
||||
@@ -20,6 +20,7 @@ from harness import (
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
run_planning_phase,
|
||||
settle_and_fetch_placements,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
@@ -962,6 +963,21 @@ Examples:
|
||||
|
||||
selected.sort(key=_placement_sort_key)
|
||||
preview = selected[0]
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
print("Planning phase: checking downloads...", file=log)
|
||||
run_planning_phase(
|
||||
exo,
|
||||
full_model_id,
|
||||
preview,
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
sharding = str(preview["sharding"])
|
||||
|
||||
@@ -35,6 +35,7 @@ from harness import (
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
run_planning_phase,
|
||||
settle_and_fetch_placements,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
@@ -332,6 +333,20 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
150
bench/harness.py
150
bench/harness.py
@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
|
||||
return selected
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
import enum
|
||||
import typing
|
||||
|
||||
@typing.final
|
||||
@@ -11,138 +10,33 @@ class AllQueuesFullError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
def generate() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ecdsa() -> Keypair:
|
||||
def from_bytes(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
Construct an Ed25519 keypair from secret key bytes
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
Get the secret key bytes underlying the keypair
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
def to_node_id(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
@@ -161,18 +55,7 @@ class NetworkingHandle:
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
async def recv(self) -> PyFromSwarm: ...
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
@@ -180,42 +63,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
class PyFromSwarm:
|
||||
@typing.final
|
||||
class Connection(PyFromSwarm):
|
||||
__match_args__ = ("peer_id", "connected",)
|
||||
@property
|
||||
def peer_id(self) -> builtins.str: ...
|
||||
@property
|
||||
def connected(self) -> builtins.bool: ...
|
||||
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
@typing.final
|
||||
class Message(PyFromSwarm):
|
||||
__match_args__ = ("origin", "topic", "data",)
|
||||
@property
|
||||
def origin(self) -> builtins.str: ...
|
||||
@property
|
||||
def topic(self) -> builtins.str: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
|
||||
|
||||
...
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::types::{PyBytes, PyBytesMethods as _};
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
@@ -18,142 +16,32 @@ pub struct PyKeypair(pub Keypair);
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
/// Construct an Ed25519 keypair from secret key bytes
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
/// Get the secret key bytes underlying the keypair
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self
|
||||
.0
|
||||
.clone()
|
||||
.try_into_ed25519()
|
||||
.pyerr()?
|
||||
.secret()
|
||||
.as_ref()
|
||||
.to_vec();
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
fn to_node_id(&self) -> String {
|
||||
self.0.public().to_peer_id().to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -8,9 +8,10 @@ mod allow_threading;
|
||||
mod ident;
|
||||
mod networking;
|
||||
|
||||
use crate::ident::ident_submodule;
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::networking_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::types::PyModuleMethods;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -154,11 +155,14 @@ pub(crate) mod ext {
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
let mut builder = tokio::runtime::Builder::new_multi_thread();
|
||||
builder.enable_all();
|
||||
pyo3_async_runtimes::tokio::init(builder);
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
m.add_class::<PyKeypair>()?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::{PyKeypair, PyPeerId};
|
||||
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::exception::{PyAllQueuesFullError, PyNoPeersSubscribedToTopicError};
|
||||
use crate::pyclass;
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use futures_lite::StreamExt as _;
|
||||
use libp2p::gossipsub::PublishError;
|
||||
use networking::swarm::{FromSwarm, Swarm, ToSwarm, create_swarm};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{
|
||||
gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
|
||||
};
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
@@ -100,235 +95,45 @@ mod exception {
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
pub to_swarm: mpsc::Sender<ToSwarm>,
|
||||
pub swarm: Arc<Mutex<Swarm>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
#[gen_stub_pyclass_complex_enum]
|
||||
#[pyclass]
|
||||
enum PyFromSwarm {
|
||||
Connection {
|
||||
peer_id: String,
|
||||
connected: bool,
|
||||
},
|
||||
Message {
|
||||
origin: String,
|
||||
topic: String,
|
||||
data: Py<PyBytes>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
impl From<FromSwarm> for PyFromSwarm {
|
||||
fn from(value: FromSwarm) -> Self {
|
||||
match value {
|
||||
FromSwarm::Discovered { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: true,
|
||||
},
|
||||
FromSwarm::Expired { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: false,
|
||||
},
|
||||
FromSwarm::Message { from, topic, data } => Self::Message {
|
||||
origin: from.to_base58(),
|
||||
topic: topic,
|
||||
data: data.pybytes(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
@@ -342,97 +147,36 @@ impl PyNetworkingHandle {
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
|
||||
let swarm = { create_swarm(identity, from_client).pyerr()? };
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
Ok(Self {
|
||||
swarm: Arc::new(Mutex::new(swarm)),
|
||||
to_swarm,
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
|
||||
let swarm = Arc::clone(&self.swarm);
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
swarm
|
||||
.try_lock()
|
||||
.map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
|
||||
.next()
|
||||
.await
|
||||
.ok_or(PyErr::receiver_channel_closed())
|
||||
.map(PyFromSwarm::from)
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
@@ -442,10 +186,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -454,6 +198,7 @@ impl PyNetworkingHandle {
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.pyerr()
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
@@ -463,10 +208,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -485,11 +230,11 @@ impl PyNetworkingHandle {
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -498,74 +243,33 @@ impl PyNetworkingHandle {
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.map_err(|e| match e {
|
||||
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
|
||||
PublishError::NoPeersSubscribedToTopic => {
|
||||
PyNoPeersSubscribedToTopicError::new_err()
|
||||
}
|
||||
e => PyRuntimeError::new_err(e.to_string()),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
pyo3_stub_gen::inventory::submit! {
|
||||
gen_methods_from_python! {
|
||||
r#"
|
||||
class PyNetworkingHandle:
|
||||
async def recv() -> PyFromSwarm: ...
|
||||
"#
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
m.add_class::<PyFromSwarm>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -35,3 +35,4 @@ log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use libp2p::identity;
|
||||
use networking::swarm;
|
||||
use networking::swarm::{FromSwarm, ToSwarm};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::{io, io::AsyncBufReadExt as _};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
@@ -11,64 +13,68 @@ async fn main() {
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (to_swarm, from_client) = mpsc::channel(20);
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
|
||||
.expect("Swarm creation failed");
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
let (tx, rx) = oneshot::channel();
|
||||
_ = to_swarm
|
||||
.send(ToSwarm::Subscribe {
|
||||
topic: "test-net".to_string(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
.expect("should send");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
rx.await
|
||||
.expect("tx not dropped")
|
||||
.expect("subscribe shouldn't fail");
|
||||
loop {
|
||||
if let Ok(Some(line)) = stdin.next_line().await {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = to_swarm
|
||||
.send(swarm::ToSwarm::Publish {
|
||||
topic: "test-net".to_string(),
|
||||
data: line.as_bytes().to_vec(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
{
|
||||
println!("Send error: {e:?}");
|
||||
return;
|
||||
};
|
||||
match rx.await {
|
||||
Ok(Err(e)) => println!("Publish error: {e:?}"),
|
||||
Err(e) => println!("Publish error: {e:?}"),
|
||||
Ok(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
// on gossipsub outgoing
|
||||
match swarm.next().await {
|
||||
// on gossipsub incoming
|
||||
Some(FromSwarm::Discovered { peer_id }) => {
|
||||
println!("\n\nconnected to {peer_id}\n\n")
|
||||
}
|
||||
event = swarm.next() => match event {
|
||||
// on gossipsub incoming
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
}))) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
Some(FromSwarm::Expired { peer_id }) => {
|
||||
println!("\n\ndisconnected from {peer_id}\n\n")
|
||||
}
|
||||
Some(FromSwarm::Message { from, topic, data }) => {
|
||||
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
use crate::{alias, discovery};
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use futures_lite::Stream;
|
||||
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
@@ -15,8 +18,144 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
pub enum ToSwarm {
|
||||
Unsubscribe {
|
||||
topic: String,
|
||||
result_sender: oneshot::Sender<bool>,
|
||||
},
|
||||
Subscribe {
|
||||
topic: String,
|
||||
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
|
||||
},
|
||||
Publish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
|
||||
},
|
||||
}
|
||||
pub enum FromSwarm {
|
||||
Message {
|
||||
from: PeerId,
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
},
|
||||
Discovered {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
Expired {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
}
|
||||
#[pin_project::pin_project]
|
||||
pub struct Swarm {
|
||||
#[pin]
|
||||
inner: libp2p::Swarm<Behaviour>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
}
|
||||
impl Swarm {
|
||||
fn on_message(mut self: Pin<&mut Self>, message: ToSwarm) {
|
||||
match message {
|
||||
ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
// try to subscribe
|
||||
let result = self
|
||||
.inner
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&gossipsub::IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
_ = result_sender.send(result)
|
||||
}
|
||||
ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = self
|
||||
.inner
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.unsubscribe(&gossipsub::IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
_ = result_sender.send(result)
|
||||
}
|
||||
ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_sender,
|
||||
} => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = self
|
||||
.inner
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.publish(gossipsub::IdentTopic::new(topic), data);
|
||||
// send response oneshot (or exit if connection closed)
|
||||
_ = result_sender.send(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Stream for Swarm {
|
||||
type Item = FromSwarm;
|
||||
fn poll_next(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
let recv = self.as_mut().project().from_client;
|
||||
match recv.poll_recv(cx) {
|
||||
Poll::Ready(Some(msg)) => {
|
||||
self.as_mut().on_message(msg);
|
||||
// continue to re-poll after consumption
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Pending => {}
|
||||
}
|
||||
let inner = self.as_mut().project().inner;
|
||||
return match inner.poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(swarm_event)) => match swarm_event {
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(
|
||||
gossipsub::Event::Message {
|
||||
message:
|
||||
gossipsub::Message {
|
||||
source: Some(peer_id),
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
},
|
||||
)) => Poll::Ready(Some(FromSwarm::Message {
|
||||
from: peer_id,
|
||||
topic: topic.into_string(),
|
||||
data,
|
||||
})),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionEstablished { peer_id, .. },
|
||||
)) => Poll::Ready(Some(FromSwarm::Discovered { peer_id })),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionClosed { peer_id, .. },
|
||||
)) => Poll::Ready(Some(FromSwarm::Expired { peer_id })),
|
||||
// continue to re-poll after consumption
|
||||
_ => continue,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
pub fn create_swarm(
|
||||
keypair: identity::Keypair,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
@@ -25,7 +164,10 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
Ok(Swarm {
|
||||
inner: swarm,
|
||||
from_client,
|
||||
})
|
||||
}
|
||||
|
||||
mod transport {
|
||||
|
||||
@@ -45,7 +45,7 @@ class Node:
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
|
||||
@@ -36,8 +36,6 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
@@ -62,7 +60,6 @@ from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
|
||||
@@ -97,7 +94,6 @@ class Master:
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -411,11 +407,6 @@ class Master:
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
# After broadcasting JacclSideChannelData, accumulate and
|
||||
# emit gathered result when all runners have contributed.
|
||||
if isinstance(event, JacclSideChannelData):
|
||||
await self._handle_jaccl_side_channel(event)
|
||||
|
||||
async def _loopback_processor(self) -> None:
|
||||
# this would ideally not be necessary.
|
||||
# this is WAY less hacky than how I was working around this before
|
||||
@@ -469,42 +460,3 @@ class Master:
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
|
||||
"""Accumulate SideChannel contributions; when all runners for an instance
|
||||
have submitted for the same sequence, emit JacclSideChannelGathered."""
|
||||
iid = event.instance_id
|
||||
seq = event.sequence
|
||||
|
||||
if iid not in self._jaccl_pending:
|
||||
self._jaccl_pending[iid] = {}
|
||||
if seq not in self._jaccl_pending[iid]:
|
||||
self._jaccl_pending[iid][seq] = {}
|
||||
self._jaccl_pending[iid][seq][event.runner_id] = event.data
|
||||
|
||||
instance = self.state.instances.get(iid)
|
||||
if instance is None:
|
||||
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
|
||||
return
|
||||
|
||||
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
|
||||
submitted = set(self._jaccl_pending[iid][seq].keys())
|
||||
|
||||
logger.info(
|
||||
f"JACCL side channel: instance={iid} seq={seq} "
|
||||
f"submitted={len(submitted)}/{len(expected_runners)}"
|
||||
)
|
||||
|
||||
if submitted >= expected_runners:
|
||||
gathered = dict(self._jaccl_pending[iid][seq])
|
||||
del self._jaccl_pending[iid][seq]
|
||||
if not self._jaccl_pending[iid]:
|
||||
del self._jaccl_pending[iid]
|
||||
|
||||
await self.event_sender.send(
|
||||
JacclSideChannelGathered(
|
||||
instance_id=iid,
|
||||
sequence=seq,
|
||||
gathered_data=gathered,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -42,7 +42,7 @@ from exo.utils.channels import channel
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
@@ -75,7 +75,7 @@ async def test_master():
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
|
||||
# inject a NodeGatheredInfo event
|
||||
logger.info("inject a NodeGatheredInfo event")
|
||||
await local_event_sender.send(
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
from exo_pyo3_bindings import PyFromSwarm
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
connected: bool
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
)
|
||||
def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
|
||||
return cls(node_id=NodeId(update.peer_id), connected=update.connected)
|
||||
|
||||
@@ -18,6 +18,7 @@ from exo_pyo3_bindings import (
|
||||
Keypair,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
PyFromSwarm,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
@@ -114,7 +115,6 @@ class Router:
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
|
||||
assert self._tg is None, "Attempted to register topic after setup time"
|
||||
send = self._tmp_networking_sender
|
||||
if send:
|
||||
self._tmp_networking_sender = None
|
||||
@@ -122,7 +122,8 @@ class Router:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
if self._tg is not None:
|
||||
await self._networking_subscribe(topic.topic)
|
||||
|
||||
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
@@ -154,8 +155,10 @@ class Router:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# subscribe to pending topics
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_subscribe(topic)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
@@ -179,27 +182,33 @@ class Router:
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(f"Received message on unknown or inactive topic {topic}")
|
||||
continue
|
||||
from_swarm = await self._net.recv()
|
||||
logger.debug(from_swarm)
|
||||
match from_swarm:
|
||||
case PyFromSwarm.Message(origin, topic, data):
|
||||
logger.trace(
|
||||
f"Received message on {topic} from {origin} with payload {data}"
|
||||
)
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(
|
||||
f"Received message on unknown or inactive topic {topic}"
|
||||
)
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
case PyFromSwarm.Connection():
|
||||
message = ConnectionMessage.from_update(from_swarm)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
case _:
|
||||
raise AssertionError("exhaustive net messages have been checked")
|
||||
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
@@ -221,7 +230,7 @@ def get_node_id_keypair(
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
return Keypair.generate()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
@@ -235,12 +244,12 @@ def get_node_id_keypair(
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
return Keypair.from_bytes(protobuf_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
|
||||
@@ -12,8 +12,6 @@ from exo.shared.types.events import (
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
@@ -70,8 +68,6 @@ def event_apply(event: Event, state: State) -> State:
|
||||
| PrefillProgress()
|
||||
| TracesCollected()
|
||||
| TracesMerged()
|
||||
| JacclSideChannelData()
|
||||
| JacclSideChannelGathered()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from anyio import create_task_group, fail_after, move_on_after
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Send any connection message object; we close quickly to cancel before result creation
|
||||
await cm_tx.send(
|
||||
ConnectionMessage(
|
||||
node_id=NodeId(),
|
||||
connection_type=ConnectionMessageType.Connected,
|
||||
remote_ipv4="",
|
||||
remote_tcp_port=0,
|
||||
)
|
||||
)
|
||||
await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
|
||||
|
||||
# Expect a broadcast for the new round at clock=1
|
||||
while True:
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
||||
sem.release()
|
||||
# wait to be told to begin simultaneous read
|
||||
ev.wait()
|
||||
queue.put(get_node_id_keypair().to_protobuf_encoding())
|
||||
queue.put(get_node_id_keypair().to_bytes())
|
||||
|
||||
|
||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import base64
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Annotated, final
|
||||
from typing import final
|
||||
|
||||
from pydantic import BeforeValidator, Field, PlainSerializer
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
@@ -16,28 +14,6 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||
|
||||
|
||||
def _decode_base64_bytes(v: bytes | str) -> bytes:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
return base64.b64decode(v)
|
||||
|
||||
|
||||
def _encode_base64_bytes(v: bytes) -> str:
|
||||
return base64.b64encode(v).decode("ascii")
|
||||
|
||||
|
||||
Base64Bytes = Annotated[
|
||||
bytes,
|
||||
BeforeValidator(_decode_base64_bytes),
|
||||
PlainSerializer(_encode_base64_bytes, return_type=str),
|
||||
]
|
||||
"""bytes that serialize to/from base64 strings in JSON.
|
||||
|
||||
Needed because TaggedModel's wrap validator converts JSON→Python validation
|
||||
context, which breaks strict-mode bytes deserialization from JSON strings.
|
||||
"""
|
||||
|
||||
|
||||
class EventId(Id):
|
||||
"""
|
||||
Newtype around `ID`
|
||||
@@ -163,25 +139,6 @@ class TracesMerged(BaseEvent):
|
||||
traces: list[TraceEventData]
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelData(BaseEvent):
|
||||
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
runner_id: RunnerId
|
||||
sequence: int
|
||||
data: Base64Bytes
|
||||
|
||||
|
||||
@final
|
||||
class JacclSideChannelGathered(BaseEvent):
|
||||
"""Gathered result of a JACCL SideChannel all_gather round."""
|
||||
|
||||
instance_id: InstanceId
|
||||
sequence: int
|
||||
gathered_data: Mapping[RunnerId, Base64Bytes]
|
||||
|
||||
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
@@ -203,8 +160,6 @@ Event = (
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
| TracesMerged
|
||||
| JacclSideChannelData
|
||||
| JacclSideChannelGathered
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -643,11 +643,6 @@ def mlx_cleanup(
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
"""Synchronize a boolean across all distributed nodes.
|
||||
|
||||
Returns True if any node has bool_=True. Uses all_sum so every
|
||||
node participates in the collective — preventing GPU deadlocks.
|
||||
"""
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
|
||||
@@ -24,7 +24,6 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
JacclSideChannelGathered,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -160,15 +159,6 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
|
||||
if isinstance(event, JacclSideChannelGathered):
|
||||
for runner in self.runners.values():
|
||||
if (
|
||||
runner.bound_instance.instance.instance_id
|
||||
== event.instance_id
|
||||
):
|
||||
runner.notify_gathered(event)
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
|
||||
@@ -17,7 +17,6 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
pipe_fifo_paths: tuple[str, str] | None = None,
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
@@ -31,16 +30,6 @@ def entrypoint(
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
# Open JACCL FIFOs by path and set env vars for C++ SideChannel.
|
||||
# Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
|
||||
if pipe_fifo_paths is not None:
|
||||
fifo_c2p, fifo_p2c = pipe_fifo_paths
|
||||
# C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
|
||||
pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
|
||||
pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
|
||||
os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
|
||||
os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from multiprocessing import Process
|
||||
from typing import Self
|
||||
|
||||
@@ -18,14 +14,12 @@ from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
JacclSideChannelData,
|
||||
JacclSideChannelGathered,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
@@ -40,26 +34,6 @@ from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
|
||||
def _pipe_read_exact(fd: int, n: int) -> bytes | None:
|
||||
"""Read exactly n bytes from a file descriptor. Returns None on EOF."""
|
||||
data = b""
|
||||
while len(data) < n:
|
||||
chunk = os.read(fd, n - len(data))
|
||||
if not chunk:
|
||||
return None
|
||||
data += chunk
|
||||
return data
|
||||
|
||||
|
||||
def _pipe_write_all(fd: int, data: bytes) -> None:
|
||||
"""Write all bytes to a file descriptor."""
|
||||
view = memoryview(data)
|
||||
while view:
|
||||
written = os.write(fd, view)
|
||||
view = view[written:]
|
||||
|
||||
|
||||
PREFILL_TIMEOUT_SECONDS = 60
|
||||
DECODE_TIMEOUT_SECONDS = 5
|
||||
|
||||
@@ -74,19 +48,10 @@ class RunnerSupervisor:
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_pipe_read_fd: int | None = None # Python reads runner's pipe output
|
||||
_pipe_write_fd: int | None = None # Python writes gathered data to runner
|
||||
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
|
||||
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
|
||||
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
|
||||
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
_gathered_waiters: dict[
|
||||
int, tuple[anyio.Event, JacclSideChannelGathered | None]
|
||||
] = field(default_factory=dict, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -100,23 +65,6 @@ class RunnerSupervisor:
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
|
||||
# Named pipes work across multiprocessing.Process spawn (macOS default).
|
||||
# FIFO c2p: C++ writes local data → Python reads it
|
||||
# FIFO p2c: Python writes gathered data → C++ reads it
|
||||
fifo_dir: str | None = None
|
||||
fifo_c2p: str | None = None
|
||||
fifo_p2c: str | None = None
|
||||
pipe_fifo_paths: tuple[str, str] | None = None
|
||||
|
||||
if isinstance(bound_instance.instance, MlxJacclInstance):
|
||||
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
|
||||
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
|
||||
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
|
||||
os.mkfifo(fifo_c2p)
|
||||
os.mkfifo(fifo_p2c)
|
||||
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
args=(
|
||||
@@ -125,7 +73,6 @@ class RunnerSupervisor:
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
pipe_fifo_paths,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
@@ -141,58 +88,22 @@ class RunnerSupervisor:
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
_fifo_dir=fifo_dir,
|
||||
_fifo_c2p=fifo_c2p,
|
||||
_fifo_p2c=fifo_p2c,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
|
||||
if self._fifo_c2p is not None and self._fifo_p2c is not None:
|
||||
# Open FIFOs from parent side. These block until child opens the other end,
|
||||
# so we run them in threads concurrently to avoid deadlock.
|
||||
fifo_c2p = self._fifo_c2p
|
||||
fifo_p2c = self._fifo_p2c
|
||||
|
||||
async def open_read() -> None:
|
||||
self._pipe_read_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_c2p, os.O_RDONLY)
|
||||
)
|
||||
|
||||
async def open_write() -> None:
|
||||
self._pipe_write_fd = await to_thread.run_sync(
|
||||
partial(os.open, fifo_p2c, os.O_WRONLY)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as open_tg:
|
||||
open_tg.start_soon(open_read)
|
||||
open_tg.start_soon(open_write)
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(self._pipe_relay)
|
||||
tg.start_soon(self._forward_events)
|
||||
else:
|
||||
await self._forward_events()
|
||||
await self._forward_events()
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
try:
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
except ClosedResourceError:
|
||||
pass
|
||||
self._event_sender.close()
|
||||
self._close_pipe_fds()
|
||||
self.runner_process.join(1)
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
self.runner_process.join(5)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
@@ -229,7 +140,6 @@ class RunnerSupervisor:
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
"""Send a cancellation signal to the runner process."""
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
@@ -271,110 +181,6 @@ class RunnerSupervisor:
|
||||
for tid in self.pending:
|
||||
self.pending[tid].set()
|
||||
|
||||
def _close_pipe_fds(self) -> None:
|
||||
if self._pipe_read_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_read_fd)
|
||||
self._pipe_read_fd = None
|
||||
if self._pipe_write_fd is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._pipe_write_fd)
|
||||
self._pipe_write_fd = None
|
||||
if self._child_pipe_fds is not None:
|
||||
for fd in self._child_pipe_fds:
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(fd)
|
||||
self._child_pipe_fds = None
|
||||
# Clean up FIFO files
|
||||
if self._fifo_c2p is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_c2p)
|
||||
self._fifo_c2p = None
|
||||
if self._fifo_p2c is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self._fifo_p2c)
|
||||
self._fifo_p2c = None
|
||||
if self._fifo_dir is not None:
|
||||
with contextlib.suppress(OSError):
|
||||
os.rmdir(self._fifo_dir)
|
||||
self._fifo_dir = None
|
||||
|
||||
async def _pipe_relay(self) -> None:
|
||||
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
|
||||
assert self._pipe_read_fd is not None
|
||||
assert self._pipe_write_fd is not None
|
||||
read_fd = self._pipe_read_fd
|
||||
write_fd = self._pipe_write_fd
|
||||
sequence = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 1. Read local data from runner: [uint32 size][size bytes]
|
||||
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
|
||||
if header is None:
|
||||
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
|
||||
break
|
||||
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
|
||||
local_data = await to_thread.run_sync(
|
||||
partial(_pipe_read_exact, read_fd, data_size)
|
||||
)
|
||||
if local_data is None:
|
||||
logger.warning("JACCL pipe relay: EOF reading data payload")
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
|
||||
)
|
||||
|
||||
# 2. Emit JacclSideChannelData event
|
||||
waiter = anyio.Event()
|
||||
self._gathered_waiters[sequence] = (waiter, None)
|
||||
await self._event_sender.send(
|
||||
JacclSideChannelData(
|
||||
instance_id=self.bound_instance.instance.instance_id,
|
||||
runner_id=self.bound_instance.bound_runner_id,
|
||||
sequence=sequence,
|
||||
data=local_data,
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Wait for gathered result
|
||||
await waiter.wait()
|
||||
_, gathered_event = self._gathered_waiters.pop(sequence)
|
||||
assert gathered_event is not None
|
||||
|
||||
# 4. Order gathered data by runner rank and concatenate
|
||||
instance = self.bound_instance.instance
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
|
||||
ordered_data = b"".join(
|
||||
gathered_event.gathered_data[rid] for rid in runner_order
|
||||
)
|
||||
|
||||
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
|
||||
total_size = len(ordered_data)
|
||||
response = struct.pack("<I", total_size) + ordered_data
|
||||
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
|
||||
|
||||
logger.info(
|
||||
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
|
||||
)
|
||||
sequence += 1
|
||||
except OSError as e:
|
||||
logger.warning(f"JACCL pipe relay: OS error: {e}")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
|
||||
|
||||
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
|
||||
"""Called by the worker when a JacclSideChannelGathered event arrives."""
|
||||
seq = event.sequence
|
||||
if seq not in self._gathered_waiters:
|
||||
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
|
||||
return
|
||||
waiter, _ = self._gathered_waiters[seq]
|
||||
self._gathered_waiters[seq] = (waiter, event)
|
||||
waiter.set()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.runner_process.is_alive():
|
||||
logger.warning("RunnerSupervisor was not stopped cleanly.")
|
||||
|
||||
Reference in New Issue
Block a user