Compare commits


1 Commit

Author: Alex Cheema
SHA1: 1a4be539da
Message: feat: add bug report to GitHub issue prompt modal
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Date: 2026-02-19 07:31:45 -08:00
39 changed files with 823 additions and 388 deletions

Cargo.lock (generated)

@@ -890,7 +890,7 @@ dependencies = [
  "delegate",
  "env_logger",
  "extend",
- "futures-lite",
+ "futures",
  "libp2p",
  "log",
  "networking",
@@ -914,12 +914,6 @@ dependencies = [
  "syn 2.0.111",
 ]
 
-[[package]]
-name = "fastrand"
-version = "2.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
-
 [[package]]
 name = "ff"
 version = "0.13.1"
@@ -1028,10 +1022,7 @@ version = "2.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
 dependencies = [
- "fastrand",
  "futures-core",
- "futures-io",
- "parking",
  "pin-project-lite",
 ]
@@ -2762,7 +2753,7 @@ dependencies = [
  "delegate",
  "either",
  "extend",
- "futures-lite",
+ "futures",
  "futures-timer",
  "keccak-const",
  "libp2p",


@@ -29,13 +29,14 @@ util = { path = "rust/util" }
 # Macro dependecies
 extend = "1.2"
 delegate = "0.13"
+pin-project = "1"
 
 # Utility dependencies
 keccak-const = "0.2"
 
 # Async dependencies
 tokio = "1.46"
-futures-lite = "2.6.1"
+futures = "0.3"
 futures-timer = "3.0"
 
 # Data structures


@@ -20,7 +20,6 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
-    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -963,21 +962,6 @@ Examples:
     selected.sort(key=_placement_sort_key)
     preview = selected[0]
 
-    settle_deadline = (
-        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
-    )
-    print("Planning phase: checking downloads...", file=log)
-    run_planning_phase(
-        exo,
-        full_model_id,
-        preview,
-        args.danger_delete_downloads,
-        args.timeout,
-        settle_deadline,
-    )
-
     instance = preview["instance"]
     instance_id = instance_id_from_instance(instance)
     sharding = str(preview["sharding"])


@@ -35,7 +35,6 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
-    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -333,20 +332,6 @@ def main() -> int:
     if args.dry_run:
         return 0
 
-    settle_deadline = (
-        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
-    )
-    logger.info("Planning phase: checking downloads...")
-    run_planning_phase(
-        client,
-        full_model_id,
-        selected[0],
-        args.danger_delete_downloads,
-        args.timeout,
-        settle_deadline,
-    )
-
     all_rows: list[dict[str, Any]] = []
     for preview in selected:


@@ -282,151 +282,6 @@ def settle_and_fetch_placements(
     return selected
 
 
-def run_planning_phase(
-    client: ExoClient,
-    full_model_id: str,
-    preview: dict[str, Any],
-    danger_delete: bool,
-    timeout: float,
-    settle_deadline: float | None,
-) -> None:
-    """Check disk space and ensure model is downloaded before benchmarking."""
-    # Get model size from /models
-    models = client.request_json("GET", "/models") or {}
-    model_bytes = 0
-    for m in models.get("data", []):
-        if m.get("hugging_face_id") == full_model_id:
-            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
-            break
-    if not model_bytes:
-        logger.warning(
-            f"Could not determine size for {full_model_id}, skipping disk check"
-        )
-        return
-
-    # Get nodes from preview
-    inner = unwrap_instance(preview["instance"])
-    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
-    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
-
-    state = client.request_json("GET", "/state")
-    downloads = state.get("downloads", {})
-    node_disk = state.get("nodeDisk", {})
-
-    for node_id in node_ids:
-        node_downloads = downloads.get(node_id, [])
-        # Check if model already downloaded on this node
-        already_downloaded = any(
-            "DownloadCompleted" in p
-            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
-                "modelId"
-            ]
-            == full_model_id
-            for p in node_downloads
-        )
-        if already_downloaded:
-            continue
-
-        # Wait for disk info if settle_deadline is set
-        disk_info = node_disk.get(node_id, {})
-        backoff = _SETTLE_INITIAL_BACKOFF_S
-        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
-            remaining = settle_deadline - time.monotonic()
-            logger.info(
-                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
-            )
-            time.sleep(min(backoff, remaining))
-            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
-            state = client.request_json("GET", "/state")
-            node_disk = state.get("nodeDisk", {})
-            disk_info = node_disk.get(node_id, {})
-        if not disk_info:
-            logger.warning(f"No disk info for {node_id}, skipping space check")
-            continue
-
-        avail = disk_info.get("available", {}).get("inBytes", 0)
-        if avail >= model_bytes:
-            continue
-        if not danger_delete:
-            raise RuntimeError(
-                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
-                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
-            )
-
-        # Delete from smallest to largest
-        completed = [
-            (
-                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
-                    "modelId"
-                ],
-                p["DownloadCompleted"]["totalBytes"]["inBytes"],
-            )
-            for p in node_downloads
-            if "DownloadCompleted" in p
-        ]
-        for del_model, size in sorted(completed, key=lambda x: x[1]):
-            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
-            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
-            avail += size
-            if avail >= model_bytes:
-                break
-        if avail < model_bytes:
-            raise RuntimeError(f"Could not free enough space on {node_id}")
-
-    # Start downloads (idempotent)
-    for node_id in node_ids:
-        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
-        shard = runner_to_shard[runner_id]
-        client.request_json(
-            "POST",
-            "/download/start",
-            body={
-                "targetNodeId": node_id,
-                "shardMetadata": shard,
-            },
-        )
-        logger.info(f"Started download on {node_id}")
-
-    # Wait for downloads
-    start = time.time()
-    while time.time() - start < timeout:
-        state = client.request_json("GET", "/state")
-        downloads = state.get("downloads", {})
-        all_done = True
-        for node_id in node_ids:
-            done = any(
-                "DownloadCompleted" in p
-                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
-                    "modelCard"
-                ]["modelId"]
-                == full_model_id
-                for p in downloads.get(node_id, [])
-            )
-            failed = [
-                p["DownloadFailed"]["errorMessage"]
-                for p in downloads.get(node_id, [])
-                if "DownloadFailed" in p
-                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
-                    "modelId"
-                ]
-                == full_model_id
-            ]
-            if failed:
-                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
-            if not done:
-                all_done = False
-        if all_done:
-            return
-        time.sleep(1)
-    raise TimeoutError("Downloads did not complete in time")
-
-
 def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
     ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
     ap.add_argument(
@@ -470,8 +325,3 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
         default=0,
         help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
     )
-    ap.add_argument(
-        "--danger-delete-downloads",
-        action="store_true",
-        help="Delete existing models from smallest to largest to make room for benchmark model.",
-    )


@@ -0,0 +1,188 @@
<script lang="ts">
  import { fade, fly } from "svelte/transition";
  import { cubicOut } from "svelte/easing";

  interface Props {
    isOpen: boolean;
    onClose: () => void;
  }

  let { isOpen, onClose }: Props = $props();

  let bugReportId = $state<string | null>(null);
  let githubIssueUrl = $state<string | null>(null);
  let isLoading = $state(false);
  let error = $state<string | null>(null);

  async function generateBugReport() {
    isLoading = true;
    error = null;
    try {
      const response = await fetch("/bug-report", { method: "POST" });
      if (!response.ok) {
        error = "Failed to generate bug report. Please try again.";
        return;
      }
      const data = await response.json();
      bugReportId = data.bugReportId;
      githubIssueUrl = data.githubIssueUrl;
    } catch {
      error = "Failed to connect to the server. Please try again.";
    } finally {
      isLoading = false;
    }
  }

  function handleClose() {
    bugReportId = null;
    githubIssueUrl = null;
    error = null;
    isLoading = false;
    onClose();
  }

  // Generate bug report when modal opens
  $effect(() => {
    if (isOpen && !bugReportId && !isLoading) {
      generateBugReport();
    }
  });
</script>

{#if isOpen}
  <!-- Backdrop -->
  <div
    class="fixed inset-0 z-50 bg-black/80 backdrop-blur-sm"
    transition:fade={{ duration: 200 }}
    onclick={handleClose}
    role="presentation"
  ></div>

  <!-- Modal -->
  <div
    class="fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,480px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col"
    transition:fly={{ y: 20, duration: 300, easing: cubicOut }}
    role="dialog"
    aria-modal="true"
    aria-label="Bug Report"
  >
    <!-- Header -->
    <div
      class="flex items-center justify-between px-5 py-4 border-b border-exo-medium-gray/30"
    >
      <div class="flex items-center gap-2">
        <svg
          class="w-5 h-5 text-exo-yellow"
          fill="none"
          viewBox="0 0 24 24"
          stroke="currentColor"
          stroke-width="2"
        >
          <path
            stroke-linecap="round"
            stroke-linejoin="round"
            d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
          />
        </svg>
        <h2 class="text-sm font-mono text-exo-yellow tracking-wider uppercase">
          Report a Bug
        </h2>
      </div>
      <button
        onclick={handleClose}
        class="text-exo-light-gray hover:text-white transition-colors cursor-pointer"
        aria-label="Close"
      >
        <svg
          class="w-5 h-5"
          fill="none"
          viewBox="0 0 24 24"
          stroke="currentColor"
          stroke-width="2"
        >
          <path
            stroke-linecap="round"
            stroke-linejoin="round"
            d="M6 18L18 6M6 6l12 12"
          />
        </svg>
      </button>
    </div>

    <!-- Body -->
    <div class="px-5 py-5 space-y-4">
      {#if isLoading}
        <div class="flex items-center justify-center py-6">
          <div
            class="w-5 h-5 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"
          ></div>
          <span class="ml-3 text-sm text-exo-light-gray font-mono"
            >Generating bug report...</span
          >
        </div>
      {:else if error}
        <div
          class="text-sm text-red-400 font-mono bg-red-400/10 border border-red-400/20 rounded px-4 py-3"
        >
          {error}
        </div>
        <button
          onclick={generateBugReport}
          class="w-full px-4 py-2.5 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded text-sm font-mono text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer"
        >
          Try Again
        </button>
      {:else if bugReportId && githubIssueUrl}
        <p class="text-sm text-exo-light-gray leading-relaxed">
          Would you like to create a GitHub issue? This would help us track and
          fix the issue for you.
        </p>

        <!-- Bug Report ID -->
        <div
          class="bg-exo-black/50 border border-exo-medium-gray/30 rounded px-4 py-3"
        >
          <div
            class="text-[11px] text-exo-light-gray/60 font-mono tracking-wider uppercase mb-1"
          >
            Bug Report ID
          </div>
          <div class="text-sm text-exo-yellow font-mono tracking-wide">
            {bugReportId}
          </div>
          <div class="text-[11px] text-exo-light-gray/50 font-mono mt-1">
            Include this ID when communicating with the team.
          </div>
        </div>

        <p class="text-xs text-exo-light-gray/60 leading-relaxed">
          No diagnostic data is attached. The issue template contains
          placeholder fields for you to fill in.
        </p>

        <!-- Actions -->
        <div class="flex gap-3 pt-1">
          <a
            href={githubIssueUrl}
            target="_blank"
            rel="noopener noreferrer"
            class="flex-1 flex items-center justify-center gap-2 px-4 py-2.5 bg-exo-yellow/10 border border-exo-yellow/40 rounded text-sm font-mono text-exo-yellow hover:bg-exo-yellow/20 hover:border-exo-yellow/60 transition-colors"
          >
            <svg class="w-4 h-4" viewBox="0 0 16 16" fill="currentColor">
              <path
                d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"
              />
            </svg>
            Create GitHub Issue
          </a>
          <button
            onclick={handleClose}
            class="px-4 py-2.5 border border-exo-medium-gray/40 rounded text-sm font-mono text-exo-light-gray hover:border-exo-medium-gray/60 transition-colors cursor-pointer"
          >
            Close
          </button>
        </div>
      {/if}
    </div>
  </div>
{/if}
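For reference, the modal's only server dependency is a POST /bug-report endpoint returning a JSON object with the camelCase fields bugReportId and githubIssueUrl, both read verbatim by the component above. A minimal sketch of that contract from the Python side (hypothetical helper; the requests library is used here only for brevity):

    import requests  # assumption: any HTTP client would do

    def generate_bug_report(base_url: str) -> tuple[str, str]:
        # Mirrors the component's fetch("/bug-report", { method: "POST" })
        resp = requests.post(f"{base_url}/bug-report", timeout=10)
        resp.raise_for_status()
        data = resp.json()
        # The Svelte component reads exactly these two fields
        return data["bugReportId"], data["githubIssueUrl"]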


@@ -74,6 +74,7 @@
     perSystem =
       { config, self', inputs', pkgs, lib, system, ... }:
       let
+        fenixToolchain = inputs'.fenix.packages.complete;
         # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
         pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
       in

rust/clippy.toml (new file)

@@ -0,0 +1,2 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]


@@ -27,7 +27,7 @@ networking = { workspace = true }
 # interop
 pyo3 = { version = "0.27.2", features = [
     # "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
-    # "nightly", # enables better-supported GIL integration
+    "nightly", # enables better-supported GIL integration
     "experimental-async", # async support in #[pyfunction] & #[pymethods]
     #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
     #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
@@ -45,10 +45,11 @@ pyo3-log = "0.13.2"
 # macro dependencies
 extend = { workspace = true }
 delegate = { workspace = true }
+pin-project = { workspace = true }
 
 # async runtime
 tokio = { workspace = true, features = ["full", "tracing"] }
-futures-lite = { workspace = true }
+futures = { workspace = true }
 
 # utility dependencies
 util = { workspace = true }
@@ -59,4 +60,3 @@ env_logger = "0.11"
 # Networking
 libp2p = { workspace = true, features = ["full"] }
-pin-project = "1.1.10"


@@ -19,7 +19,7 @@ class ConnectionUpdate:
     Whether this is a connection or disconnection event
     """
     @property
-    def peer_id(self) -> builtins.str:
+    def peer_id(self) -> PeerId:
         r"""
         Identity of the peer that we have connected to or disconnected from.
         """
@@ -40,22 +40,92 @@ class Keypair:
     Identity keypair of a node.
     """
     @staticmethod
-    def generate() -> Keypair:
+    def generate_ed25519() -> Keypair:
         r"""
         Generate a new Ed25519 keypair.
         """
     @staticmethod
-    def from_bytes(bytes: bytes) -> Keypair:
+    def generate_ecdsa() -> Keypair:
         r"""
-        Construct an Ed25519 keypair from secret key bytes
+        Generate a new ECDSA keypair.
+        """
+    @staticmethod
+    def generate_secp256k1() -> Keypair:
+        r"""
+        Generate a new Secp256k1 keypair.
+        """
+    @staticmethod
+    def from_protobuf_encoding(bytes: bytes) -> Keypair:
+        r"""
+        Decode a private key from a protobuf structure and parse it as a `Keypair`.
+        """
+    @staticmethod
+    def rsa_from_pkcs8(bytes: bytes) -> Keypair:
+        r"""
+        Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
+        format (i.e. unencrypted) as defined in [RFC5208].
+
+        [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
+        """
+    @staticmethod
+    def secp256k1_from_der(bytes: bytes) -> Keypair:
+        r"""
+        Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
+        structure as defined in [RFC5915].
+
+        [RFC5915]: https://tools.ietf.org/html/rfc5915
+        """
+    @staticmethod
+    def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
+    def to_protobuf_encoding(self) -> bytes:
+        r"""
+        Encode a private key as protobuf structure.
+        """
+    def to_peer_id(self) -> PeerId:
+        r"""
+        Convert the `Keypair` into the corresponding `PeerId`.
+        """
+
+@typing.final
+class Multiaddr:
+    r"""
+    Representation of a Multiaddr.
+    """
+    @staticmethod
+    def empty() -> Multiaddr:
+        r"""
+        Create a new, empty multiaddress.
+        """
+    @staticmethod
+    def with_capacity(n: builtins.int) -> Multiaddr:
+        r"""
+        Create a new, empty multiaddress with the given capacity.
+        """
+    @staticmethod
+    def from_bytes(bytes: bytes) -> Multiaddr:
+        r"""
+        Parse a `Multiaddr` value from its byte slice representation.
+        """
+    @staticmethod
+    def from_string(string: builtins.str) -> Multiaddr:
+        r"""
+        Parse a `Multiaddr` value from its string representation.
+        """
+    def len(self) -> builtins.int:
+        r"""
+        Return the length in bytes of this multiaddress.
+        """
+    def is_empty(self) -> builtins.bool:
+        r"""
+        Returns true if the length of this multiaddress is 0.
         """
     def to_bytes(self) -> bytes:
         r"""
-        Get the secret key bytes underlying the keypair
+        Return a copy of this [`Multiaddr`]'s byte representation.
         """
-    def to_node_id(self) -> builtins.str:
+    def to_string(self) -> builtins.str:
         r"""
-        Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
+        Convert a Multiaddr to a string.
         """
 @typing.final
@@ -110,6 +180,37 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
     def __repr__(self) -> builtins.str: ...
     def __str__(self) -> builtins.str: ...
+
+@typing.final
+class PeerId:
+    r"""
+    Identifier of a peer of the network.
+
+    The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
+    as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
+    """
+    @staticmethod
+    def random() -> PeerId:
+        r"""
+        Generates a random peer ID from a cryptographically secure PRNG.
+
+        This is useful for randomly walking on a DHT, or for testing purposes.
+        """
+    @staticmethod
+    def from_bytes(bytes: bytes) -> PeerId:
+        r"""
+        Parses a `PeerId` from bytes.
+        """
+    def to_bytes(self) -> bytes:
+        r"""
+        Returns a raw bytes representation of this `PeerId`.
+        """
+    def to_base58(self) -> builtins.str:
+        r"""
+        Returns a base-58 encoded string of this `PeerId`.
+        """
+    def __repr__(self) -> builtins.str: ...
+    def __str__(self) -> builtins.str: ...
 @typing.final
 class ConnectionUpdateType(enum.Enum):
     r"""

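Taken together, the regenerated stubs above describe the new identity API surface. A short usage sketch against those stubs (the extension module's import name is not shown in this diff, so it is assumed here):

    from exo_networking import Keypair, PeerId  # hypothetical module name

    kp = Keypair.generate_ed25519()             # was Keypair.generate()
    blob = kp.to_protobuf_encoding()            # replaces the old to_bytes()
    kp2 = Keypair.from_protobuf_encoding(blob)

    peer: PeerId = kp.to_peer_id()              # was to_node_id() -> str
    node_id = peer.to_base58()                  # the old string form, when needed
    assert PeerId.from_bytes(peer.to_bytes()).to_base58() == node_id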

@@ -2,6 +2,7 @@
 //!
 use pin_project::pin_project;
+use pyo3::marker::Ungil;
 use pyo3::prelude::*;
 use std::{
     future::Future,
@@ -25,8 +26,8 @@ where
 impl<F> Future for AllowThreads<F>
 where
-    F: Future + Send,
-    F::Output: Send,
+    F: Future + Ungil,
+    F::Output: Ungil,
 {
     type Output = F::Output;


@@ -1,47 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};

/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate() -> Self {
        Self(Keypair::generate_ed25519())
    }

    /// Construct an Ed25519 keypair from secret key bytes
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
    }

    /// Get the secret key bytes underlying the keypair
    fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = self
            .0
            .clone()
            .try_into_ed25519()
            .pyerr()?
            .secret()
            .as_ref()
            .to_vec();
        Ok(PyBytes::new(py, &bytes))
    }

    /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
    fn to_node_id(&self) -> String {
        self.0.public().to_peer_id().to_base58()
    }
}


@@ -4,14 +4,26 @@
 //!
 //!
-mod allow_threading;
-mod ident;
-mod networking;
+// enable Rust-unstable features for convenience
+#![feature(trait_alias)]
+#![feature(tuple_trait)]
+#![feature(unboxed_closures)]
+// #![feature(stmt_expr_attributes)]
+// #![feature(assert_matches)]
+// #![feature(async_fn_in_dyn_trait)]
+// #![feature(async_for_loop)]
+// #![feature(auto_traits)]
+// #![feature(negative_impls)]
+
+extern crate core;
+
+mod allow_threading;
+pub(crate) mod networking;
+pub(crate) mod pylibp2p;
 
-use crate::ident::PyKeypair;
 use crate::networking::networking_submodule;
+use crate::pylibp2p::ident::ident_submodule;
+use crate::pylibp2p::multiaddr::multiaddr_submodule;
 use pyo3::prelude::PyModule;
+use pyo3::types::PyModuleMethods;
 use pyo3::{Bound, PyResult, pyclass, pymodule};
 use pyo3_stub_gen::define_stub_info_gatherer;
@@ -20,6 +32,14 @@ pub(crate) mod r#const {
     pub const MPSC_CHANNEL_SIZE: usize = 1024;
 }
 
+/// Namespace for all the type/trait aliases used by this crate.
+pub(crate) mod alias {
+    use std::marker::Tuple;
+
+    pub trait SendFn<Args: Tuple + Send + 'static, Output> =
+        Fn<Args, Output = Output> + Send + 'static;
+}
+
 /// Namespace for crate-wide extension traits/methods
 pub(crate) mod ext {
     use crate::allow_threading::AllowThreads;
@@ -159,7 +179,8 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
     // work with maturin, where the types generate correctly, in the right folder, without
     // too many importing issues...
-    m.add_class::<PyKeypair>()?;
+    ident_submodule(m)?;
+    multiaddr_submodule(m)?;
     networking_submodule(m)?;
 
     // top-level constructs


@@ -8,8 +8,8 @@
 use crate::r#const::MPSC_CHANNEL_SIZE;
 use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
 use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
-use crate::ident::PyKeypair;
 use crate::pyclass;
+use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
 use libp2p::futures::StreamExt as _;
 use libp2p::gossipsub;
 use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
     /// Identity of the peer that we have connected to or disconnected from.
     #[pyo3(get)]
-    peer_id: String,
+    peer_id: PyPeerId,
 
     /// Remote connection's IPv4 address.
     #[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
             // send connection event to channel (or exit if connection closed)
             if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                 update_type: PyConnectionUpdateType::Connected,
-                peer_id: peer_id.to_base58(),
+                peer_id: PyPeerId(peer_id),
                 remote_ipv4,
                 remote_tcp_port,
             }).await {
@@ -272,7 +272,7 @@ async fn networking_task(
             // send disconnection event to channel (or exit if connection closed)
             if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                 update_type: PyConnectionUpdateType::Disconnected,
-                peer_id: peer_id.to_base58(),
+                peer_id: PyPeerId(peer_id),
                 remote_ipv4,
                 remote_tcp_port,
             }).await {


@@ -0,0 +1,159 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};

/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate_ed25519() -> Self {
        Self(Keypair::generate_ed25519())
    }

    /// Generate a new ECDSA keypair.
    #[staticmethod]
    fn generate_ecdsa() -> Self {
        Self(Keypair::generate_ecdsa())
    }

    /// Generate a new Secp256k1 keypair.
    #[staticmethod]
    fn generate_secp256k1() -> Self {
        Self(Keypair::generate_secp256k1())
    }

    /// Decode a private key from a protobuf structure and parse it as a `Keypair`.
    #[staticmethod]
    fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
    }

    /// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
    /// format (i.e. unencrypted) as defined in [RFC5208].
    ///
    /// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
    #[staticmethod]
    fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
    }

    /// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
    /// structure as defined in [RFC5915].
    ///
    /// [RFC5915]: https://tools.ietf.org/html/rfc5915
    #[staticmethod]
    fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
    }

    #[staticmethod]
    fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
    }

    /// Encode a private key as protobuf structure.
    fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = self.0.to_protobuf_encoding().pyerr()?;
        Ok(PyBytes::new(py, &bytes))
    }

    /// Convert the `Keypair` into the corresponding `PeerId`.
    fn to_peer_id(&self) -> PyPeerId {
        PyPeerId(self.0.public().to_peer_id())
    }

    // /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
    // #[gen_stub(skip)]
    // #[new]
    // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
    //     Self::from_protobuf_encoding(bytes)
    // }
    //
    // #[gen_stub(skip)]
    // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
    //     *self = Self::from_protobuf_encoding(state)?;
    //     Ok(())
    // }
    //
    // #[gen_stub(skip)]
    // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
    //     self.to_protobuf_encoding(py)
    // }
    //
    // #[gen_stub(skip)]
    // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
    //     Ok((self.to_protobuf_encoding(py)?,))
    // }
}

/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
    /// Generates a random peer ID from a cryptographically secure PRNG.
    ///
    /// This is useful for randomly walking on a DHT, or for testing purposes.
    #[staticmethod]
    fn random() -> Self {
        Self(PeerId::random())
    }

    /// Parses a `PeerId` from bytes.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
    }

    /// Returns a raw bytes representation of this `PeerId`.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_bytes();
        PyBytes::new(py, &bytes)
    }

    /// Returns a base-58 encoded string of this `PeerId`.
    fn to_base58(&self) -> String {
        self.0.to_base58()
    }

    fn __repr__(&self) -> String {
        format!("PeerId({})", self.to_base58())
    }

    fn __str__(&self) -> String {
        self.to_base58()
    }
}

pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyKeypair>()?;
    m.add_class::<PyPeerId>()?;
    Ok(())
}


@@ -0,0 +1,8 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;


@@ -0,0 +1,81 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;

/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
    /// Create a new, empty multiaddress.
    #[staticmethod]
    fn empty() -> Self {
        Self(Multiaddr::empty())
    }

    /// Create a new, empty multiaddress with the given capacity.
    #[staticmethod]
    fn with_capacity(n: usize) -> Self {
        Self(Multiaddr::with_capacity(n))
    }

    /// Parse a `Multiaddr` value from its byte slice representation.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
    }

    /// Parse a `Multiaddr` value from its string representation.
    #[staticmethod]
    fn from_string(string: String) -> PyResult<Self> {
        Ok(Self(Multiaddr::from_str(&string).pyerr()?))
    }

    /// Return the length in bytes of this multiaddress.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns true if the length of this multiaddress is 0.
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Return a copy of this [`Multiaddr`]'s byte representation.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_vec();
        PyBytes::new(py, &bytes)
    }

    /// Convert a Multiaddr to a string.
    fn to_string(&self) -> String {
        self.0.to_string()
    }

    #[gen_stub(skip)]
    fn __repr__(&self) -> String {
        format!("Multiaddr({})", self.0)
    }

    #[gen_stub(skip)]
    fn __str__(&self) -> String {
        self.to_string()
    }
}

pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyMultiaddr>()?;
    Ok(())
}
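A matching sketch for the new Multiaddr binding as seen from Python (same assumed module name as above):

    from exo_networking import Multiaddr  # hypothetical module name

    addr = Multiaddr.from_string("/ip4/127.0.0.1/tcp/4001")
    assert not addr.is_empty()
    assert addr.len() == len(addr.to_bytes())  # len() is the byte length
    assert Multiaddr.from_bytes(addr.to_bytes()).to_string() == addr.to_string()
    assert Multiaddr.empty().is_empty()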


@@ -22,7 +22,7 @@ delegate = { workspace = true }
 # async
 tokio = { workspace = true, features = ["full"] }
-futures-lite = { workspace = true }
+futures = { workspace = true }
 futures-timer = { workspace = true }
 
 # utility dependencies


@@ -1,4 +1,4 @@
-use futures_lite::StreamExt;
+use futures::stream::StreamExt as _;
 use libp2p::{gossipsub, identity, swarm::SwarmEvent};
 use networking::{discovery, swarm};
 use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
                     println!("Publish error: {e:?}");
                 }
             }
-            event = swarm.next() => match event {
+            event = swarm.select_next_some() => match event {
                 // on gossipsub incoming
-                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
+                SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                     propagation_source: peer_id,
                     message_id: id,
                     message,
-                }))) => println!(
+                })) => println!(
                     "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
                     String::from_utf8_lossy(&message.data),
                 ),
                 // on discovery
-                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e))) => match e {
+                SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
                     discovery::Event::ConnectionEstablished {
                         peer_id, connection_id, remote_ip, remote_tcp_port
                     } => {
@@ -64,7 +64,7 @@ async fn main() {
             }
             // ignore outgoing errors: those are normal
-            e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
+            e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
             // otherwise log any other event
             e => { log::info!("Other event {e:?}"); }


@@ -0,0 +1,127 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use futures::stream::StreamExt;
use libp2p::{
    gossipsub, mdns, noise,
    swarm::{NetworkBehaviour, SwarmEvent},
    tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
    gossipsub: gossipsub::Behaviour,
    mdns: mdns::tokio::Behaviour,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init();

    let mut swarm = libp2p::SwarmBuilder::with_new_identity()
        .with_tokio()
        .with_tcp(
            tcp::Config::default(),
            noise::Config::new,
            yamux::Config::default,
        )?
        .with_behaviour(|key| {
            // Set a custom gossipsub configuration
            let gossipsub_config = gossipsub::ConfigBuilder::default()
                .heartbeat_interval(Duration::from_secs(10))
                .validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
                .build()
                .map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.

            // build a gossipsub network behaviour
            let gossipsub = gossipsub::Behaviour::new(
                gossipsub::MessageAuthenticity::Signed(key.clone()),
                gossipsub_config,
            )?;

            let mdns =
                mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
            Ok(MyBehaviour { gossipsub, mdns })
        })?
        .build();

    println!("Running swarm with identity {}", swarm.local_peer_id());

    // Create a Gossipsub topic
    let topic = gossipsub::IdentTopic::new("test-net");
    // subscribes to our topic
    swarm.behaviour_mut().gossipsub.subscribe(&topic)?;

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;

    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    // Kick it off
    loop {
        select! {
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.select_next_some() => match event {
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                })) => println!(
                    "Got message: '{}' with id: {id} from peer: {peer_id}",
                    String::from_utf8_lossy(&message.data),
                ),
                SwarmEvent::NewListenAddr { address, .. } => {
                    println!("Local node is listening on {address}");
                }
                e => {
                    println!("Other swarm event: {:?}", e);
                }
            }
        }
    }
}


@@ -1,7 +1,7 @@
 use crate::ext::MultiaddrExt;
 use delegate::delegate;
 use either::Either;
-use futures_lite::FutureExt;
+use futures::FutureExt;
 use futures_timer::Delay;
 use libp2p::core::transport::PortUse;
 use libp2p::core::{ConnectedPoint, Endpoint};
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
         }
 
         // retry connecting to all mDNS peers periodically (fails safely if already connected)
-        if self.retry_delay.poll(cx).is_ready() {
+        if self.retry_delay.poll_unpin(cx).is_ready() {
             for (p, mas) in self.mdns_discovered.clone() {
                 for ma in mas {
                     self.dial(p, ma)


@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
 mod transport {
     use crate::alias;
     use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
-    use futures_lite::{AsyncRead, AsyncWrite};
+    use futures::{AsyncRead, AsyncWrite};
     use keccak_const::Sha3_256;
     use libp2p::core::muxing;
     use libp2p::core::transport::Boxed;


@@ -1,10 +1,11 @@
 { inputs, ... }:
 {
   perSystem =
-    { inputs', pkgs, lib, ... }:
+    { config, self', inputs', pkgs, lib, ... }:
     let
       # Fenix nightly toolchain with all components
-      rustToolchain = inputs'.fenix.packages.stable.withComponents [
+      fenixPkgs = inputs'.fenix.packages;
+      rustToolchain = fenixPkgs.complete.withComponents [
         "cargo"
         "rustc"
         "clippy"

rust/rust-toolchain.toml (new file)

@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"


@@ -1,6 +1,7 @@
 import asyncio
 import socket
 from dataclasses import dataclass, field
+from typing import Iterator
 
 import anyio
 from anyio import current_time
@@ -21,10 +22,10 @@ from exo.shared.types.commands import (
     ForwarderDownloadCommand,
     StartDownload,
 )
-from exo.shared.types.common import NodeId, SessionId, SystemId
+from exo.shared.types.common import NodeId, SessionId
 from exo.shared.types.events import (
     Event,
-    LocalForwarderEvent,
+    ForwarderEvent,
     NodeDownloadProgress,
 )
 from exo.shared.types.worker.downloads import (
@@ -44,9 +45,9 @@ class DownloadCoordinator:
     session_id: SessionId
     shard_downloader: ShardDownloader
     download_command_receiver: Receiver[ForwarderDownloadCommand]
-    local_event_sender: Sender[LocalForwarderEvent]
+    local_event_sender: Sender[ForwarderEvent]
+    event_index_counter: Iterator[int]
     offline: bool = False
-    _system_id: SystemId = field(default_factory=SystemId)
 
     # Local state
     download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -297,16 +298,15 @@ class DownloadCoordinator:
             del self.download_status[model_id]
 
     async def _forward_events(self) -> None:
-        idx = 0
         with self.event_receiver as events:
             async for event in events:
-                fe = LocalForwarderEvent(
+                idx = next(self.event_index_counter)
+                fe = ForwarderEvent(
                     origin_idx=idx,
-                    origin=self._system_id,
+                    origin=self.node_id,
                     session=self.session_id,
                     event=event,
                 )
-                idx += 1
                 logger.debug(
                     f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
                 )


@@ -1,10 +1,11 @@
 import argparse
+import itertools
 import multiprocessing as mp
 import os
 import resource
 import signal
 from dataclasses import dataclass, field
-from typing import Self
+from typing import Iterator, Self
 
 import anyio
 from anyio.abc import TaskGroup
@@ -37,13 +38,14 @@ class Node:
     api: API | None
     node_id: NodeId
+    event_index_counter: Iterator[int]
     offline: bool
     _tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
 
     @classmethod
-    async def create(cls, args: "Args") -> Self:
+    async def create(cls, args: "Args") -> "Self":
         keypair = get_node_id_keypair()
-        node_id = NodeId(keypair.to_node_id())
+        node_id = NodeId(keypair.to_peer_id().to_base58())
         session_id = SessionId(master_node_id=node_id, election_clock=0)
         router = Router.create(keypair)
         await router.register_topic(topics.GLOBAL_EVENTS)
@@ -55,6 +57,9 @@ class Node:
         logger.info(f"Starting node {node_id}")
 
+        # Create shared event index counter for Worker and DownloadCoordinator
+        event_index_counter = itertools.count()
+
         # Create DownloadCoordinator (unless --no-downloads)
         if not args.no_downloads:
             download_coordinator = DownloadCoordinator(
@@ -63,6 +68,7 @@ class Node:
                 exo_shard_downloader(),
                 download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
                 local_event_sender=router.sender(topics.LOCAL_EVENTS),
+                event_index_counter=event_index_counter,
                 offline=args.offline,
             )
         else:
@@ -89,6 +95,7 @@ class Node:
                 local_event_sender=router.sender(topics.LOCAL_EVENTS),
                 command_sender=router.sender(topics.COMMANDS),
                 download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
+                event_index_counter=event_index_counter,
             )
         else:
             worker = None
@@ -126,6 +133,7 @@ class Node:
             master,
             api,
             node_id,
+            event_index_counter,
             args.offline,
         )
@@ -204,6 +212,8 @@ class Node:
             )
             if result.is_new_master:
                 await anyio.sleep(0)
+                # Fresh counter for new session (buffer expects indices from 0)
+                self.event_index_counter = itertools.count()
                 if self.download_coordinator:
                     self.download_coordinator.shutdown()
                     self.download_coordinator = DownloadCoordinator(
@@ -214,6 +224,7 @@ class Node:
                             topics.DOWNLOAD_COMMANDS
                         ),
                         local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
+                        event_index_counter=self.event_index_counter,
                         offline=self.offline,
                     )
                     self._tg.start_soon(self.download_coordinator.run)
@@ -231,6 +242,7 @@ class Node:
                         download_command_sender=self.router.sender(
                             topics.DOWNLOAD_COMMANDS
                         ),
+                        event_index_counter=self.event_index_counter,
                     )
                     self._tg.start_soon(self.worker.run)
             if self.api:
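The shared itertools.count() gives the Worker and the DownloadCoordinator a single gapless origin_idx sequence per session, which is what the master's ordered buffer needs to release events strictly in order. A minimal sketch of the sharing pattern (illustrative only, assuming both producers run on one event loop):

    import itertools
    from typing import Iterator

    counter: Iterator[int] = itertools.count()  # fresh per session, starts at 0

    def emit(origin: str, counter: Iterator[int]) -> dict:
        # Both producers draw from the same counter, so indices
        # interleave without gaps or duplicates.
        return {"origin": origin, "origin_idx": next(counter)}

    events = [emit("worker", counter), emit("downloads", counter), emit("worker", counter)]
    assert [e["origin_idx"] for e in events] == [0, 1, 2]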


@@ -31,7 +31,6 @@ from exo.shared.types.openai_responses import (
     ResponseOutputText,
     ResponsesRequest,
     ResponsesResponse,
-    ResponsesStreamEvent,
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
     ResponseUsage,
@@ -39,11 +38,6 @@ from exo.shared.types.openai_responses import (
 from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
 
 
-def _format_sse(event: ResponsesStreamEvent) -> str:
-    """Format a streaming event as an SSE message."""
-    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
-
-
 def _extract_content(content: str | list[ResponseContentPart]) -> str:
     """Extract plain text from a content field that may be a string or list of parts."""
     if isinstance(content, str):
@@ -225,13 +219,13 @@ async def generate_responses_stream(
     created_event = ResponseCreatedEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield _format_sse(created_event)
+    yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
 
     # response.in_progress
     in_progress_event = ResponseInProgressEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield _format_sse(in_progress_event)
+    yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
 
     # response.output_item.added
     initial_item = ResponseMessageItem(
@@ -242,7 +236,7 @@ async def generate_responses_stream(
     item_added = ResponseOutputItemAddedEvent(
         sequence_number=next(seq), output_index=0, item=initial_item
     )
-    yield _format_sse(item_added)
+    yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
 
     # response.content_part.added
     initial_part = ResponseOutputText(text="")
@@ -253,7 +247,7 @@ async def generate_responses_stream(
         content_index=0,
         part=initial_part,
     )
-    yield _format_sse(part_added)
+    yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
 
     accumulated_text = ""
     function_call_items: list[ResponseFunctionCallItem] = []
@@ -287,7 +281,7 @@ async def generate_responses_stream(
                     output_index=next_output_index,
                     item=fc_item,
                 )
-                yield _format_sse(fc_added)
+                yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
 
                 # response.function_call_arguments.delta
                 args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -296,7 +290,7 @@ async def generate_responses_stream(
                     output_index=next_output_index,
                     delta=tool.arguments,
                 )
-                yield _format_sse(args_delta)
+                yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
 
                 # response.function_call_arguments.done
                 args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -306,7 +300,7 @@ async def generate_responses_stream(
                     name=tool.name,
                     arguments=tool.arguments,
                 )
-                yield _format_sse(args_done)
+                yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
 
                 # response.output_item.done
                 fc_done_item = ResponseFunctionCallItem(
@@ -321,7 +315,7 @@ async def generate_responses_stream(
                     output_index=next_output_index,
                     item=fc_done_item,
                 )
-                yield _format_sse(fc_item_done)
+                yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
 
                 function_call_items.append(fc_done_item)
                 next_output_index += 1
@@ -337,7 +331,7 @@ async def generate_responses_stream(
                     content_index=0,
                     delta=chunk.text,
                 )
-                yield _format_sse(delta_event)
+                yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
 
     # response.output_text.done
     text_done = ResponseTextDoneEvent(
@@ -347,7 +341,7 @@ async def generate_responses_stream(
         content_index=0,
         text=accumulated_text,
     )
-    yield _format_sse(text_done)
+    yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
 
     # response.content_part.done
     final_part = ResponseOutputText(text=accumulated_text)
@@ -358,7 +352,7 @@ async def generate_responses_stream(
         content_index=0,
         part=final_part,
     )
-    yield _format_sse(part_done)
+    yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
 
     # response.output_item.done
     final_message_item = ResponseMessageItem(
@@ -369,7 +363,7 @@ async def generate_responses_stream(
     item_done = ResponseOutputItemDoneEvent(
         sequence_number=next(seq), output_index=0, item=final_message_item
     )
-    yield _format_sse(item_done)
+    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
 
     # Create usage from usage data if available
     usage = None
@@ -394,4 +388,4 @@ async def generate_responses_stream(
     completed_event = ResponseCompletedEvent(
         sequence_number=next(seq), response=final_response
     )
-    yield _format_sse(completed_event)
+    yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
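Inlining the event name at each yield removes the only use of the deleted ResponsesStreamEvent union while producing the same output. For illustration, every yielded string is one standard Server-Sent Events frame (hypothetical payload shown; real payloads come from model_dump_json()):

    event_type = "response.output_text.delta"
    payload = '{"sequence_number": 3, "delta": "Hel"}'
    frame = f"event: {event_type}\ndata: {payload}\n\n"
    # The trailing blank line (the second \n) terminates the SSE frame.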


@@ -132,11 +132,11 @@ from exo.shared.types.commands import (
     TaskFinished,
     TextGeneration,
 )
-from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
+from exo.shared.types.common import CommandId, Id, NodeId, SessionId
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
-    GlobalForwarderEvent,
+    ForwarderEvent,
     IndexedEvent,
     PrefillProgress,
     TracesMerged,
@@ -177,7 +177,8 @@ class API:
         session_id: SessionId,
         *,
         port: int,
-        global_event_receiver: Receiver[GlobalForwarderEvent],
+        # Ideally this would be a MasterForwarderEvent but type system says no :(
+        global_event_receiver: Receiver[ForwarderEvent],
         command_sender: Sender[ForwarderCommand],
         download_command_sender: Sender[ForwarderDownloadCommand],
         # This lets us pause the API if an election is running
@@ -185,7 +186,6 @@ class API:
     ) -> None:
         self.state = State()
         self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
-        self._system_id = SystemId()
         self.command_sender = command_sender
         self.download_command_sender = download_command_sender
         self.global_event_receiver = global_event_receiver
@@ -237,7 +237,6 @@ class API:
self._event_log.close() self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR) self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State() self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]() self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {} self._text_generation_queues = {}
@@ -555,7 +554,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id) command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
raise raise
finally: finally:
@@ -903,7 +902,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id) command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
raise raise
finally: finally:
@@ -989,7 +988,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id) command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True): with anyio.CancelScope(shield=True):
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
raise raise
finally: finally:
@@ -1430,8 +1429,6 @@ class API:
async def _apply_state(self): async def _apply_state(self):
with self.global_event_receiver as events: with self.global_event_receiver as events:
async for f_event in events: async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id: if f_event.origin != self.session_id.master_node_id:
continue continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event) self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -1511,12 +1508,12 @@ class API:
while self.paused: while self.paused:
await self.paused_ev.wait() await self.paused_ev.wait()
await self.command_sender.send( await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command) ForwarderCommand(origin=self.node_id, command=command)
) )
async def _send_download(self, command: DownloadCommand): async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send( await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self._system_id, command=command) ForwarderDownloadCommand(origin=self.node_id, command=command)
) )
async def start_download( async def start_download(

View File

@@ -29,14 +29,13 @@ from exo.shared.types.commands import (
     TestCommand,
     TextGeneration,
 )
-from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
+from exo.shared.types.common import CommandId, NodeId, SessionId
 from exo.shared.types.events import (
     Event,
-    GlobalForwarderEvent,
+    ForwarderEvent,
     IndexedEvent,
     InputChunkReceived,
     InstanceDeleted,
-    LocalForwarderEvent,
     NodeGatheredInfo,
     NodeTimedOut,
     TaskCreated,
@@ -72,8 +71,8 @@ class Master:
         session_id: SessionId,
         *,
         command_receiver: Receiver[ForwarderCommand],
-        local_event_receiver: Receiver[LocalForwarderEvent],
-        global_event_sender: Sender[GlobalForwarderEvent],
+        local_event_receiver: Receiver[ForwarderEvent],
+        global_event_sender: Sender[ForwarderEvent],
         download_command_sender: Sender[ForwarderDownloadCommand],
     ):
         self.state = State()
@@ -88,11 +87,10 @@ class Master:
         send, recv = channel[Event]()
         self.event_sender: Sender[Event] = send
         self._loopback_event_receiver: Receiver[Event] = recv
-        self._loopback_event_sender: Sender[LocalForwarderEvent] = (
+        self._loopback_event_sender: Sender[ForwarderEvent] = (
             local_event_receiver.clone_sender()
         )
-        self._system_id = SystemId()
-        self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
+        self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
         self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
         self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
         self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -290,7 +288,7 @@ class Master:
                 ):
                     await self.download_command_sender.send(
                         ForwarderDownloadCommand(
-                            origin=self._system_id, command=cmd
+                            origin=self.node_id, command=cmd
                         )
                     )
                 generated_events.extend(transition_events)
@@ -416,8 +414,8 @@ class Master:
         with self._loopback_event_receiver as events:
             async for event in events:
                 await self._loopback_event_sender.send(
-                    LocalForwarderEvent(
-                        origin=self._system_id,
+                    ForwarderEvent(
+                        origin=NodeId(f"master_{self.node_id}"),
                         origin_idx=local_index,
                         session=self.session_id,
                         event=event,
@@ -429,7 +427,7 @@ class Master:
     async def _send_event(self, event: IndexedEvent):
         # Convenience method since this line is ugly
         await self.global_event_sender.send(
-            GlobalForwarderEvent(
+            ForwarderEvent(
                 origin=self.node_id,
                 origin_idx=event.idx,
                 session=self.session_id,
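
Editor's note on the loopback change above: with SystemId gone, the master republishes its own loopback events under a synthetic NodeId(f"master_{self.node_id}"), so they occupy a separate stream in the MultiSourceBuffer[NodeId, Event] from events the same node's worker emits under its bare node id. A minimal sketch of the per-origin reordering such a buffer presumably performs (names and behavior inferred from usage in this diff, not the repo's implementation):

# Hypothetical sketch of per-origin ordered buffering, inferred from how
# MultiSourceBuffer[NodeId, Event] is used here; not the repo's code.
from collections import defaultdict

class MultiSourceBuffer[K, T]:
    def __init__(self) -> None:
        # Per-origin reorder state: next expected index, plus stashed items.
        self._next: dict[K, int] = defaultdict(int)
        self._pending: dict[K, dict[int, T]] = defaultdict(dict)

    def ingest(self, origin: K, idx: int, item: T) -> list[T]:
        # Stash the item, then release the longest contiguous run per origin.
        self._pending[origin][idx] = item
        ready: list[T] = []
        while self._next[origin] in self._pending[origin]:
            ready.append(self._pending[origin].pop(self._next[origin]))
            self._next[origin] += 1
        return ready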

View File

@@ -15,12 +15,11 @@ from exo.shared.types.commands import (
     PlaceInstance,
     TextGeneration,
 )
-from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
+from exo.shared.types.common import ModelId, NodeId, SessionId
 from exo.shared.types.events import (
-    GlobalForwarderEvent,
+    ForwarderEvent,
     IndexedEvent,
     InstanceCreated,
-    LocalForwarderEvent,
     NodeGatheredInfo,
     TaskCreated,
 )
@@ -43,12 +42,12 @@ from exo.utils.channels import channel
 @pytest.mark.asyncio
 async def test_master():
     keypair = get_node_id_keypair()
-    node_id = NodeId(keypair.to_node_id())
+    node_id = NodeId(keypair.to_peer_id().to_base58())
     session_id = SessionId(master_node_id=node_id, election_clock=0)
-    ge_sender, global_event_receiver = channel[GlobalForwarderEvent]()
+    ge_sender, global_event_receiver = channel[ForwarderEvent]()
     command_sender, co_receiver = channel[ForwarderCommand]()
-    local_event_sender, le_receiver = channel[LocalForwarderEvent]()
+    local_event_sender, le_receiver = channel[ForwarderEvent]()
     fcds, _fcdr = channel[ForwarderDownloadCommand]()

     all_events: list[IndexedEvent] = []
@@ -76,12 +75,13 @@ async def test_master():
     async with anyio.create_task_group() as tg:
         tg.start_soon(master.run)

+        sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
         # inject a NodeGatheredInfo event
         logger.info("inject a NodeGatheredInfo event")
         await local_event_sender.send(
-            LocalForwarderEvent(
+            ForwarderEvent(
                 origin_idx=0,
-                origin=SystemId("Worker"),
+                origin=sender_node_id,
                 session=session_id,
                 event=(
                     NodeGatheredInfo(
@@ -108,7 +108,7 @@ async def test_master():
         logger.info("inject a CreateInstance Command")
         await command_sender.send(
             ForwarderCommand(
-                origin=SystemId("API"),
+                origin=node_id,
                 command=(
                     PlaceInstance(
                         command_id=CommandId(),
@@ -133,7 +133,7 @@ async def test_master():
         logger.info("inject a TextGeneration Command")
         await command_sender.send(
             ForwarderCommand(
-                origin=SystemId("API"),
+                origin=node_id,
                 command=(
                     TextGeneration(
                         command_id=CommandId(),

View File

@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
     @classmethod
     def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
         return cls(
-            node_id=NodeId(update.peer_id),
+            node_id=NodeId(update.peer_id.to_base58()),
             connection_type=ConnectionMessageType.from_update_type(update.update_type),
             remote_ipv4=update.remote_ipv4,
             remote_tcp_port=update.remote_tcp_port,

View File

@@ -221,7 +221,7 @@ def get_node_id_keypair(
     Obtain the :class:`PeerId` by from it.
     """
     # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
-    return Keypair.generate()
+    return Keypair.generate_ed25519()

 def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
     return Path(str(path) + ".lock")
@@ -235,12 +235,12 @@ def get_node_id_keypair(
             protobuf_encoded = f.read()
         try:  # if decoded successfully, save & return
-            return Keypair.from_bytes(protobuf_encoded)
+            return Keypair.from_protobuf_encoding(protobuf_encoded)
         except ValueError as e:  # on runtime error, assume corrupt file
             logger.warning(f"Encountered error when trying to get keypair: {e}")

     # if no valid credentials, create new ones and persist
     with open(path, "w+b") as f:
         keypair = Keypair.generate_ed25519()
-        f.write(keypair.to_bytes())
+        f.write(keypair.to_protobuf_encoding())
     return keypair
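
Taken together with the test change above, the renamed Keypair methods compose like this; a small usage sketch using only calls that appear in this diff (the Keypair import path does not appear here, so it is deliberately omitted):

# Sketch: derive a node id and round-trip the persisted encoding.
from exo.shared.types.common import NodeId
# from <keypair module> import Keypair  # import path assumed, not shown in the diff

keypair = Keypair.generate_ed25519()
node_id = NodeId(keypair.to_peer_id().to_base58())

encoded = keypair.to_protobuf_encoding()       # what gets written to disk
restored = Keypair.from_protobuf_encoding(encoded)
assert restored.to_peer_id().to_base58() == node_id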

View File

@@ -5,8 +5,7 @@ from exo.routing.connection_message import ConnectionMessage
 from exo.shared.election import ElectionMessage
 from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
 from exo.shared.types.events import (
-    GlobalForwarderEvent,
-    LocalForwarderEvent,
+    ForwarderEvent,
 )
 from exo.utils.pydantic_ext import CamelCaseModel
@@ -37,8 +36,8 @@ class TypedTopic[T: CamelCaseModel]:
         return self.model_type.model_validate_json(b.decode("utf-8"))

-GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, GlobalForwarderEvent)
-LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, LocalForwarderEvent)
+GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
+LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
 COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
 ELECTION_MESSAGES = TypedTopic(
     "election_messages", PublishPolicy.Always, ElectionMessage

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
 from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
 from exo.shared.election import Election, ElectionMessage, ElectionResult
 from exo.shared.types.commands import ForwarderCommand, TestCommand
-from exo.shared.types.common import NodeId, SessionId, SystemId
+from exo.shared.types.common import NodeId, SessionId
 from exo.utils.channels import channel

 # ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
     # Pump local commands so our commands_seen is high before the round starts
     for _ in range(50):
         await co_tx.send(
-            ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
+            ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
         )

     # Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
     sem.release()
     # wait to be told to begin simultaneous read
     ev.wait()
-    queue.put(get_node_id_keypair().to_bytes())
+    queue.put(get_node_id_keypair().to_protobuf_encoding())

 def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
     ImageGenerationTaskParams,
 )
 from exo.shared.types.chunks import InputImageChunk
-from exo.shared.types.common import CommandId, NodeId, SystemId
+from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.text_generation import TextGenerationTaskParams
 from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
 from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
 class ForwarderCommand(CamelCaseModel):
-    origin: SystemId
+    origin: NodeId
     command: Command

 class ForwarderDownloadCommand(CamelCaseModel):
-    origin: SystemId
+    origin: NodeId
     command: DownloadCommand

View File

@@ -25,10 +25,6 @@ class NodeId(Id):
     pass


-class SystemId(Id):
-    pass
-
-
 class ModelId(Id):
     def normalize(self) -> str:
         return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
 from exo.shared.topology import Connection
 from exo.shared.types.chunks import GenerationChunk, InputImageChunk
-from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId
+from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
 from exo.shared.types.tasks import Task, TaskId, TaskStatus
 from exo.shared.types.worker.downloads import DownloadProgress
 from exo.shared.types.worker.instances import Instance, InstanceId
@@ -170,19 +170,10 @@ class IndexedEvent(CamelCaseModel):
     event: Event


-class GlobalForwarderEvent(CamelCaseModel):
+class ForwarderEvent(CamelCaseModel):
     """An event the forwarder will serialize and send over the network"""

     origin_idx: int = Field(ge=0)
     origin: NodeId
     session: SessionId
     event: Event
-
-
-class LocalForwarderEvent(CamelCaseModel):
-    """An event the forwarder will serialize and send over the network"""
-
-    origin_idx: int = Field(ge=0)
-    origin: SystemId
-    session: SessionId
-    event: Event
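
The two event wrappers collapse into a single ForwarderEvent keyed by NodeId. Since it extends CamelCaseModel, its wire form presumably uses camelCase aliases; a self-contained illustration of that aliasing with plain pydantic v2 (generic pydantic here, not exo's actual base class):

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

class CamelCaseModel(BaseModel):
    # Assumed behavior: snake_case fields serialize under camelCase aliases.
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

class Frame(CamelCaseModel):
    origin_idx: int
    origin: str

print(Frame(origin_idx=7, origin="node-a").model_dump_json(by_alias=True))
# {"originIdx":7,"origin":"node-a"}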

View File

@@ -1,6 +1,7 @@
 from collections import defaultdict
 from datetime import datetime, timezone
 from random import random
+from typing import Iterator

 import anyio
 from anyio import CancelScope, create_task_group, fail_after
@@ -16,14 +17,13 @@ from exo.shared.types.commands import (
     RequestEventLog,
     StartDownload,
 )
-from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
+from exo.shared.types.common import CommandId, NodeId, SessionId
 from exo.shared.types.events import (
     Event,
     EventId,
-    GlobalForwarderEvent,
+    ForwarderEvent,
     IndexedEvent,
     InputChunkReceived,
-    LocalForwarderEvent,
     NodeGatheredInfo,
     TaskCreated,
     TaskStatusUpdated,
@@ -58,22 +58,24 @@ class Worker:
         node_id: NodeId,
         session_id: SessionId,
         *,
-        global_event_receiver: Receiver[GlobalForwarderEvent],
-        local_event_sender: Sender[LocalForwarderEvent],
+        global_event_receiver: Receiver[ForwarderEvent],
+        local_event_sender: Sender[ForwarderEvent],
         # This is for requesting updates. It doesn't need to be a general command sender right now,
         # but I think it's the correct way to be thinking about commands
         command_sender: Sender[ForwarderCommand],
         download_command_sender: Sender[ForwarderDownloadCommand],
+        event_index_counter: Iterator[int],
     ):
         self.node_id: NodeId = node_id
         self.session_id: SessionId = session_id
         self.global_event_receiver = global_event_receiver
         self.local_event_sender = local_event_sender
+        self.event_index_counter = event_index_counter
         self.command_sender = command_sender
         self.download_command_sender = download_command_sender
         self.event_buffer = OrderedBuffer[Event]()
-        self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
+        self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
         self.state: State = State()
         self.runners: dict[RunnerId, RunnerSupervisor] = {}
@@ -84,8 +86,6 @@ class Worker:
         self._nack_base_seconds: float = 0.5
         self._nack_cap_seconds: float = 10.0

-        self._system_id = SystemId()
-
         self.event_sender, self.event_receiver = channel[Event]()

         # Buffer for input image chunks (for image editing)
@@ -132,8 +132,6 @@ class Worker:
     async def _event_applier(self):
         with self.global_event_receiver as events:
             async for f_event in events:
-                if f_event.session != self.session_id:
-                    continue
                 if f_event.origin != self.session_id.master_node_id:
                     continue
                 self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -214,7 +212,7 @@ class Worker:
                 await self.download_command_sender.send(
                     ForwarderDownloadCommand(
-                        origin=self._system_id,
+                        origin=self.node_id,
                         command=StartDownload(
                             target_node_id=self.node_id,
                             shard_metadata=shard,
@@ -243,11 +241,6 @@ class Worker:
                     cancelled_task_id=cancelled_task_id, runner_id=runner_id
                 ):
                     await self.runners[runner_id].cancel_task(cancelled_task_id)
-                    await self.event_sender.send(
-                        TaskStatusUpdated(
-                            task_id=task.task_id, task_status=TaskStatus.Complete
-                        )
-                    )
                 case ImageEdits() if task.task_params.total_input_chunks > 0:
                     # Assemble image from chunks and inject into task
                     cmd_id = task.command_id
@@ -319,7 +312,7 @@ class Worker:
                 )
                 await self.command_sender.send(
                     ForwarderCommand(
-                        origin=self._system_id,
+                        origin=self.node_id,
                         command=RequestEventLog(since_idx=since_idx),
                     )
                 )
@@ -346,16 +339,15 @@ class Worker:
         return runner

     async def _forward_events(self) -> None:
-        idx = 0
         with self.event_receiver as events:
             async for event in events:
-                fe = LocalForwarderEvent(
+                idx = next(self.event_index_counter)
+                fe = ForwarderEvent(
                     origin_idx=idx,
-                    origin=self._system_id,
+                    origin=self.node_id,
                     session=self.session_id,
                     event=event,
                 )
-                idx += 1
                 logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
                 await self.local_event_sender.send(fe)
                 self.out_for_delivery[event.event_id] = fe
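
Why _forward_events now takes its indices from an injected event_index_counter: the counter can outlive a single Worker instance, so origin_idx stays monotonic across worker restarts within a session and receivers' ordered buffers never see a reused index. A sketch of the wiring, assuming a plain itertools.count() shared by whoever constructs the worker (the surrounding variables are illustrative, not the repo's bootstrap code):

import itertools
from typing import Iterator

# One counter per process/session, shared across Worker rebuilds.
event_index_counter: Iterator[int] = itertools.count()

worker = Worker(
    node_id,
    session_id,
    global_event_receiver=global_event_receiver,
    local_event_sender=local_event_sender,
    command_sender=command_sender,
    download_command_sender=download_command_sender,
    event_index_counter=event_index_counter,
)
# If this worker is torn down and a new one is built in the same session,
# passing the same counter means the next event's origin_idx continues from
# where the old worker stopped instead of restarting at 0.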