Compare commits

..

7 Commits

Author SHA1 Message Date
Alex Cheema
302e67c8c0 fix: explicitly disable add_generation_prompt when using continue_final_message
Some tokenizers (e.g., Kimi-K2-Thinking) default to add_generation_prompt=True,
which conflicts with continue_final_message=True. Explicitly set
add_generation_prompt=False to avoid the resulting ValueError.
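
For reference, a minimal sketch of the template call this fix targets, using the Hugging Face transformers API; the model name is illustrative and the exo-specific wiring is omitted:

from transformers import AutoTokenizer

# Illustrative tokenizer; in exo it comes from the loaded model shard.
tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Thinking")

messages = [
    {"role": "user", "content": "Write a haiku about the sea."},
    # Partial assistant turn the model should continue verbatim.
    {"role": "assistant", "content": "Waves fold into"},
]

# Some chat templates default add_generation_prompt to True, and transformers
# raises a ValueError when it is combined with continue_final_message=True,
# so the fix passes add_generation_prompt=False explicitly.
prompt_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,
    continue_final_message=True,
)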

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 22:48:23 +00:00
Alex Cheema
63ec56c696 style: format app.svelte.ts
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 20:50:12 +00:00
Alex Cheema
578a417a7e chore: update package-lock.json
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 20:30:36 +00:00
Alex Cheema
57ec2c9011 feat: add regenerate from token with proper continuation
- Add continue_from_prefix flag to ChatCompletionTaskParams
- Use continue_final_message in apply_chat_template to keep the
  assistant turn open (no EOS token appended) when continuing from a
  prefix (see the sketch after this list)
- Add AbortController to cancel in-flight requests when regenerating
- Handle BrokenResourceError server-side when the client disconnects
- Improve TokenHeatmap hover stability during generation
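
A rough sketch of how the continuation path can fit together on the inference side, assuming a pydantic params object carrying the new continue_from_prefix flag; the helper name is illustrative, not the actual exo code:

def build_prompt_ids(tokenizer, params):
    # Illustrative helper: tokenize the chat, keeping the trailing
    # assistant message open when the request continues from a prefix.
    messages = [m.model_dump() for m in params.messages]
    if params.continue_from_prefix:
        # No generation prompt and no EOS: generation resumes exactly
        # where the kept prefix ends.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            continue_final_message=True,
        )
    return tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True
    )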

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 19:43:49 +00:00
Alex Cheema
c6963cafa0 fix: collect logprobs in sendMessage and refine heatmap styling
- Fixed a bug where sendMessage wasn't extracting logprobs from the streaming
  response (only regenerate was doing this)
- Refined TokenHeatmap styling following Apple design principles:
  - High confidence tokens (>80%) blend in completely
  - Only uncertain tokens draw attention with subtle highlights
  - Tooltip uses contextual colors based on confidence level

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 17:51:18 +00:00
Alex Cheema
663a0faaeb feat: add uncertainty visualization with token-level logprobs
Wire log probabilities from MLX through the API to enable uncertainty
visualization in the dashboard:

Backend:
- Extract top-k logprobs from MLX stream_generate output (see the sketch after these lists)
- Add logprob and top_logprobs fields to GenerationResponse and TokenChunk
- Populate Logprobs in streaming API response when requested

Dashboard:
- Add TokenHeatmap component with color-coded token confidence
- Parse logprobs from SSE responses and store on messages
- Add toggle button to switch between normal and uncertainty view
- Hover tooltip shows exact probability and top-5 alternatives

Color scheme:
- Green (>80%): High confidence
- Yellow (50-80%): Medium confidence
- Orange (20-50%): Low confidence
- Red (<20%): Very low confidence
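
A minimal sketch of the backend extraction described above, assuming mlx_lm's stream_generate yields a per-step full-vocabulary logprobs vector (as in recent mlx_lm releases); the model path is a placeholder:

import math

import mlx.core as mx
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")  # placeholder

TOP_K = 5

for step in stream_generate(model, tokenizer, "Hello!", max_tokens=32):
    logprobs = step.logprobs.squeeze()          # log-probabilities over the vocab
    chosen_logprob = logprobs[step.token].item()
    top_ids = mx.argsort(-logprobs)[:TOP_K]     # ids of the k most likely tokens
    top_logprobs = [
        {"token": tokenizer.decode([int(i)]), "logprob": logprobs[int(i)].item()}
        for i in top_ids
    ]
    probability = math.exp(chosen_logprob)
    # The dashboard buckets this probability: >80% green, 50-80% yellow,
    # 20-50% orange, <20% red.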

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 16:45:23 +00:00
Alex Cheema
0a58aa73ec feat: add Claude Messages API and OpenAI Responses API support
Adds two new API endpoints that wrap the existing chat completions pipeline:

- /v1/messages - Claude Messages API compatible endpoint
- /v1/responses - OpenAI Responses API compatible endpoint

Both support streaming (SSE) and non-streaming modes with proper
token usage reporting from actual inference stats.

Also adds a top_k sampling parameter and stop-sequence support to the
MLX inference engine.
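
For reference, a minimal non-streaming call against the new Claude-compatible endpoint; host, port, and model id are placeholders. The /v1/responses endpoint is analogous but takes input and max_output_tokens instead of messages and max_tokens:

import requests

resp = requests.post(
    "http://localhost:52415/v1/messages",      # placeholder host/port
    json={
        "model": "llama-3.2-1b",               # placeholder model id
        "max_tokens": 128,
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    },
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
print(body["content"][0]["text"])   # assistant text
print(body["usage"])                # input/output token counts from inference stats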

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 15:14:21 +00:00
27 changed files with 3283 additions and 639 deletions

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:

View File

@@ -241,9 +241,6 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn

View File

@@ -863,7 +863,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -903,7 +902,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,7 +1518,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1530,7 +1527,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1943,7 +1939,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2651,7 +2646,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2839,7 +2833,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2984,7 +2977,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3006,7 +2998,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -1,14 +1,16 @@
<script lang="ts">
import {
messages,
currentResponse,
import {
messages,
currentResponse,
isLoading,
deleteMessage,
editAndRegenerate,
regenerateLastResponse
regenerateLastResponse,
regenerateFromToken
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import MarkdownContent from './MarkdownContent.svelte';
import TokenHeatmap from './TokenHeatmap.svelte';
interface Props {
class?: string;
@@ -95,6 +97,23 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty view state - tracks which messages show token heatmap
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
function toggleUncertaintyView(messageId: string) {
const newSet = new Set(uncertaintyViewMessageIds);
if (newSet.has(messageId)) {
newSet.delete(messageId);
} else {
newSet.add(messageId);
}
uncertaintyViewMessageIds = newSet;
}
function isUncertaintyViewEnabled(messageId: string): boolean {
return uncertaintyViewMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString('en-US', {
hour12: false,
@@ -366,7 +385,17 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
<!-- Uncertainty heatmap view -->
<TokenHeatmap
tokens={message.tokens}
isGenerating={loading}
onRegenerateFrom={(tokenIndex) => regenerateFromToken(message.id, tokenIndex)}
/>
{:else}
<!-- Normal markdown view -->
<MarkdownContent content={message.content || (loading ? response : '')} />
{/if}
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -419,6 +448,19 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
</button>
{/if}
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleUncertaintyView(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
>
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
</svg>
</button>
{/if}
<!-- Delete button -->
<button

View File

@@ -0,0 +1,192 @@
<script lang="ts">
import type { TokenData } from '$lib/stores/app.svelte';
interface Props {
tokens: TokenData[];
class?: string;
isGenerating?: boolean;
onRegenerateFrom?: (tokenIndex: number) => void;
}
let { tokens, class: className = '', isGenerating = false, onRegenerateFrom }: Props = $props();
// Tooltip state - track both token data and index
let hoveredTokenIndex = $state<number | null>(null);
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
let isTooltipHovered = $state(false);
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
// Derive the hovered token from the index (stable across re-renders)
const hoveredToken = $derived(
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
? { token: tokens[hoveredTokenIndex], index: hoveredTokenIndex, ...hoveredPosition }
: null
);
/**
* Get confidence styling based on probability.
* Following Apple design principles: high confidence tokens blend in,
* only uncertainty draws attention.
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return 'text-inherit'; // Expected tokens - blend in
if (probability > 0.5) return 'bg-gray-500/10 text-inherit'; // Slight hint
if (probability > 0.2) return 'bg-amber-500/15 text-amber-200/90'; // Subtle warmth
return 'bg-red-500/20 text-red-200/90'; // Draws attention
}
/**
* Get border/underline styling for uncertain tokens
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return 'border-transparent'; // No border for expected
if (probability > 0.5) return 'border-gray-500/20';
if (probability > 0.2) return 'border-amber-500/30';
return 'border-red-500/40';
}
function clearHideTimeout() {
if (hideTimeoutId) {
clearTimeout(hideTimeoutId);
hideTimeoutId = null;
}
}
function handleMouseEnter(event: MouseEvent, token: TokenData, index: number) {
clearHideTimeout();
const rect = (event.target as HTMLElement).getBoundingClientRect();
hoveredTokenIndex = index;
hoveredPosition = {
x: rect.left + rect.width / 2,
y: rect.top - 10
};
}
function handleMouseLeave() {
clearHideTimeout();
// Use longer delay during generation to account for re-renders
const delay = isGenerating ? 300 : 100;
hideTimeoutId = setTimeout(() => {
if (!isTooltipHovered) {
hoveredTokenIndex = null;
hoveredPosition = null;
}
}, delay);
}
function handleTooltipEnter() {
clearHideTimeout();
isTooltipHovered = true;
}
function handleTooltipLeave() {
isTooltipHovered = false;
hoveredTokenIndex = null;
hoveredPosition = null;
}
function handleRegenerate() {
if (hoveredToken && onRegenerateFrom) {
const indexToRegenerate = hoveredToken.index;
// Clear hover state immediately
hoveredTokenIndex = null;
hoveredPosition = null;
isTooltipHovered = false;
// Call regenerate
onRegenerateFrom(indexToRegenerate);
}
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + '%';
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
function getProbabilityColor(probability: number): string {
if (probability > 0.8) return 'text-gray-300';
if (probability > 0.5) return 'text-gray-400';
if (probability > 0.2) return 'text-amber-400';
return 'text-red-400';
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
onmouseleave={handleMouseLeave}
>{tokenData.token}</span>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
onmouseenter={handleTooltipEnter}
onmouseleave={handleTooltipLeave}
>
<div class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48">
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-500 text-xs">Token:</span>
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2">{formatProbability(hoveredToken.token.probability)}</span>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700/50 mt-2 pt-2">
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
</div>
{/each}
</div>
{/if}
<!-- Regenerate button -->
{#if onRegenerateFrom}
<button
onclick={handleRegenerate}
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
>
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
</svg>
Regenerate from here
</button>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -182,6 +182,20 @@ export interface MessageAttachment {
mimeType?: string;
}
// Token-level data for uncertainty visualization
export interface TopLogprob {
token: string;
logprob: number;
bytes?: number[];
}
export interface TokenData {
token: string;
logprob: number;
probability: number; // exp(logprob)
topLogprobs: TopLogprob[];
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -191,6 +205,7 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
tokens?: TokenData[]; // Token-level data for uncertainty visualization
}
export interface Conversation {
@@ -368,6 +383,21 @@ class AppStore {
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private currentRequestController: AbortController | null = null;
/**
* Abort any in-flight generation request
*/
abortCurrentRequest(): boolean {
if (this.currentRequestController) {
this.currentRequestController.abort();
this.currentRequestController = null;
this.isLoading = false;
this.currentResponse = "";
return true;
}
return false;
}
constructor() {
if (browser) {
@@ -1046,6 +1076,10 @@ class AppStore {
// Remove any messages after the user message
this.messages = this.messages.slice(0, lastUserIndex + 1);
// Create abort controller for this request
const controller = new AbortController();
this.currentRequestController = controller;
// Resend the message to get a new response
this.isLoading = true;
this.currentResponse = "";
@@ -1107,7 +1141,10 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
signal: controller.signal,
});
if (!response.ok) {
@@ -1140,6 +1177,7 @@ class AppStore {
const decoder = new TextDecoder();
let fullContent = "";
let partialLine = "";
const collectedTokens: TokenData[] = [];
while (true) {
const { done, value } = await reader.read();
@@ -1158,6 +1196,29 @@ class AppStore {
const json = JSON.parse(trimmed.slice(6));
const delta = json.choices?.[0]?.delta?.content;
if (delta) {
// Extract logprobs for uncertainty visualization
const logprobsData = json.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
const tokenData: TokenData = {
token: logprobItem.token || delta,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: {
token: string;
logprob: number;
bytes?: number[];
}) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
};
collectedTokens.push(tokenData);
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
@@ -1170,6 +1231,7 @@ class AppStore {
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
this.messages[idx].tokens = [...collectedTokens];
}
this.persistActiveConversation();
}
@@ -1187,9 +1249,16 @@ class AppStore {
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
if (collectedTokens.length > 0) {
this.messages[idx].tokens = collectedTokens;
}
}
this.persistActiveConversation();
} catch (error) {
// Don't show error for aborted requests (user cancelled)
if (error instanceof Error && error.name === "AbortError") {
return;
}
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content =
@@ -1197,6 +1266,10 @@ class AppStore {
}
this.persistActiveConversation();
} finally {
// Clean up controller if this is still the active request
if (this.currentRequestController === controller) {
this.currentRequestController = null;
}
this.isLoading = false;
this.currentResponse = "";
this.updateActiveConversation();
@@ -1218,6 +1291,210 @@ class AppStore {
this.tps = null;
}
/**
* Regenerate from a specific token in an assistant message.
* Keeps content up to and including the specified token, then continues generation.
* If a generation is already in progress, it will be aborted first.
*/
async regenerateFromToken(
messageId: string,
tokenIndex: number,
): Promise<void> {
// Abort any in-flight request first
this.abortCurrentRequest();
const messageIdx = this.messages.findIndex((m) => m.id === messageId);
if (messageIdx === -1) return;
const message = this.messages[messageIdx];
if (message.role !== "assistant" || !message.tokens) return;
// Get tokens up to and including the specified index
const keptTokens = message.tokens.slice(0, tokenIndex + 1);
const prefixText = keptTokens.map((t) => t.token).join("");
// Update the message with just the prefix
this.messages[messageIdx].content = prefixText;
this.messages[messageIdx].tokens = keptTokens;
this.messages[messageIdx].thinking = undefined;
this.persistActiveConversation();
// Start loading
this.isLoading = true;
this.currentResponse = prefixText;
// Create abort controller for this request
const controller = new AbortController();
this.currentRequestController = controller;
try {
const systemPrompt = {
role: "system" as const,
content:
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
};
// Build messages: all messages before this one, plus the prefix as assistant
const apiMessages: { role: string; content: string }[] = [systemPrompt];
for (let i = 0; i < messageIdx; i++) {
const m = this.messages[i];
apiMessages.push({ role: m.role, content: m.content || "" });
}
// Add the prefix as a partial assistant response to continue from
apiMessages.push({ role: "assistant", content: prefixText });
// Determine which model to use
let modelToUse = this.selectedChatModel;
if (!modelToUse) {
const firstInstanceKey = Object.keys(this.instances)[0];
if (firstInstanceKey) {
const instance = this.instances[firstInstanceKey] as
| Record<string, unknown>
| undefined;
if (instance) {
const keys = Object.keys(instance);
if (keys.length === 1) {
const inst = instance[keys[0]] as
| { shardAssignments?: { modelId?: string } }
| undefined;
modelToUse = inst?.shardAssignments?.modelId || "";
}
}
}
}
if (!modelToUse) {
this.messages[messageIdx].content =
prefixText + "\n\nError: No model available.";
this.isLoading = false;
this.updateActiveConversation();
return;
}
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
continue_from_prefix: true,
}),
signal: controller.signal,
});
if (!response.ok) {
const errorText = await response.text();
this.messages[messageIdx].content =
prefixText + `\n\nError: ${response.status} - ${errorText}`;
this.isLoading = false;
this.updateActiveConversation();
return;
}
const reader = response.body?.getReader();
if (!reader) {
this.messages[messageIdx].content =
prefixText + "\n\nError: No response stream available";
this.isLoading = false;
this.updateActiveConversation();
return;
}
const decoder = new TextDecoder();
let fullContent = prefixText;
let partialLine = "";
const collectedTokens: TokenData[] = [...keptTokens];
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
const lines = (partialLine + chunk).split("\n");
partialLine = lines.pop() || "";
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed || trimmed === "data: [DONE]") continue;
if (trimmed.startsWith("data: ")) {
try {
const json = JSON.parse(trimmed.slice(6));
const delta = json.choices?.[0]?.delta?.content;
if (delta) {
// Extract logprobs for uncertainty visualization
const logprobsData = json.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
const tokenData: TokenData = {
token: logprobItem.token || delta,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: {
token: string;
logprob: number;
bytes?: number[];
}) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
};
collectedTokens.push(tokenData);
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.currentResponse = displayContent;
this.messages[messageIdx].content = displayContent;
this.messages[messageIdx].thinking =
thinkingContent || undefined;
this.messages[messageIdx].tokens = [...collectedTokens];
this.persistActiveConversation();
}
} catch {
// Skip malformed JSON
}
}
}
}
// Final cleanup
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.messages[messageIdx].content = displayContent;
this.messages[messageIdx].thinking = thinkingContent || undefined;
if (collectedTokens.length > 0) {
this.messages[messageIdx].tokens = collectedTokens;
}
this.persistActiveConversation();
} catch (error) {
// Don't show error for aborted requests (user cancelled)
if (error instanceof Error && error.name === "AbortError") {
return;
}
this.messages[messageIdx].content =
prefixText +
`\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
this.persistActiveConversation();
} finally {
// Clean up controller if this is still the active request
if (this.currentRequestController === controller) {
this.currentRequestController = null;
}
this.isLoading = false;
this.currentResponse = "";
this.updateActiveConversation();
}
}
/**
* Strip thinking tags from content for display.
* Handles both complete <think>...</think> blocks and in-progress <think>... blocks during streaming.
@@ -1274,6 +1551,10 @@ class AppStore {
this.startChat();
}
// Create abort controller for this request
const controller = new AbortController();
this.currentRequestController = controller;
this.isLoading = true;
this.currentResponse = "";
this.ttftMs = null;
@@ -1408,7 +1689,10 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
signal: controller.signal,
});
if (!response.ok) {
@@ -1424,6 +1708,7 @@ class AppStore {
const decoder = new TextDecoder();
let fullContent = "";
let buffer = "";
const collectedTokens: TokenData[] = [];
while (true) {
const { done, value } = await reader.read();
@@ -1463,6 +1748,29 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000;
}
// Extract logprobs for uncertainty visualization
const logprobsData = parsed.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
const tokenData: TokenData = {
token: logprobItem.token || tokenContent,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: {
token: string;
logprob: number;
bytes?: number[];
}) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
};
collectedTokens.push(tokenData);
}
fullContent += tokenContent;
// Strip thinking tags for display and extract thinking content
@@ -1477,6 +1785,8 @@ class AppStore {
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
// Update tokens during streaming for real-time visualization
this.messages[idx].tokens = [...collectedTokens];
}
this.persistActiveConversation();
}
@@ -1524,9 +1834,17 @@ class AppStore {
if (this.tps !== null) {
this.messages[idx].tps = this.tps;
}
// Store token data for uncertainty visualization
if (collectedTokens.length > 0) {
this.messages[idx].tokens = collectedTokens;
}
}
this.persistActiveConversation();
} catch (error) {
// Don't show error for aborted requests (user cancelled)
if (error instanceof Error && error.name === "AbortError") {
return;
}
console.error("Error sending message:", error);
// Update the assistant message with error
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
@@ -1536,6 +1854,10 @@ class AppStore {
}
this.persistActiveConversation();
} finally {
// Clean up controller if this is still the active request
if (this.currentRequestController === controller) {
this.currentRequestController = null;
}
this.isLoading = false;
this.currentResponse = "";
this.updateActiveConversation();
@@ -1615,6 +1937,9 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
appStore.regenerateFromToken(messageId, tokenIndex);
export const abortCurrentRequest = () => appStore.abortCurrentRequest();
// Conversation actions
export const conversations = () => appStore.conversations;

View File

@@ -6,6 +6,8 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"pydantic>=2.11.7",
"fastapi>=0.116.1",
"filelock>=3.18.0",
@@ -21,7 +23,6 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
]
[project.scripts]

View File

@@ -0,0 +1 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -0,0 +1,184 @@
"""Claude Messages API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
FinishReason,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.tasks import ChatCompletionTaskParams
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def claude_request_to_chat_params(
request: ClaudeMessagesRequest,
) -> ChatCompletionTaskParams:
"""Convert Claude Messages API request to internal ChatCompletionTaskParams."""
messages: list[ChatCompletionMessage] = []
# Add system message if present
if request.system:
if isinstance(request.system, str):
messages.append(
ChatCompletionMessage(role="system", content=request.system)
)
else:
# List of text blocks
system_text = "".join(block.text for block in request.system)
messages.append(ChatCompletionMessage(role="system", content=system_text))
# Convert messages
for msg in request.messages:
content: str
if isinstance(msg.content, str):
content = msg.content
else:
# Concatenate text blocks (images not supported for MVP)
text_parts: list[str] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
content = "".join(text_parts)
messages.append(ChatCompletionMessage(role=msg.role, content=content))
return ChatCompletionTaskParams(
model=request.model,
messages=messages,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
)
def chat_response_to_claude_response(
response: ChatCompletionResponse,
) -> ClaudeMessagesResponse:
"""Convert internal ChatCompletionResponse to Claude Messages API response."""
content_text = ""
stop_reason: ClaudeStopReason | None = None
if response.choices:
choice = response.choices[0]
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
content_text = (
choice.message.content
if isinstance(choice.message.content, str)
else str(choice.message.content)
)
stop_reason = finish_reason_to_claude_stop_reason(choice.finish_reason)
# Use actual usage data from response if available
input_tokens = response.usage.prompt_tokens if response.usage else 0
output_tokens = response.usage.completion_tokens if response.usage else 0
return ClaudeMessagesResponse(
id=f"msg_{response.id}",
model=response.model,
content=[ClaudeTextBlock(text=content_text)],
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
async for chunk in chunk_stream:
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -0,0 +1,199 @@
"""OpenAI Responses API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
from exo.shared.types.tasks import ChatCompletionTaskParams
def responses_request_to_chat_params(
request: ResponsesRequest,
) -> ChatCompletionTaskParams:
"""Convert OpenAI Responses API request to internal ChatCompletionTaskParams."""
messages: list[ChatCompletionMessage] = []
# Add instructions as system message if present
if request.instructions:
messages.append(
ChatCompletionMessage(role="system", content=request.instructions)
)
# Convert input to messages
if isinstance(request.input, str):
messages.append(ChatCompletionMessage(role="user", content=request.input))
else:
for msg in request.input:
messages.append(
ChatCompletionMessage(
role=msg.role,
content=msg.content,
)
)
return ChatCompletionTaskParams(
model=request.model,
messages=messages,
max_tokens=request.max_output_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=request.stream,
)
def chat_response_to_responses_response(
response: ChatCompletionResponse,
) -> ResponsesResponse:
"""Convert internal ChatCompletionResponse to OpenAI Responses API response."""
output_text = ""
if response.choices:
choice = response.choices[0]
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
output_text = (
choice.message.content
if isinstance(choice.message.content, str)
else str(choice.message.content)
)
item_id = f"item_{response.id}"
output_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=output_text)],
)
usage = None
if response.usage:
usage = ResponseUsage(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return ResponsesResponse(
id=f"resp_{response.id}",
model=response.model,
output=[output_item],
output_text=output_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(response=initial_response)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(response=initial_response)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
output_index=0, content_index=0, part=initial_part
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
last_stats = None
async for chunk in chunk_stream:
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
output_index=0, content_index=0, text=accumulated_text
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
output_index=0, content_index=0, part=final_part
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[final_item],
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(response=final_response)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -14,6 +14,16 @@ from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from exo.master.adapters.claude import (
chat_response_to_claude_response,
claude_request_to_chat_params,
generate_claude_stream,
)
from exo.master.adapters.responses import (
chat_response_to_responses_response,
generate_responses_stream,
responses_request_to_chat_params,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
@@ -31,6 +41,8 @@ from exo.shared.types.api import (
DeleteInstanceResponse,
FinishReason,
GenerationStats,
Logprobs,
LogprobsContentItem,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -39,6 +51,10 @@ from exo.shared.types.api import (
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -52,6 +68,10 @@ from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -65,6 +85,20 @@ from exo.utils.event_buffer import OrderedBuffer
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
bytes=list(chunk.text.encode("utf-8")),
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -73,6 +107,7 @@ def chunk_to_response(
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
@@ -168,6 +203,8 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -548,6 +585,75 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Handle Claude Messages API requests."""
chat_params = claude_request_to_chat_params(payload)
model_meta = await resolve_model_meta(chat_params.model)
chat_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == chat_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(chat_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {chat_params.model}",
)
command = ChatCompletion(request_params=chat_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
response = await self._collect_chat_completion(command.command_id)
return chat_response_to_claude_response(response)
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""Handle OpenAI Responses API requests."""
chat_params = responses_request_to_chat_params(payload)
model_meta = await resolve_model_meta(chat_params.model)
chat_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == chat_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(chat_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {chat_params.model}",
)
command = ChatCompletion(request_params=chat_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
response = await self._collect_chat_completion(command.command_id)
return chat_response_to_responses_response(response)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -612,9 +718,17 @@ class API:
and event.command_id in self._chat_completion_queues
):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
try:
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
except (anyio.BrokenResourceError, KeyError):
# Client disconnected, queue was closed/removed - this is expected
# when clients abort requests (e.g., regenerate from token)
logger.debug(
f"Client disconnected for command {event.command_id}, "
"dropping chunk"
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -0,0 +1,392 @@
"""Tests for Claude Messages API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.claude import (
chat_response_to_claude_response,
claude_request_to_chat_params,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
Usage,
)
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessage,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToChatParams:
"""Tests for converting Claude Messages API requests to ChatCompletionTaskParams."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert params.model == "claude-3-opus"
assert params.max_tokens == 100
assert len(params.messages) == 1
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello"
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are a helpful assistant."
assert params.messages[1].role == "user"
assert params.messages[1].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are helpful. Be concise."
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 1
assert params.messages[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 3
assert params.messages[0].role == "user"
assert params.messages[1].role == "assistant"
assert params.messages[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_chat_params(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestChatResponseToClaudeResponse:
"""Tests for converting ChatCompletionResponse to Claude Messages API response."""
def test_basic_response_conversion(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content="Hello! How can I help you?",
),
finish_reason="stop",
)
],
usage=Usage(prompt_tokens=10, completion_tokens=7, total_tokens=17),
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.id == "msg_chatcmpl-123"
assert claude_response.model == "llama-3.2-1b"
assert claude_response.role == "assistant"
assert claude_response.type == "message"
assert len(claude_response.content) == 1
assert claude_response.content[0].type == "text"
assert claude_response.content[0].text == "Hello! How can I help you?"
assert claude_response.stop_reason == "end_turn"
assert claude_response.usage.input_tokens == 10
assert claude_response.usage.output_tokens == 7
def test_response_with_length_finish_reason(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content="Truncated..."
),
finish_reason="length",
)
],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.stop_reason == "max_tokens"
def test_response_with_empty_content(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content=""),
finish_reason="stop",
)
],
usage=Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10),
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == ""
assert claude_response.usage.output_tokens == 0
def test_response_with_no_choices(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == ""
assert claude_response.stop_reason is None
assert claude_response.usage.input_tokens == 0
assert claude_response.usage.output_tokens == 0
def test_response_without_usage(self):
"""Test response conversion when usage data is not available."""
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == "Hello!"
assert claude_response.usage.input_tokens == 0
assert claude_response.usage.output_tokens == 0
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)
class TestClaudeStreamingEvents:
"""Tests for Claude Messages API streaming event serialization."""
def test_message_start_event_format(self):
message = ClaudeMessageStart(
id="msg_123",
model="claude-3-opus",
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
)
event = ClaudeMessageStartEvent(message=message)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_start"
assert parsed["message"]["id"] == "msg_123"
assert parsed["message"]["type"] == "message"
assert parsed["message"]["role"] == "assistant"
assert parsed["message"]["model"] == "claude-3-opus"
def test_content_block_start_event_format(self):
event = ClaudeContentBlockStartEvent(
index=0,
content_block=ClaudeTextBlock(text=""),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_start"
assert parsed["index"] == 0
assert parsed["content_block"]["type"] == "text"
assert parsed["content_block"]["text"] == ""
def test_content_block_delta_event_format(self):
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_delta"
assert parsed["index"] == 0
assert parsed["delta"]["type"] == "text_delta"
assert parsed["delta"]["text"] == "Hello"
def test_content_block_stop_event_format(self):
event = ClaudeContentBlockStopEvent(index=0)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_stop"
assert parsed["index"] == 0
def test_message_delta_event_format(self):
event = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason="end_turn"),
usage=ClaudeMessageDeltaUsage(output_tokens=25),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_delta"
assert parsed["delta"]["stop_reason"] == "end_turn"
assert parsed["usage"]["output_tokens"] == 25
def test_message_stop_event_format(self):
event = ClaudeMessageStopEvent()
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_stop"
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
# Simulate the SSE format used in the streaming generator
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
assert sse_line.startswith("event: content_block_delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")
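As a rough sketch of the consumer side of this SSE format (illustrative only; parse_sse_line is a hypothetical helper, and ClaudeStreamEvent plus the event models from the new types module are assumed to be in scope):
from pydantic import TypeAdapter
def parse_sse_line(sse_line: str) -> ClaudeStreamEvent:
    # The generator emits "event: <name>\ndata: <json>\n\n"; keep only the data payload
    _event_field, data_field, _trailer = sse_line.split("\n", 2)
    payload = data_field.removeprefix("data: ")
    # Each event model carries a distinct Literal "type", so the union resolves unambiguously
    return TypeAdapter(ClaudeStreamEvent).validate_json(payload)
delta_event = ClaudeContentBlockDeltaEvent(index=0, delta=ClaudeTextDelta(text="Hello"))
parsed = parse_sse_line(f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n")
assert isinstance(parsed, ClaudeContentBlockDeltaEvent)
assert parsed.delta.text == "Hello"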

View File

@@ -0,0 +1,414 @@
"""Tests for OpenAI Responses API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.responses import (
chat_response_to_responses_response,
responses_request_to_chat_params,
)
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
Usage,
)
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseCreatedEvent,
ResponseInputMessage,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
class TestResponsesRequestToChatParams:
"""Tests for converting OpenAI Responses API requests to ChatCompletionTaskParams."""
def test_string_input_conversion(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello, how are you?",
)
params = responses_request_to_chat_params(request)
assert params.model == "gpt-4o"
assert len(params.messages) == 1
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello, how are you?"
def test_message_array_input_conversion(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="user", content="Hello"),
ResponseInputMessage(role="assistant", content="Hi there!"),
ResponseInputMessage(role="user", content="How are you?"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 3
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello"
assert params.messages[1].role == "assistant"
assert params.messages[1].content == "Hi there!"
assert params.messages[2].role == "user"
assert params.messages[2].content == "How are you?"
def test_request_with_instructions(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
instructions="You are a helpful assistant. Be concise.",
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are a helpful assistant. Be concise."
assert params.messages[1].role == "user"
assert params.messages[1].content == "Hello"
def test_request_with_optional_parameters(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
max_output_tokens=500,
temperature=0.8,
top_p=0.95,
stream=True,
)
params = responses_request_to_chat_params(request)
assert params.max_tokens == 500
assert params.temperature == 0.8
assert params.top_p == 0.95
assert params.stream is True
def test_request_with_system_role_in_messages(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="system", content="Be helpful"),
ResponseInputMessage(role="user", content="Hello"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[1].role == "user"
def test_request_with_developer_role(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="developer", content="Internal note"),
ResponseInputMessage(role="user", content="Hello"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "developer"
class TestChatResponseToResponsesResponse:
"""Tests for converting ChatCompletionResponse to OpenAI Responses API response."""
def test_basic_response_conversion(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content="Hello! How can I help you?",
),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.id == "resp_chatcmpl-123"
assert responses_response.object == "response"
assert responses_response.model == "llama-3.2-1b"
assert responses_response.status == "completed"
assert responses_response.output_text == "Hello! How can I help you?"
assert len(responses_response.output) == 1
assert responses_response.output[0].type == "message"
assert responses_response.output[0].role == "assistant"
assert len(responses_response.output[0].content) == 1
assert responses_response.output[0].content[0].type == "output_text"
assert (
responses_response.output[0].content[0].text == "Hello! How can I help you?"
)
def test_response_with_usage(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
usage=Usage(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.usage is not None
assert responses_response.usage.input_tokens == 10
assert responses_response.usage.output_tokens == 5
assert responses_response.usage.total_tokens == 15
def test_response_with_empty_content(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content=""),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output_text == ""
assert responses_response.output[0].content[0].text == ""
def test_response_with_no_choices(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output_text == ""
def test_response_without_usage(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.usage is None
def test_response_item_id_format(self):
response = ChatCompletionResponse(
id="chatcmpl-abc123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output[0].id == "item_chatcmpl-abc123"
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model="gpt-4o",
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1
class TestResponsesStreamingEvents:
"""Tests for OpenAI Responses API streaming event serialization."""
def test_response_created_event_format(self):
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="in_progress",
output=[],
output_text="",
)
event = ResponseCreatedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.created"
assert parsed["response"]["id"] == "resp_123"
assert parsed["response"]["object"] == "response"
assert parsed["response"]["status"] == "in_progress"
def test_output_item_added_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="")],
status="in_progress",
)
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.added"
assert parsed["output_index"] == 0
assert parsed["item"]["type"] == "message"
assert parsed["item"]["id"] == "item_123"
assert parsed["item"]["role"] == "assistant"
def test_content_part_added_event_format(self):
part = ResponseOutputText(text="")
event = ResponseContentPartAddedEvent(
output_index=0,
content_index=0,
part=part,
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.content_part.added"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["part"]["type"] == "output_text"
def test_text_delta_event_format(self):
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.delta"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["delta"] == "Hello"
def test_text_done_event_format(self):
event = ResponseTextDoneEvent(
output_index=0,
content_index=0,
text="Hello, world!",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.done"
assert parsed["text"] == "Hello, world!"
def test_output_item_done_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello, world!")],
status="completed",
)
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.done"
assert parsed["item"]["status"] == "completed"
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
def test_response_completed_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello!")],
status="completed",
)
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="completed",
output=[item],
output_text="Hello!",
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
)
event = ResponseCompletedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.completed"
assert parsed["response"]["status"] == "completed"
assert parsed["response"]["output_text"] == "Hello!"
assert parsed["response"]["usage"]["total_tokens"] == 15
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
# Simulate the SSE format used in the streaming generator
sse_line = (
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
)
assert sse_line.startswith("event: response.output_text.delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -146,10 +146,12 @@ class ChatCompletionTaskParams(BaseModel):
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
continue_from_prefix: bool = False # When True, continue the last assistant message
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):

View File

@@ -1,6 +1,6 @@
from enum import Enum
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -20,6 +20,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None

View File

@@ -0,0 +1,168 @@
"""Claude Messages API types for request/response conversion."""
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ClaudeRole = Literal["user", "assistant"]
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
# Content block types
class ClaudeTextBlock(BaseModel, frozen=True):
"""Text content block in Claude Messages API."""
type: Literal["text"] = "text"
text: str
class ClaudeImageSource(BaseModel, frozen=True):
"""Image source for Claude image blocks."""
type: Literal["base64", "url"]
media_type: str | None = None
data: str | None = None
url: str | None = None
class ClaudeImageBlock(BaseModel, frozen=True):
"""Image content block in Claude Messages API."""
type: Literal["image"] = "image"
source: ClaudeImageSource
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
# Request types
class ClaudeMessage(BaseModel, frozen=True):
"""Message in Claude Messages API request."""
role: ClaudeRole
content: str | list[ClaudeContentBlock]
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
model: str
max_tokens: int
messages: list[ClaudeMessage]
system: str | list[ClaudeTextBlock] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, str] | None = None
# Response types
class ClaudeUsage(BaseModel, frozen=True):
"""Token usage in Claude Messages API response."""
input_tokens: int
output_tokens: int
class ClaudeMessagesResponse(BaseModel, frozen=True):
"""Response body for Claude Messages API."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock]
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
# Streaming event types
class ClaudeMessageStart(BaseModel, frozen=True):
"""Partial message in message_start event."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock] = Field(default_factory=list)
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
class ClaudeMessageStartEvent(BaseModel, frozen=True):
"""Event sent at start of message stream."""
type: Literal["message_start"] = "message_start"
message: ClaudeMessageStart
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
"""Event sent at start of a content block."""
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock
class ClaudeTextDelta(BaseModel, frozen=True):
"""Delta for text content block."""
type: Literal["text_delta"] = "text_delta"
text: str
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
"""Event sent for content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
"""Event sent at end of a content block."""
type: Literal["content_block_stop"] = "content_block_stop"
index: int
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
"""Usage in message_delta event."""
output_tokens: int
class ClaudeMessageDelta(BaseModel, frozen=True):
"""Delta in message_delta event."""
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
"""Event sent with final message delta."""
type: Literal["message_delta"] = "message_delta"
delta: ClaudeMessageDelta
usage: ClaudeMessageDeltaUsage
class ClaudeMessageStopEvent(BaseModel, frozen=True):
"""Event sent at end of message stream."""
type: Literal["message_stop"] = "message_stop"
ClaudeStreamEvent = (
ClaudeMessageStartEvent
| ClaudeContentBlockStartEvent
| ClaudeContentBlockDeltaEvent
| ClaudeContentBlockStopEvent
| ClaudeMessageDeltaEvent
| ClaudeMessageStopEvent
)
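For orientation, here is a minimal sketch of the event order a streamed single-text-block reply is expected to follow when assembled from these types (ids, model names, and token counts are illustrative; the real adapter fills them in from the underlying chat completion):
def example_claude_stream() -> list[ClaudeStreamEvent]:
    # message_start -> content_block_start -> N x content_block_delta
    # -> content_block_stop -> message_delta -> message_stop
    return [
        ClaudeMessageStartEvent(
            message=ClaudeMessageStart(
                id="msg_example",
                model="example-model",
                usage=ClaudeUsage(input_tokens=10, output_tokens=0),
            )
        ),
        ClaudeContentBlockStartEvent(index=0, content_block=ClaudeTextBlock(text="")),
        ClaudeContentBlockDeltaEvent(index=0, delta=ClaudeTextDelta(text="Hello")),
        ClaudeContentBlockDeltaEvent(index=0, delta=ClaudeTextDelta(text=", world!")),
        ClaudeContentBlockStopEvent(index=0),
        ClaudeMessageDeltaEvent(
            delta=ClaudeMessageDelta(stop_reason="end_turn"),
            usage=ClaudeMessageDeltaUsage(output_tokens=4),
        ),
        ClaudeMessageStopEvent(),
    ]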

View File

@@ -0,0 +1,162 @@
"""OpenAI Responses API types for request/response conversion."""
import time
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
ResponseRole = Literal["user", "assistant", "system", "developer"]
# Request types
class ResponseInputMessage(BaseModel, frozen=True):
"""Input message for Responses API."""
role: ResponseRole
content: str
class ResponsesRequest(BaseModel):
"""Request body for OpenAI Responses API."""
model: str
input: str | list[ResponseInputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stream: bool = False
# previous_response_id not supported in MVP
metadata: dict[str, str] | None = None
# Response types
class ResponseOutputText(BaseModel, frozen=True):
"""Text content in response output."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list[dict[str, str]] = Field(default_factory=list)
class ResponseMessageItem(BaseModel, frozen=True):
"""Message item in response output array."""
type: Literal["message"] = "message"
id: str
role: Literal["assistant"] = "assistant"
content: list[ResponseOutputText]
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
class ResponseUsage(BaseModel, frozen=True):
"""Token usage in Responses API response."""
input_tokens: int
output_tokens: int
total_tokens: int
class ResponsesResponse(BaseModel, frozen=True):
"""Response body for OpenAI Responses API."""
id: str
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
status: ResponseStatus = "completed"
model: str
output: list[ResponseItem]
output_text: str
usage: ResponseUsage | None = None
# Streaming event types
class ResponseCreatedEvent(BaseModel, frozen=True):
"""Event sent when response is created."""
type: Literal["response.created"] = "response.created"
response: ResponsesResponse
class ResponseInProgressEvent(BaseModel, frozen=True):
"""Event sent when response starts processing."""
type: Literal["response.in_progress"] = "response.in_progress"
response: ResponsesResponse
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
"""Event sent when an output item is added."""
type: Literal["response.output_item.added"] = "response.output_item.added"
output_index: int
item: ResponseItem
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a content part is added."""
type: Literal["response.content_part.added"] = "response.content_part.added"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for text delta during streaming."""
type: Literal["response.output_text.delta"] = "response.output_text.delta"
output_index: int
content_index: int
delta: str
class ResponseTextDoneEvent(BaseModel, frozen=True):
"""Event sent when text content is done."""
type: Literal["response.output_text.done"] = "response.output_text.done"
output_index: int
content_index: int
text: str
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a content part is done."""
type: Literal["response.content_part.done"] = "response.content_part.done"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
"""Event sent when an output item is done."""
type: Literal["response.output_item.done"] = "response.output_item.done"
output_index: int
item: ResponseItem
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
type: Literal["response.completed"] = "response.completed"
response: ResponsesResponse
ResponsesStreamEvent = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseOutputItemAddedEvent
| ResponseContentPartAddedEvent
| ResponseTextDeltaEvent
| ResponseTextDoneEvent
| ResponseContentPartDoneEvent
| ResponseOutputItemDoneEvent
| ResponseCompletedEvent
)
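Analogously, a streamed Responses API reply is expected to arrive in roughly this order (a sketch with illustrative values; the adapter supplies the real ids and usage):
def example_responses_stream() -> list[ResponsesStreamEvent]:
    # response.created -> response.output_item.added -> response.content_part.added
    # -> N x response.output_text.delta -> response.output_text.done
    # -> response.content_part.done -> response.output_item.done -> response.completed
    in_progress = ResponsesResponse(
        id="resp_example", model="example-model", status="in_progress", output=[], output_text=""
    )
    open_item = ResponseMessageItem(
        id="item_example", content=[ResponseOutputText(text="")], status="in_progress"
    )
    done_item = ResponseMessageItem(
        id="item_example", content=[ResponseOutputText(text="Hello!")], status="completed"
    )
    completed = ResponsesResponse(
        id="resp_example", model="example-model", status="completed",
        output=[done_item], output_text="Hello!",
        usage=ResponseUsage(input_tokens=10, output_tokens=2, total_tokens=12),
    )
    return [
        ResponseCreatedEvent(response=in_progress),
        ResponseOutputItemAddedEvent(output_index=0, item=open_item),
        ResponseContentPartAddedEvent(output_index=0, content_index=0, part=ResponseOutputText(text="")),
        ResponseTextDeltaEvent(output_index=0, content_index=0, delta="Hello"),
        ResponseTextDeltaEvent(output_index=0, content_index=0, delta="!"),
        ResponseTextDoneEvent(output_index=0, content_index=0, text="Hello!"),
        ResponseContentPartDoneEvent(output_index=0, content_index=0, part=ResponseOutputText(text="Hello!")),
        ResponseOutputItemDoneEvent(output_index=0, item=done_item),
        ResponseCompletedEvent(response=completed),
    ]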

View File

@@ -1,4 +1,4 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
@@ -13,7 +13,8 @@ class TokenizedResponse(BaseRunnerResponse):
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None

View File

@@ -7,13 +7,13 @@ import time
import traceback
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal, cast
from typing import Callable, Literal
from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
import httpx
from loguru import logger
from pydantic import (
BaseModel,
@@ -207,22 +207,23 @@ async def _fetch_file_list(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
response = await session.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Failed to fetch file list: {response.status_code}")
data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
async def get_download_headers() -> dict[str, str]:
@@ -230,25 +231,31 @@ async def get_download_headers() -> dict[str, str]:
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
) -> httpx.AsyncClient:
) -> aiohttp.ClientSession:
if timeout_profile == "short":
total_timeout = 30
connect_timeout = 10
read_timeout = 30
sock_read_timeout = 30
sock_connect_timeout = 10
else:
total_timeout = 1800
connect_timeout = 60
read_timeout = 1800
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
return httpx.AsyncClient(
verify=ssl_context,
timeout=httpx.Timeout(
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
read=read_timeout,
write=total_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
),
)
@@ -275,25 +282,23 @@ async def file_meta(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.head(url, headers=headers) as r,
):
r = await session.head(url, headers=headers)
if r.status_code == 307:
if r.status == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = cast(str | None, r.headers.get("location"))
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
content_length = cast(
str | None,
r.headers.get("x-linked-size") or r.headers.get("content-length"),
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
content_length = 0 if content_length is None else int(content_length)
etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
etag = trim_etag(etag)
@@ -352,17 +357,17 @@ async def _download_file(
n_read = resume_byte_pos or 0
async with (
create_http_session(timeout_profile="long") as session,
session.get(url, headers=headers) as r,
):
r = await session.get(url, headers=headers)
if r.status_code == 404:
if r.status == 404:
raise FileNotFoundError(f"File not found: {url}")
assert r.status_code in [200, 206], (
f"Failed to download {path} from {url}: {r.status_code}"
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
)
async with aiofiles.open(
partial_path, "ab" if resume_byte_pos else "wb"
) as f:
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
while chunk := await r.content.read(8 * 1024 * 1024):
n_read = n_read + (await f.write(chunk))
on_progress(n_read, length, False)

View File

@@ -40,4 +40,6 @@ class TokenizerWrapper:
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
) -> str: ...

View File

@@ -12,6 +12,7 @@ from exo.shared.types.api import (
ChatCompletionMessage,
FinishReason,
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
@@ -115,6 +116,60 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def extract_top_logprobs(
logprobs: mx.array,
tokenizer: TokenizerWrapper,
top_k: int,
selected_token: int,
) -> tuple[float, list[TopLogprobItem]]:
"""Extract the selected token's logprob and top-k alternative tokens.
Args:
logprobs: Full vocabulary logprobs array from MLX
tokenizer: Tokenizer for decoding token IDs to strings
top_k: Number of top alternatives to return
selected_token: The token ID that was actually sampled
Returns:
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
"""
# Get the logprob of the selected token
selected_logprob = float(logprobs[selected_token].item())
# Get top-k indices (most probable tokens)
# mx.argpartition gives indices that would partition the array
# We negate logprobs since argpartition finds smallest, and we want largest
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
# Get the actual logprob values for these indices
top_values = logprobs[top_indices]
# Sort by logprob (descending) for consistent ordering
sort_order = mx.argsort(-top_values)
top_indices = top_indices[sort_order]
top_values = top_values[sort_order]
# Convert to list of TopLogprobItem
top_logprob_items: list[TopLogprobItem] = []
for i in range(top_k):
token_id = int(top_indices[i].item())
token_logprob = float(top_values[i].item())
# Decode token ID to string
token_str = tokenizer.decode([token_id])
# Get byte representation
token_bytes = list(token_str.encode("utf-8"))
top_logprob_items.append(
TopLogprobItem(
token=token_str,
logprob=token_logprob,
bytes=token_bytes,
)
)
return selected_logprob, top_logprob_items
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
@@ -146,9 +201,24 @@ def mlx_generate(
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
top_k=task.top_k if task.top_k is not None else 0,
)
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
if task.stop is not None
else []
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
max_tokens = task.max_tokens or MAX_TOKENS
accumulated_text = ""
# Determine if we need to extract logprobs
should_extract_logprobs = task.logprobs is True
num_top_logprobs = task.top_logprobs if task.top_logprobs is not None else 5
for out in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -163,9 +233,41 @@ def mlx_generate(
kv_bits=KV_BITS,
):
logger.info(out.text)
accumulated_text += out.text
# Check for stop sequences
text = out.text
finish_reason: FinishReason | None = cast(
FinishReason | None, out.finish_reason
)
stop_matched = False
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in accumulated_text:
# Trim text to just before the stop sequence
stop_index = accumulated_text.find(stop_seq)
text_before_stop = accumulated_text[:stop_index]
chunk_start = len(accumulated_text) - len(out.text)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
stop_matched = True
break
# Extract logprobs if requested
token_logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if should_extract_logprobs:
token_logprob, top_logprobs = extract_top_logprobs(
logprobs=out.logprobs,
tokenizer=tokenizer,
top_k=num_top_logprobs,
selected_token=out.token,
)
is_done = finish_reason is not None
stats: GenerationStats | None = None
if out.finish_reason is not None:
if is_done:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
@@ -173,22 +275,25 @@ def mlx_generate(
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
if not stop_matched and out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
text=text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
logprob=token_logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
)
if out.finish_reason is not None:
if is_done:
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
# TODO: Do we want an mx_barrier?
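A small worked example of extract_top_logprobs defined above, using a toy four-token vocabulary and a stub tokenizer (both purely illustrative; the real call sites pass the model's TokenizerWrapper and the full-vocabulary logprobs from stream_generate):
import mlx.core as mx
class _StubTokenizer:
    # Duck-typed stand-in for TokenizerWrapper; only decode() is needed here
    vocab = ["hello", " world", "!", "?"]
    def decode(self, token_ids: list[int]) -> str:
        return "".join(self.vocab[t] for t in token_ids)
logits = mx.array([2.0, 1.0, 0.0, -1.0])
logprobs = logits - mx.logsumexp(logits)  # normalize so exp(logprobs) sums to 1
selected_logprob, top_items = extract_top_logprobs(
    logprobs=logprobs,
    tokenizer=_StubTokenizer(),
    top_k=2,
    selected_token=0,
)
# selected_logprob is about -0.44; top_items lists "hello" then " world", each with its logprob and bytes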

View File

@@ -359,12 +359,26 @@ def apply_chat_template(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
# Use continue_final_message when continuing from prefix (e.g., regenerate from token)
# This keeps the final assistant message open without EOS tokens
# Note: explicitly set add_generation_prompt=False when using continue_final_message
# because some tokenizers (e.g., Kimi) default add_generation_prompt=True
prompt: str
if chat_task_data.continue_from_prefix:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
continue_final_message=True,
add_generation_prompt=False,
tools=chat_task_data.tools,
)
else:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
logger.info(prompt)
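To make the continue_from_prefix path concrete, a hedged sketch of the message shape it expects: the final entry is a partial assistant message, and rendering with continue_final_message=True leaves that turn open so generation resumes mid-message rather than starting a new turn (the content and exact rendering below are illustrative and depend on the model's chat template):
formatted_messages = [
    {"role": "user", "content": "Write a haiku about the sea."},
    # Partial assistant prefix to continue from (e.g., regenerate-from-token in the dashboard)
    {"role": "assistant", "content": "Waves fold into"},
]
prompt = tokenizer.apply_chat_template(
    formatted_messages,
    tokenize=False,
    continue_final_message=True,
    add_generation_prompt=False,  # avoid conflicting tokenizer defaults (e.g., Kimi-K2-Thinking)
)
# The rendered prompt ends with "Waves fold into" and no end-of-turn token,
# so newly sampled tokens extend the existing assistant message.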

View File

@@ -186,6 +186,8 @@ def main(
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
finish_reason=response.finish_reason,
stats=response.stats,
),

View File

@@ -1,63 +1,60 @@
import anyio
import httpx
from anyio import create_task_group
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
body = response.read().decode("utf-8").strip()
remote_node_id = NodeId(body)
break
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
except (
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
) as e:
last_error = e
await anyio.sleep(1)
return NodeId(body) or None
except OSError:
return None
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()
else:
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
else:
logger.warning(
f"malformed response from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
@@ -77,33 +74,18 @@ async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
client,
)
return reachable
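A brief usage sketch (illustrative; assumes an async context and a populated Topology):
async def log_reachable_peers(topology: Topology, self_node_id: NodeId) -> None:
    reachable = await check_reachable(topology, self_node_id)
    for node_id, ips in reachable.items():
        logger.info(f"node {node_id} reachable via {sorted(ips)}")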

uv.lock (generated): 1294 lines changed; diff suppressed because it is too large.