mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-17 18:41:49 -05:00
Compare commits
7 Commits
aiohttp
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
302e67c8c0 | ||
|
|
63ec56c696 | ||
|
|
578a417a7e | ||
|
|
57ec2c9011 | ||
|
|
c6963cafa0 | ||
|
|
663a0faaeb | ||
|
|
0a58aa73ec |
@@ -56,11 +56,6 @@ struct ContentView: View {
|
||||
}
|
||||
|
||||
private var shouldShowLocalNetworkWarning: Bool {
|
||||
// Show warning if local network is not working and EXO is running.
|
||||
// The checker uses a longer timeout on first launch to allow time for
|
||||
// the permission prompt, so this correctly handles both:
|
||||
// 1. User denied permission on first launch
|
||||
// 2. Permission broke after restart (macOS TCC bug)
|
||||
if case .notWorking = localNetworkChecker.status {
|
||||
return controller.status != .stopped
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import os.log
|
||||
/// Checks if the app's local network permission is actually functional.
|
||||
///
|
||||
/// macOS local network permission can appear enabled in System Preferences but not
|
||||
/// actually work after a restart. This service uses NWConnection to mDNS multicast
|
||||
/// to verify actual connectivity.
|
||||
/// actually work after a restart. This service detects this by creating a UDP
|
||||
/// connection to the mDNS multicast address (224.0.0.251:5353).
|
||||
@MainActor
|
||||
final class LocalNetworkChecker: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
|
||||
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
|
||||
|
||||
@Published private(set) var status: Status = .unknown
|
||||
@Published private(set) var lastConnectionState: String = "none"
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
func check() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
|
||||
lastConnectionState = "connecting"
|
||||
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
|
||||
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
let result = await self.checkConnectivity(timeout: timeout)
|
||||
let result = await self.performCheck()
|
||||
self.status = result
|
||||
self.hasCompletedInitialCheck = true
|
||||
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks connectivity using NWConnection to mDNS multicast.
|
||||
/// The connection attempt triggers the permission prompt if not yet shown.
|
||||
private func checkConnectivity(timeout: UInt64) async -> Status {
|
||||
private func performCheck() async -> Status {
|
||||
Self.logger.info("Checking local network access via UDP multicast")
|
||||
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
|
||||
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
continuation.resume(returning: status)
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { state in
|
||||
conn.stateUpdateHandler = { [weak self] state in
|
||||
let stateStr: String
|
||||
switch state {
|
||||
case .setup: stateStr = "setup"
|
||||
case .preparing: stateStr = "preparing"
|
||||
case .ready: stateStr = "ready"
|
||||
case .waiting(let e): stateStr = "waiting(\(e))"
|
||||
case .failed(let e): stateStr = "failed(\(e))"
|
||||
case .cancelled: stateStr = "cancelled"
|
||||
@unknown default: stateStr = "unknown"
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self?.lastConnectionState = stateStr
|
||||
}
|
||||
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
|
||||
resumeOnce(.notWorking(reason: "Connection blocked"))
|
||||
}
|
||||
// Otherwise keep waiting - might be showing permission prompt
|
||||
case .failed(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|
||||
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
conn.start(queue: .main)
|
||||
|
||||
Task {
|
||||
try? await Task.sleep(nanoseconds: timeout)
|
||||
try? await Task.sleep(nanoseconds: 3_000_000_000)
|
||||
let state = conn.state
|
||||
switch state {
|
||||
case .ready:
|
||||
|
||||
@@ -241,9 +241,6 @@ class PromptSizer:
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
# Fix for transformers 5.x
|
||||
if hasattr(ids, "input_ids"):
|
||||
ids = ids.input_ids
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
|
||||
9
dashboard/package-lock.json
generated
9
dashboard/package-lock.json
generated
@@ -863,7 +863,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -903,7 +902,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1520,7 +1518,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1530,7 +1527,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1943,7 +1939,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2651,7 +2646,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2839,7 +2833,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2984,7 +2977,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3006,7 +2998,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
isLoading,
|
||||
deleteMessage,
|
||||
editAndRegenerate,
|
||||
regenerateLastResponse
|
||||
regenerateLastResponse,
|
||||
regenerateFromToken
|
||||
} from '$lib/stores/app.svelte';
|
||||
import type { MessageAttachment } from '$lib/stores/app.svelte';
|
||||
import MarkdownContent from './MarkdownContent.svelte';
|
||||
import TokenHeatmap from './TokenHeatmap.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -95,6 +97,23 @@
|
||||
let copiedMessageId = $state<string | null>(null);
|
||||
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
// Uncertainty view state - tracks which messages show token heatmap
|
||||
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
function toggleUncertaintyView(messageId: string) {
|
||||
const newSet = new Set(uncertaintyViewMessageIds);
|
||||
if (newSet.has(messageId)) {
|
||||
newSet.delete(messageId);
|
||||
} else {
|
||||
newSet.add(messageId);
|
||||
}
|
||||
uncertaintyViewMessageIds = newSet;
|
||||
}
|
||||
|
||||
function isUncertaintyViewEnabled(messageId: string): boolean {
|
||||
return uncertaintyViewMessageIds.has(messageId);
|
||||
}
|
||||
|
||||
function formatTimestamp(timestamp: number): string {
|
||||
return new Date(timestamp).toLocaleTimeString('en-US', {
|
||||
hour12: false,
|
||||
@@ -366,7 +385,17 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-xs text-foreground">
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
|
||||
<!-- Uncertainty heatmap view -->
|
||||
<TokenHeatmap
|
||||
tokens={message.tokens}
|
||||
isGenerating={loading}
|
||||
onRegenerateFrom={(tokenIndex) => regenerateFromToken(message.id, tokenIndex)}
|
||||
/>
|
||||
{:else}
|
||||
<!-- Normal markdown view -->
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{/if}
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
@@ -419,6 +448,19 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
|
||||
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
|
||||
<button
|
||||
onclick={() => toggleUncertaintyView(message.id)}
|
||||
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
|
||||
192
dashboard/src/lib/components/TokenHeatmap.svelte
Normal file
192
dashboard/src/lib/components/TokenHeatmap.svelte
Normal file
@@ -0,0 +1,192 @@
|
||||
<script lang="ts">
|
||||
import type { TokenData } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
tokens: TokenData[];
|
||||
class?: string;
|
||||
isGenerating?: boolean;
|
||||
onRegenerateFrom?: (tokenIndex: number) => void;
|
||||
}
|
||||
|
||||
let { tokens, class: className = '', isGenerating = false, onRegenerateFrom }: Props = $props();
|
||||
|
||||
// Tooltip state - track both token data and index
|
||||
let hoveredTokenIndex = $state<number | null>(null);
|
||||
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
|
||||
let isTooltipHovered = $state(false);
|
||||
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
// Derive the hovered token from the index (stable across re-renders)
|
||||
const hoveredToken = $derived(
|
||||
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
|
||||
? { token: tokens[hoveredTokenIndex], index: hoveredTokenIndex, ...hoveredPosition }
|
||||
: null
|
||||
);
|
||||
|
||||
/**
|
||||
* Get confidence styling based on probability.
|
||||
* Following Apple design principles: high confidence tokens blend in,
|
||||
* only uncertainty draws attention.
|
||||
*/
|
||||
function getConfidenceClass(probability: number): string {
|
||||
if (probability > 0.8) return 'text-inherit'; // Expected tokens - blend in
|
||||
if (probability > 0.5) return 'bg-gray-500/10 text-inherit'; // Slight hint
|
||||
if (probability > 0.2) return 'bg-amber-500/15 text-amber-200/90'; // Subtle warmth
|
||||
return 'bg-red-500/20 text-red-200/90'; // Draws attention
|
||||
}
|
||||
|
||||
/**
|
||||
* Get border/underline styling for uncertain tokens
|
||||
*/
|
||||
function getBorderClass(probability: number): string {
|
||||
if (probability > 0.8) return 'border-transparent'; // No border for expected
|
||||
if (probability > 0.5) return 'border-gray-500/20';
|
||||
if (probability > 0.2) return 'border-amber-500/30';
|
||||
return 'border-red-500/40';
|
||||
}
|
||||
|
||||
function clearHideTimeout() {
|
||||
if (hideTimeoutId) {
|
||||
clearTimeout(hideTimeoutId);
|
||||
hideTimeoutId = null;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseEnter(event: MouseEvent, token: TokenData, index: number) {
|
||||
clearHideTimeout();
|
||||
const rect = (event.target as HTMLElement).getBoundingClientRect();
|
||||
hoveredTokenIndex = index;
|
||||
hoveredPosition = {
|
||||
x: rect.left + rect.width / 2,
|
||||
y: rect.top - 10
|
||||
};
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
clearHideTimeout();
|
||||
// Use longer delay during generation to account for re-renders
|
||||
const delay = isGenerating ? 300 : 100;
|
||||
hideTimeoutId = setTimeout(() => {
|
||||
if (!isTooltipHovered) {
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
}, delay);
|
||||
}
|
||||
|
||||
function handleTooltipEnter() {
|
||||
clearHideTimeout();
|
||||
isTooltipHovered = true;
|
||||
}
|
||||
|
||||
function handleTooltipLeave() {
|
||||
isTooltipHovered = false;
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
|
||||
function handleRegenerate() {
|
||||
if (hoveredToken && onRegenerateFrom) {
|
||||
const indexToRegenerate = hoveredToken.index;
|
||||
// Clear hover state immediately
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
isTooltipHovered = false;
|
||||
// Call regenerate
|
||||
onRegenerateFrom(indexToRegenerate);
|
||||
}
|
||||
}
|
||||
|
||||
function formatProbability(prob: number): string {
|
||||
return (prob * 100).toFixed(1) + '%';
|
||||
}
|
||||
|
||||
function formatLogprob(logprob: number): string {
|
||||
return logprob.toFixed(3);
|
||||
}
|
||||
|
||||
function getProbabilityColor(probability: number): string {
|
||||
if (probability > 0.8) return 'text-gray-300';
|
||||
if (probability > 0.5) return 'text-gray-400';
|
||||
if (probability > 0.2) return 'text-amber-400';
|
||||
return 'text-red-400';
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="token-heatmap leading-relaxed {className}">
|
||||
{#each tokens as tokenData, i (i)}
|
||||
<span
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
|
||||
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
|
||||
onmouseleave={handleMouseLeave}
|
||||
>{tokenData.token}</span>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Tooltip -->
|
||||
{#if hoveredToken}
|
||||
<div
|
||||
class="fixed z-50"
|
||||
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
|
||||
onmouseenter={handleTooltipEnter}
|
||||
onmouseleave={handleTooltipLeave}
|
||||
>
|
||||
<div class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48">
|
||||
<!-- Token info -->
|
||||
<div class="mb-2">
|
||||
<span class="text-gray-500 text-xs">Token:</span>
|
||||
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
|
||||
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2">{formatProbability(hoveredToken.token.probability)}</span>
|
||||
</div>
|
||||
|
||||
<div class="text-gray-400 text-xs mb-1">
|
||||
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
|
||||
</div>
|
||||
|
||||
<!-- Top alternatives -->
|
||||
{#if hoveredToken.token.topLogprobs.length > 0}
|
||||
<div class="border-t border-gray-700/50 mt-2 pt-2">
|
||||
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
|
||||
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
|
||||
{@const altProb = Math.exp(alt.logprob)}
|
||||
<div class="flex justify-between items-center text-xs py-0.5">
|
||||
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
|
||||
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Regenerate button -->
|
||||
{#if onRegenerateFrom}
|
||||
<button
|
||||
onclick={handleRegenerate}
|
||||
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
|
||||
>
|
||||
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
|
||||
</svg>
|
||||
Regenerate from here
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
<!-- Arrow -->
|
||||
<div class="absolute left-1/2 -translate-x-1/2 top-full">
|
||||
<div class="border-8 border-transparent border-t-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.token-heatmap {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.token-span {
|
||||
margin: 0;
|
||||
border-width: 1px;
|
||||
}
|
||||
</style>
|
||||
@@ -182,6 +182,20 @@ export interface MessageAttachment {
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
// Token-level data for uncertainty visualization
|
||||
export interface TopLogprob {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}
|
||||
|
||||
export interface TokenData {
|
||||
token: string;
|
||||
logprob: number;
|
||||
probability: number; // exp(logprob)
|
||||
topLogprobs: TopLogprob[];
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
id: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
@@ -191,6 +205,7 @@ export interface Message {
|
||||
attachments?: MessageAttachment[];
|
||||
ttftMs?: number; // Time to first token in ms (for assistant messages)
|
||||
tps?: number; // Tokens per second (for assistant messages)
|
||||
tokens?: TokenData[]; // Token-level data for uncertainty visualization
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
@@ -368,6 +383,21 @@ class AppStore {
|
||||
private fetchInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
private currentRequestController: AbortController | null = null;
|
||||
|
||||
/**
|
||||
* Abort any in-flight generation request
|
||||
*/
|
||||
abortCurrentRequest(): boolean {
|
||||
if (this.currentRequestController) {
|
||||
this.currentRequestController.abort();
|
||||
this.currentRequestController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
@@ -1046,6 +1076,10 @@ class AppStore {
|
||||
// Remove any messages after the user message
|
||||
this.messages = this.messages.slice(0, lastUserIndex + 1);
|
||||
|
||||
// Create abort controller for this request
|
||||
const controller = new AbortController();
|
||||
this.currentRequestController = controller;
|
||||
|
||||
// Resend the message to get a new response
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
@@ -1107,7 +1141,10 @@ class AppStore {
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -1140,6 +1177,7 @@ class AppStore {
|
||||
const decoder = new TextDecoder();
|
||||
let fullContent = "";
|
||||
let partialLine = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -1158,6 +1196,29 @@ class AppStore {
|
||||
const json = JSON.parse(trimmed.slice(6));
|
||||
const delta = json.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
// Extract logprobs for uncertainty visualization
|
||||
const logprobsData = json.choices?.[0]?.logprobs;
|
||||
if (logprobsData?.content?.[0]) {
|
||||
const logprobItem = logprobsData.content[0];
|
||||
const tokenData: TokenData = {
|
||||
token: logprobItem.token || delta,
|
||||
logprob: logprobItem.logprob ?? 0,
|
||||
probability: Math.exp(logprobItem.logprob ?? 0),
|
||||
topLogprobs: (logprobItem.top_logprobs || []).map(
|
||||
(item: {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}) => ({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
bytes: item.bytes,
|
||||
}),
|
||||
),
|
||||
};
|
||||
collectedTokens.push(tokenData);
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
@@ -1170,6 +1231,7 @@ class AppStore {
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
this.messages[idx].tokens = [...collectedTokens];
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
@@ -1187,9 +1249,16 @@ class AppStore {
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
if (collectedTokens.length > 0) {
|
||||
this.messages[idx].tokens = collectedTokens;
|
||||
}
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} catch (error) {
|
||||
// Don't show error for aborted requests (user cancelled)
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content =
|
||||
@@ -1197,6 +1266,10 @@ class AppStore {
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
// Clean up controller if this is still the active request
|
||||
if (this.currentRequestController === controller) {
|
||||
this.currentRequestController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.updateActiveConversation();
|
||||
@@ -1218,6 +1291,210 @@ class AppStore {
|
||||
this.tps = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Regenerate from a specific token in an assistant message.
|
||||
* Keeps content up to and including the specified token, then continues generation.
|
||||
* If a generation is already in progress, it will be aborted first.
|
||||
*/
|
||||
async regenerateFromToken(
|
||||
messageId: string,
|
||||
tokenIndex: number,
|
||||
): Promise<void> {
|
||||
// Abort any in-flight request first
|
||||
this.abortCurrentRequest();
|
||||
|
||||
const messageIdx = this.messages.findIndex((m) => m.id === messageId);
|
||||
if (messageIdx === -1) return;
|
||||
|
||||
const message = this.messages[messageIdx];
|
||||
if (message.role !== "assistant" || !message.tokens) return;
|
||||
|
||||
// Get tokens up to and including the specified index
|
||||
const keptTokens = message.tokens.slice(0, tokenIndex + 1);
|
||||
const prefixText = keptTokens.map((t) => t.token).join("");
|
||||
|
||||
// Update the message with just the prefix
|
||||
this.messages[messageIdx].content = prefixText;
|
||||
this.messages[messageIdx].tokens = keptTokens;
|
||||
this.messages[messageIdx].thinking = undefined;
|
||||
this.persistActiveConversation();
|
||||
|
||||
// Start loading
|
||||
this.isLoading = true;
|
||||
this.currentResponse = prefixText;
|
||||
|
||||
// Create abort controller for this request
|
||||
const controller = new AbortController();
|
||||
this.currentRequestController = controller;
|
||||
|
||||
try {
|
||||
const systemPrompt = {
|
||||
role: "system" as const,
|
||||
content:
|
||||
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
|
||||
};
|
||||
|
||||
// Build messages: all messages before this one, plus the prefix as assistant
|
||||
const apiMessages: { role: string; content: string }[] = [systemPrompt];
|
||||
for (let i = 0; i < messageIdx; i++) {
|
||||
const m = this.messages[i];
|
||||
apiMessages.push({ role: m.role, content: m.content || "" });
|
||||
}
|
||||
// Add the prefix as a partial assistant response to continue from
|
||||
apiMessages.push({ role: "assistant", content: prefixText });
|
||||
|
||||
// Determine which model to use
|
||||
let modelToUse = this.selectedChatModel;
|
||||
if (!modelToUse) {
|
||||
const firstInstanceKey = Object.keys(this.instances)[0];
|
||||
if (firstInstanceKey) {
|
||||
const instance = this.instances[firstInstanceKey] as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
if (instance) {
|
||||
const keys = Object.keys(instance);
|
||||
if (keys.length === 1) {
|
||||
const inst = instance[keys[0]] as
|
||||
| { shardAssignments?: { modelId?: string } }
|
||||
| undefined;
|
||||
modelToUse = inst?.shardAssignments?.modelId || "";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!modelToUse) {
|
||||
this.messages[messageIdx].content =
|
||||
prefixText + "\n\nError: No model available.";
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
continue_from_prefix: true,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
this.messages[messageIdx].content =
|
||||
prefixText + `\n\nError: ${response.status} - ${errorText}`;
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
this.messages[messageIdx].content =
|
||||
prefixText + "\n\nError: No response stream available";
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let fullContent = prefixText;
|
||||
let partialLine = "";
|
||||
const collectedTokens: TokenData[] = [...keptTokens];
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
const lines = (partialLine + chunk).split("\n");
|
||||
partialLine = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || trimmed === "data: [DONE]") continue;
|
||||
|
||||
if (trimmed.startsWith("data: ")) {
|
||||
try {
|
||||
const json = JSON.parse(trimmed.slice(6));
|
||||
const delta = json.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
// Extract logprobs for uncertainty visualization
|
||||
const logprobsData = json.choices?.[0]?.logprobs;
|
||||
if (logprobsData?.content?.[0]) {
|
||||
const logprobItem = logprobsData.content[0];
|
||||
const tokenData: TokenData = {
|
||||
token: logprobItem.token || delta,
|
||||
logprob: logprobItem.logprob ?? 0,
|
||||
probability: Math.exp(logprobItem.logprob ?? 0),
|
||||
topLogprobs: (logprobItem.top_logprobs || []).map(
|
||||
(item: {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}) => ({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
bytes: item.bytes,
|
||||
}),
|
||||
),
|
||||
};
|
||||
collectedTokens.push(tokenData);
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
this.currentResponse = displayContent;
|
||||
|
||||
this.messages[messageIdx].content = displayContent;
|
||||
this.messages[messageIdx].thinking =
|
||||
thinkingContent || undefined;
|
||||
this.messages[messageIdx].tokens = [...collectedTokens];
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
} catch {
|
||||
// Skip malformed JSON
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final cleanup
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
this.messages[messageIdx].content = displayContent;
|
||||
this.messages[messageIdx].thinking = thinkingContent || undefined;
|
||||
if (collectedTokens.length > 0) {
|
||||
this.messages[messageIdx].tokens = collectedTokens;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} catch (error) {
|
||||
// Don't show error for aborted requests (user cancelled)
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
this.messages[messageIdx].content =
|
||||
prefixText +
|
||||
`\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
// Clean up controller if this is still the active request
|
||||
if (this.currentRequestController === controller) {
|
||||
this.currentRequestController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.updateActiveConversation();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Strip thinking tags from content for display.
|
||||
* Handles both complete <think>...</think> blocks and in-progress <think>... blocks during streaming.
|
||||
@@ -1274,6 +1551,10 @@ class AppStore {
|
||||
this.startChat();
|
||||
}
|
||||
|
||||
// Create abort controller for this request
|
||||
const controller = new AbortController();
|
||||
this.currentRequestController = controller;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
this.ttftMs = null;
|
||||
@@ -1408,7 +1689,10 @@ class AppStore {
|
||||
messages: apiMessages,
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -1424,6 +1708,7 @@ class AppStore {
|
||||
const decoder = new TextDecoder();
|
||||
let fullContent = "";
|
||||
let buffer = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -1463,6 +1748,29 @@ class AppStore {
|
||||
this.tps = (tokenCount / elapsed) * 1000;
|
||||
}
|
||||
|
||||
// Extract logprobs for uncertainty visualization
|
||||
const logprobsData = parsed.choices?.[0]?.logprobs;
|
||||
if (logprobsData?.content?.[0]) {
|
||||
const logprobItem = logprobsData.content[0];
|
||||
const tokenData: TokenData = {
|
||||
token: logprobItem.token || tokenContent,
|
||||
logprob: logprobItem.logprob ?? 0,
|
||||
probability: Math.exp(logprobItem.logprob ?? 0),
|
||||
topLogprobs: (logprobItem.top_logprobs || []).map(
|
||||
(item: {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}) => ({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
bytes: item.bytes,
|
||||
}),
|
||||
),
|
||||
};
|
||||
collectedTokens.push(tokenData);
|
||||
}
|
||||
|
||||
fullContent += tokenContent;
|
||||
|
||||
// Strip thinking tags for display and extract thinking content
|
||||
@@ -1477,6 +1785,8 @@ class AppStore {
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
// Update tokens during streaming for real-time visualization
|
||||
this.messages[idx].tokens = [...collectedTokens];
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
@@ -1524,9 +1834,17 @@ class AppStore {
|
||||
if (this.tps !== null) {
|
||||
this.messages[idx].tps = this.tps;
|
||||
}
|
||||
// Store token data for uncertainty visualization
|
||||
if (collectedTokens.length > 0) {
|
||||
this.messages[idx].tokens = collectedTokens;
|
||||
}
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} catch (error) {
|
||||
// Don't show error for aborted requests (user cancelled)
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
console.error("Error sending message:", error);
|
||||
// Update the assistant message with error
|
||||
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
|
||||
@@ -1536,6 +1854,10 @@ class AppStore {
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
// Clean up controller if this is still the active request
|
||||
if (this.currentRequestController === controller) {
|
||||
this.currentRequestController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.updateActiveConversation();
|
||||
@@ -1615,6 +1937,9 @@ export const editMessage = (messageId: string, newContent: string) =>
|
||||
export const editAndRegenerate = (messageId: string, newContent: string) =>
|
||||
appStore.editAndRegenerate(messageId, newContent);
|
||||
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
|
||||
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
|
||||
appStore.regenerateFromToken(messageId, tokenIndex);
|
||||
export const abortCurrentRequest = () => appStore.abortCurrentRequest();
|
||||
|
||||
// Conversation actions
|
||||
export const conversations = () => appStore.conversations;
|
||||
|
||||
@@ -6,6 +6,8 @@ readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.12.14",
|
||||
"types-aiofiles>=24.1.0.20250708",
|
||||
"pydantic>=2.11.7",
|
||||
"fastapi>=0.116.1",
|
||||
"filelock>=3.18.0",
|
||||
@@ -21,7 +23,6 @@ dependencies = [
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
1
src/exo/master/adapters/__init__.py
Normal file
1
src/exo/master/adapters/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""
|
||||
184
src/exo/master/adapters/claude.py
Normal file
184
src/exo/master/adapters/claude.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Claude Messages API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
FinishReason,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeStopReason,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeUsage,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
|
||||
|
||||
def finish_reason_to_claude_stop_reason(
|
||||
finish_reason: FinishReason | None,
|
||||
) -> ClaudeStopReason | None:
|
||||
"""Map OpenAI finish_reason to Claude stop_reason."""
|
||||
if finish_reason is None:
|
||||
return None
|
||||
mapping: dict[FinishReason, ClaudeStopReason] = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
"function_call": "tool_use",
|
||||
}
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
|
||||
def claude_request_to_chat_params(
|
||||
request: ClaudeMessagesRequest,
|
||||
) -> ChatCompletionTaskParams:
|
||||
"""Convert Claude Messages API request to internal ChatCompletionTaskParams."""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
# Add system message if present
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
messages.append(
|
||||
ChatCompletionMessage(role="system", content=request.system)
|
||||
)
|
||||
else:
|
||||
# List of text blocks
|
||||
system_text = "".join(block.text for block in request.system)
|
||||
messages.append(ChatCompletionMessage(role="system", content=system_text))
|
||||
|
||||
# Convert messages
|
||||
for msg in request.messages:
|
||||
content: str
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
else:
|
||||
# Concatenate text blocks (images not supported for MVP)
|
||||
text_parts: list[str] = []
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
content = "".join(text_parts)
|
||||
|
||||
messages.append(ChatCompletionMessage(role=msg.role, content=content))
|
||||
|
||||
return ChatCompletionTaskParams(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
)
|
||||
|
||||
|
||||
def chat_response_to_claude_response(
|
||||
response: ChatCompletionResponse,
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Convert internal ChatCompletionResponse to Claude Messages API response."""
|
||||
content_text = ""
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
|
||||
content_text = (
|
||||
choice.message.content
|
||||
if isinstance(choice.message.content, str)
|
||||
else str(choice.message.content)
|
||||
)
|
||||
stop_reason = finish_reason_to_claude_stop_reason(choice.finish_reason)
|
||||
|
||||
# Use actual usage data from response if available
|
||||
input_tokens = response.usage.prompt_tokens if response.usage else 0
|
||||
output_tokens = response.usage.completion_tokens if response.usage else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
id=f"msg_{response.id}",
|
||||
model=response.model,
|
||||
content=[ClaudeTextBlock(text=content_text)],
|
||||
stop_reason=stop_reason,
|
||||
usage=ClaudeUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Claude Messages API streaming events from TokenChunks."""
|
||||
# Initial message_start event
|
||||
initial_message = ClaudeMessageStart(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
|
||||
)
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
|
||||
# content_block_stop
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason=stop_reason),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
|
||||
)
|
||||
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
|
||||
|
||||
# message_stop
|
||||
message_stop = ClaudeMessageStopEvent()
|
||||
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"
|
||||
199
src/exo/master/adapters/responses.py
Normal file
199
src/exo/master/adapters/responses.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""OpenAI Responses API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
|
||||
|
||||
def responses_request_to_chat_params(
|
||||
request: ResponsesRequest,
|
||||
) -> ChatCompletionTaskParams:
|
||||
"""Convert OpenAI Responses API request to internal ChatCompletionTaskParams."""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
# Add instructions as system message if present
|
||||
if request.instructions:
|
||||
messages.append(
|
||||
ChatCompletionMessage(role="system", content=request.instructions)
|
||||
)
|
||||
|
||||
# Convert input to messages
|
||||
if isinstance(request.input, str):
|
||||
messages.append(ChatCompletionMessage(role="user", content=request.input))
|
||||
else:
|
||||
for msg in request.input:
|
||||
messages.append(
|
||||
ChatCompletionMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
)
|
||||
)
|
||||
|
||||
return ChatCompletionTaskParams(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_output_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=request.stream,
|
||||
)
|
||||
|
||||
|
||||
def chat_response_to_responses_response(
|
||||
response: ChatCompletionResponse,
|
||||
) -> ResponsesResponse:
|
||||
"""Convert internal ChatCompletionResponse to OpenAI Responses API response."""
|
||||
output_text = ""
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
|
||||
output_text = (
|
||||
choice.message.content
|
||||
if isinstance(choice.message.content, str)
|
||||
else str(choice.message.content)
|
||||
)
|
||||
|
||||
item_id = f"item_{response.id}"
|
||||
output_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=output_text)],
|
||||
)
|
||||
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=response.usage.prompt_tokens,
|
||||
output_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return ResponsesResponse(
|
||||
id=f"resp_{response.id}",
|
||||
model=response.model,
|
||||
output=[output_item],
|
||||
output_text=output_text,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
|
||||
# response.created
|
||||
initial_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
created_event = ResponseCreatedEvent(response=initial_response)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(response=initial_response)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
output_index=0, content_index=0, part=initial_part
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
|
||||
accumulated_text = ""
|
||||
last_stats = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
output_index=0, content_index=0, text=accumulated_text
|
||||
)
|
||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.done
|
||||
final_part = ResponseOutputText(text=accumulated_text)
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
output_index=0, content_index=0, part=final_part
|
||||
)
|
||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.done
|
||||
final_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=[final_item],
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
completed_event = ResponseCompletedEvent(response=final_response)
|
||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||
@@ -14,6 +14,16 @@ from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
chat_response_to_claude_response,
|
||||
claude_request_to_chat_params,
|
||||
generate_claude_stream,
|
||||
)
|
||||
from exo.master.adapters.responses import (
|
||||
chat_response_to_responses_response,
|
||||
generate_responses_stream,
|
||||
responses_request_to_chat_params,
|
||||
)
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.election import ElectionMessage
|
||||
@@ -31,6 +41,8 @@ from exo.shared.types.api import (
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
Logprobs,
|
||||
LogprobsContentItem,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -39,6 +51,10 @@ from exo.shared.types.api import (
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
)
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
@@ -52,6 +68,10 @@ from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
@@ -65,6 +85,20 @@ from exo.utils.event_buffer import OrderedBuffer
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
# Build logprobs if available
|
||||
logprobs: Logprobs | None = None
|
||||
if chunk.logprob is not None:
|
||||
logprobs = Logprobs(
|
||||
content=[
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
bytes=list(chunk.text.encode("utf-8")),
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -73,6 +107,7 @@ def chunk_to_response(
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -168,6 +203,8 @@ class API:
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
|
||||
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -548,6 +585,75 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def claude_messages(
|
||||
self, payload: ClaudeMessagesRequest
|
||||
) -> ClaudeMessagesResponse | StreamingResponse:
|
||||
"""Handle Claude Messages API requests."""
|
||||
chat_params = claude_request_to_chat_params(payload)
|
||||
model_meta = await resolve_model_meta(chat_params.model)
|
||||
chat_params.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == chat_params.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(chat_params.model)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No instance found for model {chat_params.model}",
|
||||
)
|
||||
|
||||
command = ChatCompletion(request_params=chat_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_claude_stream(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._chat_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
response = await self._collect_chat_completion(command.command_id)
|
||||
return chat_response_to_claude_response(response)
|
||||
|
||||
async def openai_responses(
|
||||
self, payload: ResponsesRequest
|
||||
) -> ResponsesResponse | StreamingResponse:
|
||||
"""Handle OpenAI Responses API requests."""
|
||||
chat_params = responses_request_to_chat_params(payload)
|
||||
|
||||
model_meta = await resolve_model_meta(chat_params.model)
|
||||
chat_params.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == chat_params.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(chat_params.model)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No instance found for model {chat_params.model}",
|
||||
)
|
||||
|
||||
command = ChatCompletion(request_params=chat_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_responses_stream(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._chat_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
response = await self._collect_chat_completion(command.command_id)
|
||||
return chat_response_to_responses_response(response)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
@@ -612,9 +718,17 @@ class API:
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
try:
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
except (anyio.BrokenResourceError, KeyError):
|
||||
# Client disconnected, queue was closed/removed - this is expected
|
||||
# when clients abort requests (e.g., regenerate from token)
|
||||
logger.debug(
|
||||
f"Client disconnected for command {event.command_id}, "
|
||||
"dropping chunk"
|
||||
)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
392
src/exo/master/tests/test_claude_api.py
Normal file
392
src/exo/master/tests/test_claude_api.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Tests for Claude Messages API conversion functions and types."""
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
chat_response_to_claude_response,
|
||||
claude_request_to_chat_params,
|
||||
finish_reason_to_claude_stop_reason,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeMessage,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeUsage,
|
||||
)
|
||||
|
||||
|
||||
class TestFinishReasonToClaudeStopReason:
|
||||
"""Tests for finish_reason to Claude stop_reason mapping."""
|
||||
|
||||
def test_stop_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
|
||||
|
||||
def test_length_maps_to_max_tokens(self):
|
||||
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
|
||||
|
||||
def test_tool_calls_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
|
||||
|
||||
def test_function_call_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
|
||||
|
||||
def test_content_filter_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert finish_reason_to_claude_stop_reason(None) is None
|
||||
|
||||
|
||||
class TestClaudeRequestToChatParams:
|
||||
"""Tests for converting Claude Messages API requests to ChatCompletionTaskParams."""
|
||||
|
||||
def test_basic_request_conversion(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert params.model == "claude-3-opus"
|
||||
assert params.max_tokens == 100
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello"
|
||||
|
||||
def test_request_with_system_string(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
system="You are a helpful assistant.",
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are a helpful assistant."
|
||||
assert params.messages[1].role == "user"
|
||||
assert params.messages[1].content == "Hello"
|
||||
|
||||
def test_request_with_system_text_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
system=[
|
||||
ClaudeTextBlock(text="You are helpful. "),
|
||||
ClaudeTextBlock(text="Be concise."),
|
||||
],
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are helpful. Be concise."
|
||||
|
||||
def test_request_with_content_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(
|
||||
role="user",
|
||||
content=[
|
||||
ClaudeTextBlock(text="First part. "),
|
||||
ClaudeTextBlock(text="Second part."),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].content == "First part. Second part."
|
||||
|
||||
def test_request_with_multi_turn_conversation(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
ClaudeMessage(role="assistant", content="Hi there!"),
|
||||
ClaudeMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 3
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[1].role == "assistant"
|
||||
assert params.messages[2].role == "user"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[ClaudeMessage(role="user", content="Hello")],
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
stop_sequences=["STOP", "END"],
|
||||
stream=True,
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert params.temperature == 0.7
|
||||
assert params.top_p == 0.9
|
||||
assert params.top_k == 40
|
||||
assert params.stop == ["STOP", "END"]
|
||||
assert params.stream is True
|
||||
|
||||
|
||||
class TestChatResponseToClaudeResponse:
|
||||
"""Tests for converting ChatCompletionResponse to Claude Messages API response."""
|
||||
|
||||
def test_basic_response_conversion(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?",
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(prompt_tokens=10, completion_tokens=7, total_tokens=17),
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.id == "msg_chatcmpl-123"
|
||||
assert claude_response.model == "llama-3.2-1b"
|
||||
assert claude_response.role == "assistant"
|
||||
assert claude_response.type == "message"
|
||||
assert len(claude_response.content) == 1
|
||||
assert claude_response.content[0].type == "text"
|
||||
assert claude_response.content[0].text == "Hello! How can I help you?"
|
||||
assert claude_response.stop_reason == "end_turn"
|
||||
assert claude_response.usage.input_tokens == 10
|
||||
assert claude_response.usage.output_tokens == 7
|
||||
|
||||
def test_response_with_length_finish_reason(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content="Truncated..."
|
||||
),
|
||||
finish_reason="length",
|
||||
)
|
||||
],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.stop_reason == "max_tokens"
|
||||
|
||||
def test_response_with_empty_content(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10),
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == ""
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
def test_response_with_no_choices(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == ""
|
||||
assert claude_response.stop_reason is None
|
||||
assert claude_response.usage.input_tokens == 0
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
def test_response_without_usage(self):
|
||||
"""Test response conversion when usage data is not available."""
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == "Hello!"
|
||||
assert claude_response.usage.input_tokens == 0
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
|
||||
class TestClaudeMessagesRequestValidation:
|
||||
"""Tests for Claude Messages API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"max_tokens": 100,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_max_tokens(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_messages(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 100,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestClaudeStreamingEvents:
|
||||
"""Tests for Claude Messages API streaming event serialization."""
|
||||
|
||||
def test_message_start_event_format(self):
|
||||
message = ClaudeMessageStart(
|
||||
id="msg_123",
|
||||
model="claude-3-opus",
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
|
||||
)
|
||||
event = ClaudeMessageStartEvent(message=message)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_start"
|
||||
assert parsed["message"]["id"] == "msg_123"
|
||||
assert parsed["message"]["type"] == "message"
|
||||
assert parsed["message"]["role"] == "assistant"
|
||||
assert parsed["message"]["model"] == "claude-3-opus"
|
||||
|
||||
def test_content_block_start_event_format(self):
|
||||
event = ClaudeContentBlockStartEvent(
|
||||
index=0,
|
||||
content_block=ClaudeTextBlock(text=""),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_start"
|
||||
assert parsed["index"] == 0
|
||||
assert parsed["content_block"]["type"] == "text"
|
||||
assert parsed["content_block"]["text"] == ""
|
||||
|
||||
def test_content_block_delta_event_format(self):
|
||||
event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text="Hello"),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_delta"
|
||||
assert parsed["index"] == 0
|
||||
assert parsed["delta"]["type"] == "text_delta"
|
||||
assert parsed["delta"]["text"] == "Hello"
|
||||
|
||||
def test_content_block_stop_event_format(self):
|
||||
event = ClaudeContentBlockStopEvent(index=0)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_stop"
|
||||
assert parsed["index"] == 0
|
||||
|
||||
def test_message_delta_event_format(self):
|
||||
event = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason="end_turn"),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=25),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_delta"
|
||||
assert parsed["delta"]["stop_reason"] == "end_turn"
|
||||
assert parsed["usage"]["output_tokens"] == 25
|
||||
|
||||
def test_message_stop_event_format(self):
|
||||
event = ClaudeMessageStopEvent()
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_stop"
|
||||
|
||||
def test_sse_format(self):
|
||||
"""Test that SSE format is correctly generated."""
|
||||
event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text="Hello"),
|
||||
)
|
||||
# Simulate the SSE format used in the streaming generator
|
||||
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
|
||||
|
||||
assert sse_line.startswith("event: content_block_delta\n")
|
||||
assert "data: " in sse_line
|
||||
assert sse_line.endswith("\n\n")
|
||||
414
src/exo/master/tests/test_openai_responses_api.py
Normal file
414
src/exo/master/tests/test_openai_responses_api.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""Tests for OpenAI Responses API conversion functions and types."""
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.responses import (
|
||||
chat_response_to_responses_response,
|
||||
responses_request_to_chat_params,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInputMessage,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
|
||||
|
||||
class TestResponsesRequestToChatParams:
|
||||
"""Tests for converting OpenAI Responses API requests to ChatCompletionTaskParams."""
|
||||
|
||||
def test_string_input_conversion(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello, how are you?",
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert params.model == "gpt-4o"
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello, how are you?"
|
||||
|
||||
def test_message_array_input_conversion(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
ResponseInputMessage(role="assistant", content="Hi there!"),
|
||||
ResponseInputMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 3
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello"
|
||||
assert params.messages[1].role == "assistant"
|
||||
assert params.messages[1].content == "Hi there!"
|
||||
assert params.messages[2].role == "user"
|
||||
assert params.messages[2].content == "How are you?"
|
||||
|
||||
def test_request_with_instructions(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
instructions="You are a helpful assistant. Be concise.",
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are a helpful assistant. Be concise."
|
||||
assert params.messages[1].role == "user"
|
||||
assert params.messages[1].content == "Hello"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
max_output_tokens=500,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
stream=True,
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert params.max_tokens == 500
|
||||
assert params.temperature == 0.8
|
||||
assert params.top_p == 0.95
|
||||
assert params.stream is True
|
||||
|
||||
def test_request_with_system_role_in_messages(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="system", content="Be helpful"),
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[1].role == "user"
|
||||
|
||||
def test_request_with_developer_role(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="developer", content="Internal note"),
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "developer"
|
||||
|
||||
|
||||
class TestChatResponseToResponsesResponse:
|
||||
"""Tests for converting ChatCompletionResponse to OpenAI Responses API response."""
|
||||
|
||||
def test_basic_response_conversion(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?",
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.id == "resp_chatcmpl-123"
|
||||
assert responses_response.object == "response"
|
||||
assert responses_response.model == "llama-3.2-1b"
|
||||
assert responses_response.status == "completed"
|
||||
assert responses_response.output_text == "Hello! How can I help you?"
|
||||
assert len(responses_response.output) == 1
|
||||
assert responses_response.output[0].type == "message"
|
||||
assert responses_response.output[0].role == "assistant"
|
||||
assert len(responses_response.output[0].content) == 1
|
||||
assert responses_response.output[0].content[0].type == "output_text"
|
||||
assert (
|
||||
responses_response.output[0].content[0].text == "Hello! How can I help you?"
|
||||
)
|
||||
|
||||
def test_response_with_usage(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
),
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.usage is not None
|
||||
assert responses_response.usage.input_tokens == 10
|
||||
assert responses_response.usage.output_tokens == 5
|
||||
assert responses_response.usage.total_tokens == 15
|
||||
|
||||
def test_response_with_empty_content(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output_text == ""
|
||||
assert responses_response.output[0].content[0].text == ""
|
||||
|
||||
def test_response_with_no_choices(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output_text == ""
|
||||
|
||||
def test_response_without_usage(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.usage is None
|
||||
|
||||
def test_response_item_id_format(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-abc123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output[0].id == "item_chatcmpl-abc123"
|
||||
|
||||
|
||||
class TestResponsesRequestValidation:
|
||||
"""Tests for OpenAI Responses API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"input": "Hello",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_input(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_accepts_string_input(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
)
|
||||
assert request.input == "Hello"
|
||||
|
||||
def test_request_accepts_message_array_input(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[ResponseInputMessage(role="user", content="Hello")],
|
||||
)
|
||||
assert len(request.input) == 1
|
||||
|
||||
|
||||
class TestResponsesStreamingEvents:
|
||||
"""Tests for OpenAI Responses API streaming event serialization."""
|
||||
|
||||
def test_response_created_event_format(self):
|
||||
response = ResponsesResponse(
|
||||
id="resp_123",
|
||||
model="gpt-4o",
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
event = ResponseCreatedEvent(response=response)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.created"
|
||||
assert parsed["response"]["id"] == "resp_123"
|
||||
assert parsed["response"]["object"] == "response"
|
||||
assert parsed["response"]["status"] == "in_progress"
|
||||
|
||||
def test_output_item_added_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_item.added"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["item"]["type"] == "message"
|
||||
assert parsed["item"]["id"] == "item_123"
|
||||
assert parsed["item"]["role"] == "assistant"
|
||||
|
||||
def test_content_part_added_event_format(self):
|
||||
part = ResponseOutputText(text="")
|
||||
event = ResponseContentPartAddedEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=part,
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.content_part.added"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["content_index"] == 0
|
||||
assert parsed["part"]["type"] == "output_text"
|
||||
|
||||
def test_text_delta_event_format(self):
|
||||
event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta="Hello",
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_text.delta"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["content_index"] == 0
|
||||
assert parsed["delta"] == "Hello"
|
||||
|
||||
def test_text_done_event_format(self):
|
||||
event = ResponseTextDoneEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
text="Hello, world!",
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_text.done"
|
||||
assert parsed["text"] == "Hello, world!"
|
||||
|
||||
def test_output_item_done_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="Hello, world!")],
|
||||
status="completed",
|
||||
)
|
||||
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_item.done"
|
||||
assert parsed["item"]["status"] == "completed"
|
||||
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
|
||||
|
||||
def test_response_completed_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="Hello!")],
|
||||
status="completed",
|
||||
)
|
||||
response = ResponsesResponse(
|
||||
id="resp_123",
|
||||
model="gpt-4o",
|
||||
status="completed",
|
||||
output=[item],
|
||||
output_text="Hello!",
|
||||
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
|
||||
)
|
||||
event = ResponseCompletedEvent(response=response)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.completed"
|
||||
assert parsed["response"]["status"] == "completed"
|
||||
assert parsed["response"]["output_text"] == "Hello!"
|
||||
assert parsed["response"]["usage"]["total_tokens"] == 15
|
||||
|
||||
def test_sse_format(self):
|
||||
"""Test that SSE format is correctly generated."""
|
||||
event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta="Hello",
|
||||
)
|
||||
# Simulate the SSE format used in the streaming generator
|
||||
sse_line = (
|
||||
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
|
||||
)
|
||||
|
||||
assert sse_line.startswith("event: response.output_text.delta\n")
|
||||
assert "data: " in sse_line
|
||||
assert sse_line.endswith("\n\n")
|
||||
@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
|
||||
|
||||
def logger_setup(log_file: Path | None, verbosity: int = 0):
|
||||
"""Set up logging for this process - formatting, file handles, verbosity and output"""
|
||||
|
||||
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
logger.remove()
|
||||
|
||||
# replace all stdlib loggers with _InterceptHandlers that log to loguru
|
||||
|
||||
@@ -146,10 +146,12 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
stream: bool = False
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
continue_from_prefix: bool = False # When True, continue the last assistant message
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo.shared.types.api import GenerationStats
|
||||
from exo.shared.types.api import GenerationStats, TopLogprobItem
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -20,6 +20,8 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
logprob: float | None = None # Log probability of the selected token
|
||||
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
168
src/exo/shared/types/claude_api.py
Normal file
168
src/exo/shared/types/claude_api.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Claude Messages API types for request/response conversion."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Type aliases
|
||||
ClaudeRole = Literal["user", "assistant"]
|
||||
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
|
||||
|
||||
|
||||
# Content block types
|
||||
class ClaudeTextBlock(BaseModel, frozen=True):
|
||||
"""Text content block in Claude Messages API."""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeImageSource(BaseModel, frozen=True):
|
||||
"""Image source for Claude image blocks."""
|
||||
|
||||
type: Literal["base64", "url"]
|
||||
media_type: str | None = None
|
||||
data: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class ClaudeImageBlock(BaseModel, frozen=True):
|
||||
"""Image content block in Claude Messages API."""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
source: ClaudeImageSource
|
||||
|
||||
|
||||
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
|
||||
|
||||
|
||||
# Request types
|
||||
class ClaudeMessage(BaseModel, frozen=True):
|
||||
"""Message in Claude Messages API request."""
|
||||
|
||||
role: ClaudeRole
|
||||
content: str | list[ClaudeContentBlock]
|
||||
|
||||
|
||||
class ClaudeMessagesRequest(BaseModel):
|
||||
"""Request body for Claude Messages API."""
|
||||
|
||||
model: str
|
||||
max_tokens: int
|
||||
messages: list[ClaudeMessage]
|
||||
system: str | list[ClaudeTextBlock] | None = None
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool = False
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
class ClaudeUsage(BaseModel, frozen=True):
|
||||
"""Token usage in Claude Messages API response."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
|
||||
|
||||
class ClaudeMessagesResponse(BaseModel, frozen=True):
|
||||
"""Response body for Claude Messages API."""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ClaudeTextBlock]
|
||||
model: str
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
usage: ClaudeUsage
|
||||
|
||||
|
||||
# Streaming event types
|
||||
class ClaudeMessageStart(BaseModel, frozen=True):
|
||||
"""Partial message in message_start event."""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ClaudeTextBlock] = Field(default_factory=list)
|
||||
model: str
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
usage: ClaudeUsage
|
||||
|
||||
|
||||
class ClaudeMessageStartEvent(BaseModel, frozen=True):
|
||||
"""Event sent at start of message stream."""
|
||||
|
||||
type: Literal["message_start"] = "message_start"
|
||||
message: ClaudeMessageStart
|
||||
|
||||
|
||||
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
|
||||
"""Event sent at start of a content block."""
|
||||
|
||||
type: Literal["content_block_start"] = "content_block_start"
|
||||
index: int
|
||||
content_block: ClaudeTextBlock
|
||||
|
||||
|
||||
class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
"""Delta for text content block."""
|
||||
|
||||
type: Literal["text_delta"] = "text_delta"
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for content block delta."""
|
||||
|
||||
type: Literal["content_block_delta"] = "content_block_delta"
|
||||
index: int
|
||||
delta: ClaudeTextDelta
|
||||
|
||||
|
||||
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
||||
"""Event sent at end of a content block."""
|
||||
|
||||
type: Literal["content_block_stop"] = "content_block_stop"
|
||||
index: int
|
||||
|
||||
|
||||
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
|
||||
"""Usage in message_delta event."""
|
||||
|
||||
output_tokens: int
|
||||
|
||||
|
||||
class ClaudeMessageDelta(BaseModel, frozen=True):
|
||||
"""Delta in message_delta event."""
|
||||
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
|
||||
|
||||
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent with final message delta."""
|
||||
|
||||
type: Literal["message_delta"] = "message_delta"
|
||||
delta: ClaudeMessageDelta
|
||||
usage: ClaudeMessageDeltaUsage
|
||||
|
||||
|
||||
class ClaudeMessageStopEvent(BaseModel, frozen=True):
|
||||
"""Event sent at end of message stream."""
|
||||
|
||||
type: Literal["message_stop"] = "message_stop"
|
||||
|
||||
|
||||
ClaudeStreamEvent = (
|
||||
ClaudeMessageStartEvent
|
||||
| ClaudeContentBlockStartEvent
|
||||
| ClaudeContentBlockDeltaEvent
|
||||
| ClaudeContentBlockStopEvent
|
||||
| ClaudeMessageDeltaEvent
|
||||
| ClaudeMessageStopEvent
|
||||
)
|
||||
162
src/exo/shared/types/openai_responses.py
Normal file
162
src/exo/shared/types/openai_responses.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""OpenAI Responses API types for request/response conversion."""
|
||||
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Type aliases
|
||||
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
|
||||
ResponseRole = Literal["user", "assistant", "system", "developer"]
|
||||
|
||||
|
||||
# Request types
|
||||
class ResponseInputMessage(BaseModel, frozen=True):
|
||||
"""Input message for Responses API."""
|
||||
|
||||
role: ResponseRole
|
||||
content: str
|
||||
|
||||
|
||||
class ResponsesRequest(BaseModel):
|
||||
"""Request body for OpenAI Responses API."""
|
||||
|
||||
model: str
|
||||
input: str | list[ResponseInputMessage]
|
||||
instructions: str | None = None
|
||||
max_output_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
stream: bool = False
|
||||
# previous_response_id not supported in MVP
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
class ResponseOutputText(BaseModel, frozen=True):
|
||||
"""Text content in response output."""
|
||||
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str
|
||||
annotations: list[dict[str, str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ResponseMessageItem(BaseModel, frozen=True):
|
||||
"""Message item in response output array."""
|
||||
|
||||
type: Literal["message"] = "message"
|
||||
id: str
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ResponseOutputText]
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
|
||||
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
|
||||
|
||||
|
||||
class ResponseUsage(BaseModel, frozen=True):
|
||||
"""Token usage in Responses API response."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ResponsesResponse(BaseModel, frozen=True):
|
||||
"""Response body for OpenAI Responses API."""
|
||||
|
||||
id: str
|
||||
object: Literal["response"] = "response"
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
status: ResponseStatus = "completed"
|
||||
model: str
|
||||
output: list[ResponseItem]
|
||||
output_text: str
|
||||
usage: ResponseUsage | None = None
|
||||
|
||||
|
||||
# Streaming event types
|
||||
class ResponseCreatedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is created."""
|
||||
|
||||
type: Literal["response.created"] = "response.created"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
class ResponseInProgressEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response starts processing."""
|
||||
|
||||
type: Literal["response.in_progress"] = "response.in_progress"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when an output item is added."""
|
||||
|
||||
type: Literal["response.output_item.added"] = "response.output_item.added"
|
||||
output_index: int
|
||||
item: ResponseItem
|
||||
|
||||
|
||||
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a content part is added."""
|
||||
|
||||
type: Literal["response.content_part.added"] = "response.content_part.added"
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ResponseOutputText
|
||||
|
||||
|
||||
class ResponseTextDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for text delta during streaming."""
|
||||
|
||||
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseTextDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when text content is done."""
|
||||
|
||||
type: Literal["response.output_text.done"] = "response.output_text.done"
|
||||
output_index: int
|
||||
content_index: int
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a content part is done."""
|
||||
|
||||
type: Literal["response.content_part.done"] = "response.content_part.done"
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ResponseOutputText
|
||||
|
||||
|
||||
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when an output item is done."""
|
||||
|
||||
type: Literal["response.output_item.done"] = "response.output_item.done"
|
||||
output_index: int
|
||||
item: ResponseItem
|
||||
|
||||
|
||||
class ResponseCompletedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is completed."""
|
||||
|
||||
type: Literal["response.completed"] = "response.completed"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
ResponsesStreamEvent = (
|
||||
ResponseCreatedEvent
|
||||
| ResponseInProgressEvent
|
||||
| ResponseOutputItemAddedEvent
|
||||
| ResponseContentPartAddedEvent
|
||||
| ResponseTextDeltaEvent
|
||||
| ResponseTextDoneEvent
|
||||
| ResponseContentPartDoneEvent
|
||||
| ResponseOutputItemDoneEvent
|
||||
| ResponseCompletedEvent
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ class TokenizedResponse(BaseRunnerResponse):
|
||||
class GenerationResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
token: int
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
logprob: float | None = None # Log probability of the selected token
|
||||
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -7,13 +7,13 @@ import time
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, cast
|
||||
from typing import Callable, Literal
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import aiohttp
|
||||
import certifi
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -207,22 +207,23 @@ async def _fetch_file_list(
|
||||
headers = await get_download_headers()
|
||||
async with (
|
||||
create_http_session(timeout_profile="short") as session,
|
||||
session.get(url, headers=headers) as response,
|
||||
):
|
||||
response = await session.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to fetch file list: {response.status_code}")
|
||||
|
||||
data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
|
||||
files: list[FileListEntry] = []
|
||||
for item in data:
|
||||
if item.type == "file":
|
||||
files.append(FileListEntry.model_validate(item))
|
||||
elif item.type == "directory" and recursive:
|
||||
subfiles = await _fetch_file_list(
|
||||
repo_id, revision, item.path, recursive
|
||||
)
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
if response.status == 200:
|
||||
data_json = await response.text()
|
||||
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
|
||||
files: list[FileListEntry] = []
|
||||
for item in data:
|
||||
if item.type == "file":
|
||||
files.append(FileListEntry.model_validate(item))
|
||||
elif item.type == "directory" and recursive:
|
||||
subfiles = await _fetch_file_list(
|
||||
repo_id, revision, item.path, recursive
|
||||
)
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
|
||||
async def get_download_headers() -> dict[str, str]:
|
||||
@@ -230,25 +231,31 @@ async def get_download_headers() -> dict[str, str]:
|
||||
|
||||
|
||||
def create_http_session(
|
||||
auto_decompress: bool = False,
|
||||
timeout_profile: Literal["short", "long"] = "long",
|
||||
) -> httpx.AsyncClient:
|
||||
) -> aiohttp.ClientSession:
|
||||
if timeout_profile == "short":
|
||||
total_timeout = 30
|
||||
connect_timeout = 10
|
||||
read_timeout = 30
|
||||
sock_read_timeout = 30
|
||||
sock_connect_timeout = 10
|
||||
else:
|
||||
total_timeout = 1800
|
||||
connect_timeout = 60
|
||||
read_timeout = 1800
|
||||
sock_read_timeout = 1800
|
||||
sock_connect_timeout = 60
|
||||
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
return httpx.AsyncClient(
|
||||
verify=ssl_context,
|
||||
timeout=httpx.Timeout(
|
||||
return aiohttp.ClientSession(
|
||||
auto_decompress=auto_decompress,
|
||||
connector=connector,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=total_timeout,
|
||||
connect=connect_timeout,
|
||||
read=read_timeout,
|
||||
write=total_timeout,
|
||||
sock_read=sock_read_timeout,
|
||||
sock_connect=sock_connect_timeout,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -275,25 +282,23 @@ async def file_meta(
|
||||
headers = await get_download_headers()
|
||||
async with (
|
||||
create_http_session(timeout_profile="short") as session,
|
||||
session.head(url, headers=headers) as r,
|
||||
):
|
||||
r = await session.head(url, headers=headers)
|
||||
if r.status_code == 307:
|
||||
if r.status == 307:
|
||||
# On redirect, only trust Hugging Face's x-linked-* headers.
|
||||
x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
|
||||
x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
|
||||
x_linked_size = r.headers.get("x-linked-size")
|
||||
x_linked_etag = r.headers.get("x-linked-etag")
|
||||
if x_linked_size and x_linked_etag:
|
||||
content_length = int(x_linked_size)
|
||||
etag = trim_etag(x_linked_etag)
|
||||
return content_length, etag
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
redirected_location = cast(str | None, r.headers.get("location"))
|
||||
redirected_location = r.headers.get("location")
|
||||
return await file_meta(repo_id, revision, path, redirected_location)
|
||||
content_length = cast(
|
||||
str | None,
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length"),
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
content_length = 0 if content_length is None else int(content_length)
|
||||
etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
|
||||
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
|
||||
assert content_length > 0, f"No content length for {url}"
|
||||
assert etag is not None, f"No remote hash for {url}"
|
||||
etag = trim_etag(etag)
|
||||
@@ -352,17 +357,17 @@ async def _download_file(
|
||||
n_read = resume_byte_pos or 0
|
||||
async with (
|
||||
create_http_session(timeout_profile="long") as session,
|
||||
session.get(url, headers=headers) as r,
|
||||
):
|
||||
r = await session.get(url, headers=headers)
|
||||
if r.status_code == 404:
|
||||
if r.status == 404:
|
||||
raise FileNotFoundError(f"File not found: {url}")
|
||||
assert r.status_code in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status_code}"
|
||||
assert r.status in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status}"
|
||||
)
|
||||
async with aiofiles.open(
|
||||
partial_path, "ab" if resume_byte_pos else "wb"
|
||||
) as f:
|
||||
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
|
||||
while chunk := await r.content.read(8 * 1024 * 1024):
|
||||
n_read = n_read + (await f.write(chunk))
|
||||
on_progress(n_read, length, False)
|
||||
|
||||
|
||||
@@ -40,4 +40,6 @@ class TokenizerWrapper:
|
||||
messages_dicts: list[dict[str, Any]],
|
||||
tokenize: bool = False,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> str: ...
|
||||
|
||||
@@ -12,6 +12,7 @@ from exo.shared.types.api import (
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
@@ -115,6 +116,60 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs: mx.array,
|
||||
tokenizer: TokenizerWrapper,
|
||||
top_k: int,
|
||||
selected_token: int,
|
||||
) -> tuple[float, list[TopLogprobItem]]:
|
||||
"""Extract the selected token's logprob and top-k alternative tokens.
|
||||
|
||||
Args:
|
||||
logprobs: Full vocabulary logprobs array from MLX
|
||||
tokenizer: Tokenizer for decoding token IDs to strings
|
||||
top_k: Number of top alternatives to return
|
||||
selected_token: The token ID that was actually sampled
|
||||
|
||||
Returns:
|
||||
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
|
||||
"""
|
||||
# Get the logprob of the selected token
|
||||
selected_logprob = float(logprobs[selected_token].item())
|
||||
|
||||
# Get top-k indices (most probable tokens)
|
||||
# mx.argpartition gives indices that would partition the array
|
||||
# We negate logprobs since argpartition finds smallest, and we want largest
|
||||
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
|
||||
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
|
||||
|
||||
# Get the actual logprob values for these indices
|
||||
top_values = logprobs[top_indices]
|
||||
|
||||
# Sort by logprob (descending) for consistent ordering
|
||||
sort_order = mx.argsort(-top_values)
|
||||
top_indices = top_indices[sort_order]
|
||||
top_values = top_values[sort_order]
|
||||
|
||||
# Convert to list of TopLogprobItem
|
||||
top_logprob_items: list[TopLogprobItem] = []
|
||||
for i in range(top_k):
|
||||
token_id = int(top_indices[i].item())
|
||||
token_logprob = float(top_values[i].item())
|
||||
# Decode token ID to string
|
||||
token_str = tokenizer.decode([token_id])
|
||||
# Get byte representation
|
||||
token_bytes = list(token_str.encode("utf-8"))
|
||||
top_logprob_items.append(
|
||||
TopLogprobItem(
|
||||
token=token_str,
|
||||
logprob=token_logprob,
|
||||
bytes=token_bytes,
|
||||
)
|
||||
)
|
||||
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -146,9 +201,24 @@ def mlx_generate(
|
||||
sampler = make_sampler(
|
||||
temp=task.temperature if task.temperature is not None else 0.7,
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
top_k=task.top_k if task.top_k is not None else 0,
|
||||
)
|
||||
|
||||
# Normalize stop sequences to a list
|
||||
stop_sequences: list[str] = (
|
||||
([task.stop] if isinstance(task.stop, str) else task.stop)
|
||||
if task.stop is not None
|
||||
else []
|
||||
)
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
accumulated_text = ""
|
||||
|
||||
# Determine if we need to extract logprobs
|
||||
should_extract_logprobs = task.logprobs is True
|
||||
num_top_logprobs = task.top_logprobs if task.top_logprobs is not None else 5
|
||||
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -163,9 +233,41 @@ def mlx_generate(
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(out.text)
|
||||
accumulated_text += out.text
|
||||
|
||||
# Check for stop sequences
|
||||
text = out.text
|
||||
finish_reason: FinishReason | None = cast(
|
||||
FinishReason | None, out.finish_reason
|
||||
)
|
||||
stop_matched = False
|
||||
|
||||
if stop_sequences:
|
||||
for stop_seq in stop_sequences:
|
||||
if stop_seq in accumulated_text:
|
||||
# Trim text to just before the stop sequence
|
||||
stop_index = accumulated_text.find(stop_seq)
|
||||
text_before_stop = accumulated_text[:stop_index]
|
||||
chunk_start = len(accumulated_text) - len(out.text)
|
||||
text = text_before_stop[chunk_start:]
|
||||
finish_reason = "stop"
|
||||
stop_matched = True
|
||||
break
|
||||
|
||||
# Extract logprobs if requested
|
||||
token_logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if should_extract_logprobs:
|
||||
token_logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs=out.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
top_k=num_top_logprobs,
|
||||
selected_token=out.token,
|
||||
)
|
||||
|
||||
is_done = finish_reason is not None
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
if is_done:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
@@ -173,22 +275,25 @@ def mlx_generate(
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
if not stop_matched and out.finish_reason not in get_args(FinishReason):
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
text=text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
logprob=token_logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
if is_done:
|
||||
break
|
||||
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -359,12 +359,26 @@ def apply_chat_template(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
# Use continue_final_message when continuing from prefix (e.g., regenerate from token)
|
||||
# This keeps the final assistant message open without EOS tokens
|
||||
# Note: explicitly set add_generation_prompt=False when using continue_final_message
|
||||
# because some tokenizers (e.g., Kimi) default add_generation_prompt=True
|
||||
prompt: str
|
||||
if chat_task_data.continue_from_prefix:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
continue_final_message=True,
|
||||
add_generation_prompt=False,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
else:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
@@ -186,6 +186,8 @@ def main(
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
|
||||
@@ -1,63 +1,60 @@
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio import create_task_group
|
||||
import http.client
|
||||
import time
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
REACHABILITY_ATTEMPTS = 3
|
||||
BAD_STATUSLINE_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
expected_node_id: NodeId,
|
||||
self_node_id: NodeId,
|
||||
out: dict[NodeId, set[str]],
|
||||
client: httpx.AsyncClient,
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
if ":" in target_ip:
|
||||
# TODO: use real IpAddress types
|
||||
target_ip = f"[{target_ip}]"
|
||||
url = f"http://{target_ip}:52415/node_id"
|
||||
|
||||
remote_node_id = None
|
||||
|
||||
last_error = None
|
||||
|
||||
for _ in range(REACHABILITY_ATTEMPTS):
|
||||
# TODO: use an async http client
|
||||
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
|
||||
try:
|
||||
r = await client.get(url)
|
||||
if r.status_code != 200:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
body = r.text.strip().strip('"')
|
||||
if not body:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
body = response.read().decode("utf-8").strip()
|
||||
|
||||
remote_node_id = NodeId(body)
|
||||
break
|
||||
# Strip quotes if present (JSON string response)
|
||||
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
|
||||
body = body[1:-1]
|
||||
|
||||
except (
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadTimeout,
|
||||
httpx.RemoteProtocolError,
|
||||
) as e:
|
||||
last_error = e
|
||||
await anyio.sleep(1)
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.BadStatusLine:
|
||||
if attempt >= BAD_STATUSLINE_ATTEMPTS:
|
||||
logger.warning(
|
||||
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
|
||||
)
|
||||
return None
|
||||
time.sleep(1)
|
||||
return _fetch_remote_node_id(attempt=attempt + 1)
|
||||
except http.client.HTTPException as e:
|
||||
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
else:
|
||||
if last_error is not None:
|
||||
logger.warning(
|
||||
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"malformed response from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
|
||||
if remote_node_id is None:
|
||||
return
|
||||
|
||||
if remote_node_id == self_node_id:
|
||||
return
|
||||
|
||||
if remote_node_id != expected_node_id:
|
||||
@@ -77,33 +74,18 @@ async def check_reachable(
|
||||
topology: Topology, self_node_id: NodeId
|
||||
) -> dict[NodeId, set[str]]:
|
||||
"""Check which nodes are reachable and return their IPs."""
|
||||
|
||||
reachable: dict[NodeId, set[str]] = {}
|
||||
|
||||
# these are intentionally httpx's defaults so we can tune them later
|
||||
timeout = httpx.Timeout(timeout=5.0)
|
||||
limits = httpx.Limits(
|
||||
max_connections=100,
|
||||
max_keepalive_connections=20,
|
||||
keepalive_expiry=5,
|
||||
)
|
||||
|
||||
async with (
|
||||
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
|
||||
create_task_group() as tg,
|
||||
):
|
||||
async with create_task_group() as tg:
|
||||
for node in topology.list_nodes():
|
||||
if not node.node_profile:
|
||||
continue
|
||||
if node.node_id == self_node_id:
|
||||
continue
|
||||
for iface in node.node_profile.network_interfaces:
|
||||
tg.start_soon(
|
||||
check_reachability,
|
||||
iface.ip_address,
|
||||
node.node_id,
|
||||
self_node_id,
|
||||
reachable,
|
||||
client,
|
||||
)
|
||||
|
||||
return reachable
|
||||
|
||||
Reference in New Issue
Block a user