mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-20 07:46:42 -05:00
Compare commits
4 Commits
meta-insta
...
session-id
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
932878c9e7 | ||
|
|
4c4c6ce99f | ||
|
|
42e1e7322b | ||
|
|
aa3f106fb9 |
@@ -20,6 +20,7 @@ from harness import (
|
|||||||
instance_id_from_instance,
|
instance_id_from_instance,
|
||||||
nodes_used_in_instance,
|
nodes_used_in_instance,
|
||||||
resolve_model_short_id,
|
resolve_model_short_id,
|
||||||
|
run_planning_phase,
|
||||||
settle_and_fetch_placements,
|
settle_and_fetch_placements,
|
||||||
wait_for_instance_gone,
|
wait_for_instance_gone,
|
||||||
wait_for_instance_ready,
|
wait_for_instance_ready,
|
||||||
@@ -962,6 +963,21 @@ Examples:
|
|||||||
|
|
||||||
selected.sort(key=_placement_sort_key)
|
selected.sort(key=_placement_sort_key)
|
||||||
preview = selected[0]
|
preview = selected[0]
|
||||||
|
|
||||||
|
settle_deadline = (
|
||||||
|
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Planning phase: checking downloads...", file=log)
|
||||||
|
run_planning_phase(
|
||||||
|
exo,
|
||||||
|
full_model_id,
|
||||||
|
preview,
|
||||||
|
args.danger_delete_downloads,
|
||||||
|
args.timeout,
|
||||||
|
settle_deadline,
|
||||||
|
)
|
||||||
|
|
||||||
instance = preview["instance"]
|
instance = preview["instance"]
|
||||||
instance_id = instance_id_from_instance(instance)
|
instance_id = instance_id_from_instance(instance)
|
||||||
sharding = str(preview["sharding"])
|
sharding = str(preview["sharding"])
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from harness import (
|
|||||||
instance_id_from_instance,
|
instance_id_from_instance,
|
||||||
nodes_used_in_instance,
|
nodes_used_in_instance,
|
||||||
resolve_model_short_id,
|
resolve_model_short_id,
|
||||||
|
run_planning_phase,
|
||||||
settle_and_fetch_placements,
|
settle_and_fetch_placements,
|
||||||
wait_for_instance_gone,
|
wait_for_instance_gone,
|
||||||
wait_for_instance_ready,
|
wait_for_instance_ready,
|
||||||
@@ -332,6 +333,20 @@ def main() -> int:
|
|||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
settle_deadline = (
|
||||||
|
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Planning phase: checking downloads...")
|
||||||
|
run_planning_phase(
|
||||||
|
client,
|
||||||
|
full_model_id,
|
||||||
|
selected[0],
|
||||||
|
args.danger_delete_downloads,
|
||||||
|
args.timeout,
|
||||||
|
settle_deadline,
|
||||||
|
)
|
||||||
|
|
||||||
all_rows: list[dict[str, Any]] = []
|
all_rows: list[dict[str, Any]] = []
|
||||||
|
|
||||||
for preview in selected:
|
for preview in selected:
|
||||||
|
|||||||
150
bench/harness.py
150
bench/harness.py
@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
|
|||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
|
||||||
|
def _download_entry_model_id(entry: dict[str, Any], status_key: str) -> str | None:
    """Return the model id of a download entry tagged with ``status_key``, else None."""
    if status_key not in entry:
        return None
    return unwrap_instance(entry[status_key]["shardMetadata"])["modelCard"]["modelId"]


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure model is downloaded before benchmarking.

    Steps, in order:

    1. Look up the model's size via ``GET /models``.
    2. For every node in the previewed placement that does not already hold a
       completed download of the model, compare the node's available bytes
       (from ``GET /state``) with the model size. If space is short and
       ``danger_delete`` is set, delete completed downloads on that node from
       smallest to largest until the model fits.
    3. ``POST /download/start`` for every node in the placement.
    4. Poll ``GET /state`` once per second until every node reports a
       completed download of the model.

    Args:
        client: API client used for all requests.
        full_model_id: Model id matched against ``hugging_face_id`` in
            ``/models`` and against ``modelCard.modelId`` in download entries.
        preview: Placement preview; its ``"instance"`` entry supplies the
            node-to-runner and runner-to-shard assignments.
        danger_delete: Allow deleting existing downloads to free disk space.
        timeout: Maximum seconds to wait for the downloads to complete.
        settle_deadline: ``time.monotonic()`` deadline for waiting on missing
            per-node disk info, or ``None`` to not wait at all.

    Raises:
        RuntimeError: Not enough disk space (and deletion is disallowed or
            insufficient), or a download of the model reported failure.
        TimeoutError: Downloads did not complete within ``timeout`` seconds.
    """
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            # `or 0` also covers a key that is present but null.
            model_bytes = (m.get("storage_size_megabytes") or 0) * 1024 * 1024
            break

    if not model_bytes:
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        # NOTE(review): this returns before downloads are started or awaited,
        # so more than the disk check is skipped — confirm that is intended.
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_to_runner = inner["shardAssignments"]["nodeToRunner"]
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
    node_ids = list(node_to_runner)

    # Guard against an empty response, same as the /models request above.
    state = client.request_json("GET", "/state") or {}
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            _download_entry_model_id(p, "DownloadCompleted") == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline is not None:
            # Read the clock once per iteration so `remaining` can never be
            # negative when passed to time.sleep (which raises on negatives).
            remaining = settle_deadline - time.monotonic()
            if remaining <= 0:
                break
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state") or {}
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                _download_entry_model_id(p, "DownloadCompleted"),
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            # Freed space is estimated from the recorded download size.
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = node_to_runner[node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads. Use the monotonic clock so wall-clock adjustments
    # cannot shorten or extend the timeout.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        state = client.request_json("GET", "/state") or {}
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            entries = downloads.get(node_id, [])
            done = any(
                _download_entry_model_id(p, "DownloadCompleted") == full_model_id
                for p in entries
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in entries
                if _download_entry_model_id(p, "DownloadFailed") == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")
|
||||||
|
|
||||||
|
|
||||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
|||||||
default=0,
|
default=0,
|
||||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||||
)
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--danger-delete-downloads",
|
||||||
|
action="store_true",
|
||||||
|
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||||
|
)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class ConnectionUpdate:
|
|||||||
Whether this is a connection or disconnection event
|
Whether this is a connection or disconnection event
|
||||||
"""
|
"""
|
||||||
@property
|
@property
|
||||||
def peer_id(self) -> PeerId:
|
def peer_id(self) -> builtins.str:
|
||||||
r"""
|
r"""
|
||||||
Identity of the peer that we have connected to or disconnected from.
|
Identity of the peer that we have connected to or disconnected from.
|
||||||
"""
|
"""
|
||||||
@@ -40,92 +40,22 @@ class Keypair:
|
|||||||
Identity keypair of a node.
|
Identity keypair of a node.
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_ed25519() -> Keypair:
|
def generate() -> Keypair:
|
||||||
r"""
|
r"""
|
||||||
Generate a new Ed25519 keypair.
|
Generate a new Ed25519 keypair.
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_ecdsa() -> Keypair:
|
def from_bytes(bytes: bytes) -> Keypair:
|
||||||
r"""
|
r"""
|
||||||
Generate a new ECDSA keypair.
|
Construct an Ed25519 keypair from secret key bytes
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def generate_secp256k1() -> Keypair:
|
|
||||||
r"""
|
|
||||||
Generate a new Secp256k1 keypair.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
|
||||||
r"""
|
|
||||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
|
||||||
r"""
|
|
||||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
|
||||||
format (i.e. unencrypted) as defined in [RFC5208].
|
|
||||||
|
|
||||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
|
||||||
r"""
|
|
||||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
|
||||||
structure as defined in [RFC5915].
|
|
||||||
|
|
||||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
|
||||||
def to_protobuf_encoding(self) -> bytes:
|
|
||||||
r"""
|
|
||||||
Encode a private key as protobuf structure.
|
|
||||||
"""
|
|
||||||
def to_peer_id(self) -> PeerId:
|
|
||||||
r"""
|
|
||||||
Convert the `Keypair` into the corresponding `PeerId`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@typing.final
|
|
||||||
class Multiaddr:
|
|
||||||
r"""
|
|
||||||
Representation of a Multiaddr.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def empty() -> Multiaddr:
|
|
||||||
r"""
|
|
||||||
Create a new, empty multiaddress.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
|
||||||
r"""
|
|
||||||
Create a new, empty multiaddress with the given capacity.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
|
||||||
r"""
|
|
||||||
Parse a `Multiaddr` value from its byte slice representation.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def from_string(string: builtins.str) -> Multiaddr:
|
|
||||||
r"""
|
|
||||||
Parse a `Multiaddr` value from its string representation.
|
|
||||||
"""
|
|
||||||
def len(self) -> builtins.int:
|
|
||||||
r"""
|
|
||||||
Return the length in bytes of this multiaddress.
|
|
||||||
"""
|
|
||||||
def is_empty(self) -> builtins.bool:
|
|
||||||
r"""
|
|
||||||
Returns true if the length of this multiaddress is 0.
|
|
||||||
"""
|
"""
|
||||||
def to_bytes(self) -> bytes:
|
def to_bytes(self) -> bytes:
|
||||||
r"""
|
r"""
|
||||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
Get the secret key bytes underlying the keypair
|
||||||
"""
|
"""
|
||||||
def to_string(self) -> builtins.str:
|
def to_node_id(self) -> builtins.str:
|
||||||
r"""
|
r"""
|
||||||
Convert a Multiaddr to a string.
|
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@typing.final
|
@typing.final
|
||||||
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
|||||||
def __repr__(self) -> builtins.str: ...
|
def __repr__(self) -> builtins.str: ...
|
||||||
def __str__(self) -> builtins.str: ...
|
def __str__(self) -> builtins.str: ...
|
||||||
|
|
||||||
@typing.final
|
|
||||||
class PeerId:
|
|
||||||
r"""
|
|
||||||
Identifier of a peer of the network.
|
|
||||||
|
|
||||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
|
||||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def random() -> PeerId:
|
|
||||||
r"""
|
|
||||||
Generates a random peer ID from a cryptographically secure PRNG.
|
|
||||||
|
|
||||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def from_bytes(bytes: bytes) -> PeerId:
|
|
||||||
r"""
|
|
||||||
Parses a `PeerId` from bytes.
|
|
||||||
"""
|
|
||||||
def to_bytes(self) -> bytes:
|
|
||||||
r"""
|
|
||||||
Returns a raw bytes representation of this `PeerId`.
|
|
||||||
"""
|
|
||||||
def to_base58(self) -> builtins.str:
|
|
||||||
r"""
|
|
||||||
Returns a base-58 encoded string of this `PeerId`.
|
|
||||||
"""
|
|
||||||
def __repr__(self) -> builtins.str: ...
|
|
||||||
def __str__(self) -> builtins.str: ...
|
|
||||||
|
|
||||||
@typing.final
|
@typing.final
|
||||||
class ConnectionUpdateType(enum.Enum):
|
class ConnectionUpdateType(enum.Enum):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
use crate::ext::ResultExt as _;
|
use crate::ext::ResultExt as _;
|
||||||
use libp2p::PeerId;
|
|
||||||
use libp2p::identity::Keypair;
|
use libp2p::identity::Keypair;
|
||||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
use pyo3::types::{PyBytes, PyBytesMethods as _};
|
||||||
use pyo3::types::PyBytes;
|
|
||||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||||
|
|
||||||
@@ -18,142 +16,32 @@ pub struct PyKeypair(pub Keypair);
|
|||||||
impl PyKeypair {
|
impl PyKeypair {
|
||||||
/// Generate a new Ed25519 keypair.
|
/// Generate a new Ed25519 keypair.
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn generate_ed25519() -> Self {
|
fn generate() -> Self {
|
||||||
Self(Keypair::generate_ed25519())
|
Self(Keypair::generate_ed25519())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a new ECDSA keypair.
|
/// Construct an Ed25519 keypair from secret key bytes
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn generate_ecdsa() -> Self {
|
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||||
Self(Keypair::generate_ecdsa())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate a new Secp256k1 keypair.
|
|
||||||
#[staticmethod]
|
|
||||||
fn generate_secp256k1() -> Self {
|
|
||||||
Self(Keypair::generate_secp256k1())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
|
||||||
#[staticmethod]
|
|
||||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
|
||||||
let bytes = Vec::from(bytes.as_bytes());
|
|
||||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
|
||||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
|
||||||
///
|
|
||||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
|
||||||
#[staticmethod]
|
|
||||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
|
||||||
let mut bytes = Vec::from(bytes.as_bytes());
|
|
||||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
|
||||||
/// structure as defined in [RFC5915].
|
|
||||||
///
|
|
||||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
|
||||||
#[staticmethod]
|
|
||||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
|
||||||
let mut bytes = Vec::from(bytes.as_bytes());
|
|
||||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[staticmethod]
|
|
||||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
|
||||||
let mut bytes = Vec::from(bytes.as_bytes());
|
let mut bytes = Vec::from(bytes.as_bytes());
|
||||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encode a private key as protobuf structure.
|
/// Get the secret key bytes underlying the keypair
|
||||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
let bytes = self
|
||||||
|
.0
|
||||||
|
.clone()
|
||||||
|
.try_into_ed25519()
|
||||||
|
.pyerr()?
|
||||||
|
.secret()
|
||||||
|
.as_ref()
|
||||||
|
.to_vec();
|
||||||
Ok(PyBytes::new(py, &bytes))
|
Ok(PyBytes::new(py, &bytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||||
fn to_peer_id(&self) -> PyPeerId {
|
fn to_node_id(&self) -> String {
|
||||||
PyPeerId(self.0.public().to_peer_id())
|
self.0.public().to_peer_id().to_base58()
|
||||||
}
|
|
||||||
|
|
||||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
|
||||||
// #[gen_stub(skip)]
|
|
||||||
// #[new]
|
|
||||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
|
||||||
// Self::from_protobuf_encoding(bytes)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[gen_stub(skip)]
|
|
||||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
|
||||||
// *self = Self::from_protobuf_encoding(state)?;
|
|
||||||
// Ok(())
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[gen_stub(skip)]
|
|
||||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
|
||||||
// self.to_protobuf_encoding(py)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// #[gen_stub(skip)]
|
|
||||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
|
||||||
// Ok((self.to_protobuf_encoding(py)?,))
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Identifier of a peer of the network.
|
|
||||||
///
|
|
||||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
|
||||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
|
||||||
#[gen_stub_pyclass]
|
|
||||||
#[pyclass(name = "PeerId", frozen)]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
#[repr(transparent)]
|
|
||||||
pub struct PyPeerId(pub PeerId);
|
|
||||||
|
|
||||||
#[gen_stub_pymethods]
|
|
||||||
#[pymethods]
|
|
||||||
#[allow(clippy::needless_pass_by_value)]
|
|
||||||
impl PyPeerId {
|
|
||||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
|
||||||
///
|
|
||||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
|
||||||
#[staticmethod]
|
|
||||||
fn random() -> Self {
|
|
||||||
Self(PeerId::random())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parses a `PeerId` from bytes.
|
|
||||||
#[staticmethod]
|
|
||||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
|
||||||
let bytes = Vec::from(bytes.as_bytes());
|
|
||||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a raw bytes representation of this `PeerId`.
|
|
||||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
|
||||||
let bytes = self.0.to_bytes();
|
|
||||||
PyBytes::new(py, &bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a base-58 encoded string of this `PeerId`.
|
|
||||||
fn to_base58(&self) -> String {
|
|
||||||
self.0.to_base58()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __repr__(&self) -> String {
|
|
||||||
format!("PeerId({})", self.to_base58())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __str__(&self) -> String {
|
|
||||||
self.to_base58()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
||||||
m.add_class::<PyKeypair>()?;
|
|
||||||
m.add_class::<PyPeerId>()?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ mod allow_threading;
|
|||||||
mod ident;
|
mod ident;
|
||||||
mod networking;
|
mod networking;
|
||||||
|
|
||||||
use crate::ident::ident_submodule;
|
use crate::ident::PyKeypair;
|
||||||
use crate::networking::networking_submodule;
|
use crate::networking::networking_submodule;
|
||||||
use pyo3::prelude::PyModule;
|
use pyo3::prelude::PyModule;
|
||||||
|
use pyo3::types::PyModuleMethods;
|
||||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||||
|
|
||||||
@@ -158,7 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|||||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||||
// work with maturin, where the types generate correctly, in the right folder, without
|
// work with maturin, where the types generate correctly, in the right folder, without
|
||||||
// too many importing issues...
|
// too many importing issues...
|
||||||
ident_submodule(m)?;
|
m.add_class::<PyKeypair>()?;
|
||||||
networking_submodule(m)?;
|
networking_submodule(m)?;
|
||||||
|
|
||||||
// top-level constructs
|
// top-level constructs
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||||
use crate::ident::{PyKeypair, PyPeerId};
|
use crate::ident::PyKeypair;
|
||||||
use crate::pyclass;
|
use crate::pyclass;
|
||||||
use libp2p::futures::StreamExt as _;
|
use libp2p::futures::StreamExt as _;
|
||||||
use libp2p::gossipsub;
|
use libp2p::gossipsub;
|
||||||
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
|
|||||||
|
|
||||||
/// Identity of the peer that we have connected to or disconnected from.
|
/// Identity of the peer that we have connected to or disconnected from.
|
||||||
#[pyo3(get)]
|
#[pyo3(get)]
|
||||||
peer_id: PyPeerId,
|
peer_id: String,
|
||||||
|
|
||||||
/// Remote connection's IPv4 address.
|
/// Remote connection's IPv4 address.
|
||||||
#[pyo3(get)]
|
#[pyo3(get)]
|
||||||
@@ -251,7 +251,7 @@ async fn networking_task(
|
|||||||
// send connection event to channel (or exit if connection closed)
|
// send connection event to channel (or exit if connection closed)
|
||||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||||
update_type: PyConnectionUpdateType::Connected,
|
update_type: PyConnectionUpdateType::Connected,
|
||||||
peer_id: PyPeerId(peer_id),
|
peer_id: peer_id.to_base58(),
|
||||||
remote_ipv4,
|
remote_ipv4,
|
||||||
remote_tcp_port,
|
remote_tcp_port,
|
||||||
}).await {
|
}).await {
|
||||||
@@ -272,7 +272,7 @@ async fn networking_task(
|
|||||||
// send disconnection event to channel (or exit if connection closed)
|
// send disconnection event to channel (or exit if connection closed)
|
||||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||||
update_type: PyConnectionUpdateType::Disconnected,
|
update_type: PyConnectionUpdateType::Disconnected,
|
||||||
peer_id: PyPeerId(peer_id),
|
peer_id: peer_id.to_base58(),
|
||||||
remote_ipv4,
|
remote_ipv4,
|
||||||
remote_tcp_port,
|
remote_tcp_port,
|
||||||
}).await {
|
}).await {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
import socket
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Iterator
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio import current_time
|
from anyio import current_time
|
||||||
@@ -22,10 +21,10 @@ from exo.shared.types.commands import (
|
|||||||
ForwarderDownloadCommand,
|
ForwarderDownloadCommand,
|
||||||
StartDownload,
|
StartDownload,
|
||||||
)
|
)
|
||||||
from exo.shared.types.common import NodeId, SessionId
|
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
Event,
|
Event,
|
||||||
ForwarderEvent,
|
LocalForwarderEvent,
|
||||||
NodeDownloadProgress,
|
NodeDownloadProgress,
|
||||||
)
|
)
|
||||||
from exo.shared.types.worker.downloads import (
|
from exo.shared.types.worker.downloads import (
|
||||||
@@ -45,9 +44,9 @@ class DownloadCoordinator:
|
|||||||
session_id: SessionId
|
session_id: SessionId
|
||||||
shard_downloader: ShardDownloader
|
shard_downloader: ShardDownloader
|
||||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||||
local_event_sender: Sender[ForwarderEvent]
|
local_event_sender: Sender[LocalForwarderEvent]
|
||||||
event_index_counter: Iterator[int]
|
|
||||||
offline: bool = False
|
offline: bool = False
|
||||||
|
_system_id: SystemId = field(default_factory=SystemId)
|
||||||
|
|
||||||
# Local state
|
# Local state
|
||||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||||
@@ -298,15 +297,16 @@ class DownloadCoordinator:
|
|||||||
del self.download_status[model_id]
|
del self.download_status[model_id]
|
||||||
|
|
||||||
async def _forward_events(self) -> None:
|
async def _forward_events(self) -> None:
|
||||||
|
idx = 0
|
||||||
with self.event_receiver as events:
|
with self.event_receiver as events:
|
||||||
async for event in events:
|
async for event in events:
|
||||||
idx = next(self.event_index_counter)
|
fe = LocalForwarderEvent(
|
||||||
fe = ForwarderEvent(
|
|
||||||
origin_idx=idx,
|
origin_idx=idx,
|
||||||
origin=self.node_id,
|
origin=self._system_id,
|
||||||
session=self.session_id,
|
session=self.session_id,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
idx += 1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||||
)
|
)
|
||||||
@@ -338,17 +338,7 @@ class DownloadCoordinator:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif progress.status in ["in_progress", "not_started"]:
|
elif progress.status in ["in_progress", "not_started"]:
|
||||||
if (
|
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||||
progress.downloaded_bytes.in_bytes
|
|
||||||
>= progress.total_bytes.in_bytes
|
|
||||||
> 0
|
|
||||||
):
|
|
||||||
status = DownloadCompleted(
|
|
||||||
node_id=self.node_id,
|
|
||||||
shard_metadata=progress.shard,
|
|
||||||
total_bytes=progress.total_bytes,
|
|
||||||
)
|
|
||||||
elif progress.downloaded_bytes_this_session.in_bytes == 0:
|
|
||||||
status = DownloadPending(
|
status = DownloadPending(
|
||||||
node_id=self.node_id,
|
node_id=self.node_id,
|
||||||
shard_metadata=progress.shard,
|
shard_metadata=progress.shard,
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import itertools
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
import resource
|
import resource
|
||||||
import signal
|
import signal
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Iterator, Self
|
from typing import Self
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.abc import TaskGroup
|
from anyio.abc import TaskGroup
|
||||||
@@ -38,14 +37,13 @@ class Node:
|
|||||||
api: API | None
|
api: API | None
|
||||||
|
|
||||||
node_id: NodeId
|
node_id: NodeId
|
||||||
event_index_counter: Iterator[int]
|
|
||||||
offline: bool
|
offline: bool
|
||||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(cls, args: "Args") -> "Self":
|
async def create(cls, args: "Args") -> Self:
|
||||||
keypair = get_node_id_keypair()
|
keypair = get_node_id_keypair()
|
||||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
node_id = NodeId(keypair.to_node_id())
|
||||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||||
router = Router.create(keypair)
|
router = Router.create(keypair)
|
||||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||||
@@ -57,9 +55,6 @@ class Node:
|
|||||||
|
|
||||||
logger.info(f"Starting node {node_id}")
|
logger.info(f"Starting node {node_id}")
|
||||||
|
|
||||||
# Create shared event index counter for Worker and DownloadCoordinator
|
|
||||||
event_index_counter = itertools.count()
|
|
||||||
|
|
||||||
# Create DownloadCoordinator (unless --no-downloads)
|
# Create DownloadCoordinator (unless --no-downloads)
|
||||||
if not args.no_downloads:
|
if not args.no_downloads:
|
||||||
download_coordinator = DownloadCoordinator(
|
download_coordinator = DownloadCoordinator(
|
||||||
@@ -68,7 +63,6 @@ class Node:
|
|||||||
exo_shard_downloader(),
|
exo_shard_downloader(),
|
||||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||||
event_index_counter=event_index_counter,
|
|
||||||
offline=args.offline,
|
offline=args.offline,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -95,7 +89,6 @@ class Node:
|
|||||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||||
command_sender=router.sender(topics.COMMANDS),
|
command_sender=router.sender(topics.COMMANDS),
|
||||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||||
event_index_counter=event_index_counter,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
worker = None
|
worker = None
|
||||||
@@ -133,7 +126,6 @@ class Node:
|
|||||||
master,
|
master,
|
||||||
api,
|
api,
|
||||||
node_id,
|
node_id,
|
||||||
event_index_counter,
|
|
||||||
args.offline,
|
args.offline,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,8 +204,6 @@ class Node:
|
|||||||
)
|
)
|
||||||
if result.is_new_master:
|
if result.is_new_master:
|
||||||
await anyio.sleep(0)
|
await anyio.sleep(0)
|
||||||
# Fresh counter for new session (buffer expects indices from 0)
|
|
||||||
self.event_index_counter = itertools.count()
|
|
||||||
if self.download_coordinator:
|
if self.download_coordinator:
|
||||||
self.download_coordinator.shutdown()
|
self.download_coordinator.shutdown()
|
||||||
self.download_coordinator = DownloadCoordinator(
|
self.download_coordinator = DownloadCoordinator(
|
||||||
@@ -224,7 +214,6 @@ class Node:
|
|||||||
topics.DOWNLOAD_COMMANDS
|
topics.DOWNLOAD_COMMANDS
|
||||||
),
|
),
|
||||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||||
event_index_counter=self.event_index_counter,
|
|
||||||
offline=self.offline,
|
offline=self.offline,
|
||||||
)
|
)
|
||||||
self._tg.start_soon(self.download_coordinator.run)
|
self._tg.start_soon(self.download_coordinator.run)
|
||||||
@@ -242,7 +231,6 @@ class Node:
|
|||||||
download_command_sender=self.router.sender(
|
download_command_sender=self.router.sender(
|
||||||
topics.DOWNLOAD_COMMANDS
|
topics.DOWNLOAD_COMMANDS
|
||||||
),
|
),
|
||||||
event_index_counter=self.event_index_counter,
|
|
||||||
)
|
)
|
||||||
self._tg.start_soon(self.worker.run)
|
self._tg.start_soon(self.worker.run)
|
||||||
if self.api:
|
if self.api:
|
||||||
@@ -258,7 +246,7 @@ def main():
|
|||||||
target = min(max(soft, 65535), hard)
|
target = min(max(soft, 65535), hard)
|
||||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||||
|
|
||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn")
|
||||||
# TODO: Refactor the current verbosity system
|
# TODO: Refactor the current verbosity system
|
||||||
logger_setup(EXO_LOG, args.verbosity)
|
logger_setup(EXO_LOG, args.verbosity)
|
||||||
logger.info("Starting EXO")
|
logger.info("Starting EXO")
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
|
|||||||
ResponseOutputText,
|
ResponseOutputText,
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
ResponsesResponse,
|
ResponsesResponse,
|
||||||
|
ResponsesStreamEvent,
|
||||||
ResponseTextDeltaEvent,
|
ResponseTextDeltaEvent,
|
||||||
ResponseTextDoneEvent,
|
ResponseTextDoneEvent,
|
||||||
ResponseUsage,
|
ResponseUsage,
|
||||||
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
|
|||||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||||
|
|
||||||
|
|
||||||
|
def _format_sse(event: ResponsesStreamEvent) -> str:
|
||||||
|
"""Format a streaming event as an SSE message."""
|
||||||
|
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
|
||||||
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -219,13 +225,13 @@ async def generate_responses_stream(
|
|||||||
created_event = ResponseCreatedEvent(
|
created_event = ResponseCreatedEvent(
|
||||||
sequence_number=next(seq), response=initial_response
|
sequence_number=next(seq), response=initial_response
|
||||||
)
|
)
|
||||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
yield _format_sse(created_event)
|
||||||
|
|
||||||
# response.in_progress
|
# response.in_progress
|
||||||
in_progress_event = ResponseInProgressEvent(
|
in_progress_event = ResponseInProgressEvent(
|
||||||
sequence_number=next(seq), response=initial_response
|
sequence_number=next(seq), response=initial_response
|
||||||
)
|
)
|
||||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
yield _format_sse(in_progress_event)
|
||||||
|
|
||||||
# response.output_item.added
|
# response.output_item.added
|
||||||
initial_item = ResponseMessageItem(
|
initial_item = ResponseMessageItem(
|
||||||
@@ -236,7 +242,7 @@ async def generate_responses_stream(
|
|||||||
item_added = ResponseOutputItemAddedEvent(
|
item_added = ResponseOutputItemAddedEvent(
|
||||||
sequence_number=next(seq), output_index=0, item=initial_item
|
sequence_number=next(seq), output_index=0, item=initial_item
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
yield _format_sse(item_added)
|
||||||
|
|
||||||
# response.content_part.added
|
# response.content_part.added
|
||||||
initial_part = ResponseOutputText(text="")
|
initial_part = ResponseOutputText(text="")
|
||||||
@@ -247,7 +253,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
part=initial_part,
|
part=initial_part,
|
||||||
)
|
)
|
||||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
yield _format_sse(part_added)
|
||||||
|
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
function_call_items: list[ResponseFunctionCallItem] = []
|
function_call_items: list[ResponseFunctionCallItem] = []
|
||||||
@@ -281,7 +287,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
item=fc_item,
|
item=fc_item,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
|
yield _format_sse(fc_added)
|
||||||
|
|
||||||
# response.function_call_arguments.delta
|
# response.function_call_arguments.delta
|
||||||
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
||||||
@@ -290,7 +296,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
delta=tool.arguments,
|
delta=tool.arguments,
|
||||||
)
|
)
|
||||||
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
|
yield _format_sse(args_delta)
|
||||||
|
|
||||||
# response.function_call_arguments.done
|
# response.function_call_arguments.done
|
||||||
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
||||||
@@ -300,7 +306,7 @@ async def generate_responses_stream(
|
|||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool.arguments,
|
arguments=tool.arguments,
|
||||||
)
|
)
|
||||||
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
|
yield _format_sse(args_done)
|
||||||
|
|
||||||
# response.output_item.done
|
# response.output_item.done
|
||||||
fc_done_item = ResponseFunctionCallItem(
|
fc_done_item = ResponseFunctionCallItem(
|
||||||
@@ -315,7 +321,7 @@ async def generate_responses_stream(
|
|||||||
output_index=next_output_index,
|
output_index=next_output_index,
|
||||||
item=fc_done_item,
|
item=fc_done_item,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
|
yield _format_sse(fc_item_done)
|
||||||
|
|
||||||
function_call_items.append(fc_done_item)
|
function_call_items.append(fc_done_item)
|
||||||
next_output_index += 1
|
next_output_index += 1
|
||||||
@@ -331,7 +337,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
delta=chunk.text,
|
delta=chunk.text,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
yield _format_sse(delta_event)
|
||||||
|
|
||||||
# response.output_text.done
|
# response.output_text.done
|
||||||
text_done = ResponseTextDoneEvent(
|
text_done = ResponseTextDoneEvent(
|
||||||
@@ -341,7 +347,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
text=accumulated_text,
|
text=accumulated_text,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
yield _format_sse(text_done)
|
||||||
|
|
||||||
# response.content_part.done
|
# response.content_part.done
|
||||||
final_part = ResponseOutputText(text=accumulated_text)
|
final_part = ResponseOutputText(text=accumulated_text)
|
||||||
@@ -352,7 +358,7 @@ async def generate_responses_stream(
|
|||||||
content_index=0,
|
content_index=0,
|
||||||
part=final_part,
|
part=final_part,
|
||||||
)
|
)
|
||||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
yield _format_sse(part_done)
|
||||||
|
|
||||||
# response.output_item.done
|
# response.output_item.done
|
||||||
final_message_item = ResponseMessageItem(
|
final_message_item = ResponseMessageItem(
|
||||||
@@ -363,7 +369,7 @@ async def generate_responses_stream(
|
|||||||
item_done = ResponseOutputItemDoneEvent(
|
item_done = ResponseOutputItemDoneEvent(
|
||||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
yield _format_sse(item_done)
|
||||||
|
|
||||||
# Create usage from usage data if available
|
# Create usage from usage data if available
|
||||||
usage = None
|
usage = None
|
||||||
@@ -388,4 +394,4 @@ async def generate_responses_stream(
|
|||||||
completed_event = ResponseCompletedEvent(
|
completed_event = ResponseCompletedEvent(
|
||||||
sequence_number=next(seq), response=final_response
|
sequence_number=next(seq), response=final_response
|
||||||
)
|
)
|
||||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
yield _format_sse(completed_event)
|
||||||
|
|||||||
@@ -132,11 +132,11 @@ from exo.shared.types.commands import (
|
|||||||
TaskFinished,
|
TaskFinished,
|
||||||
TextGeneration,
|
TextGeneration,
|
||||||
)
|
)
|
||||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
ChunkGenerated,
|
ChunkGenerated,
|
||||||
Event,
|
Event,
|
||||||
ForwarderEvent,
|
GlobalForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
PrefillProgress,
|
PrefillProgress,
|
||||||
TracesMerged,
|
TracesMerged,
|
||||||
@@ -177,8 +177,7 @@ class API:
|
|||||||
session_id: SessionId,
|
session_id: SessionId,
|
||||||
*,
|
*,
|
||||||
port: int,
|
port: int,
|
||||||
# Ideally this would be a MasterForwarderEvent but type system says no :(
|
global_event_receiver: Receiver[GlobalForwarderEvent],
|
||||||
global_event_receiver: Receiver[ForwarderEvent],
|
|
||||||
command_sender: Sender[ForwarderCommand],
|
command_sender: Sender[ForwarderCommand],
|
||||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||||
# This lets us pause the API if an election is running
|
# This lets us pause the API if an election is running
|
||||||
@@ -186,6 +185,7 @@ class API:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.state = State()
|
self.state = State()
|
||||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||||
|
self._system_id = SystemId()
|
||||||
self.command_sender = command_sender
|
self.command_sender = command_sender
|
||||||
self.download_command_sender = download_command_sender
|
self.download_command_sender = download_command_sender
|
||||||
self.global_event_receiver = global_event_receiver
|
self.global_event_receiver = global_event_receiver
|
||||||
@@ -237,6 +237,7 @@ class API:
|
|||||||
self._event_log.close()
|
self._event_log.close()
|
||||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||||
self.state = State()
|
self.state = State()
|
||||||
|
self._system_id = SystemId()
|
||||||
self.session_id = new_session_id
|
self.session_id = new_session_id
|
||||||
self.event_buffer = OrderedBuffer[Event]()
|
self.event_buffer = OrderedBuffer[Event]()
|
||||||
self._text_generation_queues = {}
|
self._text_generation_queues = {}
|
||||||
@@ -554,7 +555,7 @@ class API:
|
|||||||
command = TaskCancelled(cancelled_command_id=command_id)
|
command = TaskCancelled(cancelled_command_id=command_id)
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
await self.command_sender.send(
|
await self.command_sender.send(
|
||||||
ForwarderCommand(origin=self.node_id, command=command)
|
ForwarderCommand(origin=self._system_id, command=command)
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
@@ -902,7 +903,7 @@ class API:
|
|||||||
command = TaskCancelled(cancelled_command_id=command_id)
|
command = TaskCancelled(cancelled_command_id=command_id)
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
await self.command_sender.send(
|
await self.command_sender.send(
|
||||||
ForwarderCommand(origin=self.node_id, command=command)
|
ForwarderCommand(origin=self._system_id, command=command)
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
@@ -988,7 +989,7 @@ class API:
|
|||||||
command = TaskCancelled(cancelled_command_id=command_id)
|
command = TaskCancelled(cancelled_command_id=command_id)
|
||||||
with anyio.CancelScope(shield=True):
|
with anyio.CancelScope(shield=True):
|
||||||
await self.command_sender.send(
|
await self.command_sender.send(
|
||||||
ForwarderCommand(origin=self.node_id, command=command)
|
ForwarderCommand(origin=self._system_id, command=command)
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
@@ -1429,6 +1430,8 @@ class API:
|
|||||||
async def _apply_state(self):
|
async def _apply_state(self):
|
||||||
with self.global_event_receiver as events:
|
with self.global_event_receiver as events:
|
||||||
async for f_event in events:
|
async for f_event in events:
|
||||||
|
if f_event.session != self.session_id:
|
||||||
|
continue
|
||||||
if f_event.origin != self.session_id.master_node_id:
|
if f_event.origin != self.session_id.master_node_id:
|
||||||
continue
|
continue
|
||||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||||
@@ -1508,12 +1511,12 @@ class API:
|
|||||||
while self.paused:
|
while self.paused:
|
||||||
await self.paused_ev.wait()
|
await self.paused_ev.wait()
|
||||||
await self.command_sender.send(
|
await self.command_sender.send(
|
||||||
ForwarderCommand(origin=self.node_id, command=command)
|
ForwarderCommand(origin=self._system_id, command=command)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _send_download(self, command: DownloadCommand):
|
async def _send_download(self, command: DownloadCommand):
|
||||||
await self.download_command_sender.send(
|
await self.download_command_sender.send(
|
||||||
ForwarderDownloadCommand(origin=self.node_id, command=command)
|
ForwarderDownloadCommand(origin=self._system_id, command=command)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start_download(
|
async def start_download(
|
||||||
|
|||||||
@@ -29,13 +29,14 @@ from exo.shared.types.commands import (
|
|||||||
TestCommand,
|
TestCommand,
|
||||||
TextGeneration,
|
TextGeneration,
|
||||||
)
|
)
|
||||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
Event,
|
Event,
|
||||||
ForwarderEvent,
|
GlobalForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
InputChunkReceived,
|
InputChunkReceived,
|
||||||
InstanceDeleted,
|
InstanceDeleted,
|
||||||
|
LocalForwarderEvent,
|
||||||
NodeGatheredInfo,
|
NodeGatheredInfo,
|
||||||
NodeTimedOut,
|
NodeTimedOut,
|
||||||
TaskCreated,
|
TaskCreated,
|
||||||
@@ -71,8 +72,8 @@ class Master:
|
|||||||
session_id: SessionId,
|
session_id: SessionId,
|
||||||
*,
|
*,
|
||||||
command_receiver: Receiver[ForwarderCommand],
|
command_receiver: Receiver[ForwarderCommand],
|
||||||
local_event_receiver: Receiver[ForwarderEvent],
|
local_event_receiver: Receiver[LocalForwarderEvent],
|
||||||
global_event_sender: Sender[ForwarderEvent],
|
global_event_sender: Sender[GlobalForwarderEvent],
|
||||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||||
):
|
):
|
||||||
self.state = State()
|
self.state = State()
|
||||||
@@ -87,10 +88,11 @@ class Master:
|
|||||||
send, recv = channel[Event]()
|
send, recv = channel[Event]()
|
||||||
self.event_sender: Sender[Event] = send
|
self.event_sender: Sender[Event] = send
|
||||||
self._loopback_event_receiver: Receiver[Event] = recv
|
self._loopback_event_receiver: Receiver[Event] = recv
|
||||||
self._loopback_event_sender: Sender[ForwarderEvent] = (
|
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
|
||||||
local_event_receiver.clone_sender()
|
local_event_receiver.clone_sender()
|
||||||
)
|
)
|
||||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
self._system_id = SystemId()
|
||||||
|
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
|
||||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||||
@@ -288,7 +290,7 @@ class Master:
|
|||||||
):
|
):
|
||||||
await self.download_command_sender.send(
|
await self.download_command_sender.send(
|
||||||
ForwarderDownloadCommand(
|
ForwarderDownloadCommand(
|
||||||
origin=self.node_id, command=cmd
|
origin=self._system_id, command=cmd
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
generated_events.extend(transition_events)
|
generated_events.extend(transition_events)
|
||||||
@@ -414,8 +416,8 @@ class Master:
|
|||||||
with self._loopback_event_receiver as events:
|
with self._loopback_event_receiver as events:
|
||||||
async for event in events:
|
async for event in events:
|
||||||
await self._loopback_event_sender.send(
|
await self._loopback_event_sender.send(
|
||||||
ForwarderEvent(
|
LocalForwarderEvent(
|
||||||
origin=NodeId(f"master_{self.node_id}"),
|
origin=self._system_id,
|
||||||
origin_idx=local_index,
|
origin_idx=local_index,
|
||||||
session=self.session_id,
|
session=self.session_id,
|
||||||
event=event,
|
event=event,
|
||||||
@@ -427,7 +429,7 @@ class Master:
|
|||||||
async def _send_event(self, event: IndexedEvent):
|
async def _send_event(self, event: IndexedEvent):
|
||||||
# Convenience method since this line is ugly
|
# Convenience method since this line is ugly
|
||||||
await self.global_event_sender.send(
|
await self.global_event_sender.send(
|
||||||
ForwarderEvent(
|
GlobalForwarderEvent(
|
||||||
origin=self.node_id,
|
origin=self.node_id,
|
||||||
origin_idx=event.idx,
|
origin_idx=event.idx,
|
||||||
session=self.session_id,
|
session=self.session_id,
|
||||||
|
|||||||
@@ -15,11 +15,12 @@ from exo.shared.types.commands import (
|
|||||||
PlaceInstance,
|
PlaceInstance,
|
||||||
TextGeneration,
|
TextGeneration,
|
||||||
)
|
)
|
||||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
ForwarderEvent,
|
GlobalForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
InstanceCreated,
|
InstanceCreated,
|
||||||
|
LocalForwarderEvent,
|
||||||
NodeGatheredInfo,
|
NodeGatheredInfo,
|
||||||
TaskCreated,
|
TaskCreated,
|
||||||
)
|
)
|
||||||
@@ -42,12 +43,12 @@ from exo.utils.channels import channel
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_master():
|
async def test_master():
|
||||||
keypair = get_node_id_keypair()
|
keypair = get_node_id_keypair()
|
||||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
node_id = NodeId(keypair.to_node_id())
|
||||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||||
|
|
||||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
ge_sender, global_event_receiver = channel[GlobalForwarderEvent]()
|
||||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
|
||||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||||
|
|
||||||
all_events: list[IndexedEvent] = []
|
all_events: list[IndexedEvent] = []
|
||||||
@@ -75,13 +76,12 @@ async def test_master():
|
|||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
tg.start_soon(master.run)
|
tg.start_soon(master.run)
|
||||||
|
|
||||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
|
||||||
# inject a NodeGatheredInfo event
|
# inject a NodeGatheredInfo event
|
||||||
logger.info("inject a NodeGatheredInfo event")
|
logger.info("inject a NodeGatheredInfo event")
|
||||||
await local_event_sender.send(
|
await local_event_sender.send(
|
||||||
ForwarderEvent(
|
LocalForwarderEvent(
|
||||||
origin_idx=0,
|
origin_idx=0,
|
||||||
origin=sender_node_id,
|
origin=SystemId("Worker"),
|
||||||
session=session_id,
|
session=session_id,
|
||||||
event=(
|
event=(
|
||||||
NodeGatheredInfo(
|
NodeGatheredInfo(
|
||||||
@@ -108,7 +108,7 @@ async def test_master():
|
|||||||
logger.info("inject a CreateInstance Command")
|
logger.info("inject a CreateInstance Command")
|
||||||
await command_sender.send(
|
await command_sender.send(
|
||||||
ForwarderCommand(
|
ForwarderCommand(
|
||||||
origin=node_id,
|
origin=SystemId("API"),
|
||||||
command=(
|
command=(
|
||||||
PlaceInstance(
|
PlaceInstance(
|
||||||
command_id=CommandId(),
|
command_id=CommandId(),
|
||||||
@@ -133,7 +133,7 @@ async def test_master():
|
|||||||
logger.info("inject a TextGeneration Command")
|
logger.info("inject a TextGeneration Command")
|
||||||
await command_sender.send(
|
await command_sender.send(
|
||||||
ForwarderCommand(
|
ForwarderCommand(
|
||||||
origin=node_id,
|
origin=SystemId("API"),
|
||||||
command=(
|
command=(
|
||||||
TextGeneration(
|
TextGeneration(
|
||||||
command_id=CommandId(),
|
command_id=CommandId(),
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||||
return cls(
|
return cls(
|
||||||
node_id=NodeId(update.peer_id.to_base58()),
|
node_id=NodeId(update.peer_id),
|
||||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||||
remote_ipv4=update.remote_ipv4,
|
remote_ipv4=update.remote_ipv4,
|
||||||
remote_tcp_port=update.remote_tcp_port,
|
remote_tcp_port=update.remote_tcp_port,
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ def get_node_id_keypair(
|
|||||||
Obtain the :class:`PeerId` by from it.
|
Obtain the :class:`PeerId` by from it.
|
||||||
"""
|
"""
|
||||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||||
return Keypair.generate_ed25519()
|
return Keypair.generate()
|
||||||
|
|
||||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||||
return Path(str(path) + ".lock")
|
return Path(str(path) + ".lock")
|
||||||
@@ -235,12 +235,12 @@ def get_node_id_keypair(
|
|||||||
protobuf_encoded = f.read()
|
protobuf_encoded = f.read()
|
||||||
|
|
||||||
try: # if decoded successfully, save & return
|
try: # if decoded successfully, save & return
|
||||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
return Keypair.from_bytes(protobuf_encoded)
|
||||||
except ValueError as e: # on runtime error, assume corrupt file
|
except ValueError as e: # on runtime error, assume corrupt file
|
||||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||||
|
|
||||||
# if no valid credentials, create new ones and persist
|
# if no valid credentials, create new ones and persist
|
||||||
with open(path, "w+b") as f:
|
with open(path, "w+b") as f:
|
||||||
keypair = Keypair.generate_ed25519()
|
keypair = Keypair.generate_ed25519()
|
||||||
f.write(keypair.to_protobuf_encoding())
|
f.write(keypair.to_bytes())
|
||||||
return keypair
|
return keypair
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ from exo.routing.connection_message import ConnectionMessage
|
|||||||
from exo.shared.election import ElectionMessage
|
from exo.shared.election import ElectionMessage
|
||||||
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
|
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
ForwarderEvent,
|
GlobalForwarderEvent,
|
||||||
|
LocalForwarderEvent,
|
||||||
)
|
)
|
||||||
from exo.utils.pydantic_ext import CamelCaseModel
|
from exo.utils.pydantic_ext import CamelCaseModel
|
||||||
|
|
||||||
@@ -36,8 +37,8 @@ class TypedTopic[T: CamelCaseModel]:
|
|||||||
return self.model_type.model_validate_json(b.decode("utf-8"))
|
return self.model_type.model_validate_json(b.decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
|
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, GlobalForwarderEvent)
|
||||||
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
|
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, LocalForwarderEvent)
|
||||||
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
|
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
|
||||||
ELECTION_MESSAGES = TypedTopic(
|
ELECTION_MESSAGES = TypedTopic(
|
||||||
"election_messages", PublishPolicy.Always, ElectionMessage
|
"election_messages", PublishPolicy.Always, ElectionMessage
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
|
|||||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||||
from exo.shared.types.common import NodeId, SessionId
|
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||||
from exo.utils.channels import channel
|
from exo.utils.channels import channel
|
||||||
|
|
||||||
# ======= #
|
# ======= #
|
||||||
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
|
|||||||
# Pump local commands so our commands_seen is high before the round starts
|
# Pump local commands so our commands_seen is high before the round starts
|
||||||
for _ in range(50):
|
for _ in range(50):
|
||||||
await co_tx.send(
|
await co_tx.send(
|
||||||
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
|
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands
|
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
|||||||
sem.release()
|
sem.release()
|
||||||
# wait to be told to begin simultaneous read
|
# wait to be told to begin simultaneous read
|
||||||
ev.wait()
|
ev.wait()
|
||||||
queue.put(get_node_id_keypair().to_protobuf_encoding())
|
queue.put(get_node_id_keypair().to_bytes())
|
||||||
|
|
||||||
|
|
||||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from exo.shared.types.api import (
|
|||||||
ImageGenerationTaskParams,
|
ImageGenerationTaskParams,
|
||||||
)
|
)
|
||||||
from exo.shared.types.chunks import InputImageChunk
|
from exo.shared.types.chunks import InputImageChunk
|
||||||
from exo.shared.types.common import CommandId, NodeId
|
from exo.shared.types.common import CommandId, NodeId, SystemId
|
||||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||||
@@ -100,10 +100,10 @@ Command = (
|
|||||||
|
|
||||||
|
|
||||||
class ForwarderCommand(CamelCaseModel):
|
class ForwarderCommand(CamelCaseModel):
|
||||||
origin: NodeId
|
origin: SystemId
|
||||||
command: Command
|
command: Command
|
||||||
|
|
||||||
|
|
||||||
class ForwarderDownloadCommand(CamelCaseModel):
|
class ForwarderDownloadCommand(CamelCaseModel):
|
||||||
origin: NodeId
|
origin: SystemId
|
||||||
command: DownloadCommand
|
command: DownloadCommand
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ class NodeId(Id):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SystemId(Id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelId(Id):
|
class ModelId(Id):
|
||||||
def normalize(self) -> str:
|
def normalize(self) -> str:
|
||||||
return self.replace("/", "--")
|
return self.replace("/", "--")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
|||||||
|
|
||||||
from exo.shared.topology import Connection
|
from exo.shared.topology import Connection
|
||||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||||
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
|
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId
|
||||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||||
from exo.shared.types.worker.downloads import DownloadProgress
|
from exo.shared.types.worker.downloads import DownloadProgress
|
||||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||||
@@ -170,10 +170,19 @@ class IndexedEvent(CamelCaseModel):
|
|||||||
event: Event
|
event: Event
|
||||||
|
|
||||||
|
|
||||||
class ForwarderEvent(CamelCaseModel):
|
class GlobalForwarderEvent(CamelCaseModel):
|
||||||
"""An event the forwarder will serialize and send over the network"""
|
"""An event the forwarder will serialize and send over the network"""
|
||||||
|
|
||||||
origin_idx: int = Field(ge=0)
|
origin_idx: int = Field(ge=0)
|
||||||
origin: NodeId
|
origin: NodeId
|
||||||
session: SessionId
|
session: SessionId
|
||||||
event: Event
|
event: Event
|
||||||
|
|
||||||
|
|
||||||
|
class LocalForwarderEvent(CamelCaseModel):
|
||||||
|
"""An event the forwarder will serialize and send over the network"""
|
||||||
|
|
||||||
|
origin_idx: int = Field(ge=0)
|
||||||
|
origin: SystemId
|
||||||
|
session: SessionId
|
||||||
|
event: Event
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from random import random
|
from random import random
|
||||||
from typing import Iterator
|
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from anyio import CancelScope, create_task_group, fail_after
|
from anyio import CancelScope, create_task_group, fail_after
|
||||||
@@ -17,13 +16,14 @@ from exo.shared.types.commands import (
|
|||||||
RequestEventLog,
|
RequestEventLog,
|
||||||
StartDownload,
|
StartDownload,
|
||||||
)
|
)
|
||||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
|
||||||
from exo.shared.types.events import (
|
from exo.shared.types.events import (
|
||||||
Event,
|
Event,
|
||||||
EventId,
|
EventId,
|
||||||
ForwarderEvent,
|
GlobalForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
InputChunkReceived,
|
InputChunkReceived,
|
||||||
|
LocalForwarderEvent,
|
||||||
NodeGatheredInfo,
|
NodeGatheredInfo,
|
||||||
TaskCreated,
|
TaskCreated,
|
||||||
TaskStatusUpdated,
|
TaskStatusUpdated,
|
||||||
@@ -58,24 +58,22 @@ class Worker:
|
|||||||
node_id: NodeId,
|
node_id: NodeId,
|
||||||
session_id: SessionId,
|
session_id: SessionId,
|
||||||
*,
|
*,
|
||||||
global_event_receiver: Receiver[ForwarderEvent],
|
global_event_receiver: Receiver[GlobalForwarderEvent],
|
||||||
local_event_sender: Sender[ForwarderEvent],
|
local_event_sender: Sender[LocalForwarderEvent],
|
||||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||||
# but I think it's the correct way to be thinking about commands
|
# but I think it's the correct way to be thinking about commands
|
||||||
command_sender: Sender[ForwarderCommand],
|
command_sender: Sender[ForwarderCommand],
|
||||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||||
event_index_counter: Iterator[int],
|
|
||||||
):
|
):
|
||||||
self.node_id: NodeId = node_id
|
self.node_id: NodeId = node_id
|
||||||
self.session_id: SessionId = session_id
|
self.session_id: SessionId = session_id
|
||||||
|
|
||||||
self.global_event_receiver = global_event_receiver
|
self.global_event_receiver = global_event_receiver
|
||||||
self.local_event_sender = local_event_sender
|
self.local_event_sender = local_event_sender
|
||||||
self.event_index_counter = event_index_counter
|
|
||||||
self.command_sender = command_sender
|
self.command_sender = command_sender
|
||||||
self.download_command_sender = download_command_sender
|
self.download_command_sender = download_command_sender
|
||||||
self.event_buffer = OrderedBuffer[Event]()
|
self.event_buffer = OrderedBuffer[Event]()
|
||||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
|
||||||
|
|
||||||
self.state: State = State()
|
self.state: State = State()
|
||||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||||
@@ -86,6 +84,8 @@ class Worker:
|
|||||||
self._nack_base_seconds: float = 0.5
|
self._nack_base_seconds: float = 0.5
|
||||||
self._nack_cap_seconds: float = 10.0
|
self._nack_cap_seconds: float = 10.0
|
||||||
|
|
||||||
|
self._system_id = SystemId()
|
||||||
|
|
||||||
self.event_sender, self.event_receiver = channel[Event]()
|
self.event_sender, self.event_receiver = channel[Event]()
|
||||||
|
|
||||||
# Buffer for input image chunks (for image editing)
|
# Buffer for input image chunks (for image editing)
|
||||||
@@ -132,6 +132,8 @@ class Worker:
|
|||||||
async def _event_applier(self):
|
async def _event_applier(self):
|
||||||
with self.global_event_receiver as events:
|
with self.global_event_receiver as events:
|
||||||
async for f_event in events:
|
async for f_event in events:
|
||||||
|
if f_event.session != self.session_id:
|
||||||
|
continue
|
||||||
if f_event.origin != self.session_id.master_node_id:
|
if f_event.origin != self.session_id.master_node_id:
|
||||||
continue
|
continue
|
||||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||||
@@ -212,7 +214,7 @@ class Worker:
|
|||||||
|
|
||||||
await self.download_command_sender.send(
|
await self.download_command_sender.send(
|
||||||
ForwarderDownloadCommand(
|
ForwarderDownloadCommand(
|
||||||
origin=self.node_id,
|
origin=self._system_id,
|
||||||
command=StartDownload(
|
command=StartDownload(
|
||||||
target_node_id=self.node_id,
|
target_node_id=self.node_id,
|
||||||
shard_metadata=shard,
|
shard_metadata=shard,
|
||||||
@@ -317,7 +319,7 @@ class Worker:
|
|||||||
)
|
)
|
||||||
await self.command_sender.send(
|
await self.command_sender.send(
|
||||||
ForwarderCommand(
|
ForwarderCommand(
|
||||||
origin=self.node_id,
|
origin=self._system_id,
|
||||||
command=RequestEventLog(since_idx=since_idx),
|
command=RequestEventLog(since_idx=since_idx),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -344,15 +346,16 @@ class Worker:
|
|||||||
return runner
|
return runner
|
||||||
|
|
||||||
async def _forward_events(self) -> None:
|
async def _forward_events(self) -> None:
|
||||||
|
idx = 0
|
||||||
with self.event_receiver as events:
|
with self.event_receiver as events:
|
||||||
async for event in events:
|
async for event in events:
|
||||||
idx = next(self.event_index_counter)
|
fe = LocalForwarderEvent(
|
||||||
fe = ForwarderEvent(
|
|
||||||
origin_idx=idx,
|
origin_idx=idx,
|
||||||
origin=self.node_id,
|
origin=self._system_id,
|
||||||
session=self.session_id,
|
session=self.session_id,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
idx += 1
|
||||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||||
await self.local_event_sender.send(fe)
|
await self.local_event_sender.send(fe)
|
||||||
self.out_for_delivery[event.event_id] = fe
|
self.out_for_delivery[event.event_id] = fe
|
||||||
|
|||||||
@@ -98,16 +98,11 @@ class RunnerSupervisor:
|
|||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
logger.info("Runner supervisor shutting down")
|
logger.info("Runner supervisor shutting down")
|
||||||
with contextlib.suppress(ClosedResourceError):
|
self._ev_recv.close()
|
||||||
self._ev_recv.close()
|
self._task_sender.close()
|
||||||
with contextlib.suppress(ClosedResourceError):
|
self._event_sender.close()
|
||||||
self._task_sender.close()
|
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||||
with contextlib.suppress(ClosedResourceError):
|
self._cancel_sender.close()
|
||||||
self._event_sender.close()
|
|
||||||
with contextlib.suppress(ClosedResourceError):
|
|
||||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
|
||||||
with contextlib.suppress(ClosedResourceError):
|
|
||||||
self._cancel_sender.close()
|
|
||||||
self.runner_process.join(5)
|
self.runner_process.join(5)
|
||||||
if not self.runner_process.is_alive():
|
if not self.runner_process.is_alive():
|
||||||
logger.info("Runner process succesfully terminated")
|
logger.info("Runner process succesfully terminated")
|
||||||
|
|||||||
Reference in New Issue
Block a user