Compare commits

...

4 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
31cd8b9d68 Delete model earlier in shutdown case 2026-02-19 19:56:07 +00:00
Evan Quiney
4c4c6ce99f simplify rust ident module
this is partly dead code, partly narrowing the rust-python boundary in
prep for future rewrites. no testing as this is all type safe
refactoring.
2026-02-19 17:19:31 +00:00
Jake Hillion
42e1e7322b bench: restore --danger-delete-downloads planning phase (#1542)
c2f2111b extracted shared utilities from exo_bench.py into harness.py
but accidentally dropped the run_planning_phase function and
--danger-delete-downloads CLI argument in the process.

Restored run_planning_phase in harness.py (where its dependencies now
live) and re-added the --danger-delete-downloads argument to
add_common_instance_args. Re-wired the planning phase call in
exo_bench.py's main() before the benchmark loop.
2026-02-19 15:42:02 +00:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
14 changed files with 247 additions and 271 deletions

View File

@@ -20,6 +20,7 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:
selected.sort(key=_placement_sort_key)
preview = selected[0]
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
print("Planning phase: checking downloads...", file=log)
run_planning_phase(
exo,
full_model_id,
preview,
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])

View File

@@ -35,6 +35,7 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

View File

@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
return selected
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)

View File

@@ -19,7 +19,7 @@ class ConnectionUpdate:
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
def peer_id(self) -> builtins.str:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@@ -40,92 +40,22 @@ class Keypair:
Identity keypair of a node.
"""
@staticmethod
def generate_ed25519() -> Keypair:
def generate() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def generate_ecdsa() -> Keypair:
def from_bytes(bytes: bytes) -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
Construct an Ed25519 keypair from secret key bytes
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
Get the secret key bytes underlying the keypair
"""
def to_string(self) -> builtins.str:
def to_node_id(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
"""
@typing.final
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""

View File

@@ -1,8 +1,6 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
@@ -18,142 +16,32 @@ pub struct PyKeypair(pub Keypair);
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
/// Construct an Ed25519 keypair from secret key bytes
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
/// Get the secret key bytes underlying the keypair
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
fn to_node_id(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -8,9 +8,10 @@ mod allow_threading;
mod ident;
mod networking;
use crate::ident::ident_submodule;
use crate::ident::PyKeypair;
use crate::networking::networking_submodule;
use pyo3::prelude::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -158,7 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
m.add_class::<PyKeypair>()?;
networking_submodule(m)?;
// top-level constructs

View File

@@ -8,7 +8,7 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::{PyKeypair, PyPeerId};
use crate::ident::PyKeypair;
use crate::pyclass;
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
peer_id: String,
/// Remote connection's IPv4 address.
#[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
@@ -272,7 +272,7 @@ async fn networking_task(
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {

View File

@@ -45,7 +45,7 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)

View File

@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
@@ -219,13 +225,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)
# response.output_item.added
initial_item = ResponseMessageItem(
@@ -236,7 +242,7 @@ async def generate_responses_stream(
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
yield _format_sse(item_added)
# response.content_part.added
initial_part = ResponseOutputText(text="")
@@ -247,7 +253,7 @@ async def generate_responses_stream(
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(part_added)
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
@@ -281,7 +287,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +296,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -315,7 +321,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)
function_call_items.append(fc_done_item)
next_output_index += 1
@@ -331,7 +337,7 @@ async def generate_responses_stream(
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)
# response.output_text.done
text_done = ResponseTextDoneEvent(
@@ -341,7 +347,7 @@ async def generate_responses_stream(
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
@@ -352,7 +358,7 @@ async def generate_responses_stream(
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)
# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -363,7 +369,7 @@ async def generate_responses_stream(
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)
# Create usage from usage data if available
usage = None
@@ -388,4 +394,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)

View File

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(

View File

@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
node_id=NodeId(update.peer_id),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,

View File

@@ -221,7 +221,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
return Keypair.generate()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -235,12 +235,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.from_protobuf_encoding(protobuf_encoded)
return Keypair.from_bytes(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
f.write(keypair.to_bytes())
return keypair

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
queue.put(get_node_id_keypair().to_protobuf_encoding())
queue.put(get_node_id_keypair().to_bytes())
def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -573,6 +573,13 @@ def main(
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
@@ -581,6 +588,7 @@ def main(
event_sender.send(TaskAcknowledged(task_id=task.task_id))
current_status = RunnerShutdown()
break
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
@@ -597,13 +605,6 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
break
@cache