Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 15:27:02 -05:00)

Compare commits: feat/e2e-c ... memory-tid (7 commits)

| SHA1 |
|---|
| 41c4240919 |
| 08b4ec5724 |
| 4c4c6ce99f |
| 42e1e7322b |
| aa3f106fb9 |
| 2e29605194 |
| cacb456cb2 |

Cargo.lock (generated): 13 changes
@@ -890,7 +890,7 @@ dependencies = [
 "delegate",
 "env_logger",
 "extend",
 "futures",
 "futures-lite",
 "libp2p",
 "log",
 "networking",
@@ -914,6 +914,12 @@ dependencies = [
 "syn 2.0.111",
]

[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"

[[package]]
name = "ff"
version = "0.13.1"
@@ -1022,7 +1028,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
 "fastrand",
 "futures-core",
 "futures-io",
 "parking",
 "pin-project-lite",
]

@@ -2753,7 +2762,7 @@ dependencies = [
 "delegate",
 "either",
 "extend",
 "futures",
 "futures-lite",
 "futures-timer",
 "keccak-const",
 "libp2p",

@@ -29,14 +29,13 @@ util = { path = "rust/util" }
# Macro dependecies
extend = "1.2"
delegate = "0.13"
pin-project = "1"

# Utility dependencies
keccak-const = "0.2"

# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-timer = "3.0"

# Data structures
@@ -20,6 +20,7 @@ from harness import (
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
    run_planning_phase,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:

    selected.sort(key=_placement_sort_key)
    preview = selected[0]

    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )

    print("Planning phase: checking downloads...", file=log)
    run_planning_phase(
        exo,
        full_model_id,
        preview,
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )

    instance = preview["instance"]
    instance_id = instance_id_from_instance(instance)
    sharding = str(preview["sharding"])

@@ -35,6 +35,7 @@ from harness import (
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
    run_planning_phase,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
    if args.dry_run:
        return 0

    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )

    logger.info("Planning phase: checking downloads...")
    run_planning_phase(
        client,
        full_model_id,
        selected[0],
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )

    all_rows: list[dict[str, Any]] = []

    for preview in selected:
bench/harness.py (150 changes)

@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
    return selected


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure model is downloaded before benchmarking."""
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
            break

    if not model_bytes:
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]

    state = client.request_json("GET", "/state")
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            "DownloadCompleted" in p
            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                "modelId"
            ]
            == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state")
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                    "modelId"
                ],
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads
    start = time.time()
    while time.time() - start < timeout:
        state = client.request_json("GET", "/state")
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            done = any(
                "DownloadCompleted" in p
                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
                    "modelCard"
                ]["modelId"]
                == full_model_id
                for p in downloads.get(node_id, [])
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in downloads.get(node_id, [])
                if "DownloadFailed" in p
                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
                    "modelId"
                ]
                == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")


def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
    ap.add_argument(
        "--danger-delete-downloads",
        action="store_true",
        help="Delete existing models from smallest to largest to make room for benchmark model.",
    )
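A minimal sketch of how a bench script is expected to drive this helper, assuming a client and a placement preview are already in hand. The ExoClient construction and the settle_and_fetch_placements call are elided because their exact signatures are not shown in this diff; the argument order for run_planning_phase follows the definition above, and the model id is illustrative:

import time
from typing import Any

from harness import run_planning_phase  # as added in bench/harness.py above

client: Any = ...               # an ExoClient connected to the cluster (construction elided)
preview: dict[str, Any] = ...   # one entry returned by settle_and_fetch_placements(...)
full_model_id = "org/example-model"  # illustrative model id

settle_timeout = 120.0  # seconds; 0 means "try once", matching --settle-timeout
settle_deadline = (
    time.monotonic() + settle_timeout if settle_timeout > 0 else None
)

# Checks nodeDisk in /state, optionally frees space, then starts and awaits downloads.
run_planning_phase(
    client,
    full_model_id,
    preview,
    False,            # danger_delete: refuse to delete existing downloads
    3600.0,           # timeout (seconds) for the download wait loop
    settle_deadline,
)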
@@ -250,6 +250,11 @@ interface RawStateResponse {
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
// Disk usage per node
|
||||
nodeDisk?: Record<
|
||||
string,
|
||||
{ total: { inBytes: number }; available: { inBytes: number } }
|
||||
>;
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
|
||||
@@ -790,10 +790,8 @@
|
||||
if (!progress || typeof progress !== "object") return null;
|
||||
|
||||
const prog = progress as Record<string, unknown>;
|
||||
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
|
||||
const downloadedBytes = getBytes(
|
||||
prog.downloaded_bytes ?? prog.downloadedBytes,
|
||||
);
|
||||
const totalBytes = getBytes(prog.total);
|
||||
const downloadedBytes = getBytes(prog.downloaded);
|
||||
const speed = (prog.speed as number) ?? 0;
|
||||
const completedFiles =
|
||||
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
|
||||
@@ -806,8 +804,8 @@
|
||||
for (const [fileName, fileData] of Object.entries(filesObj)) {
|
||||
if (!fileData || typeof fileData !== "object") continue;
|
||||
const fd = fileData as Record<string, unknown>;
|
||||
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
|
||||
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
|
||||
const fTotal = getBytes(fd.total);
|
||||
const fDownloaded = getBytes(fd.downloaded);
|
||||
files.push({
|
||||
name: fileName,
|
||||
totalBytes: fTotal,
|
||||
@@ -1196,7 +1194,6 @@
|
||||
if (typeof value === "number") return value;
|
||||
if (value && typeof value === "object") {
|
||||
const v = value as Record<string, unknown>;
|
||||
if (typeof v.in_bytes === "number") return v.in_bytes;
|
||||
if (typeof v.inBytes === "number") return v.inBytes;
|
||||
}
|
||||
return 0;
|
||||
|
||||
@@ -74,7 +74,6 @@
|
||||
if (typeof value === "number") return value;
|
||||
if (value && typeof value === "object") {
|
||||
const v = value as Record<string, unknown>;
|
||||
if (typeof v.in_bytes === "number") return v.in_bytes;
|
||||
if (typeof v.inBytes === "number") return v.inBytes;
|
||||
}
|
||||
return 0;
|
||||
@@ -231,23 +230,14 @@
|
||||
undefined;
|
||||
let cell: CellStatus;
|
||||
if (tag === "DownloadCompleted") {
|
||||
const totalBytes = getBytes(
|
||||
payload.total_bytes ?? payload.totalBytes,
|
||||
);
|
||||
const totalBytes = getBytes(payload.total);
|
||||
cell = { kind: "completed", totalBytes, modelDirectory };
|
||||
} else if (tag === "DownloadOngoing") {
|
||||
const rawProgress =
|
||||
payload.download_progress ?? payload.downloadProgress ?? {};
|
||||
const prog = rawProgress as Record<string, unknown>;
|
||||
const totalBytes = getBytes(
|
||||
prog.total_bytes ??
|
||||
prog.totalBytes ??
|
||||
payload.total_bytes ??
|
||||
payload.totalBytes,
|
||||
);
|
||||
const downloadedBytes = getBytes(
|
||||
prog.downloaded_bytes ?? prog.downloadedBytes,
|
||||
);
|
||||
const totalBytes = getBytes(prog.total ?? payload.total);
|
||||
const downloadedBytes = getBytes(prog.downloaded);
|
||||
const speed = (prog.speed as number) ?? 0;
|
||||
const etaMs =
|
||||
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;
|
||||
|
||||
@@ -74,7 +74,6 @@
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, system, ... }:
|
||||
let
|
||||
fenixToolchain = inputs'.fenix.packages.complete;
|
||||
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
|
||||
#allowed-duplicate-crates = ["hashbrown"]
|
||||
@@ -27,7 +27,7 @@ networking = { workspace = true }
|
||||
# interop
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
"nightly", # enables better-supported GIL integration
|
||||
# "nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
@@ -60,3 +59,4 @@ env_logger = "0.11"
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
|
||||
@@ -19,7 +19,7 @@ class ConnectionUpdate:
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
def peer_id(self) -> builtins.str:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@@ -40,92 +40,22 @@ class Keypair:
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
def generate() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ecdsa() -> Keypair:
|
||||
def from_bytes(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
Construct an Ed25519 keypair from secret key bytes
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
Get the secret key bytes underlying the keypair
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
def to_node_id(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
//!
|
||||
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
@@ -26,8 +25,8 @@ where
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
F: Future + Send,
|
||||
F::Output: Send,
|
||||
{
|
||||
type Output = F::Output;
|
||||
|
||||
|
||||
rust/exo_pyo3_bindings/src/ident.rs (new file, 47 lines)

@@ -0,0 +1,47 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};

/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate() -> Self {
        Self(Keypair::generate_ed25519())
    }

    /// Construct an Ed25519 keypair from secret key bytes
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
    }

    /// Get the secret key bytes underlying the keypair
    fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = self
            .0
            .clone()
            .try_into_ed25519()
            .pyerr()?
            .secret()
            .as_ref()
            .to_vec();
        Ok(PyBytes::new(py, &bytes))
    }

    /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
    fn to_node_id(&self) -> String {
        self.0.public().to_peer_id().to_base58()
    }
}
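On the Python side, the rebuilt binding reduces the keypair surface to four methods. A short sketch of the intended round trip; the import path is an assumption about how the maturin-built extension is exposed, while the method names come from the stub and impl above:

from exo_pyo3_bindings import Keypair  # import path is an assumption

kp = Keypair.generate()            # fresh Ed25519 keypair
secret = kp.to_bytes()             # raw Ed25519 secret-key bytes
restored = Keypair.from_bytes(secret)

# Both map to the same base58-encoded PeerId string, used as the NodeId.
assert kp.to_node_id() == restored.to_node_id()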
@@ -4,26 +4,14 @@
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
mod ident;
|
||||
mod networking;
|
||||
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::types::PyModuleMethods;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -32,14 +20,6 @@ pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
@@ -179,8 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
m.add_class::<PyKeypair>()?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
peer_id: String,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
@@ -251,7 +251,7 @@ async fn networking_task(
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
@@ -272,7 +272,7 @@ async fn networking_task(
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
|
||||
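With peer_id now crossing the FFI boundary as a plain base58 string instead of a PeerId object, Python consumers can wrap it directly, which is what the ConnectionMessage.from_update change later in this diff does. An illustrative handler, with NodeId stubbed as a plain string type so the sketch is self-contained:

NodeId = str  # stand-in for exo's NodeId newtype

def on_connection_update(update) -> None:
    # update.peer_id is already a base58 string; no .to_base58() call is needed.
    node_id = NodeId(update.peer_id)
    print(
        f"{update.update_type}: {node_id} "
        f"at {update.remote_ipv4}:{update.remote_tcp_port}"
    )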
@@ -1,159 +0,0 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
//! A module for exposing Rust's libp2p datatypes over Pyo3
|
||||
//!
|
||||
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
|
||||
//! independent identity type of some kind or another. This may require handshaking.
|
||||
//!
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
@@ -1,81 +0,0 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr as _;
|
||||
|
||||
/// Representation of a Multiaddr.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Multiaddr", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyMultiaddr(pub Multiaddr);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyMultiaddr {
|
||||
/// Create a new, empty multiaddress.
|
||||
#[staticmethod]
|
||||
fn empty() -> Self {
|
||||
Self(Multiaddr::empty())
|
||||
}
|
||||
|
||||
/// Create a new, empty multiaddress with the given capacity.
|
||||
#[staticmethod]
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its byte slice representation.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its string representation.
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// Return the length in bytes of this multiaddress.
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns true if the length of this multiaddress is 0.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
/// Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_vec();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Convert a Multiaddr to a string.
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyMultiaddr>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -22,7 +22,7 @@ delegate = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use futures::stream::StreamExt as _;
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
@@ -38,19 +38,19 @@ async fn main() {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
event = swarm.next() => match event {
|
||||
// on gossipsub incoming
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
}))) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
@@ -64,7 +64,7 @@ async fn main() {
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::stream::StreamExt;
|
||||
use libp2p::{
|
||||
gossipsub, mdns, noise,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct MyBehaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.try_init();
|
||||
|
||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_behaviour(|key| {
|
||||
// Set a custom gossipsub configuration
|
||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
||||
.heartbeat_interval(Duration::from_secs(10))
|
||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
||||
.build()
|
||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
||||
gossipsub_config,
|
||||
)?;
|
||||
|
||||
let mdns =
|
||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
||||
Ok(MyBehaviour { gossipsub, mdns })
|
||||
})?
|
||||
.build();
|
||||
|
||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
||||
|
||||
// Create a Gossipsub topic
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
// subscribes to our topic
|
||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures_lite::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
if self.retry_delay.poll(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
|
||||
@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use futures_lite::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, ... }:
|
||||
{ inputs', pkgs, lib, ... }:
|
||||
let
|
||||
# Fenix nightly toolchain with all components
|
||||
fenixPkgs = inputs'.fenix.packages;
|
||||
rustToolchain = fenixPkgs.complete.withComponents [
|
||||
rustToolchain = inputs'.fenix.packages.stable.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
[toolchain]
|
||||
channel = "nightly"
|
||||
@@ -80,7 +80,7 @@ class DownloadCoordinator:
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
total=progress.total,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
@@ -203,7 +203,7 @@ class DownloadCoordinator:
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
total=initial_progress.total,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
@@ -332,13 +332,13 @@ class DownloadCoordinator:
|
||||
status: DownloadProgress = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
total=progress.total,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
if progress.downloaded_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
|
||||
@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
|
||||
repo_file_download_progress: RepoFileDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
downloaded_bytes=repo_file_download_progress.downloaded,
|
||||
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
|
||||
total_bytes=repo_file_download_progress.total,
|
||||
downloaded=repo_file_download_progress.downloaded,
|
||||
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
|
||||
total=repo_file_download_progress.total,
|
||||
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
|
||||
total_files=1,
|
||||
speed=repo_file_download_progress.speed,
|
||||
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
|
||||
repo_download_progress: RepoDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
total_bytes=repo_download_progress.total_bytes,
|
||||
downloaded_bytes=repo_download_progress.downloaded_bytes,
|
||||
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
|
||||
total=repo_download_progress.total,
|
||||
downloaded=repo_download_progress.downloaded,
|
||||
downloaded_this_session=repo_download_progress.downloaded_this_session,
|
||||
completed_files=repo_download_progress.completed_files,
|
||||
total_files=repo_download_progress.total_files,
|
||||
speed=repo_download_progress.overall_speed,
|
||||
@@ -578,19 +578,20 @@ def calculate_repo_progress(
|
||||
file_progress: dict[str, RepoFileDownloadProgress],
|
||||
all_start_time: float,
|
||||
) -> RepoDownloadProgress:
|
||||
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
|
||||
all_downloaded_bytes = sum(
|
||||
(p.downloaded.in_bytes for p in file_progress.values()), 0
|
||||
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
|
||||
all_downloaded = sum(
|
||||
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
|
||||
)
|
||||
all_downloaded_bytes_this_session = sum(
|
||||
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
|
||||
all_downloaded_this_session = sum(
|
||||
(p.downloaded_this_session for p in file_progress.values()),
|
||||
Memory.from_bytes(0),
|
||||
)
|
||||
elapsed_time = time.time() - all_start_time
|
||||
all_speed = (
|
||||
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
|
||||
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
|
||||
)
|
||||
all_eta = (
|
||||
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
|
||||
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
|
||||
if all_speed > 0
|
||||
else timedelta(seconds=0)
|
||||
)
|
||||
@@ -609,11 +610,9 @@ def calculate_repo_progress(
|
||||
[p for p in file_progress.values() if p.downloaded == p.total]
|
||||
),
|
||||
total_files=len(file_progress),
|
||||
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(
|
||||
all_downloaded_bytes_this_session
|
||||
),
|
||||
total_bytes=Memory.from_bytes(all_total_bytes),
|
||||
downloaded=all_downloaded,
|
||||
downloaded_this_session=all_downloaded_this_session,
|
||||
total=all_total,
|
||||
overall_speed=all_speed,
|
||||
overall_eta=all_eta,
|
||||
status=status,
|
||||
|
||||
@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
downloaded=Memory.from_bytes(0),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
|
||||
@@ -45,7 +45,7 @@ class Node:
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
|
||||
@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponsesStreamEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


def _format_sse(event: ResponsesStreamEvent) -> str:
    """Format a streaming event as an SSE message."""
    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"

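The helper centralizes the SSE framing that was previously repeated at every yield site. A small sketch of the frame it produces; the fake event below stands in for the pydantic ResponsesStreamEvent models the real code passes:

class _FakeEvent:
    type = "response.output_text.delta"

    def model_dump_json(self) -> str:
        return '{"type": "response.output_text.delta", "delta": "Hello"}'

frame = _format_sse(_FakeEvent())
# frame is:
#   event: response.output_text.delta
#   data: {"type": "response.output_text.delta", "delta": "Hello"}
# followed by a blank line terminating the SSE message.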
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||
if isinstance(content, str):
|
||||
@@ -219,13 +225,13 @@ async def generate_responses_stream(
|
||||
created_event = ResponseCreatedEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(created_event)
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(in_progress_event)
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
@@ -236,7 +242,7 @@ async def generate_responses_stream(
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq), output_index=0, item=initial_item
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
yield _format_sse(item_added)
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
@@ -247,7 +253,7 @@ async def generate_responses_stream(
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
@@ -281,7 +287,7 @@ async def generate_responses_stream(
|
||||
output_index=next_output_index,
|
||||
item=fc_item,
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
|
||||
yield _format_sse(fc_added)
|
||||
|
||||
# response.function_call_arguments.delta
|
||||
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
||||
@@ -290,7 +296,7 @@ async def generate_responses_stream(
|
||||
output_index=next_output_index,
|
||||
delta=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
|
||||
yield _format_sse(args_delta)
|
||||
|
||||
# response.function_call_arguments.done
|
||||
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
||||
@@ -300,7 +306,7 @@ async def generate_responses_stream(
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(args_done)
|
||||
|
||||
# response.output_item.done
|
||||
fc_done_item = ResponseFunctionCallItem(
|
||||
@@ -315,7 +321,7 @@ async def generate_responses_stream(
|
||||
output_index=next_output_index,
|
||||
item=fc_done_item,
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(fc_item_done)
|
||||
|
||||
function_call_items.append(fc_done_item)
|
||||
next_output_index += 1
|
||||
@@ -331,7 +337,7 @@ async def generate_responses_stream(
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(delta_event)
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
@@ -341,7 +347,7 @@ async def generate_responses_stream(
|
||||
content_index=0,
|
||||
text=accumulated_text,
|
||||
)
|
||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(text_done)
|
||||
|
||||
# response.content_part.done
|
||||
final_part = ResponseOutputText(text=accumulated_text)
|
||||
@@ -352,7 +358,7 @@ async def generate_responses_stream(
|
||||
content_index=0,
|
||||
part=final_part,
|
||||
)
|
||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(part_done)
|
||||
|
||||
# response.output_item.done
|
||||
final_message_item = ResponseMessageItem(
|
||||
@@ -363,7 +369,7 @@ async def generate_responses_stream(
|
||||
item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(item_done)
|
||||
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
@@ -388,4 +394,4 @@ async def generate_responses_stream(
|
||||
completed_event = ResponseCompletedEvent(
|
||||
sequence_number=next(seq), response=final_response
|
||||
)
|
||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(completed_event)
|
||||
|
||||
@@ -1323,7 +1323,7 @@ class API:
|
||||
name=card.model_id.short(),
|
||||
description="",
|
||||
tags=[],
|
||||
storage_size_megabytes=int(card.storage_size.in_mb),
|
||||
storage_size_megabytes=card.storage_size.in_mb,
|
||||
supports_tensor=card.supports_tensor,
|
||||
tasks=[task.value for task in card.tasks],
|
||||
is_custom=is_custom_card(card.model_id),
|
||||
|
||||
@@ -102,22 +102,21 @@ def _allocate_and_validate_layers(
    layer_allocations = allocate_layers_proportionally(
        total_layers=model_card.n_layers,
        memory_fractions=[
            node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
            for node_id in node_ids
            node_memory[node_id].ram_available / total_memory for node_id in node_ids
        ],
    )

    total_storage_bytes = model_card.storage_size.in_bytes
    total_storage = model_card.storage_size
    total_layers = model_card.n_layers
    for i, node_id in enumerate(node_ids):
        node_layers = layer_allocations[i]
        required_memory = (total_storage_bytes * node_layers) // total_layers
        available_memory = node_memory[node_id].ram_available.in_bytes
        required_memory = (total_storage * node_layers) // total_layers
        available_memory = node_memory[node_id].ram_available
        if required_memory > available_memory:
            raise ValueError(
                f"Node {i} ({node_id}) has insufficient memory: "
                f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
                f"but only has {available_memory / (1024**3):.2f} GB available"
                f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
                f"but only has {available_memory.in_gb:.2f} GB available"
            )

    return layer_allocations
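This rewrite relies on Memory supporting arithmetic and comparison directly (multiplication and floor division by an int, division by another Memory, ordering, and an in_gb view), which the Memory changes at the end of this diff appear to introduce. A hedged sketch of the per-node check under that assumption, with an illustrative import path and sizes:

from exo.utils.memory import Memory  # import path is an assumption

total_storage = Memory.from_bytes(150 * 1024**3)  # 150 GB model, illustrative
total_layers = 48
node_layers = 16
available = Memory.from_bytes(40 * 1024**3)

required = (total_storage * node_layers) // total_layers  # still a Memory
if required > available:
    raise ValueError(
        f"requires {required.in_gb:.2f} GB but only has {available.in_gb:.2f} GB available"
    )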
@@ -42,7 +42,7 @@ from exo.utils.channels import channel
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
@@ -75,7 +75,7 @@ async def test_master():
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
|
||||
# inject a NodeGatheredInfo event
|
||||
logger.info("inject a NodeGatheredInfo event")
|
||||
await local_event_sender.send(
|
||||
|
||||
@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
|
||||
):
|
||||
# arrange
|
||||
model_card.n_layers = total_layers
|
||||
model_card.storage_size.in_bytes = sum(
|
||||
available_memory
|
||||
model_card.storage_size = Memory.from_bytes(
|
||||
sum(available_memory)
|
||||
) # make it exactly fit across all nodes
|
||||
topology = Topology()
|
||||
|
||||
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
# arrange
|
||||
topology = Topology()
|
||||
model_card.n_layers = 12
|
||||
model_card.storage_size.in_bytes = 1500
|
||||
model_card.storage_size = Memory.from_bytes(1500)
|
||||
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
node_id=NodeId(update.peer_id),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
|
||||
@@ -221,7 +221,7 @@ def get_node_id_keypair(
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
return Keypair.generate()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
@@ -235,12 +235,12 @@ def get_node_id_keypair(
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
return Keypair.from_bytes(protobuf_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
|
||||
@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
|
||||
event = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
total=Memory(),
|
||||
)
|
||||
|
||||
new_state = apply_node_download_progress(
|
||||
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
|
||||
event1 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
total=Memory(),
|
||||
)
|
||||
event2 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard2,
|
||||
total_bytes=Memory(),
|
||||
total=Memory(),
|
||||
)
|
||||
state = State(downloads={NodeId("node-1"): [event1]})
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
||||
sem.release()
|
||||
# wait to be told to begin simultaneous read
|
||||
ev.wait()
|
||||
queue.put(get_node_id_keypair().to_protobuf_encoding())
|
||||
queue.put(get_node_id_keypair().to_bytes())
|
||||
|
||||
|
||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from math import ceil
|
||||
from typing import Self
|
||||
from typing import Self, overload
|
||||
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
|
||||
|
||||
class Memory(CamelCaseModel):
|
||||
class Memory(FrozenModel):
|
||||
in_bytes: int = 0
|
||||
|
||||
@classmethod
|
||||
@@ -33,12 +33,22 @@ class Memory(CamelCaseModel):
|
||||
return cls(in_bytes=round(val * 1024))
|
||||
|
||||
@property
|
||||
def in_mb(self) -> float:
|
||||
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
|
||||
return self.in_bytes / (1024**2)
|
||||
def in_mb(self) -> int:
|
||||
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
|
||||
return round(self.in_bytes / (1024**2))
|
||||
|
||||
@in_mb.setter
|
||||
def in_mb(self, val: float):
|
||||
def in_mb(self, val: int):
|
||||
"""Set the megabytes for this memory."""
|
||||
self.in_bytes = val * (1024**2)
|
||||
|
||||
@property
|
||||
def in_float_mb(self) -> float:
|
||||
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
|
||||
return self.in_bytes / (1024**2)
|
||||
|
||||
@in_float_mb.setter
|
||||
def in_float_mb(self, val: float):
|
||||
"""Set the megabytes for this memory, rounded to the nearest byte."""
|
||||
self.in_bytes = round(val * (1024**2))
|
||||
|
||||
@@ -57,17 +67,85 @@ class Memory(CamelCaseModel):
|
||||
"""The approximate gigabytes this memory represents."""
|
||||
return self.in_bytes / (1024**3)
|
||||
|
||||
def __add__(self, other: "Memory") -> "Memory":
|
||||
return Memory.from_bytes(self.in_bytes + other.in_bytes)
|
||||
def __add__(self, other: object) -> "Memory":
|
||||
if isinstance(other, Memory):
|
||||
return Memory.from_bytes(self.in_bytes + other.in_bytes)
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other: Self) -> bool:
|
||||
return self.in_bytes < other.in_bytes
|
||||
def __radd__(self, other: object) -> "Memory":
|
||||
if other == 0:
|
||||
return self
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other: Self) -> bool:
|
||||
return self.in_bytes <= other.in_bytes
|
||||
def __sub__(self, other: object) -> "Memory":
|
||||
if isinstance(other, Memory):
|
||||
return Memory.from_bytes(self.in_bytes - other.in_bytes)
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other: Self) -> bool:
|
||||
return self.in_bytes > other.in_bytes
|
||||
def __mul__(self, other: int | float):
|
||||
return Memory.from_bytes(round(self.in_bytes * other))
|
||||
|
||||
def __ge__(self, other: Self) -> bool:
|
||||
return self.in_bytes >= other.in_bytes
|
||||
def __rmul__(self, other: int | float):
|
||||
return self * other
|
||||
|
||||
@overload
|
||||
def __truediv__(self, other: "Memory") -> float: ...
|
||||
@overload
|
||||
def __truediv__(self, other: int) -> "Memory": ...
|
||||
@overload
|
||||
def __truediv__(self, other: float) -> "Memory": ...
|
||||
def __truediv__(self, other: object) -> "Memory | float":
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes / other.in_bytes
|
||||
if isinstance(other, (int, float)):
|
||||
return Memory.from_bytes(round(self.in_bytes / other))
|
||||
return NotImplemented
|
||||
|
||||
def __floordiv__(self, other: object) -> "Memory":
|
||||
if isinstance(other, (int, float)):
|
||||
return Memory.from_bytes(int(self.in_bytes // other))
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes < other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes <= other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes > other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes >= other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes == other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Memory.from_bytes({self.in_bytes})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.in_gb > 2:
|
||||
val = self.in_gb
|
||||
unit = "GiB"
|
||||
elif self.in_mb > 2:
|
||||
val = self.in_mb
|
||||
unit = "MiB"
|
||||
elif self.in_kb > 3:
|
||||
val = self.in_kb
|
||||
unit = "KiB"
|
||||
else:
|
||||
val = self.in_bytes
|
||||
unit = "B"
|
||||
|
||||
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
|
||||
|
||||
@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
class DownloadProgressData(CamelCaseModel):
|
||||
total_bytes: Memory
|
||||
downloaded_bytes: Memory
|
||||
downloaded_bytes_this_session: Memory
|
||||
total: Memory
|
||||
downloaded: Memory
|
||||
downloaded_this_session: Memory
|
||||
|
||||
completed_files: int
|
||||
total_files: int
|
||||
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):
|
||||
|
||||
|
||||
class DownloadCompleted(BaseDownloadProgress):
|
||||
total_bytes: Memory
|
||||
total: Memory
|
||||
|
||||
|
||||
class DownloadFailed(BaseDownloadProgress):
|
||||
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
|
||||
shard: ShardMetadata
|
||||
completed_files: int
|
||||
total_files: int
|
||||
downloaded_bytes: Memory
|
||||
downloaded_bytes_this_session: Memory
|
||||
total_bytes: Memory
|
||||
downloaded: Memory
|
||||
downloaded_this_session: Memory
|
||||
total: Memory
|
||||
overall_speed: float
|
||||
overall_eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
"""Shared fixtures and helpers for E2E chaos/networking tests.
|
||||
|
||||
Provides a ``MiniCluster`` that wires Master + Worker(s) + Election together
|
||||
using in-process channels. No Docker, no network, no GPU -- pure async
|
||||
integration testing of the coordination layer.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Final
|
||||
|
||||
import pytest
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
NodeGatheredInfo,
|
||||
TopologyEdgeCreated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryUsage
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.worker.main import Worker
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TEST_MODEL_ID: Final[ModelId] = ModelId("test-model/chaos-test-1b")
|
||||
TEST_MODEL_CARD: Final[ModelCard] = ModelCard(
|
||||
model_id=TEST_MODEL_ID,
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678_948),
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
FAST_ELECTION_TIMEOUT: Final[float] = 0.1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_node_id(label: str) -> NodeId:
|
||||
return NodeId(f"node-{label}")
|
||||
|
||||
|
||||
def make_session_id(master_node: NodeId) -> SessionId:
|
||||
return SessionId(master_node_id=master_node, election_clock=0)
|
||||
|
||||
|
||||
def make_memory_info() -> MemoryUsage:
|
||||
return MemoryUsage(
|
||||
ram_total=Memory.from_bytes(16 * 1024 * 1024 * 1024),
|
||||
ram_available=Memory.from_bytes(8 * 1024 * 1024 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
)
|
||||
|
||||
|
||||
def make_gathered_info_event(
|
||||
node_id: NodeId, sender_id: NodeId, session_id: SessionId, origin_idx: int
|
||||
) -> ForwarderEvent:
|
||||
return ForwarderEvent(
|
||||
origin_idx=origin_idx,
|
||||
origin=sender_id,
|
||||
session=session_id,
|
||||
event=NodeGatheredInfo(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
info=make_memory_info(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def make_topology_edge_event(
|
||||
source: NodeId,
|
||||
sink: NodeId,
|
||||
sender_id: NodeId,
|
||||
session_id: SessionId,
|
||||
origin_idx: int,
|
||||
ip_suffix: int = 1,
|
||||
) -> ForwarderEvent:
|
||||
"""Create a ForwarderEvent wrapping a TopologyEdgeCreated event."""
|
||||
return ForwarderEvent(
|
||||
origin_idx=origin_idx,
|
||||
origin=sender_id,
|
||||
session=session_id,
|
||||
event=TopologyEdgeCreated(
|
||||
conn=Connection(
|
||||
source=source,
|
||||
sink=sink,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/10.0.0.{ip_suffix}/tcp/52415"
|
||||
)
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class EventCollector:
|
||||
"""Collects ForwarderEvents from a global event receiver."""
|
||||
|
||||
def __init__(self, receiver: Receiver[ForwarderEvent]) -> None:
|
||||
self._receiver = receiver
|
||||
self.indexed_events: list[IndexedEvent] = []
|
||||
|
||||
def collect(self) -> list[IndexedEvent]:
|
||||
raw = self._receiver.collect()
|
||||
for fe in raw:
|
||||
self.indexed_events.append(
|
||||
IndexedEvent(event=fe.event, idx=len(self.indexed_events))
|
||||
)
|
||||
return self.indexed_events
|
||||
|
||||
async def wait_for_event_count(
|
||||
self, count: int, *, timeout: float = 5.0, poll_interval: float = 0.01
|
||||
) -> list[IndexedEvent]:
|
||||
import anyio
|
||||
|
||||
with anyio.fail_after(timeout):
|
||||
while len(self.collect()) < count:
|
||||
await anyio.sleep(poll_interval)
|
||||
return self.indexed_events
|
||||
|
||||
|
||||
class MiniCluster:
|
||||
"""An in-process cluster with one Master and N Workers wired via channels.
|
||||
|
||||
No networking, no real model loading -- exercises the coordination logic
|
||||
(event sourcing, command routing, election) in a deterministic, fast
|
||||
test harness.
|
||||
"""
|
||||
|
||||
def __init__(self, node_count: int = 2) -> None:
|
||||
self.node_count = node_count
|
||||
self.master_node_id = make_node_id("master")
|
||||
self.session_id = make_session_id(self.master_node_id)
|
||||
|
||||
# -- shared bus channels --
|
||||
self.global_event_sender: Sender[ForwarderEvent]
|
||||
self.global_event_internal_receiver: Receiver[ForwarderEvent]
|
||||
self.global_event_sender, self.global_event_internal_receiver = channel[
|
||||
ForwarderEvent
|
||||
]()
|
||||
|
||||
self.command_sender: Sender[ForwarderCommand]
|
||||
self.command_receiver: Receiver[ForwarderCommand]
|
||||
self.command_sender, self.command_receiver = channel[ForwarderCommand]()
|
||||
|
||||
self.local_event_sender: Sender[ForwarderEvent]
|
||||
self.local_event_receiver: Receiver[ForwarderEvent]
|
||||
self.local_event_sender, self.local_event_receiver = channel[ForwarderEvent]()
|
||||
|
||||
self.download_cmd_sender: Sender[ForwarderDownloadCommand]
|
||||
self._download_cmd_receiver: Receiver[ForwarderDownloadCommand]
|
||||
self.download_cmd_sender, self._download_cmd_receiver = channel[
|
||||
ForwarderDownloadCommand
|
||||
]()
|
||||
|
||||
# -- event collector (taps global events) --
|
||||
self.event_collector = EventCollector(
|
||||
self.global_event_internal_receiver.clone()
|
||||
)
|
||||
|
||||
# -- master --
|
||||
self.master = Master(
|
||||
self.master_node_id,
|
||||
self.session_id,
|
||||
global_event_sender=self.global_event_sender.clone(),
|
||||
local_event_receiver=self.local_event_receiver.clone(),
|
||||
command_receiver=self.command_receiver.clone(),
|
||||
download_command_sender=self.download_cmd_sender.clone(),
|
||||
)
|
||||
|
||||
# -- workers --
|
||||
self.worker_node_ids: list[NodeId] = []
|
||||
self.workers: list[Worker] = []
|
||||
for i in range(node_count):
|
||||
wid = make_node_id(f"worker-{i}")
|
||||
self.worker_node_ids.append(wid)
|
||||
|
||||
counter: Iterator[int] = iter(range(1_000_000))
|
||||
worker = Worker(
|
||||
wid,
|
||||
self.session_id,
|
||||
global_event_receiver=self.global_event_internal_receiver.clone(),
|
||||
local_event_sender=self.local_event_sender.clone(),
|
||||
command_sender=self.command_sender.clone(),
|
||||
download_command_sender=self.download_cmd_sender.clone(),
|
||||
event_index_counter=counter,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
|
||||
async def inject_node_info(self, node_id: NodeId, sender_suffix: str = "") -> None:
|
||||
"""Inject a NodeGatheredInfo event for a node into the local event bus."""
|
||||
sender_id = NodeId(f"{node_id}_sender{sender_suffix}")
|
||||
await self.local_event_sender.send(
|
||||
make_gathered_info_event(node_id, sender_id, self.session_id, 0)
|
||||
)
|
||||
|
||||
async def wait_for_topology_nodes(
|
||||
self, count: int, *, timeout: float = 5.0
|
||||
) -> None:
|
||||
import anyio
|
||||
|
||||
with anyio.fail_after(timeout):
|
||||
while len(list(self.master.state.topology.list_nodes())) < count:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
async def place_model(
|
||||
self,
|
||||
model_card: ModelCard | None = None,
|
||||
min_nodes: int = 1,
|
||||
) -> None:
|
||||
card = model_card or TEST_MODEL_CARD
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self.master_node_id,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=card,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def wait_for_instances(self, count: int, *, timeout: float = 5.0) -> None:
|
||||
import anyio
|
||||
|
||||
with anyio.fail_after(timeout):
|
||||
while len(self.master.state.instances) < count:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
async def send_chat(
|
||||
self,
|
||||
message: str,
|
||||
model: ModelId | None = None,
|
||||
) -> CommandId:
|
||||
cmd_id = CommandId()
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self.master_node_id,
|
||||
command=TextGeneration(
|
||||
command_id=cmd_id,
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=model or TEST_MODEL_ID,
|
||||
input=[InputMessage(role="user", content=message)],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return cmd_id
|
||||
|
||||
async def shutdown_master(self) -> None:
|
||||
await self.master.shutdown()
|
||||
|
||||
def shutdown_workers(self) -> None:
|
||||
for w in self.workers:
|
||||
w.shutdown()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fast_election_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def caplog(caplog: LogCaptureFixture) -> Iterator[LogCaptureFixture]:
|
||||
handler_id = logger.add(
|
||||
caplog.handler,
|
||||
format="{message}",
|
||||
level=0,
|
||||
filter=lambda record: record["level"].no >= caplog.handler.level,
|
||||
enqueue=True,
|
||||
)
|
||||
yield caplog
|
||||
logger.remove(handler_id)
|
||||
@@ -1,255 +0,0 @@
|
||||
"""E2E Chaos Test: Client disconnect.
|
||||
|
||||
Scenarios:
|
||||
1. Task cancellation after client disconnect -- a TextGeneration command is
|
||||
sent, then immediately cancelled (simulating browser tab close).
|
||||
Verify the master correctly transitions the task to Cancelled status.
|
||||
2. Multiple rapid cancellations -- several chat commands are sent and
|
||||
cancelled in quick succession; no tasks should remain in a stuck state.
|
||||
"""
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TaskCancelled,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
)
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.channels import channel
|
||||
|
||||
from .conftest import (
|
||||
TEST_MODEL_CARD,
|
||||
TEST_MODEL_ID,
|
||||
EventCollector,
|
||||
make_gathered_info_event,
|
||||
make_node_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cancelled_after_client_disconnect() -> None:
|
||||
"""Simulate a browser tab close by sending a TextGeneration command
|
||||
followed immediately by a TaskCancelled command. Verify the task
|
||||
transitions to Cancelled status.
|
||||
"""
|
||||
master_nid = make_node_id("master-cancel")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
_collector = EventCollector(ge_receiver.clone())
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
# Register node
|
||||
sender_id = NodeId(f"{master_nid}_sender")
|
||||
await le_sender.send(
|
||||
make_gathered_info_event(master_nid, sender_id, session_id, 0)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Place instance
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=TEST_MODEL_CARD,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(master.state.instances) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Send a chat command
|
||||
chat_cmd_id = CommandId()
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=TextGeneration(
|
||||
command_id=chat_cmd_id,
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=TEST_MODEL_ID,
|
||||
input=[InputMessage(role="user", content="Hello world")],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for the task to be created
|
||||
with anyio.fail_after(3):
|
||||
while len(master.state.tasks) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Immediately cancel -- simulating browser tab close
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=TaskCancelled(
|
||||
command_id=CommandId(),
|
||||
cancelled_command_id=chat_cmd_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for the task status to be updated to Cancelled
|
||||
with anyio.fail_after(3):
|
||||
while True:
|
||||
tasks_cancelled = [
|
||||
t
|
||||
for t in master.state.tasks.values()
|
||||
if t.task_status == TaskStatus.Cancelled
|
||||
]
|
||||
if tasks_cancelled:
|
||||
break
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
assert len(tasks_cancelled) == 1
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_cancel_does_not_leave_stuck_tasks() -> None:
|
||||
"""Send multiple chat commands and cancel them all rapidly.
|
||||
Verify no tasks remain in Pending or Running state.
|
||||
"""
|
||||
master_nid = make_node_id("master-rapid-cancel")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, _ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
# Register node and place instance
|
||||
sender_id = NodeId(f"{master_nid}_sender")
|
||||
await le_sender.send(
|
||||
make_gathered_info_event(master_nid, sender_id, session_id, 0)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=TEST_MODEL_CARD,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(master.state.instances) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Send 5 chat commands and immediately cancel each
|
||||
chat_cmd_ids: list[CommandId] = []
|
||||
for i in range(5):
|
||||
cmd_id = CommandId()
|
||||
chat_cmd_ids.append(cmd_id)
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=TextGeneration(
|
||||
command_id=cmd_id,
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=TEST_MODEL_ID,
|
||||
input=[InputMessage(role="user", content=f"Message {i}")],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for all tasks to be created
|
||||
with anyio.fail_after(3):
|
||||
while len(master.state.tasks) < 5:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Cancel all of them
|
||||
for cmd_id in chat_cmd_ids:
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=TaskCancelled(
|
||||
command_id=CommandId(),
|
||||
cancelled_command_id=cmd_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for all cancellations to be processed
|
||||
with anyio.fail_after(3):
|
||||
while True:
|
||||
cancelled_count = sum(
|
||||
1
|
||||
for t in master.state.tasks.values()
|
||||
if t.task_status == TaskStatus.Cancelled
|
||||
)
|
||||
if cancelled_count == 5:
|
||||
break
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# No tasks should be Pending or Running
|
||||
stuck = [
|
||||
t
|
||||
for t in master.state.tasks.values()
|
||||
if t.task_status in (TaskStatus.Pending, TaskStatus.Running)
|
||||
]
|
||||
assert len(stuck) == 0
|
||||
|
||||
await master.shutdown()
|
||||
@@ -1,395 +0,0 @@
|
||||
"""E2E Chaos Test: Concurrent requests.
|
||||
|
||||
Scenarios:
|
||||
1. Multiple simultaneous inference requests -- verify they are all created
|
||||
as tasks with no data corruption (unique task IDs, correct model IDs).
|
||||
2. Concurrent requests across multiple model instances -- verify tasks are
|
||||
routed to the correct instances.
|
||||
3. Concurrent requests with load balancing -- when multiple instances of
|
||||
the same model exist, verify tasks are distributed.
|
||||
"""
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.channels import channel
|
||||
|
||||
from .conftest import (
|
||||
TEST_MODEL_CARD,
|
||||
TEST_MODEL_ID,
|
||||
EventCollector,
|
||||
make_gathered_info_event,
|
||||
make_node_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_chat_requests_no_corruption() -> None:
|
||||
"""Send multiple TextGeneration commands concurrently and verify each
|
||||
results in a unique task with the correct model and content mapping.
|
||||
"""
|
||||
master_nid = make_node_id("master-concurrent")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
_collector = EventCollector(ge_receiver.clone())
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
# Set up node and instance
|
||||
sender_id = NodeId(f"{master_nid}_sender")
|
||||
await le_sender.send(
|
||||
make_gathered_info_event(master_nid, sender_id, session_id, 0)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=TEST_MODEL_CARD,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(master.state.instances) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Send 10 concurrent chat requests
|
||||
num_requests = 10
|
||||
cmd_ids: list[CommandId] = []
|
||||
|
||||
async def send_chat(index: int) -> None:
|
||||
cmd_id = CommandId()
|
||||
cmd_ids.append(cmd_id)
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=TextGeneration(
|
||||
command_id=cmd_id,
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=TEST_MODEL_ID,
|
||||
input=[
|
||||
InputMessage(
|
||||
role="user",
|
||||
content=f"Concurrent request #{index}",
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as send_tg:
|
||||
for i in range(num_requests):
|
||||
send_tg.start_soon(send_chat, i)
|
||||
|
||||
# Wait for all tasks to be created
|
||||
with anyio.fail_after(5):
|
||||
while len(master.state.tasks) < num_requests:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Verify no corruption
|
||||
assert len(master.state.tasks) == num_requests
|
||||
|
||||
# All task IDs should be unique
|
||||
task_ids = list(master.state.tasks.keys())
|
||||
assert len(set(task_ids)) == num_requests
|
||||
|
||||
# All tasks should target the correct model
|
||||
for task in master.state.tasks.values():
|
||||
assert isinstance(task, TextGenerationTask)
|
||||
assert task.task_params.model == TEST_MODEL_ID
|
||||
assert task.task_status == TaskStatus.Pending
|
||||
|
||||
# All tasks should reference the same instance
|
||||
instance_ids = {task.instance_id for task in master.state.tasks.values()}
|
||||
assert len(instance_ids) == 1
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests_across_multiple_models() -> None:
|
||||
"""Place two different models, then send concurrent requests for each.
|
||||
Verify tasks are routed to the correct model instances.
|
||||
"""
|
||||
master_nid = make_node_id("master-multi-model")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, _ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
# Register node
|
||||
sender_id = NodeId(f"{master_nid}_sender")
|
||||
await le_sender.send(
|
||||
make_gathered_info_event(master_nid, sender_id, session_id, 0)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Place two different models
|
||||
model_a_id = ModelId("test-model/model-a")
|
||||
model_a_card = ModelCard(
|
||||
model_id=model_a_id,
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(500_000),
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
model_b_id = ModelId("test-model/model-b")
|
||||
model_b_card = ModelCard(
|
||||
model_id=model_b_id,
|
||||
n_layers=32,
|
||||
storage_size=Memory.from_bytes(500_000),
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
for card in [model_a_card, model_b_card]:
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=card,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(5):
|
||||
while len(master.state.instances) < 2:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Map instance IDs to models
|
||||
instance_to_model: dict[str, ModelId] = {}
|
||||
for iid, inst in master.state.instances.items():
|
||||
instance_to_model[iid] = inst.shard_assignments.model_id
|
||||
|
||||
# Send concurrent requests for both models
|
||||
async def send_for_model(model_id: ModelId, count: int) -> None:
|
||||
for i in range(count):
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=TextGeneration(
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=model_id,
|
||||
input=[
|
||||
InputMessage(
|
||||
role="user",
|
||||
content=f"Request for {model_id} #{i}",
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as send_tg:
|
||||
send_tg.start_soon(send_for_model, model_a_id, 3)
|
||||
send_tg.start_soon(send_for_model, model_b_id, 3)
|
||||
|
||||
# Wait for all 6 tasks
|
||||
with anyio.fail_after(5):
|
||||
while len(master.state.tasks) < 6:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Verify task routing
|
||||
model_a_tasks = [
|
||||
t
|
||||
for t in master.state.tasks.values()
|
||||
if isinstance(t, TextGenerationTask) and t.task_params.model == model_a_id
|
||||
]
|
||||
model_b_tasks = [
|
||||
t
|
||||
for t in master.state.tasks.values()
|
||||
if isinstance(t, TextGenerationTask) and t.task_params.model == model_b_id
|
||||
]
|
||||
|
||||
assert len(model_a_tasks) == 3
|
||||
assert len(model_b_tasks) == 3
|
||||
|
||||
# All model_a tasks should reference the model_a instance
|
||||
model_a_instance_ids = {
|
||||
iid for iid, mid in instance_to_model.items() if mid == model_a_id
|
||||
}
|
||||
for task in model_a_tasks:
|
||||
assert task.instance_id in model_a_instance_ids
|
||||
|
||||
# All model_b tasks should reference the model_b instance
|
||||
model_b_instance_ids = {
|
||||
iid for iid, mid in instance_to_model.items() if mid == model_b_id
|
||||
}
|
||||
for task in model_b_tasks:
|
||||
assert task.instance_id in model_b_instance_ids
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_index_monotonically_increases_under_load() -> None:
|
||||
"""Under heavy concurrent command load, verify the master's event log
|
||||
index increases monotonically with no gaps or duplicates.
|
||||
"""
|
||||
master_nid = make_node_id("master-monotonic")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
collector = EventCollector(ge_receiver.clone())
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
# Register node and place instance
|
||||
sender_id = NodeId(f"{master_nid}_sender")
|
||||
await le_sender.send(
|
||||
make_gathered_info_event(master_nid, sender_id, session_id, 0)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=TEST_MODEL_CARD,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(master.state.instances) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Blast 20 concurrent commands
|
||||
async def blast_commands(start: int, count: int) -> None:
|
||||
for i in range(count):
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=TextGeneration(
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=TEST_MODEL_ID,
|
||||
input=[
|
||||
InputMessage(
|
||||
role="user",
|
||||
content=f"Blast {start + i}",
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as blast_tg:
|
||||
blast_tg.start_soon(blast_commands, 0, 10)
|
||||
blast_tg.start_soon(blast_commands, 10, 10)
|
||||
|
||||
# Wait for all tasks
|
||||
with anyio.fail_after(5):
|
||||
while len(master.state.tasks) < 20:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Collect all events and verify monotonic indexing
|
||||
# NodeGatheredInfo(0) + InstanceCreated(1) + 20 TaskCreated = 22 events
|
||||
await collector.wait_for_event_count(22, timeout=5.0)
|
||||
|
||||
events = collector.indexed_events
|
||||
indices = [e.idx for e in events]
|
||||
|
||||
# Should be 0, 1, 2, ..., N-1 with no gaps
|
||||
expected = list(range(len(indices)))
|
||||
assert indices == expected
|
||||
|
||||
# last_event_applied_idx should match
|
||||
assert master.state.last_event_applied_idx == len(events) - 1
|
||||
|
||||
await master.shutdown()
|
||||
@@ -1,356 +0,0 @@
|
||||
"""E2E Chaos Test: Large model distributed loading.
|
||||
|
||||
Scenarios:
|
||||
1. Multi-node sharding -- place a model with min_nodes > 1, verify sharding
|
||||
is distributed across multiple nodes with correct shard assignments.
|
||||
2. Single-node gets all layers -- place on 1 node, verify full assignment.
|
||||
3. Three-node sharding -- verify 3-way distribution.
|
||||
"""
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import InstanceMeta, MlxRingInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
|
||||
from exo.utils.channels import Sender, channel
|
||||
|
||||
from .conftest import (
|
||||
TEST_MODEL_CARD,
|
||||
make_gathered_info_event,
|
||||
make_node_id,
|
||||
make_topology_edge_event,
|
||||
)
|
||||
|
||||
# A model large enough to need sharding but small enough to fit in test node memory
|
||||
# Each test node has 8GB available, so 2 nodes = 16GB, 3 nodes = 24GB.
|
||||
# storage_size < total cluster memory to pass the memory filter.
|
||||
LARGE_MODEL_CARD = ModelCard(
|
||||
model_id=ModelId("test-model/large-70b-4bit"),
|
||||
n_layers=80,
|
||||
storage_size=Memory.from_bytes(4 * 1024 * 1024 * 1024),
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
async def _register_node(
|
||||
le_sender: Sender[ForwarderEvent],
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
) -> None:
|
||||
"""Register a node by injecting NodeGatheredInfo."""
|
||||
sender_id = NodeId(f"{node_id}_sender")
|
||||
await le_sender.send(make_gathered_info_event(node_id, sender_id, session_id, 0))
|
||||
|
||||
|
||||
async def _add_bidirectional_edge(
|
||||
le_sender: Sender[ForwarderEvent],
|
||||
node_a: NodeId,
|
||||
node_b: NodeId,
|
||||
session_id: SessionId,
|
||||
sender_id: NodeId,
|
||||
origin_idx_start: int,
|
||||
ip_a: int,
|
||||
ip_b: int,
|
||||
) -> None:
|
||||
"""Add bidirectional topology edges between two nodes."""
|
||||
await le_sender.send(
|
||||
make_topology_edge_event(
|
||||
node_a, node_b, sender_id, session_id, origin_idx_start, ip_suffix=ip_b
|
||||
)
|
||||
)
|
||||
await le_sender.send(
|
||||
make_topology_edge_event(
|
||||
node_b, node_a, sender_id, session_id, origin_idx_start + 1, ip_suffix=ip_a
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_node_sharding_distributes_layers() -> None:
|
||||
"""Place a model with min_nodes=2 on a cluster with 2 connected nodes.
|
||||
Verify the resulting instance has shard assignments spanning both nodes.
|
||||
"""
|
||||
master_nid = make_node_id("master-shard")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, _ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
worker_a = make_node_id("shard-worker-a")
|
||||
worker_b = make_node_id("shard-worker-b")
|
||||
|
||||
# Register both worker nodes (each sender uses origin_idx=0)
|
||||
for nid in [worker_a, worker_b]:
|
||||
await _register_node(le_sender, nid, session_id)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) < 2:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Add bidirectional edges to form a 2-node cycle (A <-> B)
|
||||
edge_sender = NodeId("edge_sender")
|
||||
await _add_bidirectional_edge(
|
||||
le_sender, worker_a, worker_b, session_id, edge_sender, 0, 1, 2
|
||||
)
|
||||
|
||||
# Wait for edges to be processed
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_connections())) < 2:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Place a large model requiring 2 nodes
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=LARGE_MODEL_CARD,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=2,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(5):
|
||||
while len(master.state.instances) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
instance_id = next(iter(master.state.instances))
|
||||
instance = master.state.instances[instance_id]
|
||||
assert isinstance(instance, MlxRingInstance)
|
||||
|
||||
shard_assignments = instance.shard_assignments
|
||||
runner_shards = shard_assignments.runner_to_shard
|
||||
|
||||
assert len(runner_shards) == 2
|
||||
|
||||
assigned_nodes = set(shard_assignments.node_to_runner.keys())
|
||||
assert worker_a in assigned_nodes
|
||||
assert worker_b in assigned_nodes
|
||||
|
||||
shards = list(runner_shards.values())
|
||||
assert all(isinstance(s, PipelineShardMetadata) for s in shards)
|
||||
pipeline_shards = [s for s in shards if isinstance(s, PipelineShardMetadata)]
|
||||
|
||||
assert all(s.world_size == 2 for s in pipeline_shards)
|
||||
ranks = {s.device_rank for s in pipeline_shards}
|
||||
assert ranks == {0, 1}
|
||||
|
||||
sorted_shards = sorted(pipeline_shards, key=lambda s: s.device_rank)
|
||||
assert sorted_shards[0].start_layer == 0
|
||||
assert sorted_shards[-1].end_layer == LARGE_MODEL_CARD.n_layers
|
||||
|
||||
total_layers = sum(s.end_layer - s.start_layer for s in sorted_shards)
|
||||
assert total_layers == LARGE_MODEL_CARD.n_layers
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_node_gets_all_layers() -> None:
|
||||
"""Place a model with min_nodes=1 on a single node. Verify the
|
||||
instance has one runner assigned all layers (world_size=1).
|
||||
"""
|
||||
master_nid = make_node_id("master-single")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, _ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
worker_nid = make_node_id("single-worker")
|
||||
await _register_node(le_sender, worker_nid, session_id)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) < 1:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=TEST_MODEL_CARD,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(master.state.instances) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
instance_id = next(iter(master.state.instances))
|
||||
instance = master.state.instances[instance_id]
|
||||
assert isinstance(instance, MlxRingInstance)
|
||||
|
||||
shards = list(instance.shard_assignments.runner_to_shard.values())
|
||||
assert len(shards) == 1
|
||||
|
||||
shard = shards[0]
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
assert shard.world_size == 1
|
||||
assert shard.device_rank == 0
|
||||
assert shard.start_layer == 0
|
||||
assert shard.end_layer == TEST_MODEL_CARD.n_layers
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_three_node_sharding_distributes_evenly() -> None:
|
||||
"""Place a model across 3 connected nodes. Verify all 3 get shard assignments."""
|
||||
master_nid = make_node_id("master-3way")
|
||||
session_id = SessionId(master_node_id=master_nid, election_clock=0)
|
||||
|
||||
ge_sender, _ge_receiver = channel[ForwarderEvent]()
|
||||
cmd_sender, cmd_receiver = channel[ForwarderCommand]()
|
||||
le_sender, le_receiver = channel[ForwarderEvent]()
|
||||
dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()
|
||||
|
||||
master = Master(
|
||||
master_nid,
|
||||
session_id,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=cmd_receiver,
|
||||
download_command_sender=dl_sender,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
workers: list[NodeId] = []
|
||||
for i in range(3):
|
||||
nid = make_node_id(f"three-worker-{i}")
|
||||
workers.append(nid)
|
||||
await _register_node(le_sender, nid, session_id)
|
||||
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_nodes())) < 3:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
# Add bidirectional edges to form a fully connected 3-node cycle:
|
||||
# A <-> B, B <-> C, C <-> A
|
||||
edge_sender = NodeId("edge_sender_3way")
|
||||
idx = 0
|
||||
ip_counter = 10
|
||||
for i in range(3):
|
||||
source = workers[i]
|
||||
sink = workers[(i + 1) % 3]
|
||||
# Forward edge
|
||||
await le_sender.send(
|
||||
make_topology_edge_event(
|
||||
source,
|
||||
sink,
|
||||
edge_sender,
|
||||
session_id,
|
||||
idx,
|
||||
ip_suffix=ip_counter,
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
ip_counter += 1
|
||||
# Reverse edge
|
||||
await le_sender.send(
|
||||
make_topology_edge_event(
|
||||
sink,
|
||||
source,
|
||||
edge_sender,
|
||||
session_id,
|
||||
idx,
|
||||
ip_suffix=ip_counter,
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
ip_counter += 1
|
||||
|
||||
# Wait for all 6 edges (3 pairs x 2 directions)
|
||||
with anyio.fail_after(3):
|
||||
while len(list(master.state.topology.list_connections())) < 6:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
await cmd_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=master_nid,
|
||||
command=PlaceInstance(
|
||||
command_id=CommandId(),
|
||||
model_card=LARGE_MODEL_CARD,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=3,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with anyio.fail_after(5):
|
||||
while len(master.state.instances) == 0:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
instance = next(iter(master.state.instances.values()))
|
||||
assert isinstance(instance, MlxRingInstance)
|
||||
|
||||
assignments = instance.shard_assignments
|
||||
assert len(assignments.runner_to_shard) == 3
|
||||
assert len(assignments.node_to_runner) == 3
|
||||
|
||||
for w in workers:
|
||||
assert w in assignments.node_to_runner
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
ranks = {s.device_rank for s in shards if isinstance(s, PipelineShardMetadata)}
|
||||
assert ranks == {0, 1, 2}
|
||||
|
||||
pipeline_shards = [s for s in shards if isinstance(s, PipelineShardMetadata)]
|
||||
total_layers = sum(s.end_layer - s.start_layer for s in pipeline_shards)
|
||||
assert total_layers == LARGE_MODEL_CARD.n_layers
|
||||
|
||||
await master.shutdown()
|
||||
@@ -1,272 +0,0 @@
|
||||
"""E2E Chaos Test: Failure recovery.
|
||||
|
||||
Scenarios:
|
||||
1. Master crash and re-election -- master shuts down, a new election round
|
||||
produces a new master, workers re-converge.
|
||||
2. Worker crash during task execution -- runner death is detected, instance
|
||||
is cleaned up, and cluster recovers.
|
||||
"""
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
RunnerStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.channels import channel
|
||||
|
||||
from .conftest import (
|
||||
TEST_MODEL_CARD,
|
||||
EventCollector,
|
||||
MiniCluster,
|
||||
make_gathered_info_event,
|
||||
make_node_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_master_crash_and_reelection() -> None:
|
||||
"""Simulate master crash by shutting it down, then verify a new master
|
||||
can be started with fresh state and begin accepting commands.
|
||||
|
||||
This tests the scenario where the elected master dies and a new election
|
||||
must take place. We simulate the election result directly (since
|
||||
Election is tested separately) and verify the new master works.
|
||||
"""
|
||||
cluster = MiniCluster(node_count=1)
|
||||
old_instance_id: str = ""
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(cluster.master.run)
|
||||
|
||||
# Set up initial state
|
||||
await cluster.inject_node_info(cluster.master_node_id)
|
||||
await cluster.wait_for_topology_nodes(1)
|
||||
await cluster.place_model()
|
||||
await cluster.wait_for_instances(1)
|
||||
|
||||
# Verify initial state
|
||||
assert len(cluster.master.state.instances) == 1
|
||||
old_instance_id = next(iter(cluster.master.state.instances))
|
||||
|
||||
# --- Crash the master ---
|
||||
await cluster.shutdown_master()
|
||||
|
||||
# --- Start a new master (simulating re-election) ---
    new_master_nid = make_node_id("new-master")
    new_session_id = SessionId(master_node_id=new_master_nid, election_clock=1)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    new_master = Master(
        new_master_nid,
        new_session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _new_collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(new_master.run)

        # New master starts with clean state
        assert len(new_master.state.instances) == 0
        assert new_master.state.last_event_applied_idx == -1

        # Re-register node with the new master
        sender_id = NodeId(f"{new_master_nid}_sender_new")
        await le_sender.send(
            make_gathered_info_event(new_master_nid, sender_id, new_session_id, 0)
        )

        # Wait for topology to be rebuilt
        with anyio.fail_after(3):
            while len(list(new_master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        # Place a new model instance on the new master
        await cmd_sender.send(
            ForwarderCommand(
                origin=new_master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(new_master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Verify new master is functional
        assert len(new_master.state.instances) == 1
        new_instance_id = next(iter(new_master.state.instances))
        # New instance should be different from old one
        assert new_instance_id != old_instance_id

        await new_master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_runner_failure_triggers_instance_cleanup() -> None:
    """Simulate a runner failure by injecting a RunnerStatusUpdated(RunnerFailed)
    event. Verify that the master's plan loop eventually detects the broken
    instance (no connected node for the runner) and cleans it up.
    """
    master_nid = make_node_id("master-runner-fail")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register a worker node
        worker_nid = make_node_id("worker-failing")
        sender_id = NodeId(f"{worker_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(worker_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        # Place a model instance
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

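        # With min_nodes=1 the placement yields a single instance backed by a
        # single runner, so taking the first key of each mapping below is enough.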
        instance_id = next(iter(master.state.instances))
        instance = master.state.instances[instance_id]
        runner_id = next(iter(instance.shard_assignments.runner_to_shard))

        # Inject a RunnerFailed event from the worker
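        # origin_idx=1 because this is the second event from this sender; the
        # gathered-info event above went out with origin index 0.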
        await le_sender.send(
            ForwarderEvent(
                origin_idx=1,
                origin=sender_id,
                session=session_id,
                event=RunnerStatusUpdated(
                    runner_id=runner_id,
                    runner_status=RunnerFailed(
                        error_message="Simulated OOM kill (exitcode=137)"
                    ),
                ),
            )
        )

        # Wait for the runner failure to be processed
        with anyio.fail_after(3):
            while runner_id not in master.state.runners:
                await anyio.sleep(0.01)

        # The runner status should be RunnerFailed
        assert isinstance(master.state.runners[runner_id], RunnerFailed)

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_election_recovers_after_multiple_node_joins() -> None:
    """Verify that the election protocol correctly handles rapid node
    join/leave events by running multiple election rounds.
    """
    from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
    from exo.shared.election import Election, ElectionMessage, ElectionResult

    em_out_tx, em_out_rx = channel[ElectionMessage]()
    em_in_tx, em_in_rx = channel[ElectionMessage]()
    er_tx, er_rx = channel[ElectionResult]()
    cm_tx, cm_rx = channel[ConnectionMessage]()
    co_tx, co_rx = channel[ForwarderCommand]()
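    # Channel wiring (convention assumed from the tx/rx names): em_out_* carries
    # this node's outgoing proposals, em_in_* delivers peers' messages, cm_* feeds
    # connection changes, and er_* reports the final election result.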

    election = Election(
        node_id=NodeId("SURVIVOR"),
        election_message_receiver=em_in_rx,
        election_message_sender=em_out_tx,
        election_result_sender=er_tx,
        connection_message_receiver=cm_rx,
        command_receiver=co_rx,
        is_candidate=True,
    )

    async with anyio.create_task_group() as tg:
        with anyio.fail_after(5):
            tg.start_soon(election.run)

            # Simulate rapid node joins via connection messages
            for i in range(3):
                await cm_tx.send(
                    ConnectionMessage(
                        node_id=NodeId(f"joiner-{i}"),
                        connection_type=ConnectionMessageType.Connected,
                        remote_ipv4=f"10.0.0.{i + 1}",
                        remote_tcp_port=52415,
                    )
                )
                # Each connection triggers a new election round
                while True:
                    got = await em_out_rx.receive()
                    if got.proposed_session.master_node_id == NodeId("SURVIVOR"):
                        break

            # After all joins, an election result should eventually be produced
            result = await er_rx.receive()
            assert result.session_id.master_node_id == NodeId("SURVIVOR")

            em_in_tx.close()
            cm_tx.close()
            co_tx.close()
@@ -1,227 +0,0 @@
"""E2E Chaos Test: Networking resilience.

Scenarios:
1. Node disconnect mid-inference -- a worker stops receiving global events, then
   reconnects and catches up via the event buffer / nack mechanism.
2. Master detects stale node and times it out, then the node re-announces.
"""

import anyio
import pytest

from exo.master.main import Master
from exo.shared.types.commands import (
    ForwarderCommand,
    ForwarderDownloadCommand,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
    InstanceCreated,
    NodeGatheredInfo,
    TaskCreated,
)
from exo.utils.channels import channel

from .conftest import (
    EventCollector,
    MiniCluster,
    make_gathered_info_event,
    make_node_id,
)


@pytest.mark.slow
@pytest.mark.asyncio
async def test_node_disconnect_and_reconnect_event_replay() -> None:
    """Simulate a node disconnecting by closing its global event receiver,
    then reconnecting with a fresh receiver.

    After reconnection, events that were broadcast while the node was
    disconnected should be replayed to the new receiver via the shared
    channel state. The master's state should remain consistent.
    """
    cluster = MiniCluster(node_count=1)

    async with anyio.create_task_group() as tg:
        tg.start_soon(cluster.master.run)

        # Register the master node so topology is populated
        await cluster.inject_node_info(cluster.master_node_id)
        await cluster.wait_for_topology_nodes(1)

        # Place a model instance
        await cluster.place_model()
        await cluster.wait_for_instances(1)

        # Verify instance was created
        assert len(cluster.master.state.instances) == 1

        # --- Simulate disconnection ---
        # The worker's global event receiver is independent; we just verify
        # that the master continues to accept commands while a worker is gone.
        _first_instance_id = next(iter(cluster.master.state.instances))

        # Send a chat command while "disconnected" worker can't process
        _cmd_id = await cluster.send_chat("Hello during disconnect")

        # Give master time to process the command
        await cluster.event_collector.wait_for_event_count(3, timeout=3.0)

        events = cluster.event_collector.indexed_events
        # Should have: NodeGatheredInfo, InstanceCreated, TaskCreated
        assert any(isinstance(e.event, NodeGatheredInfo) for e in events)
        assert any(isinstance(e.event, InstanceCreated) for e in events)
        assert any(isinstance(e.event, TaskCreated) for e in events)

        # --- Simulate reconnection ---
        # A reconnecting node gets a fresh receiver clone and catches up
        reconnect_receiver = cluster.global_event_internal_receiver.clone()
        _reconnect_collector = EventCollector(reconnect_receiver)

        # The new receiver should see future events; existing events are in
        # the master's event log (which would be replayed via RequestEventLog
        # in production). Here we verify the channel infrastructure works.
        await cluster.send_chat("Hello after reconnect")
        await anyio.sleep(0.1)

        # Master state should now have 2 tasks
        assert len(cluster.master.state.tasks) == 2

        # The master's state is consistent throughout
        assert len(cluster.master.state.instances) == 1
        assert cluster.master.state.last_event_applied_idx >= 3

        await cluster.shutdown_master()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_master_detects_timed_out_node_and_cleans_state() -> None:
    """Verify that the master's plan loop detects a node that hasn't sent
    a heartbeat (NodeGatheredInfo) recently and emits NodeTimedOut, cleaning
    up topology and related state.
    """
    master_nid = make_node_id("master-timeout")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    _cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register two nodes
        stale_node = make_node_id("stale")
        alive_node = make_node_id("alive")

        for node_id, suffix in [(stale_node, "_s0"), (alive_node, "_a0")]:
            sender_id = NodeId(f"{node_id}_sender{suffix}")
            await le_sender.send(
                make_gathered_info_event(node_id, sender_id, session_id, 0)
            )

        # Wait for both nodes in topology
        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 2:
                await anyio.sleep(0.01)

        assert stale_node in master.state.last_seen
        assert alive_node in master.state.last_seen

        # Manually expire the stale node's last_seen time by patching the state
        # (in production, the _plan loop checks every 10s with a 30s threshold)
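        # state is treated as an immutable snapshot, so we build a patched copy
        # (pydantic-style model_copy) instead of mutating last_seen in place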
        from datetime import timedelta

        old_time = master.state.last_seen[stale_node] - timedelta(seconds=60)
        patched_last_seen = {**master.state.last_seen, stale_node: old_time}
        master.state = master.state.model_copy(update={"last_seen": patched_last_seen})

        # Trigger the plan loop manually to speed up the test
        # The plan loop checks for stale nodes
        # We wait for the NodeTimedOut event to be emitted
        with anyio.fail_after(15):
            while stale_node in master.state.last_seen:
                await anyio.sleep(0.1)

        # Stale node should be removed from topology
        assert stale_node not in set(master.state.topology.list_nodes())

        # Alive node should still be present
        assert alive_node in set(master.state.topology.list_nodes())
        assert alive_node in master.state.last_seen

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_event_ordering_preserved_under_concurrent_writers() -> None:
    """Multiple sources writing local events concurrently. Verify that the
    master's MultiSourceBuffer correctly sequences events from each source
    and the final state is consistent.
    """
    master_nid = make_node_id("master-ordering")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    _cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Inject events from 3 different "worker" sources concurrently
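        # Each source numbers its own events 0..4; the master's buffer is
        # expected to merge them into one gap-free global sequence.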
        node_ids = [make_node_id(f"concurrent-{i}") for i in range(3)]

        async def inject_events(node_id: NodeId, count: int) -> None:
            for idx in range(count):
                sender_id = NodeId(f"{node_id}_sender")
                await le_sender.send(
                    make_gathered_info_event(node_id, sender_id, session_id, idx)
                )
                await anyio.sleep(0.001)  # slight jitter

        async with anyio.create_task_group() as inject_tg:
            for nid in node_ids:
                inject_tg.start_soon(inject_events, nid, 5)

        # Wait for master to process all events (3 nodes * 5 events each = 15)
        with anyio.fail_after(5):
            while master.state.last_event_applied_idx < 14:
                await anyio.sleep(0.01)

        # All 3 nodes should be visible in topology
        topo_nodes = set(master.state.topology.list_nodes())
        for nid in node_ids:
            assert nid in topo_nodes

        # Event indices should be sequential with no gaps
        assert master.state.last_event_applied_idx == 14

        await master.shutdown()
@@ -1,267 +0,0 @@
"""E2E Chaos Test: Node join/leave during operation.

Scenarios:
1. Add nodes dynamically -- register new nodes with the master while
   a model is already placed, verify topology grows.
2. Remove nodes -- simulate node timeout, verify instances on that node
   are cleaned up and remaining nodes are unaffected.
3. Rapid join/leave churn -- nodes join and leave quickly, verify state
   converges to a consistent snapshot.
"""

from datetime import timedelta

import anyio
import pytest

from exo.master.main import Master
from exo.shared.types.commands import (
    CommandId,
    ForwarderCommand,
    ForwarderDownloadCommand,
    PlaceInstance,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
)
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.channels import channel

from .conftest import (
    TEST_MODEL_CARD,
    make_gathered_info_event,
    make_node_id,
)


@pytest.mark.slow
@pytest.mark.asyncio
async def test_dynamic_node_registration_expands_topology() -> None:
    """Start with one node, then add more dynamically. Verify the topology
    grows and all nodes are visible in state.
    """
    master_nid = make_node_id("master-join")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register initial node
        initial_node = make_node_id("initial")
        sender_id = NodeId(f"{initial_node}_sender")
        await le_sender.send(
            make_gathered_info_event(initial_node, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 1:
                await anyio.sleep(0.01)

        # Place a model instance
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Dynamically add 3 more nodes
        new_nodes: list[NodeId] = []
        for i in range(3):
            new_nid = make_node_id(f"dynamic-{i}")
            new_nodes.append(new_nid)
            new_sender = NodeId(f"{new_nid}_sender")
            await le_sender.send(
                make_gathered_info_event(new_nid, new_sender, session_id, 0)
            )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 4:
                await anyio.sleep(0.01)

        # All 4 nodes should be in topology
        topo_nodes = set(master.state.topology.list_nodes())
        assert initial_node in topo_nodes
        for nid in new_nodes:
            assert nid in topo_nodes

        # Original instance should still exist
        assert len(master.state.instances) >= 1

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_node_removal_cleans_up_instances() -> None:
    """Place a model on a specific node, then time it out. Verify the
    instance assigned to that node is deleted by the master's plan loop.
    """
    master_nid = make_node_id("master-leave")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register a worker node
        worker_nid = make_node_id("worker-leaving")
        sender_id = NodeId(f"{worker_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(worker_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 1:
                await anyio.sleep(0.01)

        # Place instance on the worker node
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        assert len(master.state.instances) == 1

        # Simulate node leaving by expiring its last_seen
        old_time = master.state.last_seen[worker_nid] - timedelta(seconds=60)
        patched_last_seen = {**master.state.last_seen, worker_nid: old_time}
        master.state = master.state.model_copy(update={"last_seen": patched_last_seen})

        # The plan loop should detect the stale node and delete the instance
        # because the node assigned to the instance is no longer in the topology
        with anyio.fail_after(15):
            while worker_nid in master.state.last_seen:
                await anyio.sleep(0.1)

        # After timeout, the node should be removed from topology
        assert worker_nid not in set(master.state.topology.list_nodes())

        # The instance should eventually be deleted since the assigned node
        # is no longer connected (the _plan loop kills broken instances)
        with anyio.fail_after(15):
            while len(master.state.instances) > 0:
                await anyio.sleep(0.1)

        assert len(master.state.instances) == 0

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_rapid_join_leave_churn_converges() -> None:
    """Rapidly join and leave nodes. After the churn settles, verify the
    master's state reflects only the surviving nodes.
    """
    master_nid = make_node_id("master-churn")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    _cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register 5 nodes rapidly
        all_nodes: list[NodeId] = []
        for i in range(5):
            nid = make_node_id(f"churn-{i}")
            all_nodes.append(nid)
            sender_id = NodeId(f"{nid}_sender")
            await le_sender.send(
                make_gathered_info_event(nid, sender_id, session_id, 0)
            )

        with anyio.fail_after(5):
            while len(list(master.state.topology.list_nodes())) < 5:
                await anyio.sleep(0.01)

        assert len(list(master.state.topology.list_nodes())) == 5

        # Expire the first 3 nodes (simulate leaving)
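        # Same last_seen back-dating trick as above: push them past the plan
        # loop's timeout threshold so the next pass treats them as gone.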
        leaving_nodes = all_nodes[:3]
        surviving_nodes = all_nodes[3:]

        patched_last_seen = dict(master.state.last_seen)
        for nid in leaving_nodes:
            patched_last_seen[nid] = patched_last_seen[nid] - timedelta(seconds=60)
        master.state = master.state.model_copy(update={"last_seen": patched_last_seen})

        # Wait for master's plan loop to time out the expired nodes
        with anyio.fail_after(15):
            while any(nid in master.state.last_seen for nid in leaving_nodes):
                await anyio.sleep(0.1)

        # Verify only surviving nodes remain
        topo_nodes = set(master.state.topology.list_nodes())
        for nid in leaving_nodes:
            assert nid not in topo_nodes
        for nid in surviving_nodes:
            assert nid in topo_nodes

        assert len(list(master.state.topology.list_nodes())) == 2

        await master.shutdown()
@@ -166,7 +166,7 @@ def generate_image(
            else 0.0
        )

        peak_memory_gb = mx.get_peak_memory() / (1024**3)
        peak_memory = Memory.from_bytes(mx.get_peak_memory())

        stats = ImageGenerationStats(
            seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
            num_images=num_images,
            image_width=width,
            image_height=height,
            peak_memory_usage=Memory.from_gb(peak_memory_gb),
            peak_memory_usage=peak_memory,
        )

        buffer = io.BytesIO()

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
    total_gb = psutil.virtual_memory().total / (1024**3)
    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:

@@ -232,11 +232,11 @@ def shard_and_load(

    # Estimate timeout based on model size (5x default for large queued workloads)
    base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
    model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
    timeout_seconds = base_timeout + model_size_gb
    model_size = get_weights_size(shard_metadata)
    timeout_seconds = base_timeout + model_size.in_gb
    logger.info(
        f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
        f"(model size: {model_size_gb:.1f}GB)"
        f"(model size: {model_size.in_gb:.1f}GB)"
    )

    match shard_metadata:
@@ -617,18 +617,17 @@ def set_wired_limit_for_model(model_size: Memory):
    if not mx.metal.is_available():
        return

    model_bytes = model_size.in_bytes
    max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
    max_rec_size = Memory.from_bytes(
        int(mx.metal.device_info()["max_recommended_working_set_size"])
    )
    if model_size > 0.9 * max_rec_size:
        logger.warning(
            f"Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
            f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
        )
        mx.set_wired_limit(max_rec_size)
        mx.set_wired_limit(max_rec_size.in_bytes)
        logger.info(f"Wired limit set to {max_rec_size}.")


@@ -241,6 +241,11 @@ class Worker:
                        cancelled_task_id=cancelled_task_id, runner_id=runner_id
                    ):
                        await self.runners[runner_id].cancel_task(cancelled_task_id)
                        await self.event_sender.send(
                            TaskStatusUpdated(
                                task_id=task.task_id, task_status=TaskStatus.Complete
                            )
                        )
                    case ImageEdits() if task.task_params.total_input_chunks > 0:
                        # Assemble image from chunks and inject into task
                        cmd_id = task.command_id

@@ -90,14 +90,10 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
        ],
    }

@@ -138,9 +134,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
    # Global state shows shard is downloaded for NODE_A
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [],
    }
@@ -187,9 +181,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [],  # NODE_B has no downloads completed yet
    }
@@ -207,14 +199,10 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
        ],  # NODE_B has no downloads completed yet
    }
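The hunks above assume a small Memory value type with from_bytes/from_gb constructors, in_bytes/in_gb/in_float_mb accessors, comparison, and scalar multiplication (as in 0.9 * max_rec_size). A minimal sketch of such a type, for orientation only and not the repository's actual implementation, could look like:

from dataclasses import dataclass

@dataclass(frozen=True, order=True)
class Memory:
    # Hypothetical sketch; the method names come from the diff above,
    # everything else (field name, rounding, MB base) is assumed.
    num_bytes: int = 0

    @classmethod
    def from_bytes(cls, n: int) -> "Memory":
        return cls(int(n))

    @classmethod
    def from_gb(cls, gb: float) -> "Memory":
        return cls(int(gb * 1024**3))

    @property
    def in_bytes(self) -> int:
        return self.num_bytes

    @property
    def in_gb(self) -> float:
        return self.num_bytes / 1024**3

    @property
    def in_float_mb(self) -> float:
        return self.num_bytes / 1024**2

    def __rmul__(self, factor: float) -> "Memory":
        # Supports expressions like 0.9 * max_rec_size in the wired-limit check.
        return Memory(int(factor * self.num_bytes))

    def __str__(self) -> str:
        return f"{self.in_gb:.2f} GB"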