Compare commits

..

2 Commits

Author SHA1 Message Date
Alex Cheema
21c363e997 fix: move suppress(ClosedResourceError) inside runner.shutdown() per review
Move the ClosedResourceError suppression from the two call sites in
worker/main.py into RunnerSupervisor.shutdown() itself, so each
close/send on already-closed channels is individually guarded.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 09:01:54 -08:00
Alex Cheema
b1c0e3116d fix: misc bug fixes (spawn force, download restart, shutdown guard)
Three independent fixes extracted from meta-instance branch (#1519):

- Use force=True for mp.set_start_method("spawn") to prevent errors
  when the start method was already set by another initialization path
- Detect already-complete downloads on restart instead of reporting them
  as DownloadPending (checks downloaded_bytes >= total_bytes)
- Guard runner.shutdown() with contextlib.suppress(ClosedResourceError)
  to handle already-closed resources during worker teardown

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 05:37:11 -08:00
23 changed files with 358 additions and 328 deletions

View File

@@ -20,7 +20,6 @@ from harness import (
instance_id_from_instance, instance_id_from_instance,
nodes_used_in_instance, nodes_used_in_instance,
resolve_model_short_id, resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements, settle_and_fetch_placements,
wait_for_instance_gone, wait_for_instance_gone,
wait_for_instance_ready, wait_for_instance_ready,
@@ -963,21 +962,6 @@ Examples:
selected.sort(key=_placement_sort_key) selected.sort(key=_placement_sort_key)
preview = selected[0] preview = selected[0]
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
print("Planning phase: checking downloads...", file=log)
run_planning_phase(
exo,
full_model_id,
preview,
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
instance = preview["instance"] instance = preview["instance"]
instance_id = instance_id_from_instance(instance) instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"]) sharding = str(preview["sharding"])

View File

@@ -35,7 +35,6 @@ from harness import (
instance_id_from_instance, instance_id_from_instance,
nodes_used_in_instance, nodes_used_in_instance,
resolve_model_short_id, resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements, settle_and_fetch_placements,
wait_for_instance_gone, wait_for_instance_gone,
wait_for_instance_ready, wait_for_instance_ready,
@@ -333,20 +332,6 @@ def main() -> int:
if args.dry_run: if args.dry_run:
return 0 return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = [] all_rows: list[dict[str, Any]] = []
for preview in selected: for preview in selected:

View File

@@ -282,151 +282,6 @@ def settle_and_fetch_placements(
return selected return selected
def _download_model_id(entry: dict[str, Any], kind: str) -> str | None:
    """Return the model id of a download-state entry if it is of variant ``kind``.

    ``entry`` is one element of a node's download list as returned by
    ``/state``; ``kind`` is the variant key to look for (for example
    ``"DownloadCompleted"``).  Returns ``None`` when the entry is a different
    variant.
    """
    if kind not in entry:
        return None
    return unwrap_instance(entry[kind]["shardMetadata"])["modelCard"]["modelId"]


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure the model is downloaded before benchmarking.

    Steps, per node used by ``preview``:

    1. Skip the node if ``/state`` already lists a completed download of
       ``full_model_id`` for it.
    2. Otherwise compare the node's reported free bytes with the model size
       from ``/models``; optionally delete completed downloads (smallest
       first) to make room.
    3. Request the download on every node and poll ``/state`` until all nodes
       report completion.

    Args:
        client: API client; only ``request_json(method, path, body=...)`` is used.
        full_model_id: Model id matched against ``hugging_face_id`` in
            ``/models`` and against ``modelCard.modelId`` in download entries.
        preview: Placement preview; its ``"instance"`` entry supplies the
            shard assignments.
        danger_delete: When true, existing completed downloads may be deleted
            to free space; when false, insufficient space is an error.
        timeout: Maximum seconds to wait for the downloads to complete.
        settle_deadline: ``time.monotonic()`` deadline until which to keep
            polling for missing disk info, or ``None`` to not wait at all.

    Raises:
        RuntimeError: Insufficient disk space (and deletion disabled or not
            sufficient), or a download reported failure.
        TimeoutError: Downloads did not complete within ``timeout`` seconds.
    """
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            # `or 0` tolerates an explicit null size in the response.
            model_bytes = (m.get("storage_size_megabytes") or 0) * 1024 * 1024
            break

    if not model_bytes:
        # NOTE(review): this returns from the whole planning phase, so the
        # download start/wait below is skipped too, not only the disk check.
        # Confirm that is intended.
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_to_runner = inner["shardAssignments"]["nodeToRunner"]
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
    node_ids = list(node_to_runner.keys())

    # `or {}` mirrors the /models handling: an empty response means "nothing
    # known yet" instead of an AttributeError on None.
    state = client.request_json("GET", "/state") or {}
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            _download_model_id(p, "DownloadCompleted") == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline is not None:
            remaining = settle_deadline - time.monotonic()
            if remaining <= 0:
                # Deadline passed; never hand time.sleep() a negative value.
                break
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state") or {}
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                _download_model_id(p, "DownloadCompleted"),
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            # Assumes the deleted bytes become available immediately.
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = node_to_runner[node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads.  A monotonic clock keeps the timeout immune to
    # wall-clock adjustments, matching how settle_deadline is measured.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        state = client.request_json("GET", "/state") or {}
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            node_downloads = downloads.get(node_id, [])
            done = any(
                _download_model_id(p, "DownloadCompleted") == full_model_id
                for p in node_downloads
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in node_downloads
                if _download_model_id(p, "DownloadFailed") == full_model_id
            ]
            # A reported failure wins even if a completed entry also exists.
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None: def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost")) ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument( ap.add_argument(
@@ -470,8 +325,3 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
default=0, default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).", help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
) )
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)

View File

@@ -19,7 +19,7 @@ class ConnectionUpdate:
Whether this is a connection or disconnection event Whether this is a connection or disconnection event
""" """
@property @property
def peer_id(self) -> builtins.str: def peer_id(self) -> PeerId:
r""" r"""
Identity of the peer that we have connected to or disconnected from. Identity of the peer that we have connected to or disconnected from.
""" """
@@ -40,22 +40,92 @@ class Keypair:
Identity keypair of a node. Identity keypair of a node.
""" """
@staticmethod @staticmethod
def generate() -> Keypair: def generate_ed25519() -> Keypair:
r""" r"""
Generate a new Ed25519 keypair. Generate a new Ed25519 keypair.
""" """
@staticmethod @staticmethod
def from_bytes(bytes: bytes) -> Keypair: def generate_ecdsa() -> Keypair:
r""" r"""
Construct an Ed25519 keypair from secret key bytes Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
""" """
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
r""" r"""
Get the secret key bytes underlying the keypair Return a copy of this [`Multiaddr`]'s byte representation.
""" """
def to_node_id(self) -> builtins.str: def to_string(self) -> builtins.str:
r""" r"""
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`. Convert a Multiaddr to a string.
""" """
@typing.final @typing.final
@@ -110,6 +180,37 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ... def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final @typing.final
class ConnectionUpdateType(enum.Enum): class ConnectionUpdateType(enum.Enum):
r""" r"""

View File

@@ -1,6 +1,8 @@
use crate::ext::ResultExt as _; use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair; use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _}; use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods}; use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
@@ -16,32 +18,142 @@ pub struct PyKeypair(pub Keypair);
impl PyKeypair { impl PyKeypair {
/// Generate a new Ed25519 keypair. /// Generate a new Ed25519 keypair.
#[staticmethod] #[staticmethod]
fn generate() -> Self { fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519()) Self(Keypair::generate_ed25519())
} }
/// Construct an Ed25519 keypair from secret key bytes /// Generate a new ECDSA keypair.
#[staticmethod] #[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> { fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes()); let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?)) Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
} }
/// Get the secret key bytes underlying the keypair /// Encode a private key as protobuf structure.
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> { fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self let bytes = self.0.to_protobuf_encoding().pyerr()?;
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes)) Ok(PyBytes::new(py, &bytes))
} }
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`. /// Convert the `Keypair` into the corresponding `PeerId`.
fn to_node_id(&self) -> String { fn to_peer_id(&self) -> PyPeerId {
self.0.public().to_peer_id().to_base58() PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
} }
} }
/// Registers the identity pyclasses (`PyKeypair` and `PyPeerId`) on the module `m`.
///
/// Returns an error if adding either class to the module fails.
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -8,10 +8,9 @@ mod allow_threading;
mod ident; mod ident;
mod networking; mod networking;
use crate::ident::PyKeypair; use crate::ident::ident_submodule;
use crate::networking::networking_submodule; use crate::networking::networking_submodule;
use pyo3::prelude::PyModule; use pyo3::prelude::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule}; use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer; use pyo3_stub_gen::define_stub_info_gatherer;
@@ -159,7 +158,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without // work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues... // too many importing issues...
m.add_class::<PyKeypair>()?; ident_submodule(m)?;
networking_submodule(m)?; networking_submodule(m)?;
// top-level constructs // top-level constructs

View File

@@ -8,7 +8,7 @@
use crate::r#const::MPSC_CHANNEL_SIZE; use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _}; use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _}; use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair; use crate::ident::{PyKeypair, PyPeerId};
use crate::pyclass; use crate::pyclass;
use libp2p::futures::StreamExt as _; use libp2p::futures::StreamExt as _;
use libp2p::gossipsub; use libp2p::gossipsub;
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
/// Identity of the peer that we have connected to or disconnected from. /// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)] #[pyo3(get)]
peer_id: String, peer_id: PyPeerId,
/// Remote connection's IPv4 address. /// Remote connection's IPv4 address.
#[pyo3(get)] #[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
// send connection event to channel (or exit if connection closed) // send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate { if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected, update_type: PyConnectionUpdateType::Connected,
peer_id: peer_id.to_base58(), peer_id: PyPeerId(peer_id),
remote_ipv4, remote_ipv4,
remote_tcp_port, remote_tcp_port,
}).await { }).await {
@@ -272,7 +272,7 @@ async fn networking_task(
// send disconnection event to channel (or exit if connection closed) // send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate { if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected, update_type: PyConnectionUpdateType::Disconnected,
peer_id: peer_id.to_base58(), peer_id: PyPeerId(peer_id),
remote_ipv4, remote_ipv4,
remote_tcp_port, remote_tcp_port,
}).await { }).await {

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import socket import socket
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Iterator
import anyio import anyio
from anyio import current_time from anyio import current_time
@@ -21,10 +22,10 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand, ForwarderDownloadCommand,
StartDownload, StartDownload,
) )
from exo.shared.types.common import NodeId, SessionId, SystemId from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import ( from exo.shared.types.events import (
Event, Event,
LocalForwarderEvent, ForwarderEvent,
NodeDownloadProgress, NodeDownloadProgress,
) )
from exo.shared.types.worker.downloads import ( from exo.shared.types.worker.downloads import (
@@ -44,9 +45,9 @@ class DownloadCoordinator:
session_id: SessionId session_id: SessionId
shard_downloader: ShardDownloader shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand] download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[LocalForwarderEvent] local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state # Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict) download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -297,16 +298,15 @@ class DownloadCoordinator:
del self.download_status[model_id] del self.download_status[model_id]
async def _forward_events(self) -> None: async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events: with self.event_receiver as events:
async for event in events: async for event in events:
fe = LocalForwarderEvent( idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx, origin_idx=idx,
origin=self._system_id, origin=self.node_id,
session=self.session_id, session=self.session_id,
event=event, event=event,
) )
idx += 1
logger.debug( logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}" f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
) )
@@ -338,7 +338,17 @@ class DownloadCoordinator:
), ),
) )
elif progress.status in ["in_progress", "not_started"]: elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0: if (
progress.downloaded_bytes.in_bytes
>= progress.total_bytes.in_bytes
> 0
):
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending( status = DownloadPending(
node_id=self.node_id, node_id=self.node_id,
shard_metadata=progress.shard, shard_metadata=progress.shard,

View File

@@ -1,10 +1,11 @@
import argparse import argparse
import itertools
import multiprocessing as mp import multiprocessing as mp
import os import os
import resource import resource
import signal import signal
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Self from typing import Iterator, Self
import anyio import anyio
from anyio.abc import TaskGroup from anyio.abc import TaskGroup
@@ -37,13 +38,14 @@ class Node:
api: API | None api: API | None
node_id: NodeId node_id: NodeId
event_index_counter: Iterator[int]
offline: bool offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group) _tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod @classmethod
async def create(cls, args: "Args") -> Self: async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair() keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id()) node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0) session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair) router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS) await router.register_topic(topics.GLOBAL_EVENTS)
@@ -55,6 +57,9 @@ class Node:
logger.info(f"Starting node {node_id}") logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads) # Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads: if not args.no_downloads:
download_coordinator = DownloadCoordinator( download_coordinator = DownloadCoordinator(
@@ -63,6 +68,7 @@ class Node:
exo_shard_downloader(), exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS), local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline, offline=args.offline,
) )
else: else:
@@ -89,6 +95,7 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS), local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS), command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS), download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
) )
else: else:
worker = None worker = None
@@ -126,6 +133,7 @@ class Node:
master, master,
api, api,
node_id, node_id,
event_index_counter,
args.offline, args.offline,
) )
@@ -204,6 +212,8 @@ class Node:
) )
if result.is_new_master: if result.is_new_master:
await anyio.sleep(0) await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator: if self.download_coordinator:
self.download_coordinator.shutdown() self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator( self.download_coordinator = DownloadCoordinator(
@@ -214,6 +224,7 @@ class Node:
topics.DOWNLOAD_COMMANDS topics.DOWNLOAD_COMMANDS
), ),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS), local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline, offline=self.offline,
) )
self._tg.start_soon(self.download_coordinator.run) self._tg.start_soon(self.download_coordinator.run)
@@ -231,6 +242,7 @@ class Node:
download_command_sender=self.router.sender( download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS topics.DOWNLOAD_COMMANDS
), ),
event_index_counter=self.event_index_counter,
) )
self._tg.start_soon(self.worker.run) self._tg.start_soon(self.worker.run)
if self.api: if self.api:
@@ -246,7 +258,7 @@ def main():
target = min(max(soft, 65535), hard) target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard)) resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn") mp.set_start_method("spawn", force=True)
# TODO: Refactor the current verbosity system # TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity) logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO") logger.info("Starting EXO")

View File

@@ -31,7 +31,6 @@ from exo.shared.types.openai_responses import (
ResponseOutputText, ResponseOutputText,
ResponsesRequest, ResponsesRequest,
ResponsesResponse, ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent, ResponseTextDeltaEvent,
ResponseTextDoneEvent, ResponseTextDoneEvent,
ResponseUsage, ResponseUsage,
@@ -39,11 +38,6 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str: def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts.""" """Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str): if isinstance(content, str):
@@ -225,13 +219,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent( created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response sequence_number=next(seq), response=initial_response
) )
yield _format_sse(created_event) yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress # response.in_progress
in_progress_event = ResponseInProgressEvent( in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response sequence_number=next(seq), response=initial_response
) )
yield _format_sse(in_progress_event) yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added # response.output_item.added
initial_item = ResponseMessageItem( initial_item = ResponseMessageItem(
@@ -242,7 +236,7 @@ async def generate_responses_stream(
item_added = ResponseOutputItemAddedEvent( item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item sequence_number=next(seq), output_index=0, item=initial_item
) )
yield _format_sse(item_added) yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added # response.content_part.added
initial_part = ResponseOutputText(text="") initial_part = ResponseOutputText(text="")
@@ -253,7 +247,7 @@ async def generate_responses_stream(
content_index=0, content_index=0,
part=initial_part, part=initial_part,
) )
yield _format_sse(part_added) yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = "" accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = [] function_call_items: list[ResponseFunctionCallItem] = []
@@ -287,7 +281,7 @@ async def generate_responses_stream(
output_index=next_output_index, output_index=next_output_index,
item=fc_item, item=fc_item,
) )
yield _format_sse(fc_added) yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
# response.function_call_arguments.delta # response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent( args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -296,7 +290,7 @@ async def generate_responses_stream(
output_index=next_output_index, output_index=next_output_index,
delta=tool.arguments, delta=tool.arguments,
) )
yield _format_sse(args_delta) yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
# response.function_call_arguments.done # response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent( args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -306,7 +300,7 @@ async def generate_responses_stream(
name=tool.name, name=tool.name,
arguments=tool.arguments, arguments=tool.arguments,
) )
yield _format_sse(args_done) yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
# response.output_item.done # response.output_item.done
fc_done_item = ResponseFunctionCallItem( fc_done_item = ResponseFunctionCallItem(
@@ -321,7 +315,7 @@ async def generate_responses_stream(
output_index=next_output_index, output_index=next_output_index,
item=fc_done_item, item=fc_done_item,
) )
yield _format_sse(fc_item_done) yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
function_call_items.append(fc_done_item) function_call_items.append(fc_done_item)
next_output_index += 1 next_output_index += 1
@@ -337,7 +331,7 @@ async def generate_responses_stream(
content_index=0, content_index=0,
delta=chunk.text, delta=chunk.text,
) )
yield _format_sse(delta_event) yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done # response.output_text.done
text_done = ResponseTextDoneEvent( text_done = ResponseTextDoneEvent(
@@ -347,7 +341,7 @@ async def generate_responses_stream(
content_index=0, content_index=0,
text=accumulated_text, text=accumulated_text,
) )
yield _format_sse(text_done) yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done # response.content_part.done
final_part = ResponseOutputText(text=accumulated_text) final_part = ResponseOutputText(text=accumulated_text)
@@ -358,7 +352,7 @@ async def generate_responses_stream(
content_index=0, content_index=0,
part=final_part, part=final_part,
) )
yield _format_sse(part_done) yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done # response.output_item.done
final_message_item = ResponseMessageItem( final_message_item = ResponseMessageItem(
@@ -369,7 +363,7 @@ async def generate_responses_stream(
item_done = ResponseOutputItemDoneEvent( item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item sequence_number=next(seq), output_index=0, item=final_message_item
) )
yield _format_sse(item_done) yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from usage data if available # Create usage from usage data if available
usage = None usage = None
@@ -394,4 +388,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent( completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response sequence_number=next(seq), response=final_response
) )
yield _format_sse(completed_event) yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -132,11 +132,11 @@ from exo.shared.types.commands import (
TaskFinished, TaskFinished,
TextGeneration, TextGeneration,
) )
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import ( from exo.shared.types.events import (
ChunkGenerated, ChunkGenerated,
Event, Event,
GlobalForwarderEvent, ForwarderEvent,
IndexedEvent, IndexedEvent,
PrefillProgress, PrefillProgress,
TracesMerged, TracesMerged,
@@ -177,7 +177,8 @@ class API:
session_id: SessionId, session_id: SessionId,
*, *,
port: int, port: int,
global_event_receiver: Receiver[GlobalForwarderEvent], # Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand], command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand], download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running # This lets us pause the API if an election is running
@@ -185,7 +186,6 @@ class API:
) -> None: ) -> None:
self.state = State() self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR) self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender self.command_sender = command_sender
self.download_command_sender = download_command_sender self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver self.global_event_receiver = global_event_receiver
@@ -237,7 +237,6 @@ class API:
self._event_log.close() self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR) self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State() self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]() self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {} self._text_generation_queues = {}
@@ -555,7 +554,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id) command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
raise raise
finally: finally:
@@ -903,7 +902,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id) command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
raise raise
finally: finally:
@@ -989,7 +988,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id) command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
raise raise
finally: finally:
@@ -1430,8 +1429,6 @@ class API:
async def _apply_state(self): async def _apply_state(self):
with self.global_event_receiver as events: with self.global_event_receiver as events:
async for f_event in events: async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id: if f_event.origin != self.session_id.master_node_id:
continue continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event) self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -1511,12 +1508,12 @@ class API:
while self.paused: while self.paused:
await self.paused_ev.wait() await self.paused_ev.wait()
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
async def _send_download(self, command: DownloadCommand): async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send( await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self._system_id, command=command) ForwarderDownloadCommand(origin=self.node_id, command=command)
) )
async def start_download( async def start_download(

View File

@@ -29,14 +29,13 @@ from exo.shared.types.commands import (
TestCommand, TestCommand,
TextGeneration, TextGeneration,
) )
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ( from exo.shared.types.events import (
Event, Event,
GlobalForwarderEvent, ForwarderEvent,
IndexedEvent, IndexedEvent,
InputChunkReceived, InputChunkReceived,
InstanceDeleted, InstanceDeleted,
LocalForwarderEvent,
NodeGatheredInfo, NodeGatheredInfo,
NodeTimedOut, NodeTimedOut,
TaskCreated, TaskCreated,
@@ -72,8 +71,8 @@ class Master:
session_id: SessionId, session_id: SessionId,
*, *,
command_receiver: Receiver[ForwarderCommand], command_receiver: Receiver[ForwarderCommand],
local_event_receiver: Receiver[LocalForwarderEvent], local_event_receiver: Receiver[ForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent], global_event_sender: Sender[ForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand], download_command_sender: Sender[ForwarderDownloadCommand],
): ):
self.state = State() self.state = State()
@@ -88,11 +87,10 @@ class Master:
send, recv = channel[Event]() send, recv = channel[Event]()
self.event_sender: Sender[Event] = send self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[LocalForwarderEvent] = ( self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender() local_event_receiver.clone_sender()
) )
self._system_id = SystemId() self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master") self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {} self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {} self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -290,7 +288,7 @@ class Master:
): ):
await self.download_command_sender.send( await self.download_command_sender.send(
ForwarderDownloadCommand( ForwarderDownloadCommand(
origin=self._system_id, command=cmd origin=self.node_id, command=cmd
) )
) )
generated_events.extend(transition_events) generated_events.extend(transition_events)
@@ -416,8 +414,8 @@ class Master:
with self._loopback_event_receiver as events: with self._loopback_event_receiver as events:
async for event in events: async for event in events:
await self._loopback_event_sender.send( await self._loopback_event_sender.send(
LocalForwarderEvent( ForwarderEvent(
origin=self._system_id, origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index, origin_idx=local_index,
session=self.session_id, session=self.session_id,
event=event, event=event,
@@ -429,7 +427,7 @@ class Master:
async def _send_event(self, event: IndexedEvent): async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly # Convenience method since this line is ugly
await self.global_event_sender.send( await self.global_event_sender.send(
GlobalForwarderEvent( ForwarderEvent(
origin=self.node_id, origin=self.node_id,
origin_idx=event.idx, origin_idx=event.idx,
session=self.session_id, session=self.session_id,

View File

@@ -15,12 +15,11 @@ from exo.shared.types.commands import (
PlaceInstance, PlaceInstance,
TextGeneration, TextGeneration,
) )
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import ( from exo.shared.types.events import (
GlobalForwarderEvent, ForwarderEvent,
IndexedEvent, IndexedEvent,
InstanceCreated, InstanceCreated,
LocalForwarderEvent,
NodeGatheredInfo, NodeGatheredInfo,
TaskCreated, TaskCreated,
) )
@@ -43,12 +42,12 @@ from exo.utils.channels import channel
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_master(): async def test_master():
keypair = get_node_id_keypair() keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id()) node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0) session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[GlobalForwarderEvent]() ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]() command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[LocalForwarderEvent]() local_event_sender, le_receiver = channel[ForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]() fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = [] all_events: list[IndexedEvent] = []
@@ -76,12 +75,13 @@ async def test_master():
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon(master.run) tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event # inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event") logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send( await local_event_sender.send(
LocalForwarderEvent( ForwarderEvent(
origin_idx=0, origin_idx=0,
origin=SystemId("Worker"), origin=sender_node_id,
session=session_id, session=session_id,
event=( event=(
NodeGatheredInfo( NodeGatheredInfo(
@@ -108,7 +108,7 @@ async def test_master():
logger.info("inject a CreateInstance Command") logger.info("inject a CreateInstance Command")
await command_sender.send( await command_sender.send(
ForwarderCommand( ForwarderCommand(
origin=SystemId("API"), origin=node_id,
command=( command=(
PlaceInstance( PlaceInstance(
command_id=CommandId(), command_id=CommandId(),
@@ -133,7 +133,7 @@ async def test_master():
logger.info("inject a TextGeneration Command") logger.info("inject a TextGeneration Command")
await command_sender.send( await command_sender.send(
ForwarderCommand( ForwarderCommand(
origin=SystemId("API"), origin=node_id,
command=( command=(
TextGeneration( TextGeneration(
command_id=CommandId(), command_id=CommandId(),

View File

@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
@classmethod @classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage": def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls( return cls(
node_id=NodeId(update.peer_id), node_id=NodeId(update.peer_id.to_base58()),
connection_type=ConnectionMessageType.from_update_type(update.update_type), connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4, remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port, remote_tcp_port=update.remote_tcp_port,

View File

@@ -221,7 +221,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` by from it. Obtain the :class:`PeerId` by from it.
""" """
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate() return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path: def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock") return Path(str(path) + ".lock")
@@ -235,12 +235,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read() protobuf_encoded = f.read()
try: # if decoded successfully, save & return try: # if decoded successfully, save & return
return Keypair.from_bytes(protobuf_encoded) return Keypair.from_protobuf_encoding(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}") logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist # if no valid credentials, create new ones and persist
with open(path, "w+b") as f: with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519() keypair = Keypair.generate_ed25519()
f.write(keypair.to_bytes()) f.write(keypair.to_protobuf_encoding())
return keypair return keypair

View File

@@ -5,8 +5,7 @@ from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import ( from exo.shared.types.events import (
GlobalForwarderEvent, ForwarderEvent,
LocalForwarderEvent,
) )
from exo.utils.pydantic_ext import CamelCaseModel from exo.utils.pydantic_ext import CamelCaseModel
@@ -37,8 +36,8 @@ class TypedTopic[T: CamelCaseModel]:
return self.model_type.model_validate_json(b.decode("utf-8")) return self.model_type.model_validate_json(b.decode("utf-8"))
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, GlobalForwarderEvent) GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, LocalForwarderEvent) LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand) COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
ELECTION_MESSAGES = TypedTopic( ELECTION_MESSAGES = TypedTopic(
"election_messages", PublishPolicy.Always, ElectionMessage "election_messages", PublishPolicy.Always, ElectionMessage

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId, SystemId from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import channel from exo.utils.channels import channel
# ======= # # ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts # Pump local commands so our commands_seen is high before the round starts
for _ in range(50): for _ in range(50):
await co_tx.send( await co_tx.send(
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand()) ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
) )
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands # Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release() sem.release()
# wait to be told to begin simultaneous read # wait to be told to begin simultaneous read
ev.wait() ev.wait()
queue.put(get_node_id_keypair().to_bytes()) queue.put(get_node_id_keypair().to_protobuf_encoding())
def _get_keypair_concurrent(num_procs: int) -> bytes: def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams, ImageGenerationTaskParams,
) )
from exo.shared.types.chunks import InputImageChunk from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId, SystemId from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
class ForwarderCommand(CamelCaseModel): class ForwarderCommand(CamelCaseModel):
origin: SystemId origin: NodeId
command: Command command: Command
class ForwarderDownloadCommand(CamelCaseModel): class ForwarderDownloadCommand(CamelCaseModel):
origin: SystemId origin: NodeId
command: DownloadCommand command: DownloadCommand

View File

@@ -25,10 +25,6 @@ class NodeId(Id):
pass pass
class SystemId(Id):
pass
class ModelId(Id): class ModelId(Id):
def normalize(self) -> str: def normalize(self) -> str:
return self.replace("/", "--") return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId from exo.shared.types.worker.instances import Instance, InstanceId
@@ -170,19 +170,10 @@ class IndexedEvent(CamelCaseModel):
event: Event event: Event
class GlobalForwarderEvent(CamelCaseModel): class ForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network""" """An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0) origin_idx: int = Field(ge=0)
origin: NodeId origin: NodeId
session: SessionId session: SessionId
event: Event event: Event
class LocalForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: SystemId
session: SessionId
event: Event

View File

@@ -1,6 +1,7 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from random import random from random import random
from typing import Iterator
import anyio import anyio
from anyio import CancelScope, create_task_group, fail_after from anyio import CancelScope, create_task_group, fail_after
@@ -16,14 +17,13 @@ from exo.shared.types.commands import (
RequestEventLog, RequestEventLog,
StartDownload, StartDownload,
) )
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ( from exo.shared.types.events import (
Event, Event,
EventId, EventId,
GlobalForwarderEvent, ForwarderEvent,
IndexedEvent, IndexedEvent,
InputChunkReceived, InputChunkReceived,
LocalForwarderEvent,
NodeGatheredInfo, NodeGatheredInfo,
TaskCreated, TaskCreated,
TaskStatusUpdated, TaskStatusUpdated,
@@ -58,22 +58,24 @@ class Worker:
node_id: NodeId, node_id: NodeId,
session_id: SessionId, session_id: SessionId,
*, *,
global_event_receiver: Receiver[GlobalForwarderEvent], global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[LocalForwarderEvent], local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now, # This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands # but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand], command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand], download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
): ):
self.node_id: NodeId = node_id self.node_id: NodeId = node_id
self.session_id: SessionId = session_id self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.command_sender = command_sender self.command_sender = command_sender
self.download_command_sender = download_command_sender self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]() self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {} self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State() self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {} self.runners: dict[RunnerId, RunnerSupervisor] = {}
@@ -84,8 +86,6 @@ class Worker:
self._nack_base_seconds: float = 0.5 self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0 self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]() self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing) # Buffer for input image chunks (for image editing)
@@ -132,8 +132,6 @@ class Worker:
async def _event_applier(self): async def _event_applier(self):
with self.global_event_receiver as events: with self.global_event_receiver as events:
async for f_event in events: async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id: if f_event.origin != self.session_id.master_node_id:
continue continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event) self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -214,7 +212,7 @@ class Worker:
await self.download_command_sender.send( await self.download_command_sender.send(
ForwarderDownloadCommand( ForwarderDownloadCommand(
origin=self._system_id, origin=self.node_id,
command=StartDownload( command=StartDownload(
target_node_id=self.node_id, target_node_id=self.node_id,
shard_metadata=shard, shard_metadata=shard,
@@ -319,7 +317,7 @@ class Worker:
) )
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand( ForwarderCommand(
origin=self._system_id, origin=self.node_id,
command=RequestEventLog(since_idx=since_idx), command=RequestEventLog(since_idx=since_idx),
) )
) )
@@ -346,16 +344,15 @@ class Worker:
return runner return runner
async def _forward_events(self) -> None: async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events: with self.event_receiver as events:
async for event in events: async for event in events:
fe = LocalForwarderEvent( idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx, origin_idx=idx,
origin=self._system_id, origin=self.node_id,
session=self.session_id, session=self.session_id,
event=event, event=event,
) )
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}") logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe) await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe self.out_for_delivery[event.event_id] = fe

View File

@@ -98,11 +98,16 @@ class RunnerSupervisor:
def shutdown(self): def shutdown(self):
logger.info("Runner supervisor shutting down") logger.info("Runner supervisor shutting down")
self._ev_recv.close() with contextlib.suppress(ClosedResourceError):
self._task_sender.close() self._ev_recv.close()
self._event_sender.close() with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK")) self._task_sender.close()
self._cancel_sender.close() with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()
self.runner_process.join(5) self.runner_process.join(5)
if not self.runner_process.is_alive(): if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated") logger.info("Runner process succesfully terminated")