mirror of https://github.com/exo-explore/exo.git
synced 2026-02-04 19:22:39 -05:00

Compare commits: 29 commits on branch rust-explo (author: alexcheema)

SHA1s:
937da476b0, 1c1c286127, 258785be84, 13a6b9819a, 1733d07cb3, b3e4c9b1e5,
4c74792373, eadb6de1f7, 7ba2408eed, ce4d7f4d43, 6727523eab, a8f81e0495,
ba7148ccec, a64b8addc6, e6599a9408, 93f4753598, 75fe505275, d7c044e349,
53b6d56e9f, 7fe0a61230, 5a36542631, 955e0105b3, 4d1eb1d9bd, 365416c65e,
04af76e10f, a84c3431cd, 52445b21f6, 435bd7f6fa, dd25b5b90e
@@ -6,11 +6,12 @@
    deleteMessage,
    editAndRegenerate,
    regenerateLastResponse,
    regenerateFromToken,
    setEditingImage,
  } from "$lib/stores/app.svelte";
  import type { Message } from "$lib/stores/app.svelte";
  import type { MessageAttachment } from "$lib/stores/app.svelte";
  import MarkdownContent from "./MarkdownContent.svelte";
  import TokenHeatmap from "./TokenHeatmap.svelte";

  interface Props {
    class?: string;
@@ -99,6 +100,23 @@
  let copiedMessageId = $state<string | null>(null);
  let expandedThinkingMessageIds = $state<Set<string>>(new Set());

  // Uncertainty heatmap toggle
  let heatmapMessageIds = $state<Set<string>>(new Set());

  function toggleHeatmap(messageId: string) {
    const next = new Set(heatmapMessageIds);
    if (next.has(messageId)) {
      next.delete(messageId);
    } else {
      next.add(messageId);
    }
    heatmapMessageIds = next;
  }

  function isHeatmapVisible(messageId: string): boolean {
    return heatmapMessageIds.has(messageId);
  }

  function formatTimestamp(timestamp: number): string {
    return new Date(timestamp).toLocaleTimeString("en-US", {
      hour12: false,
@@ -548,13 +566,23 @@
        >
      </div>
    {:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
      <MarkdownContent
        content={message.content || (loading ? response : "")}
      />
      {#if loading && !message.content}
        <span
          class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
        ></span>
      {#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
        <TokenHeatmap
          tokens={message.tokens}
          isGenerating={loading &&
            isLastAssistantMessage(message.id)}
          onRegenerateFrom={(tokenIndex) =>
            regenerateFromToken(message.id, tokenIndex)}
        />
      {:else}
        <MarkdownContent
          content={message.content || (loading ? response : "")}
        />
        {#if loading && !message.content}
          <span
            class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
          ></span>
        {/if}
      {/if}
    {/if}
  </div>
@@ -629,6 +657,35 @@
      </button>
    {/if}

    <!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
    {#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
      <button
        onclick={() => toggleHeatmap(message.id)}
        class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
          message.id,
        )
          ? 'text-exo-yellow'
          : 'text-exo-light-gray hover:text-exo-yellow'}"
        title={isHeatmapVisible(message.id)
          ? "Hide uncertainty heatmap"
          : "Show uncertainty heatmap"}
      >
        <svg
          class="w-3.5 h-3.5"
          fill="none"
          viewBox="0 0 24 24"
          stroke="currentColor"
        >
          <path
            stroke-linecap="round"
            stroke-linejoin="round"
            stroke-width="2"
            d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
          />
        </svg>
      </button>
    {/if}

    <!-- Regenerate button (last assistant message only) -->
    {#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
      <button
dashboard/src/lib/components/PrefillProgressBar.svelte (new file, 51 lines)
@@ -0,0 +1,51 @@
<script lang="ts">
  import type { PrefillProgress } from "$lib/stores/app.svelte";

  interface Props {
    progress: PrefillProgress;
    class?: string;
  }

  let { progress, class: className = "" }: Props = $props();

  const percentage = $derived(
    progress.total > 0
      ? Math.round((progress.processed / progress.total) * 100)
      : 0,
  );

  function formatTokenCount(count: number): string {
    if (count >= 1000) {
      return `${(count / 1000).toFixed(1)}k`;
    }
    return count.toString();
  }
</script>

<div class="prefill-progress {className}">
  <div
    class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
  >
    <span>Processing prompt</span>
    <span class="font-mono">
      {formatTokenCount(progress.processed)} / {formatTokenCount(
        progress.total,
      )} tokens
    </span>
  </div>
  <div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
    <div
      class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
      style="width: {percentage}%"
    ></div>
  </div>
  <div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
    {percentage}%
  </div>
</div>

<style>
  .prefill-progress {
    width: 100%;
  }
</style>
dashboard/src/lib/components/TokenHeatmap.svelte (new file, 236 lines)
@@ -0,0 +1,236 @@
<script lang="ts">
  import type { TokenData } from "$lib/stores/app.svelte";

  interface Props {
    tokens: TokenData[];
    class?: string;
    isGenerating?: boolean;
    onRegenerateFrom?: (tokenIndex: number) => void;
  }

  let {
    tokens,
    class: className = "",
    isGenerating = false,
    onRegenerateFrom,
  }: Props = $props();

  // Tooltip state - track both token data and index
  let hoveredTokenIndex = $state<number | null>(null);
  let hoveredPosition = $state<{ x: number; y: number } | null>(null);
  let isTooltipHovered = $state(false);
  let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;

  // Derive the hovered token from the index (stable across re-renders)
  const hoveredToken = $derived(
    hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
      ? {
          token: tokens[hoveredTokenIndex],
          index: hoveredTokenIndex,
          ...hoveredPosition,
        }
      : null,
  );

  /**
   * Get confidence styling based on probability.
   * Following Apple design principles: high confidence tokens blend in,
   * only uncertainty draws attention.
   */
  function getConfidenceClass(probability: number): string {
    if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
    if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
    if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
    return "bg-red-500/20 text-red-200/90"; // Draws attention
  }

  /**
   * Get border/underline styling for uncertain tokens
   */
  function getBorderClass(probability: number): string {
    if (probability > 0.8) return "border-transparent"; // No border for expected
    if (probability > 0.5) return "border-gray-500/20";
    if (probability > 0.2) return "border-amber-500/30";
    return "border-red-500/40";
  }

  function clearHideTimeout() {
    if (hideTimeoutId) {
      clearTimeout(hideTimeoutId);
      hideTimeoutId = null;
    }
  }

  function handleMouseEnter(
    event: MouseEvent,
    token: TokenData,
    index: number,
  ) {
    clearHideTimeout();
    const rects = (event.target as HTMLElement).getClientRects();
    let rect = rects[0];
    for (let j = 0; j < rects.length; j++) {
      if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
        rect = rects[j];
        break;
      }
    }
    hoveredTokenIndex = index;
    hoveredPosition = {
      x: rect.left + rect.width / 2,
      y: rect.top - 10,
    };
  }

  function handleMouseLeave() {
    clearHideTimeout();
    // Use longer delay during generation to account for re-renders
    const delay = isGenerating ? 300 : 200;
    hideTimeoutId = setTimeout(() => {
      if (!isTooltipHovered) {
        hoveredTokenIndex = null;
        hoveredPosition = null;
      }
    }, delay);
  }

  function handleTooltipEnter() {
    clearHideTimeout();
    isTooltipHovered = true;
  }

  function handleTooltipLeave() {
    isTooltipHovered = false;
    hoveredTokenIndex = null;
    hoveredPosition = null;
  }

  function handleRegenerate() {
    if (hoveredToken && onRegenerateFrom) {
      const indexToRegenerate = hoveredToken.index;
      // Clear hover state immediately
      hoveredTokenIndex = null;
      hoveredPosition = null;
      isTooltipHovered = false;
      // Call regenerate
      onRegenerateFrom(indexToRegenerate);
    }
  }

  function formatProbability(prob: number): string {
    return (prob * 100).toFixed(1) + "%";
  }

  function formatLogprob(logprob: number): string {
    return logprob.toFixed(3);
  }

  function getProbabilityColor(probability: number): string {
    if (probability > 0.8) return "text-gray-300";
    if (probability > 0.5) return "text-gray-400";
    if (probability > 0.2) return "text-amber-400";
    return "text-red-400";
  }
</script>

<div class="token-heatmap leading-relaxed {className}">
  {#each tokens as tokenData, i (i)}
    <span
      role="button"
      tabindex="0"
      class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
        tokenData.probability,
      )} {getBorderClass(tokenData.probability)} hover:opacity-80"
      onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
      onmouseleave={handleMouseLeave}>{tokenData.token}</span
    >
  {/each}
</div>

<!-- Tooltip -->
{#if hoveredToken}
  <div
    class="fixed z-50 pb-2"
    style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
    onmouseenter={handleTooltipEnter}
    onmouseleave={handleTooltipLeave}
  >
    <div
      class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
    >
      <!-- Token info -->
      <div class="mb-2">
        <span class="text-gray-500 text-xs">Token:</span>
        <span class="text-white font-mono ml-1"
          >"{hoveredToken.token.token}"</span
        >
        <span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
          >{formatProbability(hoveredToken.token.probability)}</span
        >
      </div>

      <div class="text-gray-400 text-xs mb-1">
        logprob: <span class="text-gray-300 font-mono"
          >{formatLogprob(hoveredToken.token.logprob)}</span
        >
      </div>

      <!-- Top alternatives -->
      {#if hoveredToken.token.topLogprobs.length > 0}
        <div class="border-t border-gray-700/50 mt-2 pt-2">
          <div class="text-gray-500 text-xs mb-1">Alternatives:</div>
          {#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
            {@const altProb = Math.exp(alt.logprob)}
            <div class="flex justify-between items-center text-xs py-0.5">
              <span class="text-gray-300 font-mono truncate max-w-24"
                >"{alt.token}"</span
              >
              <span class="text-gray-400 ml-2"
                >{formatProbability(altProb)}</span
              >
            </div>
          {/each}
        </div>
      {/if}

      <!-- Regenerate button -->
      {#if onRegenerateFrom}
        <button
          onclick={handleRegenerate}
          class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
        >
          <svg
            class="w-3 h-3"
            fill="none"
            viewBox="0 0 24 24"
            stroke="currentColor"
          >
            <path
              stroke-linecap="round"
              stroke-linejoin="round"
              stroke-width="2"
              d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
            />
          </svg>
          Regenerate from here
        </button>
      {/if}
    </div>
    <!-- Arrow -->
    <div class="absolute left-1/2 -translate-x-1/2 top-full">
      <div class="border-8 border-transparent border-t-gray-900"></div>
    </div>
  </div>
{/if}

<style>
  .token-heatmap {
    word-wrap: break-word;
    white-space: pre-wrap;
  }

  .token-span {
    margin: 0;
    border-width: 1px;
  }
</style>
@@ -242,6 +242,24 @@ export interface MessageAttachment {
  mimeType?: string;
}

export interface TopLogprob {
  token: string;
  logprob: number;
  bytes: number[] | null;
}

export interface TokenData {
  token: string;
  logprob: number;
  probability: number;
  topLogprobs: TopLogprob[];
}

export interface PrefillProgress {
  processed: number;
  total: number;
}

export interface Message {
  id: string;
  role: "user" | "assistant" | "system";
@@ -253,6 +271,7 @@ export interface Message {
  tps?: number; // Tokens per second (for assistant messages)
  requestType?: "chat" | "image-generation" | "image-editing";
  sourceImageDataUrl?: string; // For image editing regeneration
  tokens?: TokenData[];
}

export interface Conversation {
@@ -540,7 +559,18 @@ class AppStore {
   */
  private saveConversationsToStorage() {
    try {
      localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
      // Strip tokens from messages before saving to avoid bloating localStorage
      const stripped = this.conversations.map((conv) => ({
        ...conv,
        messages: conv.messages.map((msg) => {
          if (msg.tokens) {
            const { tokens: _, ...rest } = msg;
            return rest;
          }
          return msg;
        }),
      }));
      localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
    } catch (error) {
      console.error("Failed to save conversations:", error);
    }
@@ -1445,6 +1475,213 @@ class AppStore {
    }
  }

  /**
   * Regenerate response from a specific token index.
   * Truncates the assistant message at the given token and re-generates from there.
   */
  async regenerateFromToken(
    messageId: string,
    tokenIndex: number,
  ): Promise<void> {
    if (this.isLoading) return;

    const targetConversationId = this.activeConversationId;
    if (!targetConversationId) return;

    const msgIndex = this.messages.findIndex((m) => m.id === messageId);
    if (msgIndex === -1) return;

    const msg = this.messages[msgIndex];
    if (
      msg.role !== "assistant" ||
      !msg.tokens ||
      tokenIndex >= msg.tokens.length
    )
      return;

    // Keep tokens up to (not including) the specified index
    const tokensToKeep = msg.tokens.slice(0, tokenIndex);
    const prefixText = tokensToKeep.map((t) => t.token).join("");

    // Remove all messages after this assistant message
    this.messages = this.messages.slice(0, msgIndex + 1);

    // Update the message to show the prefix
    this.messages[msgIndex].content = prefixText;
    this.messages[msgIndex].tokens = tokensToKeep;
    this.updateActiveConversation();

    // Set up for continuation - modify the existing message in place
    this.isLoading = true;
    this.currentResponse = prefixText;
    this.ttftMs = null;
    this.tps = null;
    this.totalTokens = tokensToKeep.length;

    try {
      // Build messages for API - include the partial assistant message
      const systemPrompt = {
        role: "system" as const,
        content:
          "You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
      };

      const apiMessages = [
        systemPrompt,
        ...this.messages.map((m) => {
          let msgContent = m.content;
          if (m.attachments) {
            for (const attachment of m.attachments) {
              if (attachment.type === "text" && attachment.content) {
                msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
              }
            }
          }
          return { role: m.role, content: msgContent };
        }),
      ];

      const modelToUse = this.getModelForRequest();
      if (!modelToUse) {
        throw new Error("No model available");
      }

      const requestStartTime = performance.now();
      let firstTokenTime: number | null = null;
      let tokenCount = tokensToKeep.length;

      const response = await fetch("/v1/chat/completions", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          model: modelToUse,
          messages: apiMessages,
          stream: true,
          logprobs: true,
          top_logprobs: 5,
        }),
      });

      if (!response.ok) {
        const errorText = await response.text();
        throw new Error(`API error: ${response.status} - ${errorText}`);
      }

      const reader = response.body?.getReader();
      if (!reader) throw new Error("No response body");

      let fullContent = prefixText;
      const collectedTokens: TokenData[] = [...tokensToKeep];

      interface ChatCompletionChunk {
        choices?: Array<{
          delta?: { content?: string };
          logprobs?: {
            content?: Array<{
              token: string;
              logprob: number;
              top_logprobs?: Array<{
                token: string;
                logprob: number;
                bytes: number[] | null;
              }>;
            }>;
          };
        }>;
      }

      await this.parseSSEStream<ChatCompletionChunk>(
        reader,
        targetConversationId,
        (parsed) => {
          const choice = parsed.choices?.[0];
          const delta = choice?.delta?.content;

          // Collect logprobs data
          const logprobsContent = choice?.logprobs?.content;
          if (logprobsContent) {
            for (const item of logprobsContent) {
              collectedTokens.push({
                token: item.token,
                logprob: item.logprob,
                probability: Math.exp(item.logprob),
                topLogprobs: (item.top_logprobs || []).map((t) => ({
                  token: t.token,
                  logprob: t.logprob,
                  bytes: t.bytes,
                })),
              });
            }
          }

          if (delta) {
            if (firstTokenTime === null) {
              firstTokenTime = performance.now();
              this.ttftMs = firstTokenTime - requestStartTime;
            }

            tokenCount += 1;
            this.totalTokens = tokenCount;

            if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
              const elapsed = performance.now() - firstTokenTime;
              this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
            }

            fullContent += delta;
            const { displayContent, thinkingContent } =
              this.stripThinkingTags(fullContent);

            if (this.activeConversationId === targetConversationId) {
              this.currentResponse = displayContent;
            }

            // Update existing message in place
            this.updateConversationMessage(
              targetConversationId,
              messageId,
              (m) => {
                m.content = displayContent;
                m.thinking = thinkingContent || undefined;
                m.tokens = [...collectedTokens];
              },
            );
            this.syncActiveMessagesIfNeeded(targetConversationId);
            this.persistConversation(targetConversationId);
          }
        },
      );

      // Final update
      if (this.conversationExists(targetConversationId)) {
        const { displayContent, thinkingContent } =
          this.stripThinkingTags(fullContent);
        this.updateConversationMessage(targetConversationId, messageId, (m) => {
          m.content = displayContent;
          m.thinking = thinkingContent || undefined;
          m.tokens = [...collectedTokens];
          if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
          if (this.tps !== null) m.tps = this.tps;
        });
        this.syncActiveMessagesIfNeeded(targetConversationId);
        this.persistConversation(targetConversationId);
      }
    } catch (error) {
      console.error("Error regenerating from token:", error);
      if (this.conversationExists(targetConversationId)) {
        this.updateConversationMessage(targetConversationId, messageId, (m) => {
          m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
        });
        this.syncActiveMessagesIfNeeded(targetConversationId);
        this.persistConversation(targetConversationId);
      }
    } finally {
      this.isLoading = false;
      this.currentResponse = "";
      this.saveConversationsToStorage();
    }
  }

  /**
   * Helper method to regenerate a chat completion response
   */
@@ -1513,6 +1750,8 @@ class AppStore {
          model: modelToUse,
          messages: apiMessages,
          stream: true,
          logprobs: true,
          top_logprobs: 5,
        }),
      });

@@ -1527,16 +1766,49 @@ class AppStore {
      }

      let streamedContent = "";
      const collectedTokens: TokenData[] = [];

      interface ChatCompletionChunk {
        choices?: Array<{ delta?: { content?: string } }>;
        choices?: Array<{
          delta?: { content?: string };
          logprobs?: {
            content?: Array<{
              token: string;
              logprob: number;
              top_logprobs?: Array<{
                token: string;
                logprob: number;
                bytes: number[] | null;
              }>;
            }>;
          };
        }>;
      }

      await this.parseSSEStream<ChatCompletionChunk>(
        reader,
        targetConversationId,
        (parsed) => {
          const delta = parsed.choices?.[0]?.delta?.content;
          const choice = parsed.choices?.[0];
          const delta = choice?.delta?.content;

          // Collect logprobs data
          const logprobsContent = choice?.logprobs?.content;
          if (logprobsContent) {
            for (const item of logprobsContent) {
              collectedTokens.push({
                token: item.token,
                logprob: item.logprob,
                probability: Math.exp(item.logprob),
                topLogprobs: (item.top_logprobs || []).map((t) => ({
                  token: t.token,
                  logprob: t.logprob,
                  bytes: t.bytes,
                })),
              });
            }
          }

          if (delta) {
            streamedContent += delta;
            const { displayContent, thinkingContent } =
@@ -1554,6 +1826,7 @@ class AppStore {
              (msg) => {
                msg.content = displayContent;
                msg.thinking = thinkingContent || undefined;
                msg.tokens = [...collectedTokens];
              },
            );
            this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1572,6 +1845,7 @@ class AppStore {
          (msg) => {
            msg.content = displayContent;
            msg.thinking = thinkingContent || undefined;
            msg.tokens = [...collectedTokens];
          },
        );
        this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1914,6 +2188,8 @@ class AppStore {
          messages: apiMessages,
          temperature: 0.7,
          stream: true,
          logprobs: true,
          top_logprobs: 5,
        }),
      });

@@ -1930,14 +2206,48 @@ class AppStore {
      let streamedContent = "";

      interface ChatCompletionChunk {
        choices?: Array<{ delta?: { content?: string } }>;
        choices?: Array<{
          delta?: { content?: string };
          logprobs?: {
            content?: Array<{
              token: string;
              logprob: number;
              top_logprobs?: Array<{
                token: string;
                logprob: number;
                bytes: number[] | null;
              }>;
            }>;
          };
        }>;
      }

      const collectedTokens: TokenData[] = [];

      await this.parseSSEStream<ChatCompletionChunk>(
        reader,
        targetConversationId,
        (parsed) => {
          const tokenContent = parsed.choices?.[0]?.delta?.content;
          const choice = parsed.choices?.[0];
          const tokenContent = choice?.delta?.content;

          // Collect logprobs data
          const logprobsContent = choice?.logprobs?.content;
          if (logprobsContent) {
            for (const item of logprobsContent) {
              collectedTokens.push({
                token: item.token,
                logprob: item.logprob,
                probability: Math.exp(item.logprob),
                topLogprobs: (item.top_logprobs || []).map((t) => ({
                  token: t.token,
                  logprob: t.logprob,
                  bytes: t.bytes,
                })),
              });
            }
          }

          if (tokenContent) {
            // Track first token for TTFT
            if (firstTokenTime === null) {
@@ -1973,6 +2283,7 @@ class AppStore {
              (msg) => {
                msg.content = displayContent;
                msg.thinking = thinkingContent || undefined;
                msg.tokens = [...collectedTokens];
              },
            );
            this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1997,6 +2308,7 @@ class AppStore {
          (msg) => {
            msg.content = displayContent;
            msg.thinking = thinkingContent || undefined;
            msg.tokens = [...collectedTokens];
            // Store performance metrics on the message
            if (this.ttftMs !== null) {
              msg.ttftMs = this.ttftMs;
@@ -2693,6 +3005,8 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
  appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
  appStore.regenerateFromToken(messageId, tokenIndex);

// Conversation actions
export const conversations = () => appStore.conversations;
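[Editor's note, not part of the diff] A minimal sketch of the mapping the store hunks above rely on: an OpenAI-style streaming chunk whose choices[0].logprobs.content items are folded into the dashboard's TokenData shape, with probability derived from the logprob. The sample payload values below are illustrative, not taken from a real response.

import type { TokenData } from "$lib/stores/app.svelte";

// Illustrative streaming chunk, shaped like the ChatCompletionChunk interface in the diff.
const chunk = {
  choices: [
    {
      delta: { content: " world" },
      logprobs: {
        content: [
          {
            token: " world",
            logprob: -0.105,
            top_logprobs: [
              { token: " world", logprob: -0.105, bytes: null },
              { token: " there", logprob: -2.4, bytes: null },
            ],
          },
        ],
      },
    },
  ],
};

// Same fold as the store's SSE callback: logprob -> probability via Math.exp.
const collected: TokenData[] = (chunk.choices[0].logprobs?.content ?? []).map(
  (item) => ({
    token: item.token,
    logprob: item.logprob,
    probability: Math.exp(item.logprob), // e.g. exp(-0.105) is roughly 0.90
    topLogprobs: (item.top_logprobs ?? []).map((t) => ({
      token: t.token,
      logprob: t.logprob,
      bytes: t.bytes,
    })),
  }),
);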
@@ -14,10 +14,17 @@ from exo.shared.types.api import (
    ErrorInfo,
    ErrorResponse,
    FinishReason,
    Logprobs,
    LogprobsContentItem,
    StreamingChoiceResponse,
    ToolCall,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressData,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams

@@ -81,6 +88,8 @@ def chat_request_to_text_generation(
        chat_template_messages=chat_template_messages
        if chat_template_messages
        else None,
        logprobs=request.logprobs or False,
        top_logprobs=request.top_logprobs,
    )


@@ -88,6 +97,19 @@ def chunk_to_response(
    chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
    """Convert a TokenChunk to a streaming ChatCompletionResponse."""
    # Build logprobs if available
    logprobs: Logprobs | None = None
    if chunk.logprob is not None:
        logprobs = Logprobs(
            content=[
                LogprobsContentItem(
                    token=chunk.text,
                    logprob=chunk.logprob,
                    top_logprobs=chunk.top_logprobs or [],
                )
            ]
        )

    return ChatCompletionResponse(
        id=command_id,
        created=int(time.time()),
@@ -96,6 +118,7 @@ def chunk_to_response(
            StreamingChoiceResponse(
                index=0,
                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
                logprobs=logprobs,
                finish_reason=chunk.finish_reason,
            )
        ],
@@ -104,55 +127,65 @@
async def generate_chat_stream(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
    event_stream: AsyncGenerator[
        PrefillProgressData | ErrorChunk | ToolCallChunk | TokenChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate Chat Completions API streaming events from chunks."""
    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
            error_response = ErrorResponse(
                error=ErrorInfo(
                    message=chunk.error_message or "Internal server error",
                    type="InternalServerError",
                    code=500,
                )
            )
            yield f"data: {error_response.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
    """Generate Chat Completions API streaming events from StreamEvents.

        if isinstance(chunk, ToolCallChunk):
            tool_call_deltas = [
                ToolCall(
                    id=str(uuid4()),
                    index=i,
                    function=tool,
                )
                for i, tool in enumerate(chunk.tool_calls)
            ]
            tool_response = ChatCompletionResponse(
                id=command_id,
                created=int(time.time()),
                model=chunk.model,
                choices=[
                    StreamingChoiceResponse(
                        index=0,
                        delta=ChatCompletionMessage(
                            role="assistant",
                            tool_calls=tool_call_deltas,
                        ),
                        finish_reason="tool_calls",
    Handles PrefillProgressData, ErrorChunk, ToolCallChunk, and TokenChunk.
    """
    async for event in event_stream:
        match event:
            case PrefillProgressData():
                yield f"event: prefill_progress\ndata: {event.model_dump_json()}\n\n"

            case ErrorChunk():
                error_response = ErrorResponse(
                    error=ErrorInfo(
                        message=event.error_message or "Internal server error",
                        type="InternalServerError",
                        code=500,
                    )
                ],
            )
            yield f"data: {tool_response.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return
                )
                yield f"data: {error_response.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return

        chunk_response = chunk_to_response(chunk, command_id)
        yield f"data: {chunk_response.model_dump_json()}\n\n"
            case ToolCallChunk():
                tool_call_deltas = [
                    ToolCall(
                        id=str(uuid4()),
                        index=i,
                        function=tool,
                    )
                    for i, tool in enumerate(event.tool_calls)
                ]
                tool_response = ChatCompletionResponse(
                    id=command_id,
                    created=int(time.time()),
                    model=event.model,
                    choices=[
                        StreamingChoiceResponse(
                            index=0,
                            delta=ChatCompletionMessage(
                                role="assistant",
                                tool_calls=tool_call_deltas,
                            ),
                            finish_reason="tool_calls",
                        )
                    ],
                )
                yield f"data: {tool_response.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
                return

        if chunk.finish_reason is not None:
            yield "data: [DONE]\n\n"
            case TokenChunk():
                chunk_response = chunk_to_response(event, command_id)
                yield f"data: {chunk_response.model_dump_json()}\n\n"

                if event.finish_reason is not None:
                    yield "data: [DONE]\n\n"


async def collect_chat_response(
@@ -162,6 +195,7 @@ async def collect_chat_response(
    """Collect all token chunks and return a single ChatCompletionResponse."""
    text_parts: list[str] = []
    tool_calls: list[ToolCall] = []
    logprobs_content: list[LogprobsContentItem] = []
    model: str | None = None
    finish_reason: FinishReason | None = None
    error_message: str | None = None
@@ -176,6 +210,14 @@

        if isinstance(chunk, TokenChunk):
            text_parts.append(chunk.text)
            if chunk.logprob is not None:
                logprobs_content.append(
                    LogprobsContentItem(
                        token=chunk.text,
                        logprob=chunk.logprob,
                        top_logprobs=chunk.top_logprobs or [],
                    )
                )

        if isinstance(chunk, ToolCallChunk):
            tool_calls.extend(
@@ -208,6 +250,9 @@
                content=combined_text,
                tool_calls=tool_calls if tool_calls else None,
            ),
            logprobs=Logprobs(content=logprobs_content)
            if logprobs_content
            else None,
            finish_reason=finish_reason,
        )
    ],
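[Editor's note, not part of the diff] generate_chat_stream now emits two kinds of SSE frames: named "event: prefill_progress" frames carrying PrefillProgressData, and regular "data:" frames carrying ChatCompletionResponse chunks. The dashboard's actual parseSSEStream is not shown in this diff; the sketch below is an illustrative client-side split of the two frame types, with hypothetical handler names.

// Minimal sketch, assuming frames are already split on blank lines ("\n\n").
// PrefillProgressData fields (processed_tokens / total_tokens) come from the diff;
// handleSSEFrame, onPrefill, and onData are illustrative names, not exo APIs.
function handleSSEFrame(
  frame: string,
  onPrefill: (processed: number, total: number) => void,
  onData: (json: string) => void,
): void {
  if (frame.startsWith("event: prefill_progress")) {
    // Frame looks like: event: prefill_progress\ndata: {"processed_tokens":128,"total_tokens":2048,...}
    const dataLine = frame.split("\n").find((l) => l.startsWith("data: "));
    if (dataLine) {
      const payload = JSON.parse(dataLine.slice("data: ".length));
      onPrefill(payload.processed_tokens, payload.total_tokens);
    }
  } else if (frame.startsWith("data: ")) {
    const body = frame.slice("data: ".length);
    if (body.trim() !== "[DONE]") onData(body); // regular ChatCompletionResponse chunk
  }
}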
@@ -98,6 +98,7 @@ from exo.shared.types.chunks import (
    ErrorChunk,
    ImageChunk,
    InputImageChunk,
    PrefillProgressData,
    TokenChunk,
    ToolCallChunk,
)
@@ -127,6 +128,7 @@ from exo.shared.types.events import (
    Event,
    ForwarderEvent,
    IndexedEvent,
    PrefillProgress,
    TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -199,7 +201,8 @@ class API:
        )

        self._text_generation_queues: dict[
            CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
            CommandId,
            Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData],
        ] = {}
        self._image_generation_queues: dict[
            CommandId, Sender[ImageChunk | ErrorChunk]
@@ -493,22 +496,27 @@
            instance_id=instance_id,
        )

    async def _token_chunk_stream(
    async def _stream_events(
        self, command_id: CommandId
    ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
        """Yield chunks for a given command until completion.
    ) -> AsyncGenerator[
        TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData, None
    ]:
        """Yield stream events for a command.

        This is the internal low-level stream used by all API adapters.
        """
        try:
            self._text_generation_queues[command_id], recv = channel[
                ErrorChunk | ToolCallChunk | TokenChunk
                TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData
            ]()

            with recv as token_chunks:
                async for chunk in token_chunks:
                    yield chunk
                    if chunk.finish_reason is not None:
            with recv as events:
                async for event in events:
                    yield event
                    if (
                        isinstance(event, TokenChunk)
                        and event.finish_reason is not None
                    ):
                        break

        except anyio.get_cancelled_exc_class():
@@ -525,6 +533,14 @@
        if command_id in self._text_generation_queues:
            del self._text_generation_queues[command_id]

    async def _chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
        """Yield chunks, filtering out prefill progress events."""
        async for event in self._stream_events(command_id):
            if not isinstance(event, PrefillProgressData):
                yield event

    async def _collect_text_generation_with_stats(
        self, command_id: CommandId
    ) -> BenchChatCompletionResponse:
@@ -535,7 +551,7 @@

        stats: GenerationStats | None = None

        async for chunk in self._token_chunk_stream(command_id):
        async for chunk in self._chunk_stream(command_id):
            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,
@@ -607,15 +623,23 @@
            return StreamingResponse(
                generate_chat_stream(
                    command.command_id,
                    self._token_chunk_stream(command.command_id),
                    self._stream_events(command.command_id),
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )

        return await collect_chat_response(
            command.command_id,
            self._token_chunk_stream(command.command_id),
        )
        try:
            return await collect_chat_response(
                command.command_id,
                self._chunk_stream(command.command_id),
            )
        except ValueError as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def bench_chat_completions(
        self, payload: BenchChatCompletionRequest
@@ -1156,16 +1180,24 @@
                generate_claude_stream(
                    command.command_id,
                    payload.model,
                    self._token_chunk_stream(command.command_id),
                    self._chunk_stream(command.command_id),
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )

        return await collect_claude_response(
            command.command_id,
            payload.model,
            self._token_chunk_stream(command.command_id),
        )
        try:
            return await collect_claude_response(
                command.command_id,
                payload.model,
                self._chunk_stream(command.command_id),
            )
        except ValueError as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def openai_responses(
        self, payload: ResponsesRequest
@@ -1183,16 +1215,24 @@
                generate_responses_stream(
                    command.command_id,
                    payload.model,
                    self._token_chunk_stream(command.command_id),
                    self._chunk_stream(command.command_id),
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )

        return await collect_responses_response(
            command.command_id,
            payload.model,
            self._token_chunk_stream(command.command_id),
        )
        try:
            return await collect_responses_response(
                command.command_id,
                payload.model,
                self._chunk_stream(command.command_id),
            )
        except ValueError as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    def _calculate_total_available_memory(self) -> Memory:
        """Calculate total available memory across all nodes in bytes."""
@@ -1275,6 +1315,20 @@
                except BrokenResourceError:
                    self._text_generation_queues.pop(event.command_id, None)

            elif isinstance(event, PrefillProgress):
                if queue := self._text_generation_queues.get(
                    event.command_id, None
                ):
                    try:
                        await queue.send(
                            PrefillProgressData(
                                processed_tokens=event.processed_tokens,
                                total_tokens=event.total_tokens,
                            )
                        )
                    except BrokenResourceError:
                        self._text_generation_queues.pop(event.command_id, None)

            if isinstance(event, TracesMerged):
                self._save_merged_trace(event)
@@ -15,6 +15,7 @@ from exo.shared.types.events import (
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeTimedOut,
    PrefillProgress,
    RunnerDeleted,
    RunnerStatusUpdated,
    TaskAcknowledged,
@@ -61,6 +62,7 @@ def event_apply(event: Event, state: State) -> State:
        | ChunkGenerated()
        | TaskAcknowledged()
        | InputChunkReceived()
        | PrefillProgress()
        | TracesCollected()
        | TracesMerged()
    ): # Pass-through events that don't modify state
@@ -2,7 +2,12 @@ from collections.abc import Generator
from typing import Any, Literal

from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.shared.types.api import (
    GenerationStats,
    ImageGenerationStats,
    TopLogprobItem,
    Usage,
)
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
@@ -20,6 +25,8 @@ class TokenChunk(BaseChunk):
    usage: Usage | None
    finish_reason: Literal["stop", "length", "content_filter"] | None = None
    stats: GenerationStats | None = None
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None


class ErrorChunk(BaseChunk):
@@ -70,3 +77,13 @@ class InputImageChunk(BaseChunk):


GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk


class PrefillProgressData(TaggedModel):
    """Data class for prefill progress events during streaming."""

    processed_tokens: int
    total_tokens: int


StreamEvent = TokenChunk | PrefillProgressData
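[Editor's note, not part of the diff] The dashboard's PrefillProgress interface (processed / total, added in app.svelte.ts above) uses shorter field names than the wire-level PrefillProgressData (processed_tokens / total_tokens); the adapter that maps one onto the other is not part of this diff. A hedged sketch of what that client-side mapping could look like, with a hypothetical helper name:

import type { PrefillProgress } from "$lib/stores/app.svelte";

// Hypothetical helper, not in this diff: converts a parsed prefill_progress
// SSE payload (backend field names) into the store's PrefillProgress shape.
function toPrefillProgress(payload: {
  processed_tokens: number;
  total_tokens: number;
}): PrefillProgress {
  return {
    processed: payload.processed_tokens,
    total: payload.total_tokens,
  };
}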
@@ -102,6 +102,12 @@ class InputChunkReceived(BaseEvent):
    chunk: InputImageChunk


class PrefillProgress(BaseEvent):
    command_id: CommandId
    processed_tokens: int
    total_tokens: int


class TopologyEdgeCreated(BaseEvent):
    conn: Connection

@@ -148,6 +154,7 @@ Event = (
    | NodeDownloadProgress
    | ChunkGenerated
    | InputChunkReceived
    | PrefillProgress
    | TopologyEdgeCreated
    | TopologyEdgeDeleted
    | TracesCollected
@@ -40,3 +40,5 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
    stop: str | list[str] | None = None
    seed: int | None = None
    chat_template_messages: list[dict[str, Any]] | None = None
    logprobs: bool = False
    top_logprobs: int | None = None
@@ -6,6 +6,7 @@ from exo.shared.types.api import (
    GenerationStats,
    ImageGenerationStats,
    ToolCallItem,
    TopLogprobItem,
    Usage,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -22,7 +23,8 @@ class TokenizedResponse(BaseRunnerResponse):
class GenerationResponse(BaseRunnerResponse):
    text: str
    token: int
    # logprobs: list[float] | None = None # too big. we can change to be top-k
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    usage: Usage | None
@@ -64,3 +66,8 @@ class ToolCallResponse(BaseRunnerResponse):

class FinishedResponse(BaseRunnerResponse):
    pass


class PrefillProgressResponse(BaseRunnerResponse):
    processed_tokens: int
    total_tokens: int
@@ -11,5 +11,7 @@ QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = None

DEFAULT_TOP_LOGPROBS: int = 5

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True
@@ -12,6 +12,7 @@ from exo.shared.types.api import (
    FinishReason,
    GenerationStats,
    PromptTokensDetails,
    TopLogprobItem,
    Usage,
)
from exo.shared.types.common import ModelId
@@ -23,7 +24,12 @@ from exo.shared.types.worker.runner_response import (
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.constants import (
    DEFAULT_TOP_LOGPROBS,
    KV_BITS,
    KV_GROUP_SIZE,
    MAX_TOKENS,
)
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    mx_barrier,
@@ -73,7 +79,7 @@ def prefill(
        max_tokens=1,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        prefill_step_size=1024,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
        prompt_progress_callback=progress_callback,
@@ -121,7 +127,7 @@ def warmup_inference(
        max_tokens=50,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        prefill_step_size=1024,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
@@ -155,12 +161,67 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
    return eos


def extract_top_logprobs(
    logprobs: mx.array,
    tokenizer: TokenizerWrapper,
    top_logprobs: int,
    selected_token: int,
) -> tuple[float, list[TopLogprobItem]]:
    """Extract the selected token's logprob and top alternative tokens.

    Args:
        logprobs: Full vocabulary logprobs array from MLX
        tokenizer: Tokenizer for decoding token IDs to strings
        top_logprobs: Number of top alternatives to return
        selected_token: The token ID that was actually sampled

    Returns:
        Tuple of (selected_token_logprob, list of TopLogprobItem for top alternatives)
    """
    # Get the logprob of the selected token
    selected_logprob = float(logprobs[selected_token].item())

    # Get top indices (most probable tokens)
    # mx.argpartition gives indices that would partition the array
    # We negate logprobs since argpartition finds smallest, and we want largest
    top_logprobs = min(top_logprobs, logprobs.shape[0])  # Don't exceed vocab size
    top_indices = mx.argpartition(-logprobs, top_logprobs)[:top_logprobs]

    # Get the actual logprob values for these indices
    top_values = logprobs[top_indices]

    # Sort by logprob (descending) for consistent ordering
    sort_order = mx.argsort(-top_values)
    top_indices = top_indices[sort_order]
    top_values = top_values[sort_order]

    # Convert to list of TopLogprobItem
    top_logprob_items: list[TopLogprobItem] = []
    for i in range(top_logprobs):
        token_id = int(top_indices[i].item())
        token_logprob = float(top_values[i].item())
        # Decode token ID to string
        token_str = tokenizer.decode([token_id])
        # Get byte representation
        token_bytes = list(token_str.encode("utf-8"))
        top_logprob_items.append(
            TopLogprobItem(
                token=token_str,
                logprob=token_logprob,
                bytes=token_bytes,
            )
        )

    return selected_logprob, top_logprob_items


def mlx_generate(
    model: Model,
    tokenizer: TokenizerWrapper,
    task: TextGenerationTaskParams,
    prompt: str,
    kv_prefix_cache: KVPrefixCache | None = None,
    on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
@@ -232,9 +293,10 @@
            logits_processors=logits_processors,
            prompt_cache=caches,
            # TODO: Dynamically change prefill step size to be the maximum possible without timing out.
            prefill_step_size=2048,
            prefill_step_size=1024,
            kv_group_size=KV_GROUP_SIZE,
            kv_bits=KV_BITS,
            prompt_progress_callback=on_prefill_progress,
        ),
        start=1,
    ):
@@ -296,9 +358,22 @@
            ),
        )

        # Extract logprobs from the full vocabulary logprobs array
        logprob: float | None = None
        top_logprobs: list[TopLogprobItem] | None = None
        if task.logprobs:
            logprob, top_logprobs = extract_top_logprobs(
                logprobs=out.logprobs,
                tokenizer=tokenizer,
                top_logprobs=task.top_logprobs or DEFAULT_TOP_LOGPROBS,
                selected_token=out.token,
            )

        yield GenerationResponse(
            text=text,
            token=out.token,
            logprob=logprob,
            top_logprobs=top_logprobs,
            finish_reason=finish_reason,
            stats=stats,
            usage=usage,
@@ -442,6 +442,12 @@ def apply_chat_template(
            continue
        formatted_messages.append({"role": msg.role, "content": msg.content})

    # For assistant prefilling, append content after templating to avoid a closing turn token.
    partial_assistant_content: str | None = None
    if formatted_messages and formatted_messages[-1].get("role") == "assistant":
        partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
        formatted_messages = formatted_messages[:-1]

    prompt: str = tokenizer.apply_chat_template(
        formatted_messages,
        tokenize=False,
@@ -449,6 +455,9 @@
        tools=task_params.tools,
    )

    if partial_assistant_content:
        prompt += partial_assistant_content

    logger.info(prompt)

    return prompt
@@ -25,6 +25,7 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
    PrefillProgress,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
@@ -237,6 +238,17 @@ def main(
            assert model and not isinstance(model, DistributedImageModel)
            assert tokenizer

            # Define callback to send prefill progress events directly
            def on_prefill_progress(processed: int, total: int) -> None:
                if device_rank == 0:
                    event_sender.send(
                        PrefillProgress(
                            command_id=command_id,
                            processed_tokens=processed,
                            total_tokens=total,
                        )
                    )

            try:
                _check_for_debug_prompts(task_params)

@@ -250,6 +262,7 @@ def main(
                    task=task_params,
                    prompt=prompt,
                    kv_prefix_cache=kv_prefix_cache,
                    on_prefill_progress=on_prefill_progress,
                )

                # For other thinking models (GLM, etc.), check if we need to
@@ -320,6 +333,8 @@ def main(
                            usage=response.usage,
                            finish_reason=response.finish_reason,
                            stats=response.stats,
                            logprob=response.logprob,
                            top_logprobs=response.top_logprobs,
                        ),
                    )
                )