mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-05 03:33:30 -05:00
Compare commits
16 Commits
alexcheema
...
sami/dashb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d611f55332 | ||
|
|
66174b6509 | ||
|
|
7a2abfa0ed | ||
|
|
5aea62c8ef | ||
|
|
32ce382445 | ||
|
|
a4c42993e0 | ||
|
|
38d03ce1fa | ||
|
|
ad0b1a2ce9 | ||
|
|
6f7c9000cf | ||
|
|
c9ff05f012 | ||
|
|
164f8fb38c | ||
|
|
698eb9ad17 | ||
|
|
2ef29eeb5f | ||
|
|
e847bbd675 | ||
|
|
8f1ca88e5d | ||
|
|
075c5c545e |
136
.github/workflows/pipeline.yml
vendored
136
.github/workflows/pipeline.yml
vendored
@@ -143,3 +143,139 @@ jobs:
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
dashboard-tests:
|
||||
name: Dashboard E2E Tests
|
||||
runs-on: macos-26
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- uses: cachix/cachix-action@v14
|
||||
name: Configure Cachix
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build Metal packages
|
||||
run: |
|
||||
if nix build .#metal-toolchain 2>/dev/null; then
|
||||
echo "metal-toolchain built successfully (likely cache hit)"
|
||||
else
|
||||
echo "metal-toolchain build failed, extracting from Xcode..."
|
||||
|
||||
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
|
||||
NAR_NAME="metal-toolchain-17C48.nar"
|
||||
|
||||
WORK_DIR="${RUNNER_TEMP}/metal-work"
|
||||
mkdir -p "$WORK_DIR"
|
||||
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
|
||||
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
|
||||
if [ -z "$DMG_PATH" ]; then
|
||||
echo "Error: Could not find Metal toolchain DMG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found DMG at: $DMG_PATH"
|
||||
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
|
||||
|
||||
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
|
||||
hdiutil detach "${WORK_DIR}/metal-dmg"
|
||||
|
||||
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
|
||||
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
|
||||
echo "Added NAR to store: $STORE_PATH"
|
||||
|
||||
rm -rf "$WORK_DIR"
|
||||
|
||||
nix build .#metal-toolchain
|
||||
fi
|
||||
|
||||
nix build .#mlx
|
||||
|
||||
- name: Install macmon for hardware monitoring
|
||||
run: brew install macmon
|
||||
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Sync Python dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: dashboard/package-lock.json
|
||||
|
||||
- name: Install dashboard dependencies
|
||||
working-directory: dashboard
|
||||
run: npm ci
|
||||
|
||||
- name: Install Playwright browsers
|
||||
working-directory: dashboard
|
||||
run: npx playwright install chromium --with-deps
|
||||
|
||||
- name: Build dashboard
|
||||
working-directory: dashboard
|
||||
run: npm run build
|
||||
|
||||
- name: Verify macmon is accessible
|
||||
run: |
|
||||
echo "PATH: $PATH"
|
||||
which macmon || echo "macmon not in PATH"
|
||||
macmon --version
|
||||
# Test macmon actually works - capture stderr too
|
||||
echo "Testing macmon pipe output (with stderr)..."
|
||||
timeout 5 macmon pipe --interval 1000 2>&1 || echo "macmon pipe exit code: $?"
|
||||
# Try running macmon raw (not pipe mode)
|
||||
echo "Testing macmon raw output..."
|
||||
macmon raw 2>&1 | head -5 || echo "macmon raw failed"
|
||||
|
||||
- name: Verify Python can find macmon
|
||||
run: |
|
||||
echo "Testing shutil.which from uv run python..."
|
||||
uv run python -c "import shutil; print('Python shutil.which macmon:', shutil.which('macmon'))"
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: dashboard
|
||||
run: |
|
||||
export PATH="/usr/sbin:/usr/bin:/opt/homebrew/bin:$PATH"
|
||||
echo "Effective PATH: $PATH"
|
||||
which macmon && echo "macmon found at $(which macmon)"
|
||||
npm test
|
||||
env:
|
||||
CI: true
|
||||
|
||||
- name: Upload test results
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: dashboard/playwright-report/
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload video recordings
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: test-videos
|
||||
path: dashboard/test-results/
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload snapshot diffs
|
||||
uses: actions/upload-artifact@v4
|
||||
if: failure()
|
||||
with:
|
||||
name: snapshot-diffs
|
||||
path: dashboard/tests/**/*-snapshots/*-diff.png
|
||||
retention-days: 30
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -29,5 +29,12 @@ dashboard/build/
|
||||
dashboard/node_modules/
|
||||
dashboard/.svelte-kit/
|
||||
|
||||
# playwright
|
||||
dashboard/test-results/
|
||||
dashboard/playwright-report/
|
||||
dashboard/playwright/.cache/
|
||||
dashboard/tests/**/*-snapshots/*-actual.png
|
||||
dashboard/tests/**/*-snapshots/*-diff.png
|
||||
|
||||
# host config snapshots
|
||||
hosts_*.json
|
||||
|
||||
72
dashboard/package-lock.json
generated
72
dashboard/package-lock.json
generated
@@ -14,12 +14,13 @@
|
||||
"mode-watcher": "^1.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.41.0",
|
||||
"@sveltejs/adapter-static": "^3.0.10",
|
||||
"@sveltejs/kit": "^2.48.4",
|
||||
"@sveltejs/vite-plugin-svelte": "^5.0.0",
|
||||
"@tailwindcss/vite": "^4.0.0",
|
||||
"@types/d3": "^7.4.3",
|
||||
"@types/node": "^22",
|
||||
"@types/node": "^22.19.8",
|
||||
"d3": "^7.9.0",
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
@@ -518,6 +519,22 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.4.14"
|
||||
}
|
||||
},
|
||||
"node_modules/@playwright/test": {
|
||||
"version": "1.58.1",
|
||||
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.58.1.tgz",
|
||||
"integrity": "sha512-6LdVIUERWxQMmUSSQi0I53GgCBYgM2RpGngCPY7hSeju+VrKjq3lvs7HpJoPbDiY5QM5EYRtRX5fvrinnMAz3w==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"playwright": "1.58.1"
|
||||
},
|
||||
"bin": {
|
||||
"playwright": "cli.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/@polka/url": {
|
||||
"version": "1.0.0-next.29",
|
||||
"resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.29.tgz",
|
||||
@@ -1515,9 +1532,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "22.19.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.1.tgz",
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"version": "22.19.8",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.8.tgz",
|
||||
"integrity": "sha512-ebO/Yl+EAvVe8DnMfi+iaAyIqYdK0q/q0y0rw82INWEKJOBe6b/P3YWE8NW7oOlF/nXFNrHwhARrN/hdgDkraA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -2655,6 +2672,53 @@
|
||||
"url": "https://github.com/sponsors/jonschlinkert"
|
||||
}
|
||||
},
|
||||
"node_modules/playwright": {
|
||||
"version": "1.58.1",
|
||||
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.58.1.tgz",
|
||||
"integrity": "sha512-+2uTZHxSCcxjvGc5C891LrS1/NlxglGxzrC4seZiVjcYVQfUa87wBL6rTDqzGjuoWNjnBzRqKmF6zRYGMvQUaQ==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"playwright-core": "1.58.1"
|
||||
},
|
||||
"bin": {
|
||||
"playwright": "cli.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"fsevents": "2.3.2"
|
||||
}
|
||||
},
|
||||
"node_modules/playwright-core": {
|
||||
"version": "1.58.1",
|
||||
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.58.1.tgz",
|
||||
"integrity": "sha512-bcWzOaTxcW+VOOGBCQgnaKToLJ65d6AqfLVKEWvexyS3AS6rbXl+xdpYRMGSRBClPvyj44njOWoxjNdL/H9UNg==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"playwright-core": "cli.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/playwright/node_modules/fsevents": {
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
|
||||
"integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"engines": {
|
||||
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/postcss": {
|
||||
"version": "8.5.6",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz",
|
||||
|
||||
@@ -8,18 +8,23 @@
|
||||
"build": "vite build",
|
||||
"preview": "vite preview",
|
||||
"prepare": "svelte-kit sync || echo ''",
|
||||
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json"
|
||||
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
|
||||
"test": "playwright test",
|
||||
"test:e2e": "playwright test tests/e2e",
|
||||
"test:visual": "playwright test tests/visual",
|
||||
"test:update-snapshots": "playwright test tests/visual --update-snapshots"
|
||||
},
|
||||
"devDependencies": {
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
"@playwright/test": "^1.41.0",
|
||||
"@sveltejs/adapter-static": "^3.0.10",
|
||||
"@sveltejs/kit": "^2.48.4",
|
||||
"@sveltejs/vite-plugin-svelte": "^5.0.0",
|
||||
"@tailwindcss/vite": "^4.0.0",
|
||||
"@types/d3": "^7.4.3",
|
||||
"@types/node": "^22",
|
||||
"@types/node": "^22.19.8",
|
||||
"d3": "^7.9.0",
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte-check": "^4.0.0",
|
||||
"tailwindcss": "^4.0.0",
|
||||
|
||||
43
dashboard/playwright.config.ts
Normal file
43
dashboard/playwright.config.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
/// <reference types="node" />
|
||||
import { defineConfig, devices } from "@playwright/test";
|
||||
|
||||
export default defineConfig({
|
||||
testDir: "./tests",
|
||||
fullyParallel: true,
|
||||
forbidOnly: !!process.env.CI,
|
||||
retries: process.env.CI ? 2 : 0,
|
||||
workers: process.env.CI ? 1 : undefined,
|
||||
reporter: [["html", { open: "never" }], ["list"]],
|
||||
use: {
|
||||
baseURL: "http://localhost:52415",
|
||||
trace: "on-first-retry",
|
||||
video: "on",
|
||||
screenshot: "only-on-failure",
|
||||
},
|
||||
projects: [
|
||||
{
|
||||
name: "chromium",
|
||||
use: { ...devices["Desktop Chrome"] },
|
||||
},
|
||||
],
|
||||
webServer: {
|
||||
command: "cd .. && uv run exo",
|
||||
url: "http://localhost:52415/node_id",
|
||||
reuseExistingServer: !process.env.CI,
|
||||
timeout: 300000, // 5 minutes - CI needs time to install dependencies
|
||||
env: {
|
||||
...process.env,
|
||||
// Ensure macmon and system tools are accessible
|
||||
PATH: `/usr/sbin:/usr/bin:/opt/homebrew/bin:${process.env.PATH}`,
|
||||
// Override memory detection for CI (macmon may not work on CI runners)
|
||||
// 24GB is typical for GitHub Actions macos-26 runners
|
||||
...(process.env.CI ? { OVERRIDE_MEMORY_MB: "24000" } : {}),
|
||||
},
|
||||
},
|
||||
expect: {
|
||||
toHaveScreenshot: {
|
||||
maxDiffPixelRatio: 0.05,
|
||||
threshold: 0.2,
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -407,6 +407,7 @@
|
||||
<!-- Custom dropdown -->
|
||||
<div class="relative flex-1 max-w-xs">
|
||||
<button
|
||||
data-testid="chat-model-selector"
|
||||
bind:this={dropdownButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
|
||||
@@ -587,6 +588,7 @@
|
||||
>
|
||||
|
||||
<textarea
|
||||
data-testid="chat-input"
|
||||
bind:this={textareaRef}
|
||||
bind:value={message}
|
||||
onkeydown={handleKeydown}
|
||||
@@ -606,6 +608,7 @@
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
data-testid="send-button"
|
||||
type="submit"
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
|
||||
@@ -6,13 +6,11 @@
|
||||
deleteMessage,
|
||||
editAndRegenerate,
|
||||
regenerateLastResponse,
|
||||
regenerateFromToken,
|
||||
setEditingImage,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import type { Message } from "$lib/stores/app.svelte";
|
||||
import type { MessageAttachment } from "$lib/stores/app.svelte";
|
||||
import MarkdownContent from "./MarkdownContent.svelte";
|
||||
import TokenHeatmap from "./TokenHeatmap.svelte";
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -101,23 +99,6 @@
|
||||
let copiedMessageId = $state<string | null>(null);
|
||||
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
// Uncertainty heatmap toggle
|
||||
let heatmapMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
function toggleHeatmap(messageId: string) {
|
||||
const next = new Set(heatmapMessageIds);
|
||||
if (next.has(messageId)) {
|
||||
next.delete(messageId);
|
||||
} else {
|
||||
next.add(messageId);
|
||||
}
|
||||
heatmapMessageIds = next;
|
||||
}
|
||||
|
||||
function isHeatmapVisible(messageId: string): boolean {
|
||||
return heatmapMessageIds.has(messageId);
|
||||
}
|
||||
|
||||
function formatTimestamp(timestamp: number): string {
|
||||
return new Date(timestamp).toLocaleTimeString("en-US", {
|
||||
hour12: false,
|
||||
@@ -256,6 +237,9 @@
|
||||
class="group flex {message.role === 'user'
|
||||
? 'justify-end'
|
||||
: 'justify-start'}"
|
||||
data-testid={message.role === "user"
|
||||
? "user-message"
|
||||
: "assistant-message"}
|
||||
>
|
||||
<div
|
||||
class={message.role === "user"
|
||||
@@ -567,23 +551,13 @@
|
||||
>
|
||||
</div>
|
||||
{:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
|
||||
{#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
|
||||
<TokenHeatmap
|
||||
tokens={message.tokens}
|
||||
isGenerating={loading &&
|
||||
isLastAssistantMessage(message.id)}
|
||||
onRegenerateFrom={(tokenIndex) =>
|
||||
regenerateFromToken(message.id, tokenIndex)}
|
||||
/>
|
||||
{:else}
|
||||
<MarkdownContent
|
||||
content={message.content || (loading ? response : "")}
|
||||
/>
|
||||
{#if loading && !message.content}
|
||||
<span
|
||||
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
|
||||
></span>
|
||||
{/if}
|
||||
<MarkdownContent
|
||||
content={message.content || (loading ? response : "")}
|
||||
/>
|
||||
{#if loading && !message.content}
|
||||
<span
|
||||
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
|
||||
></span>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
@@ -658,35 +632,6 @@
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
|
||||
{#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
|
||||
<button
|
||||
onclick={() => toggleHeatmap(message.id)}
|
||||
class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
|
||||
message.id,
|
||||
)
|
||||
? 'text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
title={isHeatmapVisible(message.id)
|
||||
? "Hide uncertainty heatmap"
|
||||
: "Show uncertainty heatmap"}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Regenerate button (last assistant message only) -->
|
||||
{#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
|
||||
<button
|
||||
|
||||
@@ -977,6 +977,7 @@
|
||||
|
||||
<!-- Launch Button -->
|
||||
<button
|
||||
data-testid="launch-button"
|
||||
onclick={onLaunch}
|
||||
disabled={isLaunching || !canFit}
|
||||
class="w-full py-2 text-sm font-mono tracking-wider uppercase border transition-all duration-200
|
||||
|
||||
@@ -1,236 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { TokenData } from "$lib/stores/app.svelte";
|
||||
|
||||
interface Props {
|
||||
tokens: TokenData[];
|
||||
class?: string;
|
||||
isGenerating?: boolean;
|
||||
onRegenerateFrom?: (tokenIndex: number) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
tokens,
|
||||
class: className = "",
|
||||
isGenerating = false,
|
||||
onRegenerateFrom,
|
||||
}: Props = $props();
|
||||
|
||||
// Tooltip state - track both token data and index
|
||||
let hoveredTokenIndex = $state<number | null>(null);
|
||||
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
|
||||
let isTooltipHovered = $state(false);
|
||||
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
// Derive the hovered token from the index (stable across re-renders)
|
||||
const hoveredToken = $derived(
|
||||
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
|
||||
? {
|
||||
token: tokens[hoveredTokenIndex],
|
||||
index: hoveredTokenIndex,
|
||||
...hoveredPosition,
|
||||
}
|
||||
: null,
|
||||
);
|
||||
|
||||
/**
|
||||
* Get confidence styling based on probability.
|
||||
* Following Apple design principles: high confidence tokens blend in,
|
||||
* only uncertainty draws attention.
|
||||
*/
|
||||
function getConfidenceClass(probability: number): string {
|
||||
if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
|
||||
if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
|
||||
if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
|
||||
return "bg-red-500/20 text-red-200/90"; // Draws attention
|
||||
}
|
||||
|
||||
/**
|
||||
* Get border/underline styling for uncertain tokens
|
||||
*/
|
||||
function getBorderClass(probability: number): string {
|
||||
if (probability > 0.8) return "border-transparent"; // No border for expected
|
||||
if (probability > 0.5) return "border-gray-500/20";
|
||||
if (probability > 0.2) return "border-amber-500/30";
|
||||
return "border-red-500/40";
|
||||
}
|
||||
|
||||
function clearHideTimeout() {
|
||||
if (hideTimeoutId) {
|
||||
clearTimeout(hideTimeoutId);
|
||||
hideTimeoutId = null;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseEnter(
|
||||
event: MouseEvent,
|
||||
token: TokenData,
|
||||
index: number,
|
||||
) {
|
||||
clearHideTimeout();
|
||||
const rects = (event.target as HTMLElement).getClientRects();
|
||||
let rect = rects[0];
|
||||
for (let j = 0; j < rects.length; j++) {
|
||||
if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
|
||||
rect = rects[j];
|
||||
break;
|
||||
}
|
||||
}
|
||||
hoveredTokenIndex = index;
|
||||
hoveredPosition = {
|
||||
x: rect.left + rect.width / 2,
|
||||
y: rect.top - 10,
|
||||
};
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
clearHideTimeout();
|
||||
// Use longer delay during generation to account for re-renders
|
||||
const delay = isGenerating ? 300 : 200;
|
||||
hideTimeoutId = setTimeout(() => {
|
||||
if (!isTooltipHovered) {
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
}, delay);
|
||||
}
|
||||
|
||||
function handleTooltipEnter() {
|
||||
clearHideTimeout();
|
||||
isTooltipHovered = true;
|
||||
}
|
||||
|
||||
function handleTooltipLeave() {
|
||||
isTooltipHovered = false;
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
|
||||
function handleRegenerate() {
|
||||
if (hoveredToken && onRegenerateFrom) {
|
||||
const indexToRegenerate = hoveredToken.index;
|
||||
// Clear hover state immediately
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
isTooltipHovered = false;
|
||||
// Call regenerate
|
||||
onRegenerateFrom(indexToRegenerate);
|
||||
}
|
||||
}
|
||||
|
||||
function formatProbability(prob: number): string {
|
||||
return (prob * 100).toFixed(1) + "%";
|
||||
}
|
||||
|
||||
function formatLogprob(logprob: number): string {
|
||||
return logprob.toFixed(3);
|
||||
}
|
||||
|
||||
function getProbabilityColor(probability: number): string {
|
||||
if (probability > 0.8) return "text-gray-300";
|
||||
if (probability > 0.5) return "text-gray-400";
|
||||
if (probability > 0.2) return "text-amber-400";
|
||||
return "text-red-400";
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="token-heatmap leading-relaxed {className}">
|
||||
{#each tokens as tokenData, i (i)}
|
||||
<span
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
|
||||
tokenData.probability,
|
||||
)} {getBorderClass(tokenData.probability)} hover:opacity-80"
|
||||
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
|
||||
onmouseleave={handleMouseLeave}>{tokenData.token}</span
|
||||
>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Tooltip -->
|
||||
{#if hoveredToken}
|
||||
<div
|
||||
class="fixed z-50 pb-2"
|
||||
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
|
||||
onmouseenter={handleTooltipEnter}
|
||||
onmouseleave={handleTooltipLeave}
|
||||
>
|
||||
<div
|
||||
class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
|
||||
>
|
||||
<!-- Token info -->
|
||||
<div class="mb-2">
|
||||
<span class="text-gray-500 text-xs">Token:</span>
|
||||
<span class="text-white font-mono ml-1"
|
||||
>"{hoveredToken.token.token}"</span
|
||||
>
|
||||
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
|
||||
>{formatProbability(hoveredToken.token.probability)}</span
|
||||
>
|
||||
</div>
|
||||
|
||||
<div class="text-gray-400 text-xs mb-1">
|
||||
logprob: <span class="text-gray-300 font-mono"
|
||||
>{formatLogprob(hoveredToken.token.logprob)}</span
|
||||
>
|
||||
</div>
|
||||
|
||||
<!-- Top alternatives -->
|
||||
{#if hoveredToken.token.topLogprobs.length > 0}
|
||||
<div class="border-t border-gray-700/50 mt-2 pt-2">
|
||||
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
|
||||
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
|
||||
{@const altProb = Math.exp(alt.logprob)}
|
||||
<div class="flex justify-between items-center text-xs py-0.5">
|
||||
<span class="text-gray-300 font-mono truncate max-w-24"
|
||||
>"{alt.token}"</span
|
||||
>
|
||||
<span class="text-gray-400 ml-2"
|
||||
>{formatProbability(altProb)}</span
|
||||
>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Regenerate button -->
|
||||
{#if onRegenerateFrom}
|
||||
<button
|
||||
onclick={handleRegenerate}
|
||||
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
|
||||
/>
|
||||
</svg>
|
||||
Regenerate from here
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
<!-- Arrow -->
|
||||
<div class="absolute left-1/2 -translate-x-1/2 top-full">
|
||||
<div class="border-8 border-transparent border-t-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.token-heatmap {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.token-span {
|
||||
margin: 0;
|
||||
border-width: 1px;
|
||||
}
|
||||
</style>
|
||||
@@ -242,19 +242,6 @@ export interface MessageAttachment {
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
export interface TopLogprob {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}
|
||||
|
||||
export interface TokenData {
|
||||
token: string;
|
||||
logprob: number;
|
||||
probability: number;
|
||||
topLogprobs: TopLogprob[];
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
id: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
@@ -266,7 +253,6 @@ export interface Message {
|
||||
tps?: number; // Tokens per second (for assistant messages)
|
||||
requestType?: "chat" | "image-generation" | "image-editing";
|
||||
sourceImageDataUrl?: string; // For image editing regeneration
|
||||
tokens?: TokenData[];
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
@@ -554,18 +540,7 @@ class AppStore {
|
||||
*/
|
||||
private saveConversationsToStorage() {
|
||||
try {
|
||||
// Strip tokens from messages before saving to avoid bloating localStorage
|
||||
const stripped = this.conversations.map((conv) => ({
|
||||
...conv,
|
||||
messages: conv.messages.map((msg) => {
|
||||
if (msg.tokens) {
|
||||
const { tokens: _, ...rest } = msg;
|
||||
return rest;
|
||||
}
|
||||
return msg;
|
||||
}),
|
||||
}));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
|
||||
} catch (error) {
|
||||
console.error("Failed to save conversations:", error);
|
||||
}
|
||||
@@ -1470,213 +1445,6 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Regenerate response from a specific token index.
|
||||
* Truncates the assistant message at the given token and re-generates from there.
|
||||
*/
|
||||
async regenerateFromToken(
|
||||
messageId: string,
|
||||
tokenIndex: number,
|
||||
): Promise<void> {
|
||||
if (this.isLoading) return;
|
||||
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
const msgIndex = this.messages.findIndex((m) => m.id === messageId);
|
||||
if (msgIndex === -1) return;
|
||||
|
||||
const msg = this.messages[msgIndex];
|
||||
if (
|
||||
msg.role !== "assistant" ||
|
||||
!msg.tokens ||
|
||||
tokenIndex >= msg.tokens.length
|
||||
)
|
||||
return;
|
||||
|
||||
// Keep tokens up to (not including) the specified index
|
||||
const tokensToKeep = msg.tokens.slice(0, tokenIndex);
|
||||
const prefixText = tokensToKeep.map((t) => t.token).join("");
|
||||
|
||||
// Remove all messages after this assistant message
|
||||
this.messages = this.messages.slice(0, msgIndex + 1);
|
||||
|
||||
// Update the message to show the prefix
|
||||
this.messages[msgIndex].content = prefixText;
|
||||
this.messages[msgIndex].tokens = tokensToKeep;
|
||||
this.updateActiveConversation();
|
||||
|
||||
// Set up for continuation - modify the existing message in place
|
||||
this.isLoading = true;
|
||||
this.currentResponse = prefixText;
|
||||
this.ttftMs = null;
|
||||
this.tps = null;
|
||||
this.totalTokens = tokensToKeep.length;
|
||||
|
||||
try {
|
||||
// Build messages for API - include the partial assistant message
|
||||
const systemPrompt = {
|
||||
role: "system" as const,
|
||||
content:
|
||||
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
|
||||
};
|
||||
|
||||
const apiMessages = [
|
||||
systemPrompt,
|
||||
...this.messages.map((m) => {
|
||||
let msgContent = m.content;
|
||||
if (m.attachments) {
|
||||
for (const attachment of m.attachments) {
|
||||
if (attachment.type === "text" && attachment.content) {
|
||||
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
|
||||
}
|
||||
}
|
||||
}
|
||||
return { role: m.role, content: msgContent };
|
||||
}),
|
||||
];
|
||||
|
||||
const modelToUse = this.getModelForRequest();
|
||||
if (!modelToUse) {
|
||||
throw new Error("No model available");
|
||||
}
|
||||
|
||||
const requestStartTime = performance.now();
|
||||
let firstTokenTime: number | null = null;
|
||||
let tokenCount = tokensToKeep.length;
|
||||
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) throw new Error("No response body");
|
||||
|
||||
let fullContent = prefixText;
|
||||
const collectedTokens: TokenData[] = [...tokensToKeep];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
top_logprobs?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ChatCompletionChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
if (logprobsContent) {
|
||||
for (const item of logprobsContent) {
|
||||
collectedTokens.push({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
probability: Math.exp(item.logprob),
|
||||
topLogprobs: (item.top_logprobs || []).map((t) => ({
|
||||
token: t.token,
|
||||
logprob: t.logprob,
|
||||
bytes: t.bytes,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (delta) {
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
this.ttftMs = firstTokenTime - requestStartTime;
|
||||
}
|
||||
|
||||
tokenCount += 1;
|
||||
this.totalTokens = tokenCount;
|
||||
|
||||
if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
|
||||
const elapsed = performance.now() - firstTokenTime;
|
||||
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
this.currentResponse = displayContent;
|
||||
}
|
||||
|
||||
// Update existing message in place
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
messageId,
|
||||
(m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// Final update
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
this.updateConversationMessage(targetConversationId, messageId, (m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
|
||||
if (this.tps !== null) m.tps = this.tps;
|
||||
});
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error regenerating from token:", error);
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
this.updateConversationMessage(targetConversationId, messageId, (m) => {
|
||||
m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
|
||||
});
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
} finally {
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to regenerate a chat completion response
|
||||
*/
|
||||
@@ -1745,8 +1513,6 @@ class AppStore {
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -1761,49 +1527,16 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
top_logprobs?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
choices?: Array<{ delta?: { content?: string } }>;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ChatCompletionChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
if (logprobsContent) {
|
||||
for (const item of logprobsContent) {
|
||||
collectedTokens.push({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
probability: Math.exp(item.logprob),
|
||||
topLogprobs: (item.top_logprobs || []).map((t) => ({
|
||||
token: t.token,
|
||||
logprob: t.logprob,
|
||||
bytes: t.bytes,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const delta = parsed.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
@@ -1821,7 +1554,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
@@ -1840,7 +1572,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
@@ -2183,8 +1914,6 @@ class AppStore {
|
||||
messages: apiMessages,
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -2201,48 +1930,14 @@ class AppStore {
|
||||
let streamedContent = "";
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
top_logprobs?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
choices?: Array<{ delta?: { content?: string } }>;
|
||||
}
|
||||
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
await this.parseSSEStream<ChatCompletionChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const tokenContent = choice?.delta?.content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
if (logprobsContent) {
|
||||
for (const item of logprobsContent) {
|
||||
collectedTokens.push({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
probability: Math.exp(item.logprob),
|
||||
topLogprobs: (item.top_logprobs || []).map((t) => ({
|
||||
token: t.token,
|
||||
logprob: t.logprob,
|
||||
bytes: t.bytes,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const tokenContent = parsed.choices?.[0]?.delta?.content;
|
||||
if (tokenContent) {
|
||||
// Track first token for TTFT
|
||||
if (firstTokenTime === null) {
|
||||
@@ -2278,7 +1973,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
@@ -2303,7 +1997,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
// Store performance metrics on the message
|
||||
if (this.ttftMs !== null) {
|
||||
msg.ttftMs = this.ttftMs;
|
||||
@@ -3000,8 +2693,6 @@ export const editMessage = (messageId: string, newContent: string) =>
|
||||
export const editAndRegenerate = (messageId: string, newContent: string) =>
|
||||
appStore.editAndRegenerate(messageId, newContent);
|
||||
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
|
||||
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
|
||||
appStore.regenerateFromToken(messageId, tokenIndex);
|
||||
|
||||
// Conversation actions
|
||||
export const conversations = () => appStore.conversations;
|
||||
|
||||
@@ -1663,12 +1663,14 @@
|
||||
<div
|
||||
class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden"
|
||||
>
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
<div data-testid="topology-graph" class="w-full h-full">
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
@@ -1782,12 +1784,14 @@
|
||||
class="flex-1 relative bg-exo-dark-gray/40 mx-4 mb-4 rounded-lg overflow-hidden"
|
||||
>
|
||||
<!-- The main topology graph - full container -->
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
<div data-testid="topology-graph" class="w-full h-full">
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
@@ -2363,6 +2367,7 @@
|
||||
<!-- Model Dropdown (Custom) -->
|
||||
<div class="flex-shrink-0 mb-3 relative">
|
||||
<button
|
||||
data-testid="model-dropdown"
|
||||
type="button"
|
||||
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
|
||||
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
|
||||
@@ -2499,6 +2504,7 @@
|
||||
model.id,
|
||||
)}
|
||||
<button
|
||||
data-testid="model-option"
|
||||
type="button"
|
||||
onclick={() => {
|
||||
if (modelCanFit) {
|
||||
@@ -2777,6 +2783,7 @@
|
||||
{#each allPreviews as apiPreview, i}
|
||||
<div
|
||||
role="group"
|
||||
data-testid="model-card"
|
||||
onmouseenter={() => {
|
||||
if (apiPreview.memory_delta_by_node) {
|
||||
hoveredPreviewNodes = new Set(
|
||||
|
||||
68
dashboard/tests/e2e/chat-message.spec.ts
Normal file
68
dashboard/tests/e2e/chat-message.spec.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import {
|
||||
waitForTopologyLoaded,
|
||||
waitForModelCards,
|
||||
waitForChatReady,
|
||||
waitForAssistantMessage,
|
||||
sendChatMessage,
|
||||
selectModelFromLaunchDropdown,
|
||||
} from "../helpers/wait-for-ready";
|
||||
|
||||
test.describe("Chat Message", () => {
|
||||
test("should send a message and receive a response", async ({ page }) => {
|
||||
// Increase timeout for this test since it involves model loading and inference
|
||||
test.setTimeout(600000); // 10 minutes
|
||||
|
||||
await page.goto("/");
|
||||
await waitForTopologyLoaded(page);
|
||||
|
||||
// First select the model from the dropdown (model cards appear after selection)
|
||||
await selectModelFromLaunchDropdown(page, /qwen.*0\.6b/i);
|
||||
|
||||
// Now wait for model cards to appear
|
||||
await waitForModelCards(page);
|
||||
|
||||
// Find and click on the model card (should already be filtered to Qwen)
|
||||
const modelCard = page.locator('[data-testid="model-card"]').first();
|
||||
await expect(modelCard).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Click the launch button
|
||||
const launchButton = modelCard.locator('[data-testid="launch-button"]');
|
||||
await launchButton.click();
|
||||
|
||||
// Wait for the model to be ready (may take time to download)
|
||||
await expect(
|
||||
page
|
||||
.locator('[data-testid="instance-status"]')
|
||||
.filter({ hasText: /READY/i })
|
||||
.first(),
|
||||
).toBeVisible({ timeout: 300000 }); // 5 minutes for download
|
||||
|
||||
// Wait for chat to be ready
|
||||
await waitForChatReady(page);
|
||||
|
||||
// Select the model in the chat selector if needed
|
||||
const modelSelector = page.locator('[data-testid="chat-model-selector"]');
|
||||
if (await modelSelector.isVisible()) {
|
||||
await modelSelector.click();
|
||||
await page.locator("text=/qwen.*0\\.6b/i").first().click();
|
||||
}
|
||||
|
||||
// Send a simple message
|
||||
await sendChatMessage(page, "What is 2+2?");
|
||||
|
||||
// Wait for assistant response
|
||||
await waitForAssistantMessage(page, 120000); // 2 minutes for inference
|
||||
|
||||
// Verify the assistant message is visible
|
||||
const assistantMessage = page
|
||||
.locator('[data-testid="assistant-message"]')
|
||||
.last();
|
||||
await expect(assistantMessage).toBeVisible();
|
||||
|
||||
// The response should contain something (not empty)
|
||||
const messageContent = await assistantMessage.textContent();
|
||||
expect(messageContent).toBeTruthy();
|
||||
expect(messageContent!.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
36
dashboard/tests/e2e/launch-instance.spec.ts
Normal file
36
dashboard/tests/e2e/launch-instance.spec.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import {
|
||||
waitForTopologyLoaded,
|
||||
waitForModelCards,
|
||||
selectModelFromLaunchDropdown,
|
||||
} from "../helpers/wait-for-ready";
|
||||
|
||||
test.describe("Launch Instance", () => {
|
||||
test("should launch Qwen3-0.6B-4bit model", async ({ page }) => {
|
||||
await page.goto("/");
|
||||
await waitForTopologyLoaded(page);
|
||||
|
||||
// First select the model from the dropdown (model cards appear after selection)
|
||||
await selectModelFromLaunchDropdown(page, /qwen.*0\.6b/i);
|
||||
|
||||
// Now wait for model cards to appear
|
||||
await waitForModelCards(page);
|
||||
|
||||
// Find and click on the model card (should already be filtered to Qwen)
|
||||
const modelCard = page.locator('[data-testid="model-card"]').first();
|
||||
await expect(modelCard).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Click the launch button
|
||||
const launchButton = modelCard.locator('[data-testid="launch-button"]');
|
||||
await launchButton.click();
|
||||
|
||||
// Wait for the model to start (status should change to READY or show download progress)
|
||||
// The model may need to download first, so we wait with a longer timeout
|
||||
await expect(
|
||||
page
|
||||
.locator('[data-testid="instance-status"]')
|
||||
.filter({ hasText: /READY|downloading/i })
|
||||
.first(),
|
||||
).toBeVisible({ timeout: 300000 }); // 5 minutes for download
|
||||
});
|
||||
});
|
||||
117
dashboard/tests/helpers/wait-for-ready.ts
Normal file
117
dashboard/tests/helpers/wait-for-ready.ts
Normal file
@@ -0,0 +1,117 @@
|
||||
import { expect, type Page } from "@playwright/test";
|
||||
|
||||
const BASE_URL = "http://localhost:52415";
|
||||
|
||||
export async function waitForApiReady(
|
||||
page: Page,
|
||||
timeoutMs = 30000,
|
||||
): Promise<void> {
|
||||
const startTime = Date.now();
|
||||
while (Date.now() - startTime < timeoutMs) {
|
||||
try {
|
||||
const response = await page.request.get(`${BASE_URL}/node_id`);
|
||||
if (response.ok()) {
|
||||
return;
|
||||
}
|
||||
} catch {
|
||||
// API not ready yet, continue polling
|
||||
}
|
||||
await page.waitForTimeout(500);
|
||||
}
|
||||
throw new Error(`API did not become ready within ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
export async function waitForTopologyLoaded(page: Page): Promise<void> {
|
||||
await expect(page.locator('[data-testid="topology-graph"]')).toBeVisible({
|
||||
timeout: 30000,
|
||||
});
|
||||
}
|
||||
|
||||
export async function waitForModelCards(page: Page): Promise<void> {
|
||||
await expect(page.locator('[data-testid="model-card"]').first()).toBeVisible({
|
||||
timeout: 30000,
|
||||
});
|
||||
}
|
||||
|
||||
export async function selectModelFromLaunchDropdown(
|
||||
page: Page,
|
||||
modelPattern: RegExp | string,
|
||||
): Promise<void> {
|
||||
// Click the model dropdown in the Launch Instance panel
|
||||
const dropdown = page.locator('button:has-text("SELECT MODEL")');
|
||||
await expect(dropdown).toBeVisible({ timeout: 30000 });
|
||||
await dropdown.click();
|
||||
|
||||
// Wait for dropdown menu to appear and select the model
|
||||
const modelOption = page.locator("button").filter({ hasText: modelPattern });
|
||||
await expect(modelOption.first()).toBeVisible({ timeout: 10000 });
|
||||
await modelOption.first().click();
|
||||
}
|
||||
|
||||
export async function waitForChatReady(page: Page): Promise<void> {
|
||||
await expect(page.locator('[data-testid="chat-input"]')).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
await expect(page.locator('[data-testid="send-button"]')).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
}
|
||||
|
||||
export async function waitForAssistantMessage(
|
||||
page: Page,
|
||||
timeoutMs = 60000,
|
||||
): Promise<void> {
|
||||
await expect(
|
||||
page.locator('[data-testid="assistant-message"]').last(),
|
||||
).toBeVisible({ timeout: timeoutMs });
|
||||
}
|
||||
|
||||
export async function waitForStreamingComplete(
|
||||
page: Page,
|
||||
timeoutMs = 120000,
|
||||
): Promise<void> {
|
||||
const startTime = Date.now();
|
||||
while (Date.now() - startTime < timeoutMs) {
|
||||
const sendButton = page.locator('[data-testid="send-button"]');
|
||||
const buttonText = await sendButton.textContent();
|
||||
if (
|
||||
buttonText &&
|
||||
!buttonText.includes("PROCESSING") &&
|
||||
!buttonText.includes("...")
|
||||
) {
|
||||
return;
|
||||
}
|
||||
await page.waitForTimeout(500);
|
||||
}
|
||||
throw new Error(`Streaming did not complete within ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
export async function selectModel(
|
||||
page: Page,
|
||||
modelName: string,
|
||||
): Promise<void> {
|
||||
const modelSelector = page.locator('[data-testid="chat-model-selector"]');
|
||||
await modelSelector.click();
|
||||
await page.locator(`text=${modelName}`).click();
|
||||
}
|
||||
|
||||
export async function sendChatMessage(
|
||||
page: Page,
|
||||
message: string,
|
||||
): Promise<void> {
|
||||
const chatInput = page.locator('[data-testid="chat-input"]');
|
||||
await chatInput.fill(message);
|
||||
const sendButton = page.locator('[data-testid="send-button"]');
|
||||
await sendButton.click();
|
||||
}
|
||||
|
||||
export async function launchModel(
|
||||
page: Page,
|
||||
modelCardIndex = 0,
|
||||
): Promise<void> {
|
||||
const modelCards = page.locator('[data-testid="model-card"]');
|
||||
const launchButton = modelCards
|
||||
.nth(modelCardIndex)
|
||||
.locator('[data-testid="launch-button"]');
|
||||
await launchButton.click();
|
||||
}
|
||||
26
dashboard/tests/visual/chat-interface.spec.ts
Normal file
26
dashboard/tests/visual/chat-interface.spec.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { waitForTopologyLoaded } from "../helpers/wait-for-ready";
|
||||
|
||||
test.describe("Chat Interface", () => {
|
||||
test("should display chat input and send button", async ({ page }) => {
|
||||
await page.goto("/");
|
||||
await waitForTopologyLoaded(page);
|
||||
|
||||
const chatInput = page.locator('[data-testid="chat-input"]');
|
||||
await expect(chatInput).toBeVisible();
|
||||
|
||||
const sendButton = page.locator('[data-testid="send-button"]');
|
||||
await expect(sendButton).toBeVisible();
|
||||
});
|
||||
|
||||
test("should allow typing in chat input", async ({ page }) => {
|
||||
await page.goto("/");
|
||||
await waitForTopologyLoaded(page);
|
||||
|
||||
const chatInput = page.locator('[data-testid="chat-input"]');
|
||||
await expect(chatInput).toBeVisible();
|
||||
|
||||
await chatInput.fill("Test message");
|
||||
await expect(chatInput).toHaveValue("Test message");
|
||||
});
|
||||
});
|
||||
16
dashboard/tests/visual/homepage.spec.ts
Normal file
16
dashboard/tests/visual/homepage.spec.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { waitForTopologyLoaded } from "../helpers/wait-for-ready";
|
||||
|
||||
test.describe("Homepage", () => {
|
||||
test("should load and display key elements", async ({ page }) => {
|
||||
await page.goto("/");
|
||||
await waitForTopologyLoaded(page);
|
||||
|
||||
// Verify key UI elements are present
|
||||
await expect(
|
||||
page.locator('[data-testid="topology-graph"]').first(),
|
||||
).toBeVisible();
|
||||
await expect(page.locator('[data-testid="chat-input"]')).toBeVisible();
|
||||
await expect(page.locator('[data-testid="send-button"]')).toBeVisible();
|
||||
});
|
||||
});
|
||||
@@ -14,8 +14,6 @@ from exo.shared.types.api import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
Logprobs,
|
||||
LogprobsContentItem,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
)
|
||||
@@ -83,8 +81,6 @@ def chat_request_to_text_generation(
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
logprobs=request.logprobs or False,
|
||||
top_logprobs=request.top_logprobs,
|
||||
)
|
||||
|
||||
|
||||
@@ -92,19 +88,6 @@ def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
|
||||
# Build logprobs if available
|
||||
logprobs: Logprobs | None = None
|
||||
if chunk.logprob is not None:
|
||||
logprobs = Logprobs(
|
||||
content=[
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -113,7 +96,6 @@ def chunk_to_response(
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -180,7 +162,6 @@ async def collect_chat_response(
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
logprobs_content: list[LogprobsContentItem] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
error_message: str | None = None
|
||||
@@ -195,14 +176,6 @@ async def collect_chat_response(
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
logprobs_content.append(
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
@@ -235,9 +208,6 @@ async def collect_chat_response(
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
logprobs=Logprobs(content=logprobs_content)
|
||||
if logprobs_content
|
||||
else None,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
|
||||
@@ -610,11 +610,6 @@ class API:
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_chat_response(
|
||||
@@ -1164,11 +1159,6 @@ class API:
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_claude_response(
|
||||
@@ -1196,11 +1186,6 @@ class API:
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
return await collect_responses_response(
|
||||
|
||||
@@ -2,12 +2,7 @@ from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
TopLogprobItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -25,8 +20,6 @@ class TokenChunk(BaseChunk):
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
|
||||
|
||||
class ErrorChunk(BaseChunk):
|
||||
|
||||
@@ -40,5 +40,3 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
|
||||
stop: str | list[str] | None = None
|
||||
seed: int | None = None
|
||||
chat_template_messages: list[dict[str, Any]] | None = None
|
||||
logprobs: bool = False
|
||||
top_logprobs: int | None = None
|
||||
|
||||
@@ -6,7 +6,6 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
TopLogprobItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
@@ -23,8 +22,7 @@ class TokenizedResponse(BaseRunnerResponse):
|
||||
class GenerationResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
token: int
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None
|
||||
|
||||
@@ -11,7 +11,5 @@ QUANTIZE_MODEL_MODE: str | None = "affine"
|
||||
CACHE_GROUP_SIZE: int = 64
|
||||
KV_CACHE_BITS: int | None = None
|
||||
|
||||
DEFAULT_TOP_LOGPROBS: int = 5
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
TRUST_REMOTE_CODE: bool = True
|
||||
|
||||
@@ -12,7 +12,6 @@ from exo.shared.types.api import (
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
PromptTokensDetails,
|
||||
TopLogprobItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
@@ -24,12 +23,7 @@ from exo.shared.types.worker.runner_response import (
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
DEFAULT_TOP_LOGPROBS,
|
||||
KV_BITS,
|
||||
KV_GROUP_SIZE,
|
||||
MAX_TOKENS,
|
||||
)
|
||||
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
mx_barrier,
|
||||
@@ -161,60 +155,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs: mx.array,
|
||||
tokenizer: TokenizerWrapper,
|
||||
top_logprobs: int,
|
||||
selected_token: int,
|
||||
) -> tuple[float, list[TopLogprobItem]]:
|
||||
"""Extract the selected token's logprob and top alternative tokens.
|
||||
|
||||
Args:
|
||||
logprobs: Full vocabulary logprobs array from MLX
|
||||
tokenizer: Tokenizer for decoding token IDs to strings
|
||||
top_logprobs: Number of top alternatives to return
|
||||
selected_token: The token ID that was actually sampled
|
||||
|
||||
Returns:
|
||||
Tuple of (selected_token_logprob, list of TopLogprobItem for top alternatives)
|
||||
"""
|
||||
# Get the logprob of the selected token
|
||||
selected_logprob = float(logprobs[selected_token].item())
|
||||
|
||||
# Get top indices (most probable tokens)
|
||||
# mx.argpartition gives indices that would partition the array
|
||||
# We negate logprobs since argpartition finds smallest, and we want largest
|
||||
top_logprobs = min(top_logprobs, logprobs.shape[0]) # Don't exceed vocab size
|
||||
top_indices = mx.argpartition(-logprobs, top_logprobs)[:top_logprobs]
|
||||
|
||||
# Get the actual logprob values for these indices
|
||||
top_values = logprobs[top_indices]
|
||||
|
||||
# Sort by logprob (descending) for consistent ordering
|
||||
sort_order = mx.argsort(-top_values)
|
||||
top_indices = top_indices[sort_order]
|
||||
top_values = top_values[sort_order]
|
||||
|
||||
# Convert to list of TopLogprobItem
|
||||
top_logprob_items: list[TopLogprobItem] = []
|
||||
for i in range(top_logprobs):
|
||||
token_id = int(top_indices[i].item())
|
||||
token_logprob = float(top_values[i].item())
|
||||
# Decode token ID to string
|
||||
token_str = tokenizer.decode([token_id])
|
||||
# Get byte representation
|
||||
token_bytes = list(token_str.encode("utf-8"))
|
||||
top_logprob_items.append(
|
||||
TopLogprobItem(
|
||||
token=token_str,
|
||||
logprob=token_logprob,
|
||||
bytes=token_bytes,
|
||||
)
|
||||
)
|
||||
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -356,22 +296,9 @@ def mlx_generate(
|
||||
),
|
||||
)
|
||||
|
||||
# Extract logprobs from the full vocabulary logprobs array
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if task.logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs=out.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
top_logprobs=task.top_logprobs or DEFAULT_TOP_LOGPROBS,
|
||||
selected_token=out.token,
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
|
||||
@@ -442,12 +442,6 @@ def apply_chat_template(
|
||||
continue
|
||||
formatted_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
# For assistant prefilling, append content after templating to avoid a closing turn token.
|
||||
partial_assistant_content: str | None = None
|
||||
if formatted_messages and formatted_messages[-1].get("role") == "assistant":
|
||||
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
|
||||
formatted_messages = formatted_messages[:-1]
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
@@ -455,9 +449,6 @@ def apply_chat_template(
|
||||
tools=task_params.tools,
|
||||
)
|
||||
|
||||
if partial_assistant_content:
|
||||
prompt += partial_assistant_content
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
@@ -320,8 +320,6 @@ def main(
|
||||
usage=response.usage,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user