Compare commits

..

2 Commits

Author SHA1 Message Date
Evan
41c4240919 got kimi to update the dashboard! 2026-02-19 20:17:10 +00:00
Evan
08b4ec5724 wa 2026-02-19 18:54:00 +00:00
17 changed files with 177 additions and 123 deletions

View File

@@ -250,6 +250,11 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// Disk usage per node
nodeDisk?: Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>;
}
export interface MessageAttachment {

View File

@@ -790,10 +790,8 @@
if (!progress || typeof progress !== "object") return null;
const prog = progress as Record<string, unknown>;
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const totalBytes = getBytes(prog.total);
const downloadedBytes = getBytes(prog.downloaded);
const speed = (prog.speed as number) ?? 0;
const completedFiles =
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
@@ -806,8 +804,8 @@
for (const [fileName, fileData] of Object.entries(filesObj)) {
if (!fileData || typeof fileData !== "object") continue;
const fd = fileData as Record<string, unknown>;
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
const fTotal = getBytes(fd.total);
const fDownloaded = getBytes(fd.downloaded);
files.push({
name: fileName,
totalBytes: fTotal,
@@ -1196,7 +1194,6 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;

View File

@@ -74,7 +74,6 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -231,23 +230,14 @@
undefined;
let cell: CellStatus;
if (tag === "DownloadCompleted") {
const totalBytes = getBytes(
payload.total_bytes ?? payload.totalBytes,
);
const totalBytes = getBytes(payload.total);
cell = { kind: "completed", totalBytes, modelDirectory };
} else if (tag === "DownloadOngoing") {
const rawProgress =
payload.download_progress ?? payload.downloadProgress ?? {};
const prog = rawProgress as Record<string, unknown>;
const totalBytes = getBytes(
prog.total_bytes ??
prog.totalBytes ??
payload.total_bytes ??
payload.totalBytes,
);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const totalBytes = getBytes(prog.total ?? payload.total);
const downloadedBytes = getBytes(prog.downloaded);
const speed = (prog.speed as number) ?? 0;
const etaMs =
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;

View File

@@ -80,7 +80,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
total=initial_progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
if progress.downloaded_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total,
downloaded=repo_file_download_progress.downloaded,
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
total=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
total=repo_download_progress.total,
downloaded=repo_download_progress.downloaded,
downloaded_this_session=repo_download_progress.downloaded_this_session,
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
@@ -578,19 +578,20 @@ def calculate_repo_progress(
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
) -> RepoDownloadProgress:
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum(
(p.downloaded.in_bytes for p in file_progress.values()), 0
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
all_downloaded = sum(
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
)
all_downloaded_bytes_this_session = sum(
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
all_downloaded_this_session = sum(
(p.downloaded_this_session for p in file_progress.values()),
Memory.from_bytes(0),
)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
)
all_eta = (
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
if all_speed > 0
else timedelta(seconds=0)
)
@@ -609,11 +610,9 @@ def calculate_repo_progress(
[p for p in file_progress.values() if p.downloaded == p.total]
),
total_files=len(file_progress),
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(
all_downloaded_bytes_this_session
),
total_bytes=Memory.from_bytes(all_total_bytes),
downloaded=all_downloaded,
downloaded_this_session=all_downloaded_this_session,
total=all_total,
overall_speed=all_speed,
overall_eta=all_eta,
status=status,

View File

@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
downloaded=Memory.from_bytes(0),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",

View File

@@ -1323,7 +1323,7 @@ class API:
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
storage_size_megabytes=card.storage_size.in_mb,
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),

View File

@@ -102,22 +102,21 @@ def _allocate_and_validate_layers(
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
node_memory[node_id].ram_available / total_memory for node_id in node_ids
],
)
total_storage_bytes = model_card.storage_size.in_bytes
total_storage = model_card.storage_size
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
required_memory = (total_storage * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
f"but only has {available_memory.in_gb:.2f} GB available"
)
return layer_allocations

View File

@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
model_card.storage_size = Memory.from_bytes(
sum(available_memory)
) # make it exactly fit across all nodes
topology = Topology()
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_card.storage_size = Memory.from_bytes(1500)
node_a = NodeId()
node_b = NodeId()

View File

@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
total=Memory(),
)
new_state = apply_node_download_progress(
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
total=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
total=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})

View File

@@ -1,10 +1,10 @@
from math import ceil
from typing import Self
from typing import Self, overload
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel
class Memory(CamelCaseModel):
class Memory(FrozenModel):
in_bytes: int = 0
@classmethod
@@ -33,12 +33,22 @@ class Memory(CamelCaseModel):
return cls(in_bytes=round(val * 1024))
@property
def in_mb(self) -> float:
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)
def in_mb(self) -> int:
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
return round(self.in_bytes / (1024**2))
@in_mb.setter
def in_mb(self, val: float):
def in_mb(self, val: int):
"""Set the megabytes for this memory."""
self.in_bytes = val * (1024**2)
@property
def in_float_mb(self) -> float:
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)
@in_float_mb.setter
def in_float_mb(self, val: float):
"""Set the megabytes for this memory, rounded to the nearest byte."""
self.in_bytes = round(val * (1024**2))
@@ -57,17 +67,85 @@ class Memory(CamelCaseModel):
"""The approximate gigabytes this memory represents."""
return self.in_bytes / (1024**3)
def __add__(self, other: "Memory") -> "Memory":
return Memory.from_bytes(self.in_bytes + other.in_bytes)
def __add__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes + other.in_bytes)
return NotImplemented
def __lt__(self, other: Self) -> bool:
return self.in_bytes < other.in_bytes
def __radd__(self, other: object) -> "Memory":
if other == 0:
return self
return NotImplemented
def __le__(self, other: Self) -> bool:
return self.in_bytes <= other.in_bytes
def __sub__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes - other.in_bytes)
return NotImplemented
def __gt__(self, other: Self) -> bool:
return self.in_bytes > other.in_bytes
def __mul__(self, other: int | float):
return Memory.from_bytes(round(self.in_bytes * other))
def __ge__(self, other: Self) -> bool:
return self.in_bytes >= other.in_bytes
def __rmul__(self, other: int | float):
return self * other
@overload
def __truediv__(self, other: "Memory") -> float: ...
@overload
def __truediv__(self, other: int) -> "Memory": ...
@overload
def __truediv__(self, other: float) -> "Memory": ...
def __truediv__(self, other: object) -> "Memory | float":
if isinstance(other, Memory):
return self.in_bytes / other.in_bytes
if isinstance(other, (int, float)):
return Memory.from_bytes(round(self.in_bytes / other))
return NotImplemented
def __floordiv__(self, other: object) -> "Memory":
if isinstance(other, (int, float)):
return Memory.from_bytes(int(self.in_bytes // other))
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes < other.in_bytes
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes <= other.in_bytes
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes > other.in_bytes
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes >= other.in_bytes
return NotImplemented
def __eq__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes == other.in_bytes
return NotImplemented
def __repr__(self) -> str:
return f"Memory.from_bytes({self.in_bytes})"
def __str__(self) -> str:
if self.in_gb > 2:
val = self.in_gb
unit = "GiB"
elif self.in_mb > 2:
val = self.in_mb
unit = "MiB"
elif self.in_kb > 3:
val = self.in_kb
unit = "KiB"
else:
val = self.in_bytes
unit = "B"
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"

View File

@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class DownloadProgressData(CamelCaseModel):
total_bytes: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total: Memory
downloaded: Memory
downloaded_this_session: Memory
completed_files: int
total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
total_bytes: Memory
total: Memory
class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
downloaded: Memory
downloaded_this_session: Memory
total: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]

View File

@@ -166,7 +166,7 @@ def generate_image(
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
peak_memory = Memory.from_bytes(mx.get_peak_memory())
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
peak_memory_usage=peak_memory,
)
buffer = io.BytesIO()

View File

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
total_gb = psutil.virtual_memory().total / (1024**3)
total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
if total_gb >= 128:
return 0.85
if total_gb >= 64:

View File

@@ -232,11 +232,11 @@ def shard_and_load(
# Estimate timeout based on model size (5x default for large queued workloads)
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb
model_size = get_weights_size(shard_metadata)
timeout_seconds = base_timeout + model_size.in_gb
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
f"(model size: {model_size.in_gb:.1f}GB)"
)
match shard_metadata:
@@ -617,18 +617,17 @@ def set_wired_limit_for_model(model_size: Memory):
if not mx.metal.is_available():
return
model_bytes = model_size.in_bytes
max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
max_rec_size = Memory.from_bytes(
int(mx.metal.device_info()["max_recommended_working_set_size"])
)
if model_size > 0.9 * max_rec_size:
logger.warning(
f"Generating with a model that requires {model_mb} MB "
f"which is close to the maximum recommended size of {max_rec_mb} "
f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
mx.set_wired_limit(max_rec_size)
mx.set_wired_limit(max_rec_size.in_bytes)
logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -573,13 +573,6 @@ def main(
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
@@ -588,7 +581,6 @@ def main(
event_sender.send(TaskAcknowledged(task_id=task.task_id))
current_status = RunnerShutdown()
break
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
@@ -605,6 +597,13 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
break
@cache

View File

@@ -90,14 +90,10 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
],
}
@@ -138,9 +134,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Global state shows shard is downloaded for NODE_A
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
],
NODE_B: [],
}
@@ -187,9 +181,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
],
NODE_B: [], # NODE_B has no downloads completed yet
}
@@ -207,14 +199,10 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
], # NODE_B has no downloads completed yet
}