Compare commits


4 Commits

Author SHA1 Message Date
Alex Cheema
23f295e684 feat: show ETA on prefill progress bar
Track when prefill starts via performance.now() and extrapolate
remaining time from observed tokens/sec. Displays "~Xs remaining"
(or "~Xm Ys remaining" for longer prompts) next to the percentage.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:31:47 -08:00
Alex Cheema
7637fb554f refactor: address PR #1181 review comments from Evanev7
- Rename PrefillProgressData to PrefillProgressChunk for consistency
- Convert isinstance chain to match/case in collect_chat_response
- Remove unused StreamEvent type alias from chunks.py
- Update docstrings to reflect new naming

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:31:14 -08:00
Alex Cheema
1a9c5fa6fb fix: wire prefill progress callback to prefill stream_generate, not decode
- Move on_prefill_progress callback from decode stream_generate to prefill()
- Fix SSE parser to handle named event types (event: prefill_progress)
- Wire PrefillProgressBar component into ChatMessages
- Add prefillProgress reactive state to the store

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 07:30:18 -08:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
9 changed files with 87 additions and 44 deletions

View File

@@ -12,6 +12,7 @@
} from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
import ImageLightbox from "./ImageLightbox.svelte";
@@ -625,7 +626,9 @@
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
{#if loading && !message.content && prefill}
<PrefillProgressBar progress={prefill} class="mt-2" />
{:else if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>

View File

@@ -14,6 +14,21 @@
: 0,
);
const etaText = $derived.by(() => {
if (progress.processed <= 0 || progress.total <= 0) return null;
const elapsedMs = performance.now() - progress.startedAt;
if (elapsedMs < 200) return null; // need a minimum sample window
const tokensPerMs = progress.processed / elapsedMs;
const remainingTokens = progress.total - progress.processed;
const remainingMs = remainingTokens / tokensPerMs;
const remainingSec = Math.ceil(remainingMs / 1000);
if (remainingSec <= 0) return null;
if (remainingSec < 60) return `~${remainingSec}s remaining`;
const mins = Math.floor(remainingSec / 60);
const secs = remainingSec % 60;
return `~${mins}m ${secs}s remaining`;
});
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
@@ -40,8 +55,11 @@
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
<div
class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
>
<span>{etaText ?? ""}</span>
<span>{percentage}%</span>
</div>
</div>
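
The ETA logic above extrapolates linearly from the observed prefill rate. A standalone sketch of the same arithmetic, with a worked example (`estimateEta` is a hypothetical name; the component computes this inline in `$derived.by`):

```ts
// Sketch of the ETA math from PrefillProgressBar.svelte (illustrative).
function estimateEta(
  processed: number,
  total: number,
  elapsedMs: number,
): string | null {
  if (processed <= 0 || total <= 0) return null;
  if (elapsedMs < 200) return null; // too small a sample to extrapolate from
  const tokensPerMs = processed / elapsedMs;
  const remainingSec = Math.ceil((total - processed) / tokensPerMs / 1000);
  if (remainingSec <= 0) return null;
  if (remainingSec < 60) return `~${remainingSec}s remaining`;
  return `~${Math.floor(remainingSec / 60)}m ${remainingSec % 60}s remaining`;
}

// 250 of 1000 tokens in 500ms -> 0.5 tokens/ms -> 1500ms left -> "~2s remaining"
console.log(estimateEta(250, 1000, 500));
```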

View File

@@ -276,6 +276,8 @@ export interface TokenData {
export interface PrefillProgress {
processed: number;
total: number;
/** Timestamp (performance.now()) when prefill started. */
startedAt: number;
}
export interface Message {
@@ -520,12 +522,12 @@ class AppStore {
messages = $state<Message[]>([]);
currentResponse = $state("");
isLoading = $state(false);
prefillProgress = $state<PrefillProgress | null>(null);
// Performance metrics
ttftMs = $state<number | null>(null); // Time to first token in ms
tps = $state<number | null>(null); // Tokens per second
totalTokens = $state<number>(0); // Total tokens in current response
prefillProgress = $state<PrefillProgress | null>(null);
// Abort controller for stopping generation
private currentAbortController: AbortController | null = null;
@@ -2018,6 +2020,7 @@ class AppStore {
): Promise<void> {
const decoder = new TextDecoder();
let buffer = "";
let currentEventType = "";
while (true) {
const { done, value } = await reader.read();
@@ -2033,7 +2036,15 @@ class AppStore {
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
if (!trimmed) {
currentEventType = "";
continue;
}
if (trimmed.startsWith("event: ")) {
currentEventType = trimmed.slice(7);
continue;
}
// Handle SSE comments (": key json") for prefill progress etc.
if (trimmed.startsWith(": ") && onEvent) {
@@ -2055,14 +2066,22 @@ class AppStore {
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
if (data === "[DONE]") {
currentEventType = "";
continue;
}
try {
const parsed = JSON.parse(data) as T;
onChunk(parsed);
const parsed = JSON.parse(data);
if (currentEventType && onEvent?.[currentEventType]) {
onEvent[currentEventType](parsed);
} else {
onChunk(parsed as T);
}
} catch {
// Skip malformed JSON
}
currentEventType = "";
}
}
}
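
The parser above tracks an `event:` line as `currentEventType` and routes the following `data:` payload to a matching handler, falling back to the plain chunk callback for unnamed frames. A minimal self-contained sketch of that routing (the handler map and sample frames are illustrative, not the store's actual types):

```ts
// Minimal sketch of named-event SSE routing, mirroring the parser above.
type Handlers = Record<string, (payload: unknown) => void>;

function routeSseLines(
  lines: string[],
  onChunk: (chunk: unknown) => void,
  onEvent: Handlers,
): void {
  let currentEventType = "";
  for (const line of lines) {
    const trimmed = line.trim();
    if (!trimmed) {
      currentEventType = ""; // a blank line terminates the frame
      continue;
    }
    if (trimmed.startsWith("event: ")) {
      currentEventType = trimmed.slice(7);
      continue;
    }
    if (trimmed.startsWith("data: ")) {
      const payload: unknown = JSON.parse(trimmed.slice(6));
      if (currentEventType && onEvent[currentEventType]) {
        onEvent[currentEventType](payload); // named event -> dedicated handler
      } else {
        onChunk(payload); // unnamed frames keep flowing to the default callback
      }
      currentEventType = "";
    }
  }
}

routeSseLines(
  [
    "event: prefill_progress",
    'data: {"processed_tokens":128,"total_tokens":4096}',
    "",
    'data: {"choices":[]}',
  ],
  (chunk) => console.log("chunk", chunk),
  { prefill_progress: (p) => console.log("prefill", p) },
);
```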
@@ -2163,6 +2182,7 @@ class AppStore {
this.isLoading = true;
this.currentResponse = "";
this.prefillProgress = null;
this.ttftMs = null;
this.tps = null;
this.totalTokens = 0;
@@ -2367,6 +2387,11 @@ class AppStore {
}
if (tokenContent) {
// Clear prefill progress once tokens start arriving
if (this.prefillProgress !== null) {
this.prefillProgress = null;
}
// Track first token for TTFT
if (firstTokenTime === null) {
firstTokenTime = performance.now();
@@ -2420,6 +2445,7 @@ class AppStore {
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
};
},
},
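
The `?? performance.now()` fallback above is what makes the ETA stable: only the first progress event stamps `startedAt`, and every later event carries it forward. The update pattern in isolation (a sketch, not the store's full state handling):

```ts
// Sketch: startedAt is stamped once, then reused across progress updates.
interface PrefillProgress {
  processed: number;
  total: number;
  startedAt: number;
}

let prefillProgress: PrefillProgress | null = null;

function onPrefillProgress(processed: number, total: number): void {
  prefillProgress = {
    processed,
    total,
    // reuse the first-seen timestamp so elapsed time keeps accumulating
    startedAt: prefillProgress?.startedAt ?? performance.now(),
  };
}
```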
@@ -2474,6 +2500,7 @@ class AppStore {
this.prefillProgress = null;
this.isLoading = false;
this.currentResponse = "";
this.prefillProgress = null;
this.saveConversationsToStorage();
}
}
@@ -3106,10 +3133,10 @@ export const hasStartedChat = () => appStore.hasStartedChat;
export const messages = () => appStore.messages;
export const currentResponse = () => appStore.currentResponse;
export const isLoading = () => appStore.isLoading;
export const prefillProgress = () => appStore.prefillProgress;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const runners = () => appStore.runners;

View File

@@ -338,17 +338,7 @@ class DownloadCoordinator:
),
)
elif progress.status in ["in_progress", "not_started"]:
if (
progress.downloaded_bytes.in_bytes
>= progress.total_bytes.in_bytes
> 0
):
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.downloaded_bytes_this_session.in_bytes == 0:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -258,7 +258,7 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")

View File

@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
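
The frames `_format_sse` emits are the same ones the web UI's stream parser consumes. A client-side TypeScript mirror of the wire format, for illustration only (not code from this PR):

```ts
// Each SSE frame: an "event:" line naming the type, a "data:" line with the
// JSON-serialized event, and a blank line terminating the frame.
function formatSse(type: string, payload: object): string {
  return `event: ${type}\ndata: ${JSON.stringify(payload)}\n\n`;
}

// formatSse("response.output_text.delta", { delta: "Hel" })
// -> 'event: response.output_text.delta\ndata: {"delta":"Hel"}\n\n'
```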
@@ -219,13 +225,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)
# response.output_item.added
initial_item = ResponseMessageItem(
@@ -236,7 +242,7 @@ async def generate_responses_stream(
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
yield _format_sse(item_added)
# response.content_part.added
initial_part = ResponseOutputText(text="")
@@ -247,7 +253,7 @@ async def generate_responses_stream(
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(part_added)
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
@@ -281,7 +287,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +296,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -315,7 +321,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)
function_call_items.append(fc_done_item)
next_output_index += 1
@@ -331,7 +337,7 @@ async def generate_responses_stream(
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)
# response.output_text.done
text_done = ResponseTextDoneEvent(
@@ -341,7 +347,7 @@ async def generate_responses_stream(
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
@@ -352,7 +358,7 @@ async def generate_responses_stream(
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)
# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -363,7 +369,7 @@ async def generate_responses_stream(
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)
# Create usage from usage data if available
usage = None
@@ -388,4 +394,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)

View File

@@ -562,6 +562,8 @@ class API:
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
async def _collect_text_generation_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:

View File

@@ -90,6 +90,8 @@ def prefill(
)
if has_ssm:
snapshots.append(snapshot_ssm_states(cache))
if on_prefill_progress is not None:
on_prefill_progress(processed, total)

View File

@@ -98,16 +98,11 @@ class RunnerSupervisor:
def shutdown(self):
logger.info("Runner supervisor shutting down")
with contextlib.suppress(ClosedResourceError):
self._ev_recv.close()
with contextlib.suppress(ClosedResourceError):
self._task_sender.close()
with contextlib.suppress(ClosedResourceError):
self._event_sender.close()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(5)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")