Compare commits

..

1 Commit

Author SHA1 Message Date
Alex Cheema
089612acde feat: show ETA on prefill progress bar
Track when prefill starts via performance.now() and extrapolate
remaining time from the observed tokens/sec. Display "~Xs remaining"
(or "~Xm Ys remaining" for longer prompts) next to the percentage.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:37:16 -08:00
29 changed files with 389 additions and 392 deletions
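The ETA shown on the prefill bar is a plain extrapolation: measure how many prompt tokens have been processed since prefill started, derive a tokens-per-millisecond rate, and divide the remaining token count by that rate. A minimal Python sketch of the same arithmetic (illustrative only; the function name is invented here, and the actual implementation is the Svelte $derived block in the diff below):

import math

def prefill_eta_text(processed: int, total: int, elapsed_ms: float) -> str | None:
    # Extrapolate remaining prefill time from the observed token throughput.
    if processed <= 0 or total <= 0 or elapsed_ms < 200:  # need a minimum sample window
        return None
    tokens_per_ms = processed / elapsed_ms
    remaining_ms = (total - processed) / tokens_per_ms
    remaining_sec = math.ceil(remaining_ms / 1000)
    if remaining_sec <= 0:
        return None
    if remaining_sec < 60:
        return f"~{remaining_sec}s remaining"
    mins, secs = divmod(remaining_sec, 60)
    return f"~{mins}m {secs}s remaining"

For example, with 512 of 4096 tokens processed after 2000 ms the rate is 0.256 tokens/ms, 3584 tokens remain, which is roughly 14000 ms, so the label reads "~14s remaining".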

View File

@@ -20,7 +20,6 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -963,21 +962,6 @@ Examples:
selected.sort(key=_placement_sort_key)
preview = selected[0]
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
print("Planning phase: checking downloads...", file=log)
run_planning_phase(
exo,
full_model_id,
preview,
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])

View File

@@ -35,7 +35,6 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -333,20 +332,6 @@ def main() -> int:
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

View File

@@ -282,151 +282,6 @@ def settle_and_fetch_placements(
return selected
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
@@ -470,8 +325,3 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)

View File

@@ -14,6 +14,21 @@
: 0,
);
const etaText = $derived.by(() => {
if (progress.processed <= 0 || progress.total <= 0) return null;
const elapsedMs = performance.now() - progress.startedAt;
if (elapsedMs < 200) return null; // need a minimum sample window
const tokensPerMs = progress.processed / elapsedMs;
const remainingTokens = progress.total - progress.processed;
const remainingMs = remainingTokens / tokensPerMs;
const remainingSec = Math.ceil(remainingMs / 1000);
if (remainingSec <= 0) return null;
if (remainingSec < 60) return `~${remainingSec}s remaining`;
const mins = Math.floor(remainingSec / 60);
const secs = remainingSec % 60;
return `~${mins}m ${secs}s remaining`;
});
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
@@ -40,8 +55,11 @@
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
<div
class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
>
<span>{etaText ?? ""}</span>
<span>{percentage}%</span>
</div>
</div>

View File

@@ -250,11 +250,6 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// Disk usage per node
nodeDisk?: Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>;
}
export interface MessageAttachment {
@@ -281,6 +276,8 @@ export interface TokenData {
export interface PrefillProgress {
processed: number;
total: number;
/** Timestamp (performance.now()) when prefill started. */
startedAt: number;
}
export interface Message {
@@ -2425,6 +2422,7 @@ class AppStore {
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
};
},
},

View File

@@ -790,8 +790,10 @@
if (!progress || typeof progress !== "object") return null;
const prog = progress as Record<string, unknown>;
const totalBytes = getBytes(prog.total);
const downloadedBytes = getBytes(prog.downloaded);
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const speed = (prog.speed as number) ?? 0;
const completedFiles =
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
@@ -804,8 +806,8 @@
for (const [fileName, fileData] of Object.entries(filesObj)) {
if (!fileData || typeof fileData !== "object") continue;
const fd = fileData as Record<string, unknown>;
const fTotal = getBytes(fd.total);
const fDownloaded = getBytes(fd.downloaded);
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
files.push({
name: fileName,
totalBytes: fTotal,
@@ -1194,6 +1196,7 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;

View File

@@ -74,6 +74,7 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -230,14 +231,23 @@
undefined;
let cell: CellStatus;
if (tag === "DownloadCompleted") {
const totalBytes = getBytes(payload.total);
const totalBytes = getBytes(
payload.total_bytes ?? payload.totalBytes,
);
cell = { kind: "completed", totalBytes, modelDirectory };
} else if (tag === "DownloadOngoing") {
const rawProgress =
payload.download_progress ?? payload.downloadProgress ?? {};
const prog = rawProgress as Record<string, unknown>;
const totalBytes = getBytes(prog.total ?? payload.total);
const downloadedBytes = getBytes(prog.downloaded);
const totalBytes = getBytes(
prog.total_bytes ??
prog.totalBytes ??
payload.total_bytes ??
payload.totalBytes,
);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const speed = (prog.speed as number) ?? 0;
const etaMs =
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;

View File

@@ -19,7 +19,7 @@ class ConnectionUpdate:
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> builtins.str:
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@@ -40,22 +40,92 @@ class Keypair:
Identity keypair of a node.
"""
@staticmethod
def generate() -> Keypair:
def generate_ed25519() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Keypair:
def generate_ecdsa() -> Keypair:
r"""
Construct an Ed25519 keypair from secret key bytes
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
"""
def to_bytes(self) -> bytes:
r"""
Get the secret key bytes underlying the keypair
Return a copy of this [`Multiaddr`]'s byte representation.
"""
def to_node_id(self) -> builtins.str:
def to_string(self) -> builtins.str:
r"""
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
Convert a Multiaddr to a string.
"""
@typing.final
@@ -110,6 +180,37 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""

View File

@@ -1,6 +1,8 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
@@ -16,32 +18,142 @@ pub struct PyKeypair(pub Keypair);
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate() -> Self {
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Construct an Ed25519 keypair from secret key bytes
/// Generate a new ECDSA keypair.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Get the secret key bytes underlying the keypair
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
fn to_node_id(&self) -> String {
self.0.public().to_peer_id().to_base58()
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -8,10 +8,9 @@ mod allow_threading;
mod ident;
mod networking;
use crate::ident::PyKeypair;
use crate::ident::ident_submodule;
use crate::networking::networking_submodule;
use pyo3::prelude::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -159,7 +158,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
m.add_class::<PyKeypair>()?;
ident_submodule(m)?;
networking_submodule(m)?;
// top-level constructs

View File

@@ -8,7 +8,7 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair;
use crate::ident::{PyKeypair, PyPeerId};
use crate::pyclass;
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: String,
peer_id: PyPeerId,
/// Remote connection's IPv4 address.
#[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: peer_id.to_base58(),
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
@@ -272,7 +272,7 @@ async fn networking_task(
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: peer_id.to_base58(),
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {

View File

@@ -80,7 +80,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total=progress.total,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total=initial_progress.total,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total=progress.total,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_this_session.in_bytes == 0:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
downloaded=repo_file_download_progress.downloaded,
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
total=repo_file_download_progress.total,
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
total=repo_download_progress.total,
downloaded=repo_download_progress.downloaded,
downloaded_this_session=repo_download_progress.downloaded_this_session,
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
@@ -578,20 +578,19 @@ def calculate_repo_progress(
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
) -> RepoDownloadProgress:
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
all_downloaded = sum(
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum(
(p.downloaded.in_bytes for p in file_progress.values()), 0
)
all_downloaded_this_session = sum(
(p.downloaded_this_session for p in file_progress.values()),
Memory.from_bytes(0),
all_downloaded_bytes_this_session = sum(
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
)
all_eta = (
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
if all_speed > 0
else timedelta(seconds=0)
)
@@ -610,9 +609,11 @@ def calculate_repo_progress(
[p for p in file_progress.values() if p.downloaded == p.total]
),
total_files=len(file_progress),
downloaded=all_downloaded,
downloaded_this_session=all_downloaded_this_session,
total=all_total,
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(
all_downloaded_bytes_this_session
),
total_bytes=Memory.from_bytes(all_total_bytes),
overall_speed=all_speed,
overall_eta=all_eta,
status=status,

View File

@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
),
completed_files=0,
total_files=0,
downloaded=Memory.from_bytes(0),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(0),
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",

View File

@@ -45,7 +45,7 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)

View File

@@ -1323,7 +1323,7 @@ class API:
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=card.storage_size.in_mb,
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),

View File

@@ -102,21 +102,22 @@ def _allocate_and_validate_layers(
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available / total_memory for node_id in node_ids
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
],
)
total_storage = model_card.storage_size
total_storage_bytes = model_card.storage_size.in_bytes
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
f"but only has {available_memory.in_gb:.2f} GB available"
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
return layer_allocations

View File

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(

View File

@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size = Memory.from_bytes(
sum(available_memory)
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size = Memory.from_bytes(1500)
model_card.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()

View File

@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id),
node_id=NodeId(update.peer_id.to_base58()),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,

View File

@@ -221,7 +221,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate()
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -235,12 +235,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read()
try: # if decoded successfully, save & return
return Keypair.from_bytes(protobuf_encoded)
return Keypair.from_protobuf_encoding(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")
# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_bytes())
f.write(keypair.to_protobuf_encoding())
return keypair

View File

@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total=Memory(),
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total=Memory(),
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total=Memory(),
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
queue.put(get_node_id_keypair().to_bytes())
queue.put(get_node_id_keypair().to_protobuf_encoding())
def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -1,10 +1,10 @@
from math import ceil
from typing import Self, overload
from typing import Self
from exo.utils.pydantic_ext import FrozenModel
from exo.utils.pydantic_ext import CamelCaseModel
class Memory(FrozenModel):
class Memory(CamelCaseModel):
in_bytes: int = 0
@classmethod
@@ -33,22 +33,12 @@ class Memory(FrozenModel):
return cls(in_bytes=round(val * 1024))
@property
def in_mb(self) -> int:
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
return round(self.in_bytes / (1024**2))
@in_mb.setter
def in_mb(self, val: int):
"""Set the megabytes for this memory."""
self.in_bytes = val * (1024**2)
@property
def in_float_mb(self) -> float:
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
def in_mb(self) -> float:
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)
@in_float_mb.setter
def in_float_mb(self, val: float):
@in_mb.setter
def in_mb(self, val: float):
"""Set the megabytes for this memory, rounded to the nearest byte."""
self.in_bytes = round(val * (1024**2))
@@ -67,85 +57,17 @@ class Memory(FrozenModel):
"""The approximate gigabytes this memory represents."""
return self.in_bytes / (1024**3)
def __add__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes + other.in_bytes)
return NotImplemented
def __add__(self, other: "Memory") -> "Memory":
return Memory.from_bytes(self.in_bytes + other.in_bytes)
def __radd__(self, other: object) -> "Memory":
if other == 0:
return self
return NotImplemented
def __lt__(self, other: Self) -> bool:
return self.in_bytes < other.in_bytes
def __sub__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes - other.in_bytes)
return NotImplemented
def __le__(self, other: Self) -> bool:
return self.in_bytes <= other.in_bytes
def __mul__(self, other: int | float):
return Memory.from_bytes(round(self.in_bytes * other))
def __gt__(self, other: Self) -> bool:
return self.in_bytes > other.in_bytes
def __rmul__(self, other: int | float):
return self * other
@overload
def __truediv__(self, other: "Memory") -> float: ...
@overload
def __truediv__(self, other: int) -> "Memory": ...
@overload
def __truediv__(self, other: float) -> "Memory": ...
def __truediv__(self, other: object) -> "Memory | float":
if isinstance(other, Memory):
return self.in_bytes / other.in_bytes
if isinstance(other, (int, float)):
return Memory.from_bytes(round(self.in_bytes / other))
return NotImplemented
def __floordiv__(self, other: object) -> "Memory":
if isinstance(other, (int, float)):
return Memory.from_bytes(int(self.in_bytes // other))
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes < other.in_bytes
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes <= other.in_bytes
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes > other.in_bytes
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes >= other.in_bytes
return NotImplemented
def __eq__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes == other.in_bytes
return NotImplemented
def __repr__(self) -> str:
return f"Memory.from_bytes({self.in_bytes})"
def __str__(self) -> str:
if self.in_gb > 2:
val = self.in_gb
unit = "GiB"
elif self.in_mb > 2:
val = self.in_mb
unit = "MiB"
elif self.in_kb > 3:
val = self.in_kb
unit = "KiB"
else:
val = self.in_bytes
unit = "B"
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
def __ge__(self, other: Self) -> bool:
return self.in_bytes >= other.in_bytes

View File

@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class DownloadProgressData(CamelCaseModel):
total: Memory
downloaded: Memory
downloaded_this_session: Memory
total_bytes: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
completed_files: int
total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
total: Memory
total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
shard: ShardMetadata
completed_files: int
total_files: int
downloaded: Memory
downloaded_this_session: Memory
total: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]

View File

@@ -166,7 +166,7 @@ def generate_image(
else 0.0
)
peak_memory = Memory.from_bytes(mx.get_peak_memory())
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=peak_memory,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()

View File

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
total_gb = psutil.virtual_memory().total / (1024**3)
if total_gb >= 128:
return 0.85
if total_gb >= 64:

View File

@@ -232,11 +232,11 @@ def shard_and_load(
# Estimate timeout based on model size (5x default for large queued workloads)
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
model_size = get_weights_size(shard_metadata)
timeout_seconds = base_timeout + model_size.in_gb
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size.in_gb:.1f}GB)"
f"(model size: {model_size_gb:.1f}GB)"
)
match shard_metadata:
@@ -617,17 +617,18 @@ def set_wired_limit_for_model(model_size: Memory):
if not mx.metal.is_available():
return
max_rec_size = Memory.from_bytes(
int(mx.metal.device_info()["max_recommended_working_set_size"])
)
if model_size > 0.9 * max_rec_size:
model_bytes = model_size.in_bytes
max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
logger.warning(
f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
f"Generating with a model that requires {model_mb} MB "
f"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
mx.set_wired_limit(max_rec_size.in_bytes)
mx.set_wired_limit(max_rec_size)
logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -90,10 +90,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
],
}
@@ -134,7 +138,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Global state shows shard is downloaded for NODE_A
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}
@@ -181,7 +187,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [], # NODE_B has no downloads completed yet
}
@@ -199,10 +207,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}