Compare commits

...

6 Commits

Author SHA1 Message Date
Evan
932878c9e7 add system ids 2026-02-19 17:19:50 +00:00
Evan Quiney
4c4c6ce99f simplify rust ident module
This is partly dead code, partly narrowing the Rust-Python boundary in
prep for future rewrites. No testing, as this is all type-safe
refactoring.
2026-02-19 17:19:31 +00:00
Jake Hillion
42e1e7322b bench: restore --danger-delete-downloads planning phase (#1542)
c2f2111b extracted shared utilities from exo_bench.py into harness.py
but accidentally dropped the run_planning_phase function and
--danger-delete-downloads CLI argument in the process.

Restored run_planning_phase in harness.py (where its dependencies now
live) and re-added the --danger-delete-downloads argument to
add_common_instance_args. Re-wired the planning phase call in
exo_bench.py's main() before the benchmark loop.
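
A minimal sketch of the rewired call in exo_bench.py's main(), mirroring the run_planning_phase signature restored in harness.py further down this diff (variable names taken from the hunks below):

```python
# Sketch of the planning phase wiring, as restored by this commit.
settle_deadline = (
    time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
    client,                        # ExoClient used for /models, /state, /download calls
    full_model_id,                 # model to benchmark
    selected[0],                   # placement preview whose nodes need the download
    args.danger_delete_downloads,  # allow freeing disk by deleting other models
    args.timeout,                  # max seconds to wait for downloads to complete
    settle_deadline,               # optional deadline for disk info to appear
)
```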
2026-02-19 15:42:02 +00:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls
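
A minimal sketch of the helper described above; the body is quoted from the diff below, and the before/after usage comparison is illustrative:

```python
def _format_sse(event: ResponsesStreamEvent) -> str:
    """Format a streaming event as an SSE message."""
    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"

# Illustrative: each hand-rolled yield such as
#   yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# becomes
#   yield _format_sse(created_event)
```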

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
Mustafa Alp Yılmaz
2e29605194 fix: finalize cancel tasks (#1498)
# Cancel task finalization (main.py)

After forwarding the cancel to the runner supervisor, emit TaskStatusUpdated(Complete) for the cancel task itself. This ensures the cancel task is properly removed from state.tasks.
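
A minimal sketch of the finalization, matching the worker hunk later in this diff (surrounding match-case context abbreviated):

```python
# After forwarding the cancel to the runner supervisor, mark the cancel
# task itself Complete so it is removed from state.tasks.
await self.runners[runner_id].cancel_task(cancelled_task_id)
await self.event_sender.send(
    TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
```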
2026-02-19 13:27:34 +00:00
Evan Quiney
cacb456cb2 remove nightly (#1538)
We have no good need for Rust nightly (nor futures, for that matter).
2026-02-19 12:55:31 +00:00
38 changed files with 387 additions and 634 deletions

Cargo.lock generated
View File

@@ -890,7 +890,7 @@ dependencies = [
"delegate",
"env_logger",
"extend",
"futures",
"futures-lite",
"libp2p",
"log",
"networking",
@@ -914,6 +914,12 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1022,7 +1028,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
@@ -2753,7 +2762,7 @@ dependencies = [
"delegate",
"either",
"extend",
"futures",
"futures-lite",
"futures-timer",
"keccak-const",
"libp2p",

View File

@@ -29,14 +29,13 @@ util = { path = "rust/util" }
# Macro dependecies
extend = "1.2"
delegate = "0.13"
pin-project = "1"
# Utility dependencies
keccak-const = "0.2"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-timer = "3.0"
# Data structures

View File

@@ -20,6 +20,7 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:
selected.sort(key=_placement_sort_key)
preview = selected[0]
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
print("Planning phase: checking downloads...", file=log)
run_planning_phase(
exo,
full_model_id,
preview,
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])

View File

@@ -35,6 +35,7 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

View File

@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
return selected
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)

View File

@@ -74,7 +74,6 @@
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in

View File

@@ -1,2 +0,0 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]

View File

@@ -27,7 +27,7 @@ networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
"nightly", # enables better-supported GIL integration
# "nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
futures-lite = { workspace = true }
# utility dependencies
util = { workspace = true }
@@ -60,3 +59,4 @@ env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"

View File

@@ -19,7 +19,7 @@ class ConnectionUpdate:
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
def peer_id(self) -> builtins.str:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@@ -40,92 +40,22 @@ class Keypair:
Identity keypair of a node.
"""
@staticmethod
def generate_ed25519() -> Keypair:
def generate() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def generate_ecdsa() -> Keypair:
def from_bytes(bytes: bytes) -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
Construct an Ed25519 keypair from secret key bytes
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
Get the secret key bytes underlying the keypair
"""
def to_string(self) -> builtins.str:
def to_node_id(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
"""
@typing.final
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""

View File

@@ -2,7 +2,6 @@
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
@@ -26,8 +25,8 @@ where
impl<F> Future for AllowThreads<F>
where
F: Future + Ungil,
F::Output: Ungil,
F: Future + Send,
F::Output: Send,
{
type Output = F::Output;

View File

@@ -0,0 +1,47 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Construct an Ed25519 keypair from secret key bytes
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Get the secret key bytes underlying the keypair
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
fn to_node_id(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}
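
A minimal Python usage sketch of the slimmed-down identity API defined by this module and its stub above; the import path is hypothetical, as the extension module's name is not shown in this diff:

```python
from exo_rust import Keypair  # hypothetical module name; class shape per the stub above

keypair = Keypair.generate()           # new Ed25519 keypair
secret = keypair.to_bytes()            # secret key bytes (replaces to_protobuf_encoding)
restored = Keypair.from_bytes(secret)  # round-trip from the secret key bytes
node_id = restored.to_node_id()        # base-58 PeerId string, used as the NodeId
```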

View File

@@ -4,26 +4,14 @@
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
mod ident;
mod networking;
use crate::ident::PyKeypair;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -32,14 +20,6 @@ pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
@@ -179,8 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
m.add_class::<PyKeypair>()?;
networking_submodule(m)?;
// top-level constructs

View File

@@ -8,8 +8,8 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair;
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
peer_id: String,
/// Remote connection's IPv4 address.
#[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
@@ -272,7 +272,7 @@ async fn networking_task(
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {

View File

@@ -1,159 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -22,7 +22,7 @@ delegate = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-lite = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies

View File

@@ -1,4 +1,4 @@
use futures::stream::StreamExt as _;
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
event = swarm.next() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
}))) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
@@ -64,7 +64,7 @@ async fn main() {
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }

View File

@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -1,7 +1,7 @@
use crate::ext::MultiaddrExt;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_lite::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
if self.retry_delay.poll(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)

View File

@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures::{AsyncRead, AsyncWrite};
use futures_lite::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;

View File

@@ -1,11 +1,10 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
{ inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
rustToolchain = inputs'.fenix.packages.stable.withComponents [
"cargo"
"rustc"
"clippy"

View File

@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"

View File

@@ -1,7 +1,6 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -22,10 +21,10 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -45,9 +44,9 @@ class DownloadCoordinator:
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
local_event_sender: Sender[LocalForwarderEvent]
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -298,15 +297,16 @@ class DownloadCoordinator:
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)

View File

@@ -1,11 +1,10 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Iterator, Self
from typing import Self
import anyio
from anyio.abc import TaskGroup
@@ -38,14 +37,13 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> "Self":
async def create(cls, args: "Args") -> Self:
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
@@ -57,9 +55,6 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
@@ -68,7 +63,6 @@ class Node:
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
@@ -95,7 +89,6 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -133,7 +126,6 @@ class Node:
master,
api,
node_id,
event_index_counter,
args.offline,
)
@@ -212,8 +204,6 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
@@ -224,7 +214,6 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
@@ -242,7 +231,6 @@ class Node:
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
@@ -219,13 +225,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)
# response.output_item.added
initial_item = ResponseMessageItem(
@@ -236,7 +242,7 @@ async def generate_responses_stream(
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
yield _format_sse(item_added)
# response.content_part.added
initial_part = ResponseOutputText(text="")
@@ -247,7 +253,7 @@ async def generate_responses_stream(
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(part_added)
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
@@ -281,7 +287,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +296,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -315,7 +321,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)
function_call_items.append(fc_done_item)
next_output_index += 1
@@ -331,7 +337,7 @@ async def generate_responses_stream(
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)
# response.output_text.done
text_done = ResponseTextDoneEvent(
@@ -341,7 +347,7 @@ async def generate_responses_stream(
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
@@ -352,7 +358,7 @@ async def generate_responses_stream(
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)
# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -363,7 +369,7 @@ async def generate_responses_stream(
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)
# Create usage from usage data if available
usage = None
@@ -388,4 +394,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)

View File

@@ -132,11 +132,11 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
@@ -177,8 +177,7 @@ class API:
session_id: SessionId,
*,
port: int,
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
global_event_receiver: Receiver[GlobalForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
@@ -186,6 +185,7 @@ class API:
) -> None:
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
@@ -237,6 +237,7 @@ class API:
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
@@ -554,7 +555,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -902,7 +903,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -988,7 +989,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -1429,6 +1430,8 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -1508,12 +1511,12 @@ class API:
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
ForwarderDownloadCommand(origin=self._system_id, command=command)
)
async def start_download(

View File

@@ -29,13 +29,14 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
LocalForwarderEvent,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -71,8 +72,8 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
local_event_receiver: Receiver[ForwarderEvent],
global_event_sender: Sender[ForwarderEvent],
local_event_receiver: Receiver[LocalForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
@@ -87,10 +88,11 @@ class Master:
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -288,7 +290,7 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
origin=self._system_id, command=cmd
)
)
generated_events.extend(transition_events)
@@ -414,8 +416,8 @@ class Master:
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
LocalForwarderEvent(
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
@@ -427,7 +429,7 @@ class Master:
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
GlobalForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,

View File

@@ -15,11 +15,12 @@ from exo.shared.types.commands import (
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
InstanceCreated,
LocalForwarderEvent,
NodeGatheredInfo,
TaskCreated,
)
@@ -42,12 +43,12 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
ge_sender, global_event_receiver = channel[GlobalForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = []
@@ -75,13 +76,12 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
LocalForwarderEvent(
origin_idx=0,
origin=sender_node_id,
origin=SystemId("Worker"),
session=session_id,
event=(
NodeGatheredInfo(
@@ -108,7 +108,7 @@ async def test_master():
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
PlaceInstance(
command_id=CommandId(),
@@ -133,7 +133,7 @@ async def test_master():
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
TextGeneration(
command_id=CommandId(),

View File

@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
node_id=NodeId(update.peer_id),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,

View File

@@ -221,7 +221,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
return Keypair.generate()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -235,12 +235,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.from_protobuf_encoding(protobuf_encoded)
return Keypair.from_bytes(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
f.write(keypair.to_bytes())
return keypair

View File

@@ -5,7 +5,8 @@ from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,
GlobalForwarderEvent,
LocalForwarderEvent,
)
from exo.utils.pydantic_ext import CamelCaseModel
@@ -36,8 +37,8 @@ class TypedTopic[T: CamelCaseModel]:
return self.model_type.model_validate_json(b.decode("utf-8"))
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, GlobalForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, LocalForwarderEvent)
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
ELECTION_MESSAGES = TypedTopic(
"election_messages", PublishPolicy.Always, ElectionMessage

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.utils.channels import channel
# ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
)
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
queue.put(get_node_id_keypair().to_protobuf_encoding())
queue.put(get_node_id_keypair().to_bytes())
def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: DownloadCommand

View File

@@ -25,6 +25,10 @@ class NodeId(Id):
pass
class SystemId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -170,10 +170,19 @@ class IndexedEvent(CamelCaseModel):
event: Event
class ForwarderEvent(CamelCaseModel):
class GlobalForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: NodeId
session: SessionId
event: Event
class LocalForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: SystemId
session: SessionId
event: Event
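
A minimal sketch of how the two wrappers are used after this change (field types as defined above; the session_id, event, node_id, and indexed variables are illustrative):

```python
# Locally produced events are stamped with a per-subsystem SystemId...
local = LocalForwarderEvent(
    origin_idx=0,
    origin=SystemId(),   # e.g. one per Worker or DownloadCoordinator instance
    session=session_id,
    event=event,
)

# ...while events re-broadcast by the master keep a NodeId origin, which
# receivers check against session.master_node_id.
global_event = GlobalForwarderEvent(
    origin_idx=indexed.idx,
    origin=node_id,
    session=session_id,
    event=indexed.event,
)
```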

View File

@@ -1,7 +1,6 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
@@ -17,13 +16,14 @@ from exo.shared.types.commands import (
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
InputChunkReceived,
LocalForwarderEvent,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -58,24 +58,22 @@ class Worker:
node_id: NodeId,
session_id: SessionId,
*,
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
global_event_receiver: Receiver[GlobalForwarderEvent],
local_event_sender: Sender[LocalForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
@@ -86,6 +84,8 @@ class Worker:
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
@@ -132,6 +132,8 @@ class Worker:
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -212,7 +214,7 @@ class Worker:
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
origin=self._system_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
@@ -241,6 +243,11 @@ class Worker:
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -312,7 +319,7 @@ class Worker:
)
await self.command_sender.send(
ForwarderCommand(
origin=self.node_id,
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
@@ -339,15 +346,16 @@ class Worker:
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe