Compare commits


16 Commits

SHA1        Author     Date                        Message
d611f55332  Sami Khan  2026-02-04 15:13:58 +05:00  testing macmon
66174b6509  Sami Khan  2026-02-04 12:10:04 +05:00  test macmon
7a2abfa0ed  Sami Khan  2026-02-04 11:25:32 +05:00  test override
5aea62c8ef  Sami Khan  2026-02-04 10:37:34 +05:00  fix test flow
32ce382445  Sami Khan  2026-02-04 10:08:25 +05:00  fix path
a4c42993e0  Sami Khan  2026-02-04 09:45:38 +05:00  networksetup fix
38d03ce1fa  Sami Khan  2026-02-04 09:32:04 +05:00  macmon in path
ad0b1a2ce9  Sami Khan  2026-02-04 09:13:33 +05:00  Add macmon to CI and restore E2E tests for model launch and chat
6f7c9000cf  Sami Khan  2026-02-04 07:50:17 +05:00  Simplify to basic UI element tests (no snapshots)
c9ff05f012  Sami Khan  2026-02-04 07:44:29 +05:00  Simplify to basic UI element tests (no snapshots)
164f8fb38c  Sami Khan  2026-02-04 07:31:34 +05:00  Remove E2E tests, keep only visual snapshots for CI
698eb9ad17  Sami Khan  2026-02-04 07:11:59 +05:00  Skip model-launch tests in CI
2ef29eeb5f  Sami Khan  2026-02-04 06:04:29 +05:00  Fix CI: add uv sync step to dashboard tests
e847bbd675  Sami Khan  2026-02-04 05:28:03 +05:00  Fix CI: pre-install Python deps, increase timeout, add @types/node
8f1ca88e5d  Sami Khan  2026-02-04 01:23:12 +05:00  remove mock tests
075c5c545e  Sami Khan  2026-02-04 01:02:39 +05:00  Add dashboard Playwright tests with CI
25 changed files with 567 additions and 780 deletions


@@ -143,3 +143,139 @@ jobs:
          export HOME="$RUNNER_TEMP"
          export EXO_TESTS=1
          EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
  dashboard-tests:
    name: Dashboard E2E Tests
    runs-on: macos-26
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false
      - uses: cachix/install-nix-action@v31
        with:
          nix_path: nixpkgs=channel:nixos-unstable
      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
      - name: Build Metal packages
        run: |
          if nix build .#metal-toolchain 2>/dev/null; then
            echo "metal-toolchain built successfully (likely cache hit)"
          else
            echo "metal-toolchain build failed, extracting from Xcode..."
            NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
            NAR_NAME="metal-toolchain-17C48.nar"
            WORK_DIR="${RUNNER_TEMP}/metal-work"
            mkdir -p "$WORK_DIR"
            xcodebuild -downloadComponent MetalToolchain
            DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
            if [ -z "$DMG_PATH" ]; then
              echo "Error: Could not find Metal toolchain DMG"
              exit 1
            fi
            echo "Found DMG at: $DMG_PATH"
            hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
            cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
            hdiutil detach "${WORK_DIR}/metal-dmg"
            nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
            STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
            echo "Added NAR to store: $STORE_PATH"
            rm -rf "$WORK_DIR"
            nix build .#metal-toolchain
          fi
          nix build .#mlx
      - name: Install macmon for hardware monitoring
        run: brew install macmon
      - name: Load nix develop environment
        run: nix run github:nicknovitski/nix-develop/v1
      - name: Sync Python dependencies
        run: uv sync --all-packages
      - name: Setup Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '20'
          cache: 'npm'
          cache-dependency-path: dashboard/package-lock.json
      - name: Install dashboard dependencies
        working-directory: dashboard
        run: npm ci
      - name: Install Playwright browsers
        working-directory: dashboard
        run: npx playwright install chromium --with-deps
      - name: Build dashboard
        working-directory: dashboard
        run: npm run build
      - name: Verify macmon is accessible
        run: |
          echo "PATH: $PATH"
          which macmon || echo "macmon not in PATH"
          macmon --version
          # Test macmon actually works - capture stderr too
          echo "Testing macmon pipe output (with stderr)..."
          timeout 5 macmon pipe --interval 1000 2>&1 || echo "macmon pipe exit code: $?"
          # Try running macmon raw (not pipe mode)
          echo "Testing macmon raw output..."
          macmon raw 2>&1 | head -5 || echo "macmon raw failed"
      - name: Verify Python can find macmon
        run: |
          echo "Testing shutil.which from uv run python..."
          uv run python -c "import shutil; print('Python shutil.which macmon:', shutil.which('macmon'))"
      - name: Run Playwright tests
        working-directory: dashboard
        run: |
          export PATH="/usr/sbin:/usr/bin:/opt/homebrew/bin:$PATH"
          echo "Effective PATH: $PATH"
          which macmon && echo "macmon found at $(which macmon)"
          npm test
        env:
          CI: true
      - name: Upload test results
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: playwright-report
          path: dashboard/playwright-report/
          retention-days: 30
      - name: Upload video recordings
        uses: actions/upload-artifact@v4
        if: always()
        with:
          name: test-videos
          path: dashboard/test-results/
          retention-days: 30
      - name: Upload snapshot diffs
        uses: actions/upload-artifact@v4
        if: failure()
        with:
          name: snapshot-diffs
          path: dashboard/tests/**/*-snapshots/*-diff.png
          retention-days: 30

.gitignore

@@ -29,5 +29,12 @@ dashboard/build/
dashboard/node_modules/
dashboard/.svelte-kit/
# playwright
dashboard/test-results/
dashboard/playwright-report/
dashboard/playwright/.cache/
dashboard/tests/**/*-snapshots/*-actual.png
dashboard/tests/**/*-snapshots/*-diff.png
# host config snapshots
hosts_*.json


@@ -14,12 +14,13 @@
        "mode-watcher": "^1.1.0"
      },
      "devDependencies": {
        "@playwright/test": "^1.41.0",
        "@sveltejs/adapter-static": "^3.0.10",
        "@sveltejs/kit": "^2.48.4",
        "@sveltejs/vite-plugin-svelte": "^5.0.0",
        "@tailwindcss/vite": "^4.0.0",
        "@types/d3": "^7.4.3",
        "@types/node": "^22",
        "@types/node": "^22.19.8",
        "d3": "^7.9.0",
        "prettier": "^3.4.2",
        "prettier-plugin-svelte": "^3.3.3",
@@ -518,6 +519,22 @@
        "@jridgewell/sourcemap-codec": "^1.4.14"
      }
    },
    "node_modules/@playwright/test": {
      "version": "1.58.1",
      "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.58.1.tgz",
      "integrity": "sha512-6LdVIUERWxQMmUSSQi0I53GgCBYgM2RpGngCPY7hSeju+VrKjq3lvs7HpJoPbDiY5QM5EYRtRX5fvrinnMAz3w==",
      "dev": true,
      "license": "Apache-2.0",
      "dependencies": {
        "playwright": "1.58.1"
      },
      "bin": {
        "playwright": "cli.js"
      },
      "engines": {
        "node": ">=18"
      }
    },
    "node_modules/@polka/url": {
      "version": "1.0.0-next.29",
      "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.29.tgz",
@@ -1515,9 +1532,9 @@
      "license": "MIT"
    },
    "node_modules/@types/node": {
      "version": "22.19.1",
      "resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.1.tgz",
      "integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
      "version": "22.19.8",
      "resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.8.tgz",
      "integrity": "sha512-ebO/Yl+EAvVe8DnMfi+iaAyIqYdK0q/q0y0rw82INWEKJOBe6b/P3YWE8NW7oOlF/nXFNrHwhARrN/hdgDkraA==",
      "dev": true,
      "license": "MIT",
      "dependencies": {
@@ -2655,6 +2672,53 @@
        "url": "https://github.com/sponsors/jonschlinkert"
      }
    },
    "node_modules/playwright": {
      "version": "1.58.1",
      "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.58.1.tgz",
      "integrity": "sha512-+2uTZHxSCcxjvGc5C891LrS1/NlxglGxzrC4seZiVjcYVQfUa87wBL6rTDqzGjuoWNjnBzRqKmF6zRYGMvQUaQ==",
      "dev": true,
      "license": "Apache-2.0",
      "dependencies": {
        "playwright-core": "1.58.1"
      },
      "bin": {
        "playwright": "cli.js"
      },
      "engines": {
        "node": ">=18"
      },
      "optionalDependencies": {
        "fsevents": "2.3.2"
      }
    },
    "node_modules/playwright-core": {
      "version": "1.58.1",
      "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.58.1.tgz",
      "integrity": "sha512-bcWzOaTxcW+VOOGBCQgnaKToLJ65d6AqfLVKEWvexyS3AS6rbXl+xdpYRMGSRBClPvyj44njOWoxjNdL/H9UNg==",
      "dev": true,
      "license": "Apache-2.0",
      "bin": {
        "playwright-core": "cli.js"
      },
      "engines": {
        "node": ">=18"
      }
    },
    "node_modules/playwright/node_modules/fsevents": {
      "version": "2.3.2",
      "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
      "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==",
      "dev": true,
      "hasInstallScript": true,
      "license": "MIT",
      "optional": true,
      "os": [
        "darwin"
      ],
      "engines": {
        "node": "^8.16.0 || ^10.6.0 || >=11.0.0"
      }
    },
    "node_modules/postcss": {
      "version": "8.5.6",
      "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz",


@@ -8,18 +8,23 @@
    "build": "vite build",
    "preview": "vite preview",
    "prepare": "svelte-kit sync || echo ''",
    "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json"
    "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
    "test": "playwright test",
    "test:e2e": "playwright test tests/e2e",
    "test:visual": "playwright test tests/visual",
    "test:update-snapshots": "playwright test tests/visual --update-snapshots"
  },
  "devDependencies": {
    "prettier": "^3.4.2",
    "prettier-plugin-svelte": "^3.3.3",
    "@playwright/test": "^1.41.0",
    "@sveltejs/adapter-static": "^3.0.10",
    "@sveltejs/kit": "^2.48.4",
    "@sveltejs/vite-plugin-svelte": "^5.0.0",
    "@tailwindcss/vite": "^4.0.0",
    "@types/d3": "^7.4.3",
    "@types/node": "^22",
    "@types/node": "^22.19.8",
    "d3": "^7.9.0",
    "prettier": "^3.4.2",
    "prettier-plugin-svelte": "^3.3.3",
    "svelte": "^5.0.0",
    "svelte-check": "^4.0.0",
    "tailwindcss": "^4.0.0",


@@ -0,0 +1,43 @@
/// <reference types="node" />
import { defineConfig, devices } from "@playwright/test";

export default defineConfig({
  testDir: "./tests",
  fullyParallel: true,
  forbidOnly: !!process.env.CI,
  retries: process.env.CI ? 2 : 0,
  workers: process.env.CI ? 1 : undefined,
  reporter: [["html", { open: "never" }], ["list"]],
  use: {
    baseURL: "http://localhost:52415",
    trace: "on-first-retry",
    video: "on",
    screenshot: "only-on-failure",
  },
  projects: [
    {
      name: "chromium",
      use: { ...devices["Desktop Chrome"] },
    },
  ],
  webServer: {
    command: "cd .. && uv run exo",
    url: "http://localhost:52415/node_id",
    reuseExistingServer: !process.env.CI,
    timeout: 300000, // 5 minutes - CI needs time to install dependencies
    env: {
      ...process.env,
      // Ensure macmon and system tools are accessible
      PATH: `/usr/sbin:/usr/bin:/opt/homebrew/bin:${process.env.PATH}`,
      // Override memory detection for CI (macmon may not work on CI runners)
      // 24GB is typical for GitHub Actions macos-26 runners
      ...(process.env.CI ? { OVERRIDE_MEMORY_MB: "24000" } : {}),
    },
  },
  expect: {
    toHaveScreenshot: {
      maxDiffPixelRatio: 0.05,
      threshold: 0.2,
    },
  },
});
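No spec under tests/visual appears in this diff, but the `test:visual` script and the `toHaveScreenshot` defaults above imply one. A minimal sketch of what such a spec could look like — the file location (tests/visual/), test name, and snapshot name are assumptions, not part of this change:

import { test, expect } from "@playwright/test";
import { waitForTopologyLoaded } from "../helpers/wait-for-ready";

test("homepage visual snapshot", async ({ page }) => {
  await page.goto("/");
  await waitForTopologyLoaded(page);
  // Inherits maxDiffPixelRatio 0.05 and threshold 0.2 from playwright.config.ts;
  // the golden image would be refreshed via `npm run test:update-snapshots`.
  await expect(page).toHaveScreenshot("homepage.png", { fullPage: true });
});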


@@ -407,6 +407,7 @@
<!-- Custom dropdown -->
<div class="relative flex-1 max-w-xs">
<button
data-testid="chat-model-selector"
bind:this={dropdownButtonRef}
type="button"
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
@@ -587,6 +588,7 @@
>
<textarea
data-testid="chat-input"
bind:this={textareaRef}
bind:value={message}
onkeydown={handleKeydown}
@@ -606,6 +608,7 @@
></textarea>
<button
data-testid="send-button"
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap


@@ -6,13 +6,11 @@
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
interface Props {
class?: string;
@@ -101,23 +99,6 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty heatmap toggle
let heatmapMessageIds = $state<Set<string>>(new Set());
function toggleHeatmap(messageId: string) {
const next = new Set(heatmapMessageIds);
if (next.has(messageId)) {
next.delete(messageId);
} else {
next.add(messageId);
}
heatmapMessageIds = next;
}
function isHeatmapVisible(messageId: string): boolean {
return heatmapMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString("en-US", {
hour12: false,
@@ -256,6 +237,9 @@
class="group flex {message.role === 'user'
? 'justify-end'
: 'justify-start'}"
data-testid={message.role === "user"
? "user-message"
: "assistant-message"}
>
<div
class={message.role === "user"
@@ -567,23 +551,13 @@
>
</div>
{:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
{#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
<TokenHeatmap
tokens={message.tokens}
isGenerating={loading &&
isLastAssistantMessage(message.id)}
onRegenerateFrom={(tokenIndex) =>
regenerateFromToken(message.id, tokenIndex)}
/>
{:else}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{/if}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{/if}
{/if}
</div>
@@ -658,35 +632,6 @@
</button>
{/if}
<!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
{#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleHeatmap(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
message.id,
)
? 'text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
title={isHeatmapVisible(message.id)
? "Hide uncertainty heatmap"
: "Show uncertainty heatmap"}
>
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
/>
</svg>
</button>
{/if}
<!-- Regenerate button (last assistant message only) -->
{#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
<button


@@ -977,6 +977,7 @@
<!-- Launch Button -->
<button
data-testid="launch-button"
onclick={onLaunch}
disabled={isLaunching || !canFit}
class="w-full py-2 text-sm font-mono tracking-wider uppercase border transition-all duration-200


@@ -1,236 +0,0 @@
<script lang="ts">
  import type { TokenData } from "$lib/stores/app.svelte";

  interface Props {
    tokens: TokenData[];
    class?: string;
    isGenerating?: boolean;
    onRegenerateFrom?: (tokenIndex: number) => void;
  }

  let {
    tokens,
    class: className = "",
    isGenerating = false,
    onRegenerateFrom,
  }: Props = $props();

  // Tooltip state - track both token data and index
  let hoveredTokenIndex = $state<number | null>(null);
  let hoveredPosition = $state<{ x: number; y: number } | null>(null);
  let isTooltipHovered = $state(false);
  let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;

  // Derive the hovered token from the index (stable across re-renders)
  const hoveredToken = $derived(
    hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
      ? {
          token: tokens[hoveredTokenIndex],
          index: hoveredTokenIndex,
          ...hoveredPosition,
        }
      : null,
  );

  /**
   * Get confidence styling based on probability.
   * Following Apple design principles: high confidence tokens blend in,
   * only uncertainty draws attention.
   */
  function getConfidenceClass(probability: number): string {
    if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
    if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
    if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
    return "bg-red-500/20 text-red-200/90"; // Draws attention
  }

  /**
   * Get border/underline styling for uncertain tokens
   */
  function getBorderClass(probability: number): string {
    if (probability > 0.8) return "border-transparent"; // No border for expected
    if (probability > 0.5) return "border-gray-500/20";
    if (probability > 0.2) return "border-amber-500/30";
    return "border-red-500/40";
  }

  function clearHideTimeout() {
    if (hideTimeoutId) {
      clearTimeout(hideTimeoutId);
      hideTimeoutId = null;
    }
  }

  function handleMouseEnter(
    event: MouseEvent,
    token: TokenData,
    index: number,
  ) {
    clearHideTimeout();
    const rects = (event.target as HTMLElement).getClientRects();
    let rect = rects[0];
    for (let j = 0; j < rects.length; j++) {
      if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
        rect = rects[j];
        break;
      }
    }
    hoveredTokenIndex = index;
    hoveredPosition = {
      x: rect.left + rect.width / 2,
      y: rect.top - 10,
    };
  }

  function handleMouseLeave() {
    clearHideTimeout();
    // Use longer delay during generation to account for re-renders
    const delay = isGenerating ? 300 : 200;
    hideTimeoutId = setTimeout(() => {
      if (!isTooltipHovered) {
        hoveredTokenIndex = null;
        hoveredPosition = null;
      }
    }, delay);
  }

  function handleTooltipEnter() {
    clearHideTimeout();
    isTooltipHovered = true;
  }

  function handleTooltipLeave() {
    isTooltipHovered = false;
    hoveredTokenIndex = null;
    hoveredPosition = null;
  }

  function handleRegenerate() {
    if (hoveredToken && onRegenerateFrom) {
      const indexToRegenerate = hoveredToken.index;
      // Clear hover state immediately
      hoveredTokenIndex = null;
      hoveredPosition = null;
      isTooltipHovered = false;
      // Call regenerate
      onRegenerateFrom(indexToRegenerate);
    }
  }

  function formatProbability(prob: number): string {
    return (prob * 100).toFixed(1) + "%";
  }

  function formatLogprob(logprob: number): string {
    return logprob.toFixed(3);
  }

  function getProbabilityColor(probability: number): string {
    if (probability > 0.8) return "text-gray-300";
    if (probability > 0.5) return "text-gray-400";
    if (probability > 0.2) return "text-amber-400";
    return "text-red-400";
  }
</script>

<div class="token-heatmap leading-relaxed {className}">
  {#each tokens as tokenData, i (i)}
    <span
      role="button"
      tabindex="0"
      class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
        tokenData.probability,
      )} {getBorderClass(tokenData.probability)} hover:opacity-80"
      onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
      onmouseleave={handleMouseLeave}>{tokenData.token}</span
    >
  {/each}
</div>

<!-- Tooltip -->
{#if hoveredToken}
  <div
    class="fixed z-50 pb-2"
    style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
    onmouseenter={handleTooltipEnter}
    onmouseleave={handleTooltipLeave}
  >
    <div
      class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
    >
      <!-- Token info -->
      <div class="mb-2">
        <span class="text-gray-500 text-xs">Token:</span>
        <span class="text-white font-mono ml-1"
          >"{hoveredToken.token.token}"</span
        >
        <span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
          >{formatProbability(hoveredToken.token.probability)}</span
        >
      </div>
      <div class="text-gray-400 text-xs mb-1">
        logprob: <span class="text-gray-300 font-mono"
          >{formatLogprob(hoveredToken.token.logprob)}</span
        >
      </div>
      <!-- Top alternatives -->
      {#if hoveredToken.token.topLogprobs.length > 0}
        <div class="border-t border-gray-700/50 mt-2 pt-2">
          <div class="text-gray-500 text-xs mb-1">Alternatives:</div>
          {#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
            {@const altProb = Math.exp(alt.logprob)}
            <div class="flex justify-between items-center text-xs py-0.5">
              <span class="text-gray-300 font-mono truncate max-w-24"
                >"{alt.token}"</span
              >
              <span class="text-gray-400 ml-2"
                >{formatProbability(altProb)}</span
              >
            </div>
          {/each}
        </div>
      {/if}
      <!-- Regenerate button -->
      {#if onRegenerateFrom}
        <button
          onclick={handleRegenerate}
          class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
        >
          <svg
            class="w-3 h-3"
            fill="none"
            viewBox="0 0 24 24"
            stroke="currentColor"
          >
            <path
              stroke-linecap="round"
              stroke-linejoin="round"
              stroke-width="2"
              d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
            />
          </svg>
          Regenerate from here
        </button>
      {/if}
    </div>
    <!-- Arrow -->
    <div class="absolute left-1/2 -translate-x-1/2 top-full">
      <div class="border-8 border-transparent border-t-gray-900"></div>
    </div>
  </div>
{/if}

<style>
  .token-heatmap {
    word-wrap: break-word;
    white-space: pre-wrap;
  }

  .token-span {
    margin: 0;
    border-width: 1px;
  }
</style>


@@ -242,19 +242,6 @@ export interface MessageAttachment {
mimeType?: string;
}
export interface TopLogprob {
token: string;
logprob: number;
bytes: number[] | null;
}
export interface TokenData {
token: string;
logprob: number;
probability: number;
topLogprobs: TopLogprob[];
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -266,7 +253,6 @@ export interface Message {
tps?: number; // Tokens per second (for assistant messages)
requestType?: "chat" | "image-generation" | "image-editing";
sourceImageDataUrl?: string; // For image editing regeneration
tokens?: TokenData[];
}
export interface Conversation {
@@ -554,18 +540,7 @@ class AppStore {
*/
private saveConversationsToStorage() {
try {
// Strip tokens from messages before saving to avoid bloating localStorage
const stripped = this.conversations.map((conv) => ({
...conv,
messages: conv.messages.map((msg) => {
if (msg.tokens) {
const { tokens: _, ...rest } = msg;
return rest;
}
return msg;
}),
}));
localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
} catch (error) {
console.error("Failed to save conversations:", error);
}
@@ -1470,213 +1445,6 @@ class AppStore {
}
}
/**
* Regenerate response from a specific token index.
* Truncates the assistant message at the given token and re-generates from there.
*/
async regenerateFromToken(
messageId: string,
tokenIndex: number,
): Promise<void> {
if (this.isLoading) return;
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
const msgIndex = this.messages.findIndex((m) => m.id === messageId);
if (msgIndex === -1) return;
const msg = this.messages[msgIndex];
if (
msg.role !== "assistant" ||
!msg.tokens ||
tokenIndex >= msg.tokens.length
)
return;
// Keep tokens up to (not including) the specified index
const tokensToKeep = msg.tokens.slice(0, tokenIndex);
const prefixText = tokensToKeep.map((t) => t.token).join("");
// Remove all messages after this assistant message
this.messages = this.messages.slice(0, msgIndex + 1);
// Update the message to show the prefix
this.messages[msgIndex].content = prefixText;
this.messages[msgIndex].tokens = tokensToKeep;
this.updateActiveConversation();
// Set up for continuation - modify the existing message in place
this.isLoading = true;
this.currentResponse = prefixText;
this.ttftMs = null;
this.tps = null;
this.totalTokens = tokensToKeep.length;
try {
// Build messages for API - include the partial assistant message
const systemPrompt = {
role: "system" as const,
content:
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
};
const apiMessages = [
systemPrompt,
...this.messages.map((m) => {
let msgContent = m.content;
if (m.attachments) {
for (const attachment of m.attachments) {
if (attachment.type === "text" && attachment.content) {
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
}
}
}
return { role: m.role, content: msgContent };
}),
];
const modelToUse = this.getModelForRequest();
if (!modelToUse) {
throw new Error("No model available");
}
const requestStartTime = performance.now();
let firstTokenTime: number | null = null;
let tokenCount = tokensToKeep.length;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) throw new Error("No response body");
let fullContent = prefixText;
const collectedTokens: TokenData[] = [...tokensToKeep];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
}
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
if (delta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
}
tokenCount += 1;
this.totalTokens = tokenCount;
if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
const elapsed = performance.now() - firstTokenTime;
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent;
}
// Update existing message in place
this.updateConversationMessage(
targetConversationId,
messageId,
(m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
},
);
// Final update
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error regenerating from token:", error);
if (this.conversationExists(targetConversationId)) {
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} finally {
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
/**
* Helper method to regenerate a chat completion response
*/
@@ -1745,8 +1513,6 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1761,49 +1527,16 @@ class AppStore {
}
let streamedContent = "";
const collectedTokens: TokenData[] = [];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
choices?: Array<{ delta?: { content?: string } }>;
}
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
const delta = parsed.choices?.[0]?.delta?.content;
if (delta) {
streamedContent += delta;
const { displayContent, thinkingContent } =
@@ -1821,7 +1554,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1840,7 +1572,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -2183,8 +1914,6 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -2201,48 +1930,14 @@ class AppStore {
let streamedContent = "";
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
choices?: Array<{ delta?: { content?: string } }>;
}
const collectedTokens: TokenData[] = [];
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
const tokenContent = parsed.choices?.[0]?.delta?.content;
if (tokenContent) {
// Track first token for TTFT
if (firstTokenTime === null) {
@@ -2278,7 +1973,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -2303,7 +1997,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
// Store performance metrics on the message
if (this.ttftMs !== null) {
msg.ttftMs = this.ttftMs;
@@ -3000,8 +2693,6 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
appStore.regenerateFromToken(messageId, tokenIndex);
// Conversation actions
export const conversations = () => appStore.conversations;


@@ -1663,12 +1663,14 @@
<div
class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden"
>
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<div data-testid="topology-graph" class="w-full h-full">
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
</div>
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
@@ -1782,12 +1784,14 @@
class="flex-1 relative bg-exo-dark-gray/40 mx-4 mb-4 rounded-lg overflow-hidden"
>
<!-- The main topology graph - full container -->
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
<div data-testid="topology-graph" class="w-full h-full">
<TopologyGraph
class="w-full h-full"
highlightedNodes={highlightedNodes()}
filteredNodes={nodeFilter}
onNodeClick={togglePreviewNodeFilter}
/>
</div>
<!-- Thunderbolt Bridge Cycle Warning -->
{#if tbBridgeCycles.length > 0}
@@ -2363,6 +2367,7 @@
<!-- Model Dropdown (Custom) -->
<div class="flex-shrink-0 mb-3 relative">
<button
data-testid="model-dropdown"
type="button"
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
@@ -2499,6 +2504,7 @@
model.id,
)}
<button
data-testid="model-option"
type="button"
onclick={() => {
if (modelCanFit) {
@@ -2777,6 +2783,7 @@
{#each allPreviews as apiPreview, i}
<div
role="group"
data-testid="model-card"
onmouseenter={() => {
if (apiPreview.memory_delta_by_node) {
hoveredPreviewNodes = new Set(


@@ -0,0 +1,68 @@
import { test, expect } from "@playwright/test";
import {
  waitForTopologyLoaded,
  waitForModelCards,
  waitForChatReady,
  waitForAssistantMessage,
  sendChatMessage,
  selectModelFromLaunchDropdown,
} from "../helpers/wait-for-ready";

test.describe("Chat Message", () => {
  test("should send a message and receive a response", async ({ page }) => {
    // Increase timeout for this test since it involves model loading and inference
    test.setTimeout(600000); // 10 minutes
    await page.goto("/");
    await waitForTopologyLoaded(page);
    // First select the model from the dropdown (model cards appear after selection)
    await selectModelFromLaunchDropdown(page, /qwen.*0\.6b/i);
    // Now wait for model cards to appear
    await waitForModelCards(page);
    // Find and click on the model card (should already be filtered to Qwen)
    const modelCard = page.locator('[data-testid="model-card"]').first();
    await expect(modelCard).toBeVisible({ timeout: 10000 });
    // Click the launch button
    const launchButton = modelCard.locator('[data-testid="launch-button"]');
    await launchButton.click();
    // Wait for the model to be ready (may take time to download)
    await expect(
      page
        .locator('[data-testid="instance-status"]')
        .filter({ hasText: /READY/i })
        .first(),
    ).toBeVisible({ timeout: 300000 }); // 5 minutes for download
    // Wait for chat to be ready
    await waitForChatReady(page);
    // Select the model in the chat selector if needed
    const modelSelector = page.locator('[data-testid="chat-model-selector"]');
    if (await modelSelector.isVisible()) {
      await modelSelector.click();
      await page.locator("text=/qwen.*0\\.6b/i").first().click();
    }
    // Send a simple message
    await sendChatMessage(page, "What is 2+2?");
    // Wait for assistant response
    await waitForAssistantMessage(page, 120000); // 2 minutes for inference
    // Verify the assistant message is visible
    const assistantMessage = page
      .locator('[data-testid="assistant-message"]')
      .last();
    await expect(assistantMessage).toBeVisible();
    // The response should contain something (not empty)
    const messageContent = await assistantMessage.textContent();
    expect(messageContent).toBeTruthy();
    expect(messageContent!.length).toBeGreaterThan(0);
  });
});


@@ -0,0 +1,36 @@
import { test, expect } from "@playwright/test";
import {
  waitForTopologyLoaded,
  waitForModelCards,
  selectModelFromLaunchDropdown,
} from "../helpers/wait-for-ready";

test.describe("Launch Instance", () => {
  test("should launch Qwen3-0.6B-4bit model", async ({ page }) => {
    await page.goto("/");
    await waitForTopologyLoaded(page);
    // First select the model from the dropdown (model cards appear after selection)
    await selectModelFromLaunchDropdown(page, /qwen.*0\.6b/i);
    // Now wait for model cards to appear
    await waitForModelCards(page);
    // Find and click on the model card (should already be filtered to Qwen)
    const modelCard = page.locator('[data-testid="model-card"]').first();
    await expect(modelCard).toBeVisible({ timeout: 10000 });
    // Click the launch button
    const launchButton = modelCard.locator('[data-testid="launch-button"]');
    await launchButton.click();
    // Wait for the model to start (status should change to READY or show download progress)
    // The model may need to download first, so we wait with a longer timeout
    await expect(
      page
        .locator('[data-testid="instance-status"]')
        .filter({ hasText: /READY|downloading/i })
        .first(),
    ).toBeVisible({ timeout: 300000 }); // 5 minutes for download
  });
});


@@ -0,0 +1,117 @@
import { expect, type Page } from "@playwright/test";

const BASE_URL = "http://localhost:52415";

export async function waitForApiReady(
  page: Page,
  timeoutMs = 30000,
): Promise<void> {
  const startTime = Date.now();
  while (Date.now() - startTime < timeoutMs) {
    try {
      const response = await page.request.get(`${BASE_URL}/node_id`);
      if (response.ok()) {
        return;
      }
    } catch {
      // API not ready yet, continue polling
    }
    await page.waitForTimeout(500);
  }
  throw new Error(`API did not become ready within ${timeoutMs}ms`);
}

export async function waitForTopologyLoaded(page: Page): Promise<void> {
  await expect(page.locator('[data-testid="topology-graph"]')).toBeVisible({
    timeout: 30000,
  });
}

export async function waitForModelCards(page: Page): Promise<void> {
  await expect(page.locator('[data-testid="model-card"]').first()).toBeVisible({
    timeout: 30000,
  });
}

export async function selectModelFromLaunchDropdown(
  page: Page,
  modelPattern: RegExp | string,
): Promise<void> {
  // Click the model dropdown in the Launch Instance panel
  const dropdown = page.locator('button:has-text("SELECT MODEL")');
  await expect(dropdown).toBeVisible({ timeout: 30000 });
  await dropdown.click();
  // Wait for dropdown menu to appear and select the model
  const modelOption = page.locator("button").filter({ hasText: modelPattern });
  await expect(modelOption.first()).toBeVisible({ timeout: 10000 });
  await modelOption.first().click();
}

export async function waitForChatReady(page: Page): Promise<void> {
  await expect(page.locator('[data-testid="chat-input"]')).toBeVisible({
    timeout: 10000,
  });
  await expect(page.locator('[data-testid="send-button"]')).toBeVisible({
    timeout: 10000,
  });
}

export async function waitForAssistantMessage(
  page: Page,
  timeoutMs = 60000,
): Promise<void> {
  await expect(
    page.locator('[data-testid="assistant-message"]').last(),
  ).toBeVisible({ timeout: timeoutMs });
}

export async function waitForStreamingComplete(
  page: Page,
  timeoutMs = 120000,
): Promise<void> {
  const startTime = Date.now();
  while (Date.now() - startTime < timeoutMs) {
    const sendButton = page.locator('[data-testid="send-button"]');
    const buttonText = await sendButton.textContent();
    if (
      buttonText &&
      !buttonText.includes("PROCESSING") &&
      !buttonText.includes("...")
    ) {
      return;
    }
    await page.waitForTimeout(500);
  }
  throw new Error(`Streaming did not complete within ${timeoutMs}ms`);
}

export async function selectModel(
  page: Page,
  modelName: string,
): Promise<void> {
  const modelSelector = page.locator('[data-testid="chat-model-selector"]');
  await modelSelector.click();
  await page.locator(`text=${modelName}`).click();
}

export async function sendChatMessage(
  page: Page,
  message: string,
): Promise<void> {
  const chatInput = page.locator('[data-testid="chat-input"]');
  await chatInput.fill(message);
  const sendButton = page.locator('[data-testid="send-button"]');
  await sendButton.click();
}

export async function launchModel(
  page: Page,
  modelCardIndex = 0,
): Promise<void> {
  const modelCards = page.locator('[data-testid="model-card"]');
  const launchButton = modelCards
    .nth(modelCardIndex)
    .locator('[data-testid="launch-button"]');
  await launchButton.click();
}
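Note that `waitForApiReady` is exported above but never called by the specs in this diff — the `webServer` block in playwright.config.ts already health-checks /node_id before any test runs. A minimal sketch of how it could gate a spec directly; the file and test name are hypothetical:

import { test, expect } from "@playwright/test";
import {
  waitForApiReady,
  waitForTopologyLoaded,
} from "../helpers/wait-for-ready";

test("dashboard loads once the API is up", async ({ page }) => {
  // Poll GET /node_id (the same endpoint the webServer block waits on),
  // then load the dashboard and check a key element.
  await waitForApiReady(page, 60000);
  await page.goto("/");
  await waitForTopologyLoaded(page);
  await expect(page.locator('[data-testid="chat-input"]')).toBeVisible();
});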


@@ -0,0 +1,26 @@
import { test, expect } from "@playwright/test";
import { waitForTopologyLoaded } from "../helpers/wait-for-ready";

test.describe("Chat Interface", () => {
  test("should display chat input and send button", async ({ page }) => {
    await page.goto("/");
    await waitForTopologyLoaded(page);
    const chatInput = page.locator('[data-testid="chat-input"]');
    await expect(chatInput).toBeVisible();
    const sendButton = page.locator('[data-testid="send-button"]');
    await expect(sendButton).toBeVisible();
  });

  test("should allow typing in chat input", async ({ page }) => {
    await page.goto("/");
    await waitForTopologyLoaded(page);
    const chatInput = page.locator('[data-testid="chat-input"]');
    await expect(chatInput).toBeVisible();
    await chatInput.fill("Test message");
    await expect(chatInput).toHaveValue("Test message");
  });
});


@@ -0,0 +1,16 @@
import { test, expect } from "@playwright/test";
import { waitForTopologyLoaded } from "../helpers/wait-for-ready";

test.describe("Homepage", () => {
  test("should load and display key elements", async ({ page }) => {
    await page.goto("/");
    await waitForTopologyLoaded(page);
    // Verify key UI elements are present
    await expect(
      page.locator('[data-testid="topology-graph"]').first(),
    ).toBeVisible();
    await expect(page.locator('[data-testid="chat-input"]')).toBeVisible();
    await expect(page.locator('[data-testid="send-button"]')).toBeVisible();
  });
});


@@ -14,8 +14,6 @@ from exo.shared.types.api import (
    ErrorInfo,
    ErrorResponse,
    FinishReason,
    Logprobs,
    LogprobsContentItem,
    StreamingChoiceResponse,
    ToolCall,
)
@@ -83,8 +81,6 @@ def chat_request_to_text_generation(
        chat_template_messages=chat_template_messages
        if chat_template_messages
        else None,
        logprobs=request.logprobs or False,
        top_logprobs=request.top_logprobs,
    )
@@ -92,19 +88,6 @@ def chunk_to_response(
    chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
    """Convert a TokenChunk to a streaming ChatCompletionResponse."""
    # Build logprobs if available
    logprobs: Logprobs | None = None
    if chunk.logprob is not None:
        logprobs = Logprobs(
            content=[
                LogprobsContentItem(
                    token=chunk.text,
                    logprob=chunk.logprob,
                    top_logprobs=chunk.top_logprobs or [],
                )
            ]
        )
    return ChatCompletionResponse(
        id=command_id,
        created=int(time.time()),
@@ -113,7 +96,6 @@
            StreamingChoiceResponse(
                index=0,
                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
                logprobs=logprobs,
                finish_reason=chunk.finish_reason,
            )
        ],
@@ -180,7 +162,6 @@ async def collect_chat_response(
    """Collect all token chunks and return a single ChatCompletionResponse."""
    text_parts: list[str] = []
    tool_calls: list[ToolCall] = []
    logprobs_content: list[LogprobsContentItem] = []
    model: str | None = None
    finish_reason: FinishReason | None = None
    error_message: str | None = None
@@ -195,14 +176,6 @@
        if isinstance(chunk, TokenChunk):
            text_parts.append(chunk.text)
            if chunk.logprob is not None:
                logprobs_content.append(
                    LogprobsContentItem(
                        token=chunk.text,
                        logprob=chunk.logprob,
                        top_logprobs=chunk.top_logprobs or [],
                    )
                )
        if isinstance(chunk, ToolCallChunk):
            tool_calls.extend(
@@ -235,9 +208,6 @@
                content=combined_text,
                tool_calls=tool_calls if tool_calls else None,
            ),
            logprobs=Logprobs(content=logprobs_content)
            if logprobs_content
            else None,
            finish_reason=finish_reason,
        )
    ],


@@ -610,11 +610,6 @@ class API:
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )

        return await collect_chat_response(
@@ -1164,11 +1159,6 @@
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )

        return await collect_claude_response(
@@ -1196,11 +1186,6 @@
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )

        return await collect_responses_response(


@@ -2,12 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal

from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import (
    GenerationStats,
    ImageGenerationStats,
    TopLogprobItem,
    Usage,
)
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
@@ -25,8 +20,6 @@ class TokenChunk(BaseChunk):
    usage: Usage | None
    finish_reason: Literal["stop", "length", "content_filter"] | None = None
    stats: GenerationStats | None = None
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None


class ErrorChunk(BaseChunk):


@@ -40,5 +40,3 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
    stop: str | list[str] | None = None
    seed: int | None = None
    chat_template_messages: list[dict[str, Any]] | None = None
    logprobs: bool = False
    top_logprobs: int | None = None


@@ -6,7 +6,6 @@ from exo.shared.types.api import (
    GenerationStats,
    ImageGenerationStats,
    ToolCallItem,
    TopLogprobItem,
    Usage,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -23,8 +22,7 @@ class TokenizedResponse(BaseRunnerResponse):

class GenerationResponse(BaseRunnerResponse):
    text: str
    token: int
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None
    # logprobs: list[float] | None = None  # too big. we can change to be top-k
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    usage: Usage | None


@@ -11,7 +11,5 @@ QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = None
DEFAULT_TOP_LOGPROBS: int = 5

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True


@@ -12,7 +12,6 @@ from exo.shared.types.api import (
    FinishReason,
    GenerationStats,
    PromptTokensDetails,
    TopLogprobItem,
    Usage,
)
from exo.shared.types.common import ModelId
@@ -24,12 +23,7 @@ from exo.shared.types.worker.runner_response import (
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.constants import (
    DEFAULT_TOP_LOGPROBS,
    KV_BITS,
    KV_GROUP_SIZE,
    MAX_TOKENS,
)
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    mx_barrier,
@@ -161,60 +155,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
    return eos


def extract_top_logprobs(
    logprobs: mx.array,
    tokenizer: TokenizerWrapper,
    top_logprobs: int,
    selected_token: int,
) -> tuple[float, list[TopLogprobItem]]:
    """Extract the selected token's logprob and top alternative tokens.

    Args:
        logprobs: Full vocabulary logprobs array from MLX
        tokenizer: Tokenizer for decoding token IDs to strings
        top_logprobs: Number of top alternatives to return
        selected_token: The token ID that was actually sampled

    Returns:
        Tuple of (selected_token_logprob, list of TopLogprobItem for top alternatives)
    """
    # Get the logprob of the selected token
    selected_logprob = float(logprobs[selected_token].item())

    # Get top indices (most probable tokens)
    # mx.argpartition gives indices that would partition the array
    # We negate logprobs since argpartition finds smallest, and we want largest
    top_logprobs = min(top_logprobs, logprobs.shape[0])  # Don't exceed vocab size
    top_indices = mx.argpartition(-logprobs, top_logprobs)[:top_logprobs]

    # Get the actual logprob values for these indices
    top_values = logprobs[top_indices]

    # Sort by logprob (descending) for consistent ordering
    sort_order = mx.argsort(-top_values)
    top_indices = top_indices[sort_order]
    top_values = top_values[sort_order]

    # Convert to list of TopLogprobItem
    top_logprob_items: list[TopLogprobItem] = []
    for i in range(top_logprobs):
        token_id = int(top_indices[i].item())
        token_logprob = float(top_values[i].item())
        # Decode token ID to string
        token_str = tokenizer.decode([token_id])
        # Get byte representation
        token_bytes = list(token_str.encode("utf-8"))
        top_logprob_items.append(
            TopLogprobItem(
                token=token_str,
                logprob=token_logprob,
                bytes=token_bytes,
            )
        )

    return selected_logprob, top_logprob_items


def mlx_generate(
    model: Model,
    tokenizer: TokenizerWrapper,
@@ -356,22 +296,9 @@
            ),
        )

        # Extract logprobs from the full vocabulary logprobs array
        logprob: float | None = None
        top_logprobs: list[TopLogprobItem] | None = None
        if task.logprobs:
            logprob, top_logprobs = extract_top_logprobs(
                logprobs=out.logprobs,
                tokenizer=tokenizer,
                top_logprobs=task.top_logprobs or DEFAULT_TOP_LOGPROBS,
                selected_token=out.token,
            )

        yield GenerationResponse(
            text=text,
            token=out.token,
            logprob=logprob,
            top_logprobs=top_logprobs,
            finish_reason=finish_reason,
            stats=stats,
            usage=usage,


@@ -442,12 +442,6 @@ def apply_chat_template(
            continue
        formatted_messages.append({"role": msg.role, "content": msg.content})

    # For assistant prefilling, append content after templating to avoid a closing turn token.
    partial_assistant_content: str | None = None
    if formatted_messages and formatted_messages[-1].get("role") == "assistant":
        partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
        formatted_messages = formatted_messages[:-1]

    prompt: str = tokenizer.apply_chat_template(
        formatted_messages,
        tokenize=False,
@@ -455,9 +449,6 @@
        tools=task_params.tools,
    )

    if partial_assistant_content:
        prompt += partial_assistant_content

    logger.info(prompt)
    return prompt


@@ -320,8 +320,6 @@ def main(
                        usage=response.usage,
                        finish_reason=response.finish_reason,
                        stats=response.stats,
                        logprob=response.logprob,
                        top_logprobs=response.top_logprobs,
                    ),
                )
            )
)