Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-04 11:11:45 -05:00

Compare commits: 23 commits on branch sami/dashb, authored by alexcheema:

258785be84, 13a6b9819a, 1733d07cb3, b3e4c9b1e5, 4c74792373, eadb6de1f7,
ba7148ccec, a64b8addc6, e6599a9408, 93f4753598, 75fe505275, d7c044e349,
53b6d56e9f, 7fe0a61230, 5a36542631, 955e0105b3, 4d1eb1d9bd, 365416c65e,
04af76e10f, a84c3431cd, 52445b21f6, 435bd7f6fa, dd25b5b90e
Dashboard chat messages component:

```diff
@@ -6,11 +6,13 @@
   deleteMessage,
   editAndRegenerate,
   regenerateLastResponse,
+  regenerateFromToken,
   setEditingImage,
 } from "$lib/stores/app.svelte";
 import type { Message } from "$lib/stores/app.svelte";
 import type { MessageAttachment } from "$lib/stores/app.svelte";
 import MarkdownContent from "./MarkdownContent.svelte";
+import TokenHeatmap from "./TokenHeatmap.svelte";

 interface Props {
   class?: string;
@@ -99,6 +101,23 @@
   let copiedMessageId = $state<string | null>(null);
   let expandedThinkingMessageIds = $state<Set<string>>(new Set());

+  // Uncertainty heatmap toggle
+  let heatmapMessageIds = $state<Set<string>>(new Set());
+
+  function toggleHeatmap(messageId: string) {
+    const next = new Set(heatmapMessageIds);
+    if (next.has(messageId)) {
+      next.delete(messageId);
+    } else {
+      next.add(messageId);
+    }
+    heatmapMessageIds = next;
+  }
+
+  function isHeatmapVisible(messageId: string): boolean {
+    return heatmapMessageIds.has(messageId);
+  }
+
   function formatTimestamp(timestamp: number): string {
     return new Date(timestamp).toLocaleTimeString("en-US", {
       hour12: false,
@@ -548,13 +567,23 @@
         >
       </div>
     {:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
-      <MarkdownContent
-        content={message.content || (loading ? response : "")}
-      />
-      {#if loading && !message.content}
-        <span
-          class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
-        ></span>
-      {/if}
+      {#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
+        <TokenHeatmap
+          tokens={message.tokens}
+          isGenerating={loading &&
+            isLastAssistantMessage(message.id)}
+          onRegenerateFrom={(tokenIndex) =>
+            regenerateFromToken(message.id, tokenIndex)}
+        />
+      {:else}
+        <MarkdownContent
+          content={message.content || (loading ? response : "")}
+        />
+        {#if loading && !message.content}
+          <span
+            class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
+          ></span>
+        {/if}
+      {/if}
     {/if}
   </div>
@@ -629,6 +658,35 @@
           </button>
         {/if}

+        <!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
+        {#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
+          <button
+            onclick={() => toggleHeatmap(message.id)}
+            class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
+              message.id,
+            )
+              ? 'text-exo-yellow'
+              : 'text-exo-light-gray hover:text-exo-yellow'}"
+            title={isHeatmapVisible(message.id)
+              ? "Hide uncertainty heatmap"
+              : "Show uncertainty heatmap"}
+          >
+            <svg
+              class="w-3.5 h-3.5"
+              fill="none"
+              viewBox="0 0 24 24"
+              stroke="currentColor"
+            >
+              <path
+                stroke-linecap="round"
+                stroke-linejoin="round"
+                stroke-width="2"
+                d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
+              />
+            </svg>
+          </button>
+        {/if}
+
         <!-- Regenerate button (last assistant message only) -->
         {#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
           <button
```
dashboard/src/lib/components/TokenHeatmap.svelte (new file, 236 lines):

```svelte
<script lang="ts">
  import type { TokenData } from "$lib/stores/app.svelte";

  interface Props {
    tokens: TokenData[];
    class?: string;
    isGenerating?: boolean;
    onRegenerateFrom?: (tokenIndex: number) => void;
  }

  let {
    tokens,
    class: className = "",
    isGenerating = false,
    onRegenerateFrom,
  }: Props = $props();

  // Tooltip state - track both token data and index
  let hoveredTokenIndex = $state<number | null>(null);
  let hoveredPosition = $state<{ x: number; y: number } | null>(null);
  let isTooltipHovered = $state(false);
  let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;

  // Derive the hovered token from the index (stable across re-renders)
  const hoveredToken = $derived(
    hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
      ? {
          token: tokens[hoveredTokenIndex],
          index: hoveredTokenIndex,
          ...hoveredPosition,
        }
      : null,
  );

  /**
   * Get confidence styling based on probability.
   * Following Apple design principles: high confidence tokens blend in,
   * only uncertainty draws attention.
   */
  function getConfidenceClass(probability: number): string {
    if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
    if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
    if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
    return "bg-red-500/20 text-red-200/90"; // Draws attention
  }

  /**
   * Get border/underline styling for uncertain tokens
   */
  function getBorderClass(probability: number): string {
    if (probability > 0.8) return "border-transparent"; // No border for expected
    if (probability > 0.5) return "border-gray-500/20";
    if (probability > 0.2) return "border-amber-500/30";
    return "border-red-500/40";
  }

  function clearHideTimeout() {
    if (hideTimeoutId) {
      clearTimeout(hideTimeoutId);
      hideTimeoutId = null;
    }
  }

  function handleMouseEnter(
    event: MouseEvent,
    token: TokenData,
    index: number,
  ) {
    clearHideTimeout();
    const rects = (event.target as HTMLElement).getClientRects();
    let rect = rects[0];
    for (let j = 0; j < rects.length; j++) {
      if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
        rect = rects[j];
        break;
      }
    }
    hoveredTokenIndex = index;
    hoveredPosition = {
      x: rect.left + rect.width / 2,
      y: rect.top - 10,
    };
  }

  function handleMouseLeave() {
    clearHideTimeout();
    // Use longer delay during generation to account for re-renders
    const delay = isGenerating ? 300 : 200;
    hideTimeoutId = setTimeout(() => {
      if (!isTooltipHovered) {
        hoveredTokenIndex = null;
        hoveredPosition = null;
      }
    }, delay);
  }

  function handleTooltipEnter() {
    clearHideTimeout();
    isTooltipHovered = true;
  }

  function handleTooltipLeave() {
    isTooltipHovered = false;
    hoveredTokenIndex = null;
    hoveredPosition = null;
  }

  function handleRegenerate() {
    if (hoveredToken && onRegenerateFrom) {
      const indexToRegenerate = hoveredToken.index;
      // Clear hover state immediately
      hoveredTokenIndex = null;
      hoveredPosition = null;
      isTooltipHovered = false;
      // Call regenerate
      onRegenerateFrom(indexToRegenerate);
    }
  }

  function formatProbability(prob: number): string {
    return (prob * 100).toFixed(1) + "%";
  }

  function formatLogprob(logprob: number): string {
    return logprob.toFixed(3);
  }

  function getProbabilityColor(probability: number): string {
    if (probability > 0.8) return "text-gray-300";
    if (probability > 0.5) return "text-gray-400";
    if (probability > 0.2) return "text-amber-400";
    return "text-red-400";
  }
</script>

<div class="token-heatmap leading-relaxed {className}">
  {#each tokens as tokenData, i (i)}
    <span
      role="button"
      tabindex="0"
      class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
        tokenData.probability,
      )} {getBorderClass(tokenData.probability)} hover:opacity-80"
      onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
      onmouseleave={handleMouseLeave}>{tokenData.token}</span
    >
  {/each}
</div>

<!-- Tooltip -->
{#if hoveredToken}
  <div
    class="fixed z-50 pb-2"
    style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
    onmouseenter={handleTooltipEnter}
    onmouseleave={handleTooltipLeave}
  >
    <div
      class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
    >
      <!-- Token info -->
      <div class="mb-2">
        <span class="text-gray-500 text-xs">Token:</span>
        <span class="text-white font-mono ml-1"
          >"{hoveredToken.token.token}"</span
        >
        <span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
          >{formatProbability(hoveredToken.token.probability)}</span
        >
      </div>

      <div class="text-gray-400 text-xs mb-1">
        logprob: <span class="text-gray-300 font-mono"
          >{formatLogprob(hoveredToken.token.logprob)}</span
        >
      </div>

      <!-- Top alternatives -->
      {#if hoveredToken.token.topLogprobs.length > 0}
        <div class="border-t border-gray-700/50 mt-2 pt-2">
          <div class="text-gray-500 text-xs mb-1">Alternatives:</div>
          {#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
            {@const altProb = Math.exp(alt.logprob)}
            <div class="flex justify-between items-center text-xs py-0.5">
              <span class="text-gray-300 font-mono truncate max-w-24"
                >"{alt.token}"</span
              >
              <span class="text-gray-400 ml-2"
                >{formatProbability(altProb)}</span
              >
            </div>
          {/each}
        </div>
      {/if}

      <!-- Regenerate button -->
      {#if onRegenerateFrom}
        <button
          onclick={handleRegenerate}
          class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
        >
          <svg
            class="w-3 h-3"
            fill="none"
            viewBox="0 0 24 24"
            stroke="currentColor"
          >
            <path
              stroke-linecap="round"
              stroke-linejoin="round"
              stroke-width="2"
              d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
            />
          </svg>
          Regenerate from here
        </button>
      {/if}
    </div>
    <!-- Arrow -->
    <div class="absolute left-1/2 -translate-x-1/2 top-full">
      <div class="border-8 border-transparent border-t-gray-900"></div>
    </div>
  </div>
{/if}

<style>
  .token-heatmap {
    word-wrap: break-word;
    white-space: pre-wrap;
  }

  .token-span {
    margin: 0;
    border-width: 1px;
  }
</style>
```
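The component's color buckets are driven by probability = exp(logprob), since the API reports natural-log probabilities. A minimal Python sketch of the same bucketing (only the thresholds come from getConfidenceClass above; the bucket names are illustrative):

```python
import math

# Mirror of getConfidenceClass: >0.8 blends in, >0.5 gets a slight
# hint, >0.2 subtle warmth, anything lower draws attention.
def confidence_bucket(logprob: float) -> str:
    probability = math.exp(logprob)  # logprobs are natural-log
    if probability > 0.8:
        return "blend-in"
    if probability > 0.5:
        return "hint"
    if probability > 0.2:
        return "warm"
    return "alert"

# e.g. logprob -0.05 -> p ~ 0.951 -> "blend-in";
#      logprob -2.3  -> p ~ 0.100 -> "alert"
print(confidence_bucket(-0.05), confidence_bucket(-2.3))
```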
Dashboard app store ($lib/stores/app.svelte):

```diff
@@ -242,6 +242,19 @@ export interface MessageAttachment {
   mimeType?: string;
 }

+export interface TopLogprob {
+  token: string;
+  logprob: number;
+  bytes: number[] | null;
+}
+
+export interface TokenData {
+  token: string;
+  logprob: number;
+  probability: number;
+  topLogprobs: TopLogprob[];
+}
+
 export interface Message {
   id: string;
   role: "user" | "assistant" | "system";
@@ -253,6 +266,7 @@ export interface Message {
   tps?: number; // Tokens per second (for assistant messages)
   requestType?: "chat" | "image-generation" | "image-editing";
   sourceImageDataUrl?: string; // For image editing regeneration
+  tokens?: TokenData[];
 }

 export interface Conversation {
@@ -540,7 +554,18 @@ class AppStore {
    */
  private saveConversationsToStorage() {
    try {
-      localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
+      // Strip tokens from messages before saving to avoid bloating localStorage
+      const stripped = this.conversations.map((conv) => ({
+        ...conv,
+        messages: conv.messages.map((msg) => {
+          if (msg.tokens) {
+            const { tokens: _, ...rest } = msg;
+            return rest;
+          }
+          return msg;
+        }),
+      }));
+      localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
    } catch (error) {
      console.error("Failed to save conversations:", error);
    }
```
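Stripping matters because with top_logprobs set to 5, every generated token persists six token strings plus floats, so stored transcripts grow several-fold. A hedged Python sketch of the same strip-before-persist idea (data shapes here are hypothetical, not the store's actual types):

```python
import json

# Drop the per-token logprob data, which can dwarf the message text,
# before serializing conversations for storage.
def serialize_stripped(conversations: list[dict]) -> str:
    stripped = [
        {
            **conv,
            "messages": [
                {k: v for k, v in msg.items() if k != "tokens"}
                for msg in conv["messages"]
            ],
        }
        for conv in conversations
    ]
    return json.dumps(stripped)
```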
```diff
@@ -1445,6 +1470,213 @@ class AppStore {
     }
   }

+  /**
+   * Regenerate response from a specific token index.
+   * Truncates the assistant message at the given token and re-generates from there.
+   */
+  async regenerateFromToken(
+    messageId: string,
+    tokenIndex: number,
+  ): Promise<void> {
+    if (this.isLoading) return;
+
+    const targetConversationId = this.activeConversationId;
+    if (!targetConversationId) return;
+
+    const msgIndex = this.messages.findIndex((m) => m.id === messageId);
+    if (msgIndex === -1) return;
+
+    const msg = this.messages[msgIndex];
+    if (
+      msg.role !== "assistant" ||
+      !msg.tokens ||
+      tokenIndex >= msg.tokens.length
+    )
+      return;
+
+    // Keep tokens up to (not including) the specified index
+    const tokensToKeep = msg.tokens.slice(0, tokenIndex);
+    const prefixText = tokensToKeep.map((t) => t.token).join("");
+
+    // Remove all messages after this assistant message
+    this.messages = this.messages.slice(0, msgIndex + 1);
+
+    // Update the message to show the prefix
+    this.messages[msgIndex].content = prefixText;
+    this.messages[msgIndex].tokens = tokensToKeep;
+    this.updateActiveConversation();
+
+    // Set up for continuation - modify the existing message in place
+    this.isLoading = true;
+    this.currentResponse = prefixText;
+    this.ttftMs = null;
+    this.tps = null;
+    this.totalTokens = tokensToKeep.length;
+
+    try {
+      // Build messages for API - include the partial assistant message
+      const systemPrompt = {
+        role: "system" as const,
+        content:
+          "You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
+      };
+
+      const apiMessages = [
+        systemPrompt,
+        ...this.messages.map((m) => {
+          let msgContent = m.content;
+          if (m.attachments) {
+            for (const attachment of m.attachments) {
+              if (attachment.type === "text" && attachment.content) {
+                msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
+              }
+            }
+          }
+          return { role: m.role, content: msgContent };
+        }),
+      ];
+
+      const modelToUse = this.getModelForRequest();
+      if (!modelToUse) {
+        throw new Error("No model available");
+      }
+
+      const requestStartTime = performance.now();
+      let firstTokenTime: number | null = null;
+      let tokenCount = tokensToKeep.length;
+
+      const response = await fetch("/v1/chat/completions", {
+        method: "POST",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify({
+          model: modelToUse,
+          messages: apiMessages,
+          stream: true,
+          logprobs: true,
+          top_logprobs: 5,
+        }),
+      });
+
+      if (!response.ok) {
+        const errorText = await response.text();
+        throw new Error(`API error: ${response.status} - ${errorText}`);
+      }
+
+      const reader = response.body?.getReader();
+      if (!reader) throw new Error("No response body");
+
+      let fullContent = prefixText;
+      const collectedTokens: TokenData[] = [...tokensToKeep];
+
+      interface ChatCompletionChunk {
+        choices?: Array<{
+          delta?: { content?: string };
+          logprobs?: {
+            content?: Array<{
+              token: string;
+              logprob: number;
+              top_logprobs?: Array<{
+                token: string;
+                logprob: number;
+                bytes: number[] | null;
+              }>;
+            }>;
+          };
+        }>;
+      }
+
+      await this.parseSSEStream<ChatCompletionChunk>(
+        reader,
+        targetConversationId,
+        (parsed) => {
+          const choice = parsed.choices?.[0];
+          const delta = choice?.delta?.content;
+
+          // Collect logprobs data
+          const logprobsContent = choice?.logprobs?.content;
+          if (logprobsContent) {
+            for (const item of logprobsContent) {
+              collectedTokens.push({
+                token: item.token,
+                logprob: item.logprob,
+                probability: Math.exp(item.logprob),
+                topLogprobs: (item.top_logprobs || []).map((t) => ({
+                  token: t.token,
+                  logprob: t.logprob,
+                  bytes: t.bytes,
+                })),
+              });
+            }
+          }
+
+          if (delta) {
+            if (firstTokenTime === null) {
+              firstTokenTime = performance.now();
+              this.ttftMs = firstTokenTime - requestStartTime;
+            }
+
+            tokenCount += 1;
+            this.totalTokens = tokenCount;
+
+            if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
+              const elapsed = performance.now() - firstTokenTime;
+              this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
+            }
+
+            fullContent += delta;
+            const { displayContent, thinkingContent } =
+              this.stripThinkingTags(fullContent);
+
+            if (this.activeConversationId === targetConversationId) {
+              this.currentResponse = displayContent;
+            }
+
+            // Update existing message in place
+            this.updateConversationMessage(
+              targetConversationId,
+              messageId,
+              (m) => {
+                m.content = displayContent;
+                m.thinking = thinkingContent || undefined;
+                m.tokens = [...collectedTokens];
+              },
+            );
+            this.syncActiveMessagesIfNeeded(targetConversationId);
+            this.persistConversation(targetConversationId);
+          }
+        },
+      );
+
+      // Final update
+      if (this.conversationExists(targetConversationId)) {
+        const { displayContent, thinkingContent } =
+          this.stripThinkingTags(fullContent);
+        this.updateConversationMessage(targetConversationId, messageId, (m) => {
+          m.content = displayContent;
+          m.thinking = thinkingContent || undefined;
+          m.tokens = [...collectedTokens];
+          if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
+          if (this.tps !== null) m.tps = this.tps;
+        });
+        this.syncActiveMessagesIfNeeded(targetConversationId);
+        this.persistConversation(targetConversationId);
+      }
+    } catch (error) {
+      console.error("Error regenerating from token:", error);
+      if (this.conversationExists(targetConversationId)) {
+        this.updateConversationMessage(targetConversationId, messageId, (m) => {
+          m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
+        });
+        this.syncActiveMessagesIfNeeded(targetConversationId);
+        this.persistConversation(targetConversationId);
+      }
+    } finally {
+      this.isLoading = false;
+      this.currentResponse = "";
+      this.saveConversationsToStorage();
+    }
+  }
+
   /**
    * Helper method to regenerate a chat completion response
    */
```
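The continuation works because the last apiMessages entry is the partial assistant message, which apply_chat_template (later in this compare) appends after templating instead of closing the turn. A hedged Python sketch of the equivalent request against a local exo node (the host/port and model id are assumptions, not taken from this diff):

```python
import json
import requests  # any HTTP client works; requests is just convenient

# Continue generation from a truncated assistant prefix: the trailing
# assistant message acts as a prefill that the server extends.
resp = requests.post(
    "http://localhost:52415/v1/chat/completions",  # assumed exo API address
    json={
        "model": "llama-3.2-1b",  # hypothetical model id
        "messages": [
            {"role": "user", "content": "Name a prime number above 100."},
            {"role": "assistant", "content": "A prime above 100 is 1"},
        ],
        "stream": True,
        "logprobs": True,
        "top_logprobs": 5,
    },
    stream=True,
)
for line in resp.iter_lines():
    # SSE frames look like: b'data: {...}'
    if line.startswith(b"data: ") and line != b"data: [DONE]":
        chunk = json.loads(line[len(b"data: "):])
        print(chunk["choices"][0]["delta"].get("content", ""), end="")
```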
```diff
@@ -1513,6 +1745,8 @@ class AppStore {
           model: modelToUse,
           messages: apiMessages,
           stream: true,
+          logprobs: true,
+          top_logprobs: 5,
         }),
       });

@@ -1527,16 +1761,49 @@ class AppStore {
       }

       let streamedContent = "";
+      const collectedTokens: TokenData[] = [];

       interface ChatCompletionChunk {
-        choices?: Array<{ delta?: { content?: string } }>;
+        choices?: Array<{
+          delta?: { content?: string };
+          logprobs?: {
+            content?: Array<{
+              token: string;
+              logprob: number;
+              top_logprobs?: Array<{
+                token: string;
+                logprob: number;
+                bytes: number[] | null;
+              }>;
+            }>;
+          };
+        }>;
       }

       await this.parseSSEStream<ChatCompletionChunk>(
         reader,
         targetConversationId,
         (parsed) => {
-          const delta = parsed.choices?.[0]?.delta?.content;
+          const choice = parsed.choices?.[0];
+          const delta = choice?.delta?.content;
+
+          // Collect logprobs data
+          const logprobsContent = choice?.logprobs?.content;
+          if (logprobsContent) {
+            for (const item of logprobsContent) {
+              collectedTokens.push({
+                token: item.token,
+                logprob: item.logprob,
+                probability: Math.exp(item.logprob),
+                topLogprobs: (item.top_logprobs || []).map((t) => ({
+                  token: t.token,
+                  logprob: t.logprob,
+                  bytes: t.bytes,
+                })),
+              });
+            }
+          }

           if (delta) {
             streamedContent += delta;
             const { displayContent, thinkingContent } =
@@ -1554,6 +1821,7 @@ class AppStore {
             (msg) => {
               msg.content = displayContent;
               msg.thinking = thinkingContent || undefined;
+              msg.tokens = [...collectedTokens];
             },
           );
           this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1572,6 +1840,7 @@ class AppStore {
             (msg) => {
               msg.content = displayContent;
               msg.thinking = thinkingContent || undefined;
+              msg.tokens = [...collectedTokens];
             },
           );
           this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1914,6 +2183,8 @@ class AppStore {
           messages: apiMessages,
           temperature: 0.7,
           stream: true,
+          logprobs: true,
+          top_logprobs: 5,
         }),
       });

@@ -1930,14 +2201,48 @@ class AppStore {
       let streamedContent = "";

       interface ChatCompletionChunk {
-        choices?: Array<{ delta?: { content?: string } }>;
+        choices?: Array<{
+          delta?: { content?: string };
+          logprobs?: {
+            content?: Array<{
+              token: string;
+              logprob: number;
+              top_logprobs?: Array<{
+                token: string;
+                logprob: number;
+                bytes: number[] | null;
+              }>;
+            }>;
+          };
+        }>;
       }

+      const collectedTokens: TokenData[] = [];
+
       await this.parseSSEStream<ChatCompletionChunk>(
         reader,
         targetConversationId,
         (parsed) => {
-          const tokenContent = parsed.choices?.[0]?.delta?.content;
+          const choice = parsed.choices?.[0];
+          const tokenContent = choice?.delta?.content;
+
+          // Collect logprobs data
+          const logprobsContent = choice?.logprobs?.content;
+          if (logprobsContent) {
+            for (const item of logprobsContent) {
+              collectedTokens.push({
+                token: item.token,
+                logprob: item.logprob,
+                probability: Math.exp(item.logprob),
+                topLogprobs: (item.top_logprobs || []).map((t) => ({
+                  token: t.token,
+                  logprob: t.logprob,
+                  bytes: t.bytes,
+                })),
+              });
+            }
+          }

           if (tokenContent) {
             // Track first token for TTFT
             if (firstTokenTime === null) {
@@ -1973,6 +2278,7 @@ class AppStore {
             (msg) => {
               msg.content = displayContent;
               msg.thinking = thinkingContent || undefined;
+              msg.tokens = [...collectedTokens];
             },
           );
           this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1997,6 +2303,7 @@ class AppStore {
             (msg) => {
               msg.content = displayContent;
               msg.thinking = thinkingContent || undefined;
+              msg.tokens = [...collectedTokens];
               // Store performance metrics on the message
               if (this.ttftMs !== null) {
                 msg.ttftMs = this.ttftMs;
@@ -2693,6 +3000,8 @@ export const editMessage = (messageId: string, newContent: string) =>
 export const editAndRegenerate = (messageId: string, newContent: string) =>
   appStore.editAndRegenerate(messageId, newContent);
 export const regenerateLastResponse = () => appStore.regenerateLastResponse();
+export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
+  appStore.regenerateFromToken(messageId, tokenIndex);

 // Conversation actions
 export const conversations = () => appStore.conversations;
```
OpenAI-compatible chat API conversion helpers:

```diff
@@ -14,6 +14,8 @@ from exo.shared.types.api import (
     ErrorInfo,
     ErrorResponse,
     FinishReason,
+    Logprobs,
+    LogprobsContentItem,
     StreamingChoiceResponse,
     ToolCall,
 )
@@ -81,6 +83,8 @@ def chat_request_to_text_generation(
         chat_template_messages=chat_template_messages
         if chat_template_messages
         else None,
+        logprobs=request.logprobs or False,
+        top_logprobs=request.top_logprobs,
     )


@@ -88,6 +92,19 @@ def chunk_to_response(
     chunk: TokenChunk, command_id: CommandId
 ) -> ChatCompletionResponse:
     """Convert a TokenChunk to a streaming ChatCompletionResponse."""
+    # Build logprobs if available
+    logprobs: Logprobs | None = None
+    if chunk.logprob is not None:
+        logprobs = Logprobs(
+            content=[
+                LogprobsContentItem(
+                    token=chunk.text,
+                    logprob=chunk.logprob,
+                    top_logprobs=chunk.top_logprobs or [],
+                )
+            ]
+        )
+
     return ChatCompletionResponse(
         id=command_id,
         created=int(time.time()),
@@ -96,6 +113,7 @@ def chunk_to_response(
             StreamingChoiceResponse(
                 index=0,
                 delta=ChatCompletionMessage(role="assistant", content=chunk.text),
+                logprobs=logprobs,
                 finish_reason=chunk.finish_reason,
             )
         ],
@@ -162,6 +180,7 @@ async def collect_chat_response(
     """Collect all token chunks and return a single ChatCompletionResponse."""
     text_parts: list[str] = []
     tool_calls: list[ToolCall] = []
+    logprobs_content: list[LogprobsContentItem] = []
     model: str | None = None
     finish_reason: FinishReason | None = None
     error_message: str | None = None
@@ -176,6 +195,14 @@ async def collect_chat_response(

         if isinstance(chunk, TokenChunk):
             text_parts.append(chunk.text)
+            if chunk.logprob is not None:
+                logprobs_content.append(
+                    LogprobsContentItem(
+                        token=chunk.text,
+                        logprob=chunk.logprob,
+                        top_logprobs=chunk.top_logprobs or [],
+                    )
+                )

         if isinstance(chunk, ToolCallChunk):
             tool_calls.extend(
@@ -208,6 +235,9 @@ async def collect_chat_response(
                 content=combined_text,
                 tool_calls=tool_calls if tool_calls else None,
             ),
+            logprobs=Logprobs(content=logprobs_content)
+            if logprobs_content
+            else None,
             finish_reason=finish_reason,
         )
     ],
```
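A streamed chunk built by chunk_to_response then follows the standard OpenAI chat-completion-chunk shape that the dashboard's ChatCompletionChunk interface parses. A sketch of that JSON shape (all values here are illustrative, not taken from a real response):

```python
# Illustrative payload of one SSE "data:" frame with logprobs enabled.
chunk = {
    "id": "cmd-123",
    "created": 1730000000,
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": " Paris"},
            "logprobs": {
                "content": [
                    {
                        "token": " Paris",
                        "logprob": -0.031,  # exp(-0.031) ~ 0.969
                        "top_logprobs": [
                            {"token": " Paris", "logprob": -0.031,
                             "bytes": [32, 80, 97, 114, 105, 115]},
                            {"token": " London", "logprob": -3.9,
                             "bytes": [32, 76, 111, 110, 100, 111, 110]},
                        ],
                    }
                ]
            },
            "finish_reason": None,
        }
    ],
}
```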
API server (class API), explicit SSE headers on all three streaming endpoints:

```diff
@@ -610,6 +610,11 @@ class API:
                     self._token_chunk_stream(command.command_id),
                 ),
                 media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "close",
+                    "X-Accel-Buffering": "no",
+                },
             )

         return await collect_chat_response(
@@ -1159,6 +1164,11 @@ class API:
                     self._token_chunk_stream(command.command_id),
                 ),
                 media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "close",
+                    "X-Accel-Buffering": "no",
+                },
             )

         return await collect_claude_response(
@@ -1186,6 +1196,11 @@ class API:
                     self._token_chunk_stream(command.command_id),
                 ),
                 media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "close",
+                    "X-Accel-Buffering": "no",
+                },
             )

         return await collect_responses_response(
```

These headers keep streamed tokens flowing immediately: Cache-Control: no-cache prevents intermediaries from caching the event stream, and X-Accel-Buffering: no tells nginx-style reverse proxies not to buffer the response.
Shared chunk types:

```diff
@@ -2,7 +2,12 @@ from collections.abc import Generator
 from typing import Any, Literal

 from exo.shared.models.model_cards import ModelId
-from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
+from exo.shared.types.api import (
+    GenerationStats,
+    ImageGenerationStats,
+    TopLogprobItem,
+    Usage,
+)
 from exo.utils.pydantic_ext import TaggedModel

 from .api import FinishReason
@@ -20,6 +25,8 @@ class TokenChunk(BaseChunk):
     usage: Usage | None
     finish_reason: Literal["stop", "length", "content_filter"] | None = None
     stats: GenerationStats | None = None
+    logprob: float | None = None
+    top_logprobs: list[TopLogprobItem] | None = None


 class ErrorChunk(BaseChunk):
```
Text generation task parameters:

```diff
@@ -40,3 +40,5 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
     stop: str | list[str] | None = None
     seed: int | None = None
     chat_template_messages: list[dict[str, Any]] | None = None
+    logprobs: bool = False
+    top_logprobs: int | None = None
```
Worker runner response types (exo.shared.types.worker.runner_response):

```diff
@@ -6,6 +6,7 @@ from exo.shared.types.api import (
     GenerationStats,
     ImageGenerationStats,
     ToolCallItem,
+    TopLogprobItem,
     Usage,
 )
 from exo.utils.pydantic_ext import TaggedModel
@@ -22,7 +23,8 @@ class TokenizedResponse(BaseRunnerResponse):
 class GenerationResponse(BaseRunnerResponse):
     text: str
     token: int
-    # logprobs: list[float] | None = None  # too big. we can change to be top-k
+    logprob: float | None = None
+    top_logprobs: list[TopLogprobItem] | None = None
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
     usage: Usage | None
```
MLX engine constants (exo.worker.engines.mlx.constants):

```diff
@@ -11,5 +11,7 @@ QUANTIZE_MODEL_MODE: str | None = "affine"
 CACHE_GROUP_SIZE: int = 64
 KV_CACHE_BITS: int | None = None

+DEFAULT_TOP_LOGPROBS: int = 5
+
 # TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
 TRUST_REMOTE_CODE: bool = True
```
MLX generation (exo.worker.engines.mlx):

```diff
@@ -12,6 +12,7 @@ from exo.shared.types.api import (
     FinishReason,
     GenerationStats,
     PromptTokensDetails,
+    TopLogprobItem,
     Usage,
 )
 from exo.shared.types.common import ModelId
@@ -23,7 +24,12 @@ from exo.shared.types.worker.runner_response import (
 )
 from exo.worker.engines.mlx import Model
 from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
-from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
+from exo.worker.engines.mlx.constants import (
+    DEFAULT_TOP_LOGPROBS,
+    KV_BITS,
+    KV_GROUP_SIZE,
+    MAX_TOKENS,
+)
 from exo.worker.engines.mlx.utils_mlx import (
     apply_chat_template,
     mx_barrier,
@@ -155,6 +161,60 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
     return eos


+def extract_top_logprobs(
+    logprobs: mx.array,
+    tokenizer: TokenizerWrapper,
+    top_logprobs: int,
+    selected_token: int,
+) -> tuple[float, list[TopLogprobItem]]:
+    """Extract the selected token's logprob and top alternative tokens.
+
+    Args:
+        logprobs: Full vocabulary logprobs array from MLX
+        tokenizer: Tokenizer for decoding token IDs to strings
+        top_logprobs: Number of top alternatives to return
+        selected_token: The token ID that was actually sampled
+
+    Returns:
+        Tuple of (selected_token_logprob, list of TopLogprobItem for top alternatives)
+    """
+    # Get the logprob of the selected token
+    selected_logprob = float(logprobs[selected_token].item())
+
+    # Get top indices (most probable tokens)
+    # mx.argpartition gives indices that would partition the array
+    # We negate logprobs since argpartition finds smallest, and we want largest
+    top_logprobs = min(top_logprobs, logprobs.shape[0])  # Don't exceed vocab size
+    top_indices = mx.argpartition(-logprobs, top_logprobs)[:top_logprobs]
+
+    # Get the actual logprob values for these indices
+    top_values = logprobs[top_indices]
+
+    # Sort by logprob (descending) for consistent ordering
+    sort_order = mx.argsort(-top_values)
+    top_indices = top_indices[sort_order]
+    top_values = top_values[sort_order]
+
+    # Convert to list of TopLogprobItem
+    top_logprob_items: list[TopLogprobItem] = []
+    for i in range(top_logprobs):
+        token_id = int(top_indices[i].item())
+        token_logprob = float(top_values[i].item())
+        # Decode token ID to string
+        token_str = tokenizer.decode([token_id])
+        # Get byte representation
+        token_bytes = list(token_str.encode("utf-8"))
+        top_logprob_items.append(
+            TopLogprobItem(
+                token=token_str,
+                logprob=token_logprob,
+                bytes=token_bytes,
+            )
+        )
+
+    return selected_logprob, top_logprob_items
+
+
 def mlx_generate(
     model: Model,
     tokenizer: TokenizerWrapper,
@@ -296,9 +356,22 @@
             ),
         )

+        # Extract logprobs from the full vocabulary logprobs array
+        logprob: float | None = None
+        top_logprobs: list[TopLogprobItem] | None = None
+        if task.logprobs:
+            logprob, top_logprobs = extract_top_logprobs(
+                logprobs=out.logprobs,
+                tokenizer=tokenizer,
+                top_logprobs=task.top_logprobs or DEFAULT_TOP_LOGPROBS,
+                selected_token=out.token,
+            )
+
         yield GenerationResponse(
             text=text,
             token=out.token,
+            logprob=logprob,
+            top_logprobs=top_logprobs,
             finish_reason=finish_reason,
             stats=stats,
             usage=usage,
```
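The argpartition-then-argsort pattern avoids sorting the whole vocabulary: partitioning is O(V), and only the k survivors get the O(k log k) sort. A standalone sketch of the same selection logic using numpy as a stand-in for mx (the toy logprob values are illustrative):

```python
import numpy as np

# Toy "vocabulary" logprobs for 8 tokens.
logprobs = np.log(np.array([0.02, 0.60, 0.05, 0.01, 0.20, 0.04, 0.05, 0.03]))
k = 3

# O(V) partition: the first k positions of the negated array hold the
# k largest logprobs, in no particular order...
top_idx = np.argpartition(-logprobs, k)[:k]
# ...then an O(k log k) sort puts them in descending-probability order.
order = np.argsort(-logprobs[top_idx])
top_idx = top_idx[order]

for token_id in top_idx:
    print(int(token_id), float(logprobs[token_id]))
# Prints token id 1 (p=0.60), then 4 (p=0.20), then one of the
# tied p=0.05 tokens (id 2 or 6).
```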
Chat templating (exo.worker.engines.mlx.utils_mlx):

```diff
@@ -442,6 +442,12 @@
             continue
         formatted_messages.append({"role": msg.role, "content": msg.content})

+    # For assistant prefilling, append content after templating to avoid a closing turn token.
+    partial_assistant_content: str | None = None
+    if formatted_messages and formatted_messages[-1].get("role") == "assistant":
+        partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
+        formatted_messages = formatted_messages[:-1]
+
     prompt: str = tokenizer.apply_chat_template(
         formatted_messages,
         tokenize=False,
@@ -449,6 +455,9 @@
         tools=task_params.tools,
     )

+    if partial_assistant_content:
+        prompt += partial_assistant_content
+
     logger.info(prompt)

     return prompt
```
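Why the partial assistant message is appended after templating rather than passed through it: most chat templates close every assistant message with an end-of-turn marker, which would signal the model that the answer is already finished. A minimal Python sketch with a hand-rolled template (the marker names are illustrative; real templates differ per model):

```python
# Hand-rolled stand-in for a chat template.
def render(messages: list[dict[str, str]]) -> str:
    body = "".join(
        f"<|{m['role']}|>{m['content']}<|end|>" for m in messages
    )
    return body + "<|assistant|>"  # open a fresh assistant turn

messages = [{"role": "user", "content": "Name a prime above 100."}]
prefix = "A prime above 100 is 1"

# Passing the partial answer through the template seals the turn:
#   ...<|assistant|>A prime above 100 is 1<|end|><|assistant|>
sealed = render(messages + [{"role": "assistant", "content": prefix}])

# Appending after templating leaves the turn open, so the model
# continues the prefix: ...<|assistant|>A prime above 100 is 1
open_turn = render(messages) + prefix
print(open_turn)
```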
Worker main loop, forwarding logprobs from GenerationResponse into the TokenChunk:

```diff
@@ -320,6 +320,8 @@ def main(
                     usage=response.usage,
                     finish_reason=response.finish_reason,
                     stats=response.stats,
+                    logprob=response.logprob,
+                    top_logprobs=response.top_logprobs,
                 ),
             )
         )
```