mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 15:27:02 -05:00
Compare commits
3 Commits
test-scree
...
feat/mac-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23f295e684 | ||
|
|
7637fb554f | ||
|
|
1a9c5fa6fb |
@@ -20,7 +20,6 @@ from harness import (
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
run_planning_phase,
|
||||
settle_and_fetch_placements,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
@@ -963,21 +962,6 @@ Examples:
|
||||
|
||||
selected.sort(key=_placement_sort_key)
|
||||
preview = selected[0]
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
print("Planning phase: checking downloads...", file=log)
|
||||
run_planning_phase(
|
||||
exo,
|
||||
full_model_id,
|
||||
preview,
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
sharding = str(preview["sharding"])
|
||||
|
||||
@@ -35,7 +35,6 @@ from harness import (
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
run_planning_phase,
|
||||
settle_and_fetch_placements,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
@@ -333,20 +332,6 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
150
bench/harness.py
150
bench/harness.py
@@ -282,151 +282,6 @@ def settle_and_fetch_placements(
|
||||
return selected
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
@@ -470,8 +325,3 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
} from "$lib/stores/app.svelte";
|
||||
import type { MessageAttachment } from "$lib/stores/app.svelte";
|
||||
import MarkdownContent from "./MarkdownContent.svelte";
|
||||
import PrefillProgressBar from "./PrefillProgressBar.svelte";
|
||||
import TokenHeatmap from "./TokenHeatmap.svelte";
|
||||
import PrefillProgressBar from "./PrefillProgressBar.svelte";
|
||||
import ImageLightbox from "./ImageLightbox.svelte";
|
||||
@@ -625,7 +626,9 @@
|
||||
<MarkdownContent
|
||||
content={message.content || (loading ? response : "")}
|
||||
/>
|
||||
{#if loading && !message.content}
|
||||
{#if loading && !message.content && prefill}
|
||||
<PrefillProgressBar progress={prefill} class="mt-2" />
|
||||
{:else if loading && !message.content}
|
||||
<span
|
||||
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
|
||||
></span>
|
||||
|
||||
@@ -14,6 +14,21 @@
|
||||
: 0,
|
||||
);
|
||||
|
||||
const etaText = $derived.by(() => {
|
||||
if (progress.processed <= 0 || progress.total <= 0) return null;
|
||||
const elapsedMs = performance.now() - progress.startedAt;
|
||||
if (elapsedMs < 200) return null; // need a minimum sample window
|
||||
const tokensPerMs = progress.processed / elapsedMs;
|
||||
const remainingTokens = progress.total - progress.processed;
|
||||
const remainingMs = remainingTokens / tokensPerMs;
|
||||
const remainingSec = Math.ceil(remainingMs / 1000);
|
||||
if (remainingSec <= 0) return null;
|
||||
if (remainingSec < 60) return `~${remainingSec}s remaining`;
|
||||
const mins = Math.floor(remainingSec / 60);
|
||||
const secs = remainingSec % 60;
|
||||
return `~${mins}m ${secs}s remaining`;
|
||||
});
|
||||
|
||||
function formatTokenCount(count: number | undefined): string {
|
||||
if (count == null) return "0";
|
||||
if (count >= 1000) {
|
||||
@@ -40,8 +55,11 @@
|
||||
style="width: {percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
|
||||
{percentage}%
|
||||
<div
|
||||
class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
|
||||
>
|
||||
<span>{etaText ?? ""}</span>
|
||||
<span>{percentage}%</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -276,6 +276,8 @@ export interface TokenData {
|
||||
export interface PrefillProgress {
|
||||
processed: number;
|
||||
total: number;
|
||||
/** Timestamp (performance.now()) when prefill started. */
|
||||
startedAt: number;
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
@@ -520,12 +522,12 @@ class AppStore {
|
||||
messages = $state<Message[]>([]);
|
||||
currentResponse = $state("");
|
||||
isLoading = $state(false);
|
||||
prefillProgress = $state<PrefillProgress | null>(null);
|
||||
|
||||
// Performance metrics
|
||||
ttftMs = $state<number | null>(null); // Time to first token in ms
|
||||
tps = $state<number | null>(null); // Tokens per second
|
||||
totalTokens = $state<number>(0); // Total tokens in current response
|
||||
prefillProgress = $state<PrefillProgress | null>(null);
|
||||
|
||||
// Abort controller for stopping generation
|
||||
private currentAbortController: AbortController | null = null;
|
||||
@@ -2018,6 +2020,7 @@ class AppStore {
|
||||
): Promise<void> {
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let currentEventType = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -2033,7 +2036,15 @@ class AppStore {
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed) continue;
|
||||
if (!trimmed) {
|
||||
currentEventType = "";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (trimmed.startsWith("event: ")) {
|
||||
currentEventType = trimmed.slice(7);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle SSE comments (": key json") for prefill progress etc.
|
||||
if (trimmed.startsWith(": ") && onEvent) {
|
||||
@@ -2055,14 +2066,22 @@ class AppStore {
|
||||
|
||||
if (trimmed.startsWith("data: ")) {
|
||||
const data = trimmed.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
if (data === "[DONE]") {
|
||||
currentEventType = "";
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data) as T;
|
||||
onChunk(parsed);
|
||||
const parsed = JSON.parse(data);
|
||||
if (currentEventType && onEvent?.[currentEventType]) {
|
||||
onEvent[currentEventType](parsed);
|
||||
} else {
|
||||
onChunk(parsed as T);
|
||||
}
|
||||
} catch {
|
||||
// Skip malformed JSON
|
||||
}
|
||||
currentEventType = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2163,6 +2182,7 @@ class AppStore {
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
this.prefillProgress = null;
|
||||
this.ttftMs = null;
|
||||
this.tps = null;
|
||||
this.totalTokens = 0;
|
||||
@@ -2367,6 +2387,11 @@ class AppStore {
|
||||
}
|
||||
|
||||
if (tokenContent) {
|
||||
// Clear prefill progress once tokens start arriving
|
||||
if (this.prefillProgress !== null) {
|
||||
this.prefillProgress = null;
|
||||
}
|
||||
|
||||
// Track first token for TTFT
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
@@ -2420,6 +2445,7 @@ class AppStore {
|
||||
this.prefillProgress = {
|
||||
processed: inner.processed_tokens,
|
||||
total: inner.total_tokens,
|
||||
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
|
||||
};
|
||||
},
|
||||
},
|
||||
@@ -2474,6 +2500,7 @@ class AppStore {
|
||||
this.prefillProgress = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.prefillProgress = null;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
}
|
||||
@@ -3106,10 +3133,10 @@ export const hasStartedChat = () => appStore.hasStartedChat;
|
||||
export const messages = () => appStore.messages;
|
||||
export const currentResponse = () => appStore.currentResponse;
|
||||
export const isLoading = () => appStore.isLoading;
|
||||
export const prefillProgress = () => appStore.prefillProgress;
|
||||
export const ttftMs = () => appStore.ttftMs;
|
||||
export const tps = () => appStore.tps;
|
||||
export const totalTokens = () => appStore.totalTokens;
|
||||
export const prefillProgress = () => appStore.prefillProgress;
|
||||
export const topologyData = () => appStore.topologyData;
|
||||
export const instances = () => appStore.instances;
|
||||
export const runners = () => appStore.runners;
|
||||
|
||||
@@ -562,6 +562,8 @@ class API:
|
||||
if command_id in self._text_generation_queues:
|
||||
del self._text_generation_queues[command_id]
|
||||
|
||||
|
||||
|
||||
async def _collect_text_generation_with_stats(
|
||||
self, command_id: CommandId
|
||||
) -> BenchChatCompletionResponse:
|
||||
|
||||
@@ -90,6 +90,8 @@ def prefill(
|
||||
)
|
||||
if has_ssm:
|
||||
snapshots.append(snapshot_ssm_states(cache))
|
||||
if on_prefill_progress is not None:
|
||||
on_prefill_progress(processed, total)
|
||||
|
||||
if on_prefill_progress is not None:
|
||||
on_prefill_progress(processed, total)
|
||||
|
||||
Reference in New Issue
Block a user