Compare commits

..

3 Commits

Author SHA1 Message Date
Alex Cheema
23f295e684 feat: show ETA on prefill progress bar
Track when prefill starts via performance.now() and extrapolate
remaining time from observed tokens/sec. Displays "~Xs remaining"
(or "~Xm Ys remaining" for longer prompts) next to the percentage.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:31:47 -08:00
Alex Cheema
7637fb554f refactor: address PR #1181 review comments from Evanev7
- Rename PrefillProgressData to PrefillProgressChunk for consistency
- Convert isinstance chain to match/case in collect_chat_response
- Remove unused StreamEvent type alias from chunks.py
- Update docstrings to reflect new naming

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:31:14 -08:00
Alex Cheema
1a9c5fa6fb fix: wire prefill progress callback to prefill stream_generate, not decode
- Move on_prefill_progress callback from decode stream_generate to prefill()
- Fix SSE parser to handle named event types (event: prefill_progress)
- Wire PrefillProgressBar component into ChatMessages
- Add prefillProgress reactive state to the store

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:30:18 -08:00
8 changed files with 61 additions and 190 deletions

View File

@@ -20,7 +20,6 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -963,21 +962,6 @@ Examples:
selected.sort(key=_placement_sort_key)
preview = selected[0]
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
print("Planning phase: checking downloads...", file=log)
run_planning_phase(
exo,
full_model_id,
preview,
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])

View File

@@ -35,7 +35,6 @@ from harness import (
instance_id_from_instance,
nodes_used_in_instance,
resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements,
wait_for_instance_gone,
wait_for_instance_ready,
@@ -333,20 +332,6 @@ def main() -> int:
if args.dry_run:
return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = []
for preview in selected:

View File

@@ -282,151 +282,6 @@ def settle_and_fetch_placements(
return selected
def _download_model_id(progress: dict[str, Any], status: str) -> str | None:
    """Return the model id of a download entry that is in ``status``, else None.

    ``status`` is the key the entry is tagged with, e.g. ``"DownloadCompleted"``
    or ``"DownloadFailed"``.
    """
    if status not in progress:
        return None
    return unwrap_instance(progress[status]["shardMetadata"])["modelCard"]["modelId"]


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure the model is downloaded before benchmarking.

    For every node used by the previewed placement this verifies that the node
    either already holds the model or has enough free disk for it, optionally
    deleting other completed downloads (smallest first) to make room. It then
    starts the download on every node and polls ``/state`` until all of them
    report completion.

    Args:
        client: API client used for the ``/models``, ``/state`` and
            ``/download`` requests.
        full_model_id: Model id to look up in ``/models`` and in the download
            entries.
        preview: Placement preview; its ``"instance"`` entry supplies the
            node-to-runner and runner-to-shard assignments.
        danger_delete: When true, completed downloads may be deleted to free
            space; otherwise insufficient space is an error.
        timeout: Maximum number of seconds to wait for the downloads.
        settle_deadline: ``time.monotonic()`` deadline for waiting on missing
            disk info, or None to not wait at all.

    Raises:
        RuntimeError: A node lacks disk space (and deletion is disabled or
            cannot free enough), or a download of the model failed.
        TimeoutError: The downloads did not finish within ``timeout`` seconds.
    """
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
            break

    if not model_bytes:
        # NOTE(review): this returns before any download is started, so an
        # unknown model size skips the whole planning phase, not only the
        # disk check -- confirm that is intended.
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_to_runner = inner["shardAssignments"]["nodeToRunner"]
    node_ids = list(node_to_runner.keys())
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]

    # Treat an empty response like an empty state, as done for /models above.
    state = client.request_json("GET", "/state") or {}
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            _download_model_id(p, "DownloadCompleted") == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline is not None:
            # Read the clock once per iteration so the value handed to
            # time.sleep() can never be negative (which raises ValueError).
            remaining = settle_deadline - time.monotonic()
            if remaining <= 0:
                break
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state") or {}
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                _download_model_id(p, "DownloadCompleted"),
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        shard = runner_to_shard[node_to_runner[node_id]]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads. A monotonic clock keeps the timeout immune to
    # wall-clock adjustments and matches how settle_deadline is measured.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        state = client.request_json("GET", "/state") or {}
        downloads = state.get("downloads", {})

        all_done = True
        # Every node is inspected on each poll, even after one is found
        # unfinished, so a failure on any node is reported immediately.
        for node_id in node_ids:
            node_downloads = downloads.get(node_id, [])
            done = any(
                _download_model_id(p, "DownloadCompleted") == full_model_id
                for p in node_downloads
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in node_downloads
                if _download_model_id(p, "DownloadFailed") == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False

        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
@@ -470,8 +325,3 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
)
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)

View File

@@ -12,6 +12,7 @@
} from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import ImageLightbox from "./ImageLightbox.svelte";
@@ -625,7 +626,9 @@
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
{#if loading && !message.content && prefill}
<PrefillProgressBar progress={prefill} class="mt-2" />
{:else if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>

View File

@@ -14,6 +14,21 @@
: 0,
);
// Estimated time left in prefill as display text ("~Xs remaining" or
// "~Xm Ys remaining"), or null while no estimate can be shown.
const etaText = $derived.by(() => {
  // No estimate without a positive processed count and a positive total.
  if (progress.processed <= 0) return null;
  if (progress.total <= 0) return null;

  const sinceStartMs = performance.now() - progress.startedAt;
  // need a minimum sample window
  if (sinceStartMs < 200) return null;

  // Extrapolate from the throughput observed so far.
  const rate = progress.processed / sinceStartMs;
  const tokensLeft = progress.total - progress.processed;
  const secondsLeft = Math.ceil(tokensLeft / rate / 1000);
  if (secondsLeft <= 0) return null;

  return secondsLeft < 60
    ? `~${secondsLeft}s remaining`
    : `~${Math.floor(secondsLeft / 60)}m ${secondsLeft % 60}s remaining`;
});
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
@@ -40,8 +55,11 @@
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
<div
class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
>
<span>{etaText ?? ""}</span>
<span>{percentage}%</span>
</div>
</div>

View File

@@ -276,6 +276,8 @@ export interface TokenData {
/** Progress of the prefill phase for the response currently being generated. */
export interface PrefillProgress {
/** Tokens processed so far (taken from the progress event's processed_tokens). */
processed: number;
/** Total tokens to process (taken from the progress event's total_tokens). */
total: number;
/** Timestamp (performance.now()) when prefill started. */
startedAt: number;
}
export interface Message {
@@ -520,12 +522,12 @@ class AppStore {
messages = $state<Message[]>([]);
currentResponse = $state("");
isLoading = $state(false);
prefillProgress = $state<PrefillProgress | null>(null);
// Performance metrics
ttftMs = $state<number | null>(null); // Time to first token in ms
tps = $state<number | null>(null); // Tokens per second
totalTokens = $state<number>(0); // Total tokens in current response
prefillProgress = $state<PrefillProgress | null>(null);
// Abort controller for stopping generation
private currentAbortController: AbortController | null = null;
@@ -2018,6 +2020,7 @@ class AppStore {
): Promise<void> {
const decoder = new TextDecoder();
let buffer = "";
let currentEventType = "";
while (true) {
const { done, value } = await reader.read();
@@ -2033,7 +2036,15 @@ class AppStore {
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
if (!trimmed) {
currentEventType = "";
continue;
}
if (trimmed.startsWith("event: ")) {
currentEventType = trimmed.slice(7);
continue;
}
// Handle SSE comments (": key json") for prefill progress etc.
if (trimmed.startsWith(": ") && onEvent) {
@@ -2055,14 +2066,22 @@ class AppStore {
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
if (data === "[DONE]") {
currentEventType = "";
continue;
}
try {
const parsed = JSON.parse(data) as T;
onChunk(parsed);
const parsed = JSON.parse(data);
if (currentEventType && onEvent?.[currentEventType]) {
onEvent[currentEventType](parsed);
} else {
onChunk(parsed as T);
}
} catch {
// Skip malformed JSON
}
currentEventType = "";
}
}
}
@@ -2163,6 +2182,7 @@ class AppStore {
this.isLoading = true;
this.currentResponse = "";
this.prefillProgress = null;
this.ttftMs = null;
this.tps = null;
this.totalTokens = 0;
@@ -2367,6 +2387,11 @@ class AppStore {
}
if (tokenContent) {
// Clear prefill progress once tokens start arriving
if (this.prefillProgress !== null) {
this.prefillProgress = null;
}
// Track first token for TTFT
if (firstTokenTime === null) {
firstTokenTime = performance.now();
@@ -2420,6 +2445,7 @@ class AppStore {
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
};
},
},
@@ -2474,6 +2500,7 @@ class AppStore {
this.prefillProgress = null;
this.isLoading = false;
this.currentResponse = "";
this.prefillProgress = null;
this.saveConversationsToStorage();
}
}
@@ -3106,10 +3133,10 @@ export const hasStartedChat = () => appStore.hasStartedChat;
export const messages = () => appStore.messages;
export const currentResponse = () => appStore.currentResponse;
export const isLoading = () => appStore.isLoading;
export const prefillProgress = () => appStore.prefillProgress;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const runners = () => appStore.runners;

View File

@@ -562,6 +562,8 @@ class API:
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
async def _collect_text_generation_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:

View File

@@ -90,6 +90,8 @@ def prefill(
)
if has_ssm:
snapshots.append(snapshot_ssm_states(cache))
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
if on_prefill_progress is not None:
on_prefill_progress(processed, total)