mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-18 02:50:24 -05:00
Compare commits
3 Commits
main
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6eb8f9d9f5 | ||
|
|
663a0faaeb | ||
|
|
0a58aa73ec |
106
.github/workflows/build-app.yml
vendored
106
.github/workflows/build-app.yml
vendored
@@ -1,16 +1,5 @@
|
||||
name: Build EXO macOS DMG
|
||||
|
||||
# Release workflow:
|
||||
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
|
||||
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
|
||||
# 3. This workflow builds, signs, and notarizes the DMG
|
||||
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
|
||||
# 5. DMG and appcast.xml are uploaded to S3
|
||||
# 6. The draft GitHub Release is published with the DMG attached
|
||||
#
|
||||
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
|
||||
# If no draft exists, a release is auto-created with generated notes.
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
@@ -22,10 +11,8 @@ on:
|
||||
jobs:
|
||||
build-macos-app:
|
||||
runs-on: "macos-26"
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
SPARKLE_VERSION: 2.9.0-beta.1
|
||||
SPARKLE_VERSION: 2.8.1
|
||||
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
|
||||
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
|
||||
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
|
||||
@@ -100,52 +87,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Fetch and validate release notes
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
# Find draft release by name using gh release list (more reliable with default token)
|
||||
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
|
||||
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
|
||||
|
||||
if [[ -z "$DRAFT_EXISTS" ]]; then
|
||||
if [[ "$IS_ALPHA" == "true" ]]; then
|
||||
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
|
||||
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
|
||||
echo "Please create a draft release with release notes before pushing the tag."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Fetch full release details via API to get body and ID
|
||||
echo "Found draft release, fetching details..."
|
||||
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
|
||||
|
||||
# Extract release notes
|
||||
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
|
||||
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
|
||||
if [[ "$IS_ALPHA" == "true" ]]; then
|
||||
echo "Draft release has no notes (optional for alphas)"
|
||||
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: Draft release exists but has no release notes"
|
||||
echo "Please add release notes to the draft release before pushing the tag."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Save release ID for later publishing
|
||||
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
|
||||
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
|
||||
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
|
||||
|
||||
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
|
||||
echo "$NOTES" > /tmp/release_notes.md
|
||||
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
|
||||
|
||||
# ============================================================
|
||||
# Install dependencies
|
||||
# ============================================================
|
||||
@@ -363,28 +304,6 @@ jobs:
|
||||
$CHANNEL_FLAG \
|
||||
.
|
||||
|
||||
- name: Inject release notes into appcast
|
||||
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
|
||||
env:
|
||||
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
|
||||
run: |
|
||||
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
|
||||
export NOTES=$(cat "$RELEASE_NOTES_FILE")
|
||||
|
||||
# Insert description after the enclosure tag for this version
|
||||
awk '
|
||||
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
|
||||
print
|
||||
print " <description sparkle:format=\"markdown\"><![CDATA["
|
||||
print ENVIRON["NOTES"]
|
||||
print " ]]></description>"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
|
||||
|
||||
echo "Injected markdown release notes for version $RELEASE_VERSION"
|
||||
|
||||
# ============================================================
|
||||
# Upload artifacts
|
||||
# ============================================================
|
||||
@@ -417,26 +336,3 @@ jobs:
|
||||
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
|
||||
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
|
||||
fi
|
||||
|
||||
- name: Publish GitHub Release
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
|
||||
|
||||
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
|
||||
# Update the draft release with the tag and upload DMG
|
||||
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
|
||||
-f tag_name="$GITHUB_REF_NAME" \
|
||||
-F draft=false
|
||||
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
|
||||
echo "Published release $GITHUB_REF_NAME with DMG attached"
|
||||
else
|
||||
# Alpha without draft release - create one with auto-generated notes
|
||||
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
|
||||
--title "$GITHUB_REF_NAME" \
|
||||
--generate-notes \
|
||||
--prerelease
|
||||
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
|
||||
fi
|
||||
|
||||
25
AGENTS.md
25
AGENTS.md
@@ -40,31 +40,6 @@ uv run ruff check
|
||||
nix fmt
|
||||
```
|
||||
|
||||
## Pre-Commit Checks (REQUIRED)
|
||||
|
||||
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
|
||||
|
||||
```bash
|
||||
# 1. Type checking - MUST pass with 0 errors
|
||||
uv run basedpyright
|
||||
|
||||
# 2. Linting - MUST pass
|
||||
uv run ruff check
|
||||
|
||||
# 3. Formatting - MUST be applied
|
||||
nix fmt
|
||||
|
||||
# 4. Tests - MUST pass
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
Run all checks in sequence:
|
||||
```bash
|
||||
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
|
||||
```
|
||||
|
||||
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Node Composition
|
||||
|
||||
@@ -585,7 +585,7 @@
|
||||
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
|
||||
requirement = {
|
||||
kind = upToNextMajorVersion;
|
||||
minimumVersion = 2.9.0-beta.1;
|
||||
minimumVersion = 2.8.1;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/sparkle-project/Sparkle.git",
|
||||
"state" : {
|
||||
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
|
||||
"version" : "2.9.0-beta.1"
|
||||
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
|
||||
"version" : "2.8.1"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -56,11 +56,6 @@ struct ContentView: View {
|
||||
}
|
||||
|
||||
private var shouldShowLocalNetworkWarning: Bool {
|
||||
// Show warning if local network is not working and EXO is running.
|
||||
// The checker uses a longer timeout on first launch to allow time for
|
||||
// the permission prompt, so this correctly handles both:
|
||||
// 1. User denied permission on first launch
|
||||
// 2. Permission broke after restart (macOS TCC bug)
|
||||
if case .notWorking = localNetworkChecker.status {
|
||||
return controller.status != .stopped
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import os.log
|
||||
/// Checks if the app's local network permission is actually functional.
|
||||
///
|
||||
/// macOS local network permission can appear enabled in System Preferences but not
|
||||
/// actually work after a restart. This service uses NWConnection to mDNS multicast
|
||||
/// to verify actual connectivity.
|
||||
/// actually work after a restart. This service detects this by creating a UDP
|
||||
/// connection to the mDNS multicast address (224.0.0.251:5353).
|
||||
@MainActor
|
||||
final class LocalNetworkChecker: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
|
||||
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
|
||||
|
||||
@Published private(set) var status: Status = .unknown
|
||||
@Published private(set) var lastConnectionState: String = "none"
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
func check() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
|
||||
lastConnectionState = "connecting"
|
||||
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
|
||||
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
let result = await self.checkConnectivity(timeout: timeout)
|
||||
let result = await self.performCheck()
|
||||
self.status = result
|
||||
self.hasCompletedInitialCheck = true
|
||||
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks connectivity using NWConnection to mDNS multicast.
|
||||
/// The connection attempt triggers the permission prompt if not yet shown.
|
||||
private func checkConnectivity(timeout: UInt64) async -> Status {
|
||||
private func performCheck() async -> Status {
|
||||
Self.logger.info("Checking local network access via UDP multicast")
|
||||
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
|
||||
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
continuation.resume(returning: status)
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { state in
|
||||
conn.stateUpdateHandler = { [weak self] state in
|
||||
let stateStr: String
|
||||
switch state {
|
||||
case .setup: stateStr = "setup"
|
||||
case .preparing: stateStr = "preparing"
|
||||
case .ready: stateStr = "ready"
|
||||
case .waiting(let e): stateStr = "waiting(\(e))"
|
||||
case .failed(let e): stateStr = "failed(\(e))"
|
||||
case .cancelled: stateStr = "cancelled"
|
||||
@unknown default: stateStr = "unknown"
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self?.lastConnectionState = stateStr
|
||||
}
|
||||
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
|
||||
resumeOnce(.notWorking(reason: "Connection blocked"))
|
||||
}
|
||||
// Otherwise keep waiting - might be showing permission prompt
|
||||
case .failed(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|
||||
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
conn.start(queue: .main)
|
||||
|
||||
Task {
|
||||
try? await Task.sleep(nanoseconds: timeout)
|
||||
try? await Task.sleep(nanoseconds: 3_000_000_000)
|
||||
let state = conn.state
|
||||
switch state {
|
||||
case .ready:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
@@ -27,7 +26,7 @@ class ExoHttpError(RuntimeError):
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
@@ -105,46 +104,22 @@ def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
@@ -266,9 +241,6 @@ class PromptSizer:
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
# Fix for transformers 5.x
|
||||
if hasattr(ids, "input_ids"):
|
||||
ids = ids.input_ids
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
@@ -324,12 +296,6 @@ def main() -> int:
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
@@ -351,7 +317,7 @@ def main() -> int:
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
@@ -430,7 +396,7 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
if 0 < n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
@@ -472,13 +438,7 @@ def main() -> int:
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize placement: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
continue
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
isLoading,
|
||||
deleteMessage,
|
||||
editAndRegenerate,
|
||||
@@ -9,6 +9,8 @@
|
||||
} from '$lib/stores/app.svelte';
|
||||
import type { MessageAttachment } from '$lib/stores/app.svelte';
|
||||
import MarkdownContent from './MarkdownContent.svelte';
|
||||
import TokenHeatmap from './TokenHeatmap.svelte';
|
||||
import PrefillProgressBar from './PrefillProgressBar.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -95,6 +97,23 @@
|
||||
let copiedMessageId = $state<string | null>(null);
|
||||
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
// Uncertainty view state - tracks which messages show token heatmap
|
||||
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
function toggleUncertaintyView(messageId: string) {
|
||||
const newSet = new Set(uncertaintyViewMessageIds);
|
||||
if (newSet.has(messageId)) {
|
||||
newSet.delete(messageId);
|
||||
} else {
|
||||
newSet.add(messageId);
|
||||
}
|
||||
uncertaintyViewMessageIds = newSet;
|
||||
}
|
||||
|
||||
function isUncertaintyViewEnabled(messageId: string): boolean {
|
||||
return uncertaintyViewMessageIds.has(messageId);
|
||||
}
|
||||
|
||||
function formatTimestamp(timestamp: number): string {
|
||||
return new Date(timestamp).toLocaleTimeString('en-US', {
|
||||
hour12: false,
|
||||
@@ -330,6 +349,10 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{:else}
|
||||
<!-- Assistant message styling -->
|
||||
<div class="p-3 sm:p-4">
|
||||
{#if message.prefillProgress}
|
||||
<!-- Prefill progress bar -->
|
||||
<PrefillProgressBar progress={message.prefillProgress} class="mb-3" />
|
||||
{/if}
|
||||
{#if message.thinking && message.thinking.trim().length > 0}
|
||||
<div class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40">
|
||||
<button
|
||||
@@ -366,7 +389,13 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-xs text-foreground">
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
|
||||
<!-- Uncertainty heatmap view -->
|
||||
<TokenHeatmap tokens={message.tokens} />
|
||||
{:else}
|
||||
<!-- Normal markdown view -->
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{/if}
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
@@ -419,6 +448,19 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
|
||||
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
|
||||
<button
|
||||
onclick={() => toggleUncertaintyView(message.id)}
|
||||
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
|
||||
67
dashboard/src/lib/components/PrefillProgressBar.svelte
Normal file
67
dashboard/src/lib/components/PrefillProgressBar.svelte
Normal file
@@ -0,0 +1,67 @@
|
||||
<script lang="ts">
|
||||
import type { PrefillProgress } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
progress: PrefillProgress;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { progress, class: className = '' }: Props = $props();
|
||||
|
||||
const percentage = $derived(
|
||||
progress.total > 0 ? Math.round((progress.processed / progress.total) * 100) : 0
|
||||
);
|
||||
|
||||
function formatTokenCount(count: number): string {
|
||||
if (count >= 1000) {
|
||||
return `${(count / 1000).toFixed(1)}k`;
|
||||
}
|
||||
return count.toString();
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="prefill-progress {className}">
|
||||
<div class="flex items-center justify-between text-xs text-gray-400 mb-1">
|
||||
<span class="flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5 animate-spin"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<circle
|
||||
class="opacity-25"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke="currentColor"
|
||||
stroke-width="4"
|
||||
></circle>
|
||||
<path
|
||||
class="opacity-75"
|
||||
fill="currentColor"
|
||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||
></path>
|
||||
</svg>
|
||||
<span>Processing prompt</span>
|
||||
</span>
|
||||
<span class="font-mono">
|
||||
{formatTokenCount(progress.processed)} / {formatTokenCount(progress.total)} tokens
|
||||
</span>
|
||||
</div>
|
||||
<div class="h-1.5 bg-gray-700 rounded-full overflow-hidden">
|
||||
<div
|
||||
class="h-full bg-blue-500 rounded-full transition-all duration-150 ease-out"
|
||||
style="width: {percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="text-right text-xs text-gray-500 mt-0.5">
|
||||
{percentage}%
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.prefill-progress {
|
||||
width: 100%;
|
||||
}
|
||||
</style>
|
||||
121
dashboard/src/lib/components/TokenHeatmap.svelte
Normal file
121
dashboard/src/lib/components/TokenHeatmap.svelte
Normal file
@@ -0,0 +1,121 @@
|
||||
<script lang="ts">
|
||||
import type { TokenData } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
tokens: TokenData[];
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { tokens, class: className = '' }: Props = $props();
|
||||
|
||||
// Tooltip state
|
||||
let hoveredToken = $state<{ token: TokenData; x: number; y: number } | null>(null);
|
||||
|
||||
/**
|
||||
* Get confidence level based on probability
|
||||
* High: >0.8 (logprob > -0.22)
|
||||
* Medium: 0.5-0.8 (logprob -0.69 to -0.22)
|
||||
* Low: 0.2-0.5 (logprob -1.61 to -0.69)
|
||||
* Very Low: <0.2 (logprob < -1.61)
|
||||
*/
|
||||
function getConfidenceClass(probability: number): string {
|
||||
if (probability > 0.8) return 'bg-green-500/30 text-green-100';
|
||||
if (probability > 0.5) return 'bg-yellow-500/30 text-yellow-100';
|
||||
if (probability > 0.2) return 'bg-orange-500/30 text-orange-100';
|
||||
return 'bg-red-500/40 text-red-100';
|
||||
}
|
||||
|
||||
/**
|
||||
* Get border color for token based on probability
|
||||
*/
|
||||
function getBorderClass(probability: number): string {
|
||||
if (probability > 0.8) return 'border-green-500/50';
|
||||
if (probability > 0.5) return 'border-yellow-500/50';
|
||||
if (probability > 0.2) return 'border-orange-500/50';
|
||||
return 'border-red-500/50';
|
||||
}
|
||||
|
||||
function handleMouseEnter(event: MouseEvent, token: TokenData) {
|
||||
const rect = (event.target as HTMLElement).getBoundingClientRect();
|
||||
hoveredToken = {
|
||||
token,
|
||||
x: rect.left + rect.width / 2,
|
||||
y: rect.top - 10
|
||||
};
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
hoveredToken = null;
|
||||
}
|
||||
|
||||
function formatProbability(prob: number): string {
|
||||
return (prob * 100).toFixed(1) + '%';
|
||||
}
|
||||
|
||||
function formatLogprob(logprob: number): string {
|
||||
return logprob.toFixed(3);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="token-heatmap leading-relaxed {className}">
|
||||
{#each tokens as tokenData, i (i)}
|
||||
<span
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
|
||||
onmouseenter={(e) => handleMouseEnter(e, tokenData)}
|
||||
onmouseleave={handleMouseLeave}
|
||||
>{tokenData.token}</span>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Tooltip -->
|
||||
{#if hoveredToken}
|
||||
<div
|
||||
class="fixed z-50 pointer-events-none"
|
||||
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
|
||||
>
|
||||
<div class="bg-gray-900 border border-gray-700 rounded-lg shadow-xl p-3 text-sm min-w-48">
|
||||
<!-- Token info -->
|
||||
<div class="mb-2">
|
||||
<span class="text-gray-400 text-xs">Token:</span>
|
||||
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
|
||||
<span class="text-green-400 ml-2">{formatProbability(hoveredToken.token.probability)}</span>
|
||||
</div>
|
||||
|
||||
<div class="text-gray-400 text-xs mb-1">
|
||||
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
|
||||
</div>
|
||||
|
||||
<!-- Top alternatives -->
|
||||
{#if hoveredToken.token.topLogprobs.length > 0}
|
||||
<div class="border-t border-gray-700 mt-2 pt-2">
|
||||
<div class="text-gray-400 text-xs mb-1">Alternatives:</div>
|
||||
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
|
||||
{@const altProb = Math.exp(alt.logprob)}
|
||||
<div class="flex justify-between items-center text-xs py-0.5">
|
||||
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
|
||||
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<!-- Arrow -->
|
||||
<div class="absolute left-1/2 -translate-x-1/2 top-full">
|
||||
<div class="border-8 border-transparent border-t-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.token-heatmap {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.token-span {
|
||||
margin: 0;
|
||||
border-width: 1px;
|
||||
}
|
||||
</style>
|
||||
@@ -182,6 +182,26 @@ export interface MessageAttachment {
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
// Token-level data for uncertainty visualization
|
||||
export interface TopLogprob {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes?: number[];
|
||||
}
|
||||
|
||||
export interface TokenData {
|
||||
token: string;
|
||||
logprob: number;
|
||||
probability: number; // exp(logprob)
|
||||
topLogprobs: TopLogprob[];
|
||||
}
|
||||
|
||||
// Prefill progress data for long prompts
|
||||
export interface PrefillProgress {
|
||||
processed: number;
|
||||
total: number;
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
id: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
@@ -191,6 +211,8 @@ export interface Message {
|
||||
attachments?: MessageAttachment[];
|
||||
ttftMs?: number; // Time to first token in ms (for assistant messages)
|
||||
tps?: number; // Tokens per second (for assistant messages)
|
||||
tokens?: TokenData[]; // Token-level data for uncertainty visualization
|
||||
prefillProgress?: PrefillProgress | null; // Prefill progress for long prompts
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
@@ -1107,6 +1129,8 @@ class AppStore {
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -1408,6 +1432,8 @@ class AppStore {
|
||||
messages: apiMessages,
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -1424,6 +1450,8 @@ class AppStore {
|
||||
const decoder = new TextDecoder();
|
||||
let fullContent = "";
|
||||
let buffer = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
let currentEventType = ""; // Track SSE event type
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -1437,14 +1465,43 @@ class AppStore {
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed) continue;
|
||||
if (!trimmed) {
|
||||
// Empty line resets event type
|
||||
currentEventType = "";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle event type declaration
|
||||
if (trimmed.startsWith("event: ")) {
|
||||
currentEventType = trimmed.slice(7);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (trimmed.startsWith("data: ")) {
|
||||
const data = trimmed.slice(6);
|
||||
if (data === "[DONE]") continue;
|
||||
if (data === "[DONE]") {
|
||||
currentEventType = "";
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
|
||||
// Handle prefill progress events
|
||||
if (currentEventType === "prefill_progress") {
|
||||
const idx = this.messages.findIndex(
|
||||
(m) => m.id === assistantMessage.id,
|
||||
);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].prefillProgress = {
|
||||
processed: parsed.processed,
|
||||
total: parsed.total,
|
||||
};
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle regular token data
|
||||
const tokenContent = parsed.choices?.[0]?.delta?.content;
|
||||
if (tokenContent) {
|
||||
// Track first token for TTFT
|
||||
@@ -1453,6 +1510,14 @@ class AppStore {
|
||||
this.ttftMs = firstTokenTime - requestStartTime;
|
||||
}
|
||||
|
||||
// Clear prefill progress when first token arrives
|
||||
const msgIdx = this.messages.findIndex(
|
||||
(m) => m.id === assistantMessage.id,
|
||||
);
|
||||
if (msgIdx !== -1 && this.messages[msgIdx].prefillProgress) {
|
||||
this.messages[msgIdx].prefillProgress = null;
|
||||
}
|
||||
|
||||
// Count tokens (each SSE chunk is typically one token)
|
||||
tokenCount += 1;
|
||||
this.totalTokens = tokenCount;
|
||||
@@ -1463,6 +1528,25 @@ class AppStore {
|
||||
this.tps = (tokenCount / elapsed) * 1000;
|
||||
}
|
||||
|
||||
// Extract logprobs for uncertainty visualization
|
||||
const logprobsData = parsed.choices?.[0]?.logprobs;
|
||||
if (logprobsData?.content?.[0]) {
|
||||
const logprobItem = logprobsData.content[0];
|
||||
const tokenData: TokenData = {
|
||||
token: logprobItem.token || tokenContent,
|
||||
logprob: logprobItem.logprob ?? 0,
|
||||
probability: Math.exp(logprobItem.logprob ?? 0),
|
||||
topLogprobs: (logprobItem.top_logprobs || []).map(
|
||||
(item: { token: string; logprob: number; bytes?: number[] }) => ({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
bytes: item.bytes,
|
||||
}),
|
||||
),
|
||||
};
|
||||
collectedTokens.push(tokenData);
|
||||
}
|
||||
|
||||
fullContent += tokenContent;
|
||||
|
||||
// Strip thinking tags for display and extract thinking content
|
||||
@@ -1477,6 +1561,8 @@ class AppStore {
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
// Update tokens during streaming for real-time visualization
|
||||
this.messages[idx].tokens = [...collectedTokens];
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
@@ -1524,6 +1610,10 @@ class AppStore {
|
||||
if (this.tps !== null) {
|
||||
this.messages[idx].tps = this.tps;
|
||||
}
|
||||
// Store token data for uncertainty visualization
|
||||
if (collectedTokens.length > 0) {
|
||||
this.messages[idx].tokens = collectedTokens;
|
||||
}
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} catch (error) {
|
||||
|
||||
@@ -23,7 +23,6 @@ dependencies = [
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -205,14 +205,6 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
logger.info("FAST_SYNCH forced ON")
|
||||
elif args.fast_synch is False:
|
||||
os.environ["EXO_FAST_SYNCH"] = "off"
|
||||
logger.info("FAST_SYNCH forced OFF")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
logger.info("EXO Shutdown complete")
|
||||
@@ -226,7 +218,6 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -268,20 +259,6 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
action="store_true",
|
||||
dest="fast_synch",
|
||||
default=None,
|
||||
help="Force MLX FAST_SYNCH on (for JACCL backend)",
|
||||
)
|
||||
fast_synch_group.add_argument(
|
||||
"--no-fast-synch",
|
||||
action="store_false",
|
||||
dest="fast_synch",
|
||||
help="Force MLX FAST_SYNCH off",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
1
src/exo/master/adapters/__init__.py
Normal file
1
src/exo/master/adapters/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""
|
||||
184
src/exo/master/adapters/claude.py
Normal file
184
src/exo/master/adapters/claude.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Claude Messages API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
FinishReason,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeStopReason,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeUsage,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
|
||||
|
||||
def finish_reason_to_claude_stop_reason(
|
||||
finish_reason: FinishReason | None,
|
||||
) -> ClaudeStopReason | None:
|
||||
"""Map OpenAI finish_reason to Claude stop_reason."""
|
||||
if finish_reason is None:
|
||||
return None
|
||||
mapping: dict[FinishReason, ClaudeStopReason] = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
"function_call": "tool_use",
|
||||
}
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
|
||||
def claude_request_to_chat_params(
|
||||
request: ClaudeMessagesRequest,
|
||||
) -> ChatCompletionTaskParams:
|
||||
"""Convert Claude Messages API request to internal ChatCompletionTaskParams."""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
# Add system message if present
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
messages.append(
|
||||
ChatCompletionMessage(role="system", content=request.system)
|
||||
)
|
||||
else:
|
||||
# List of text blocks
|
||||
system_text = "".join(block.text for block in request.system)
|
||||
messages.append(ChatCompletionMessage(role="system", content=system_text))
|
||||
|
||||
# Convert messages
|
||||
for msg in request.messages:
|
||||
content: str
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
else:
|
||||
# Concatenate text blocks (images not supported for MVP)
|
||||
text_parts: list[str] = []
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
content = "".join(text_parts)
|
||||
|
||||
messages.append(ChatCompletionMessage(role=msg.role, content=content))
|
||||
|
||||
return ChatCompletionTaskParams(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
)
|
||||
|
||||
|
||||
def chat_response_to_claude_response(
|
||||
response: ChatCompletionResponse,
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Convert internal ChatCompletionResponse to Claude Messages API response."""
|
||||
content_text = ""
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
|
||||
content_text = (
|
||||
choice.message.content
|
||||
if isinstance(choice.message.content, str)
|
||||
else str(choice.message.content)
|
||||
)
|
||||
stop_reason = finish_reason_to_claude_stop_reason(choice.finish_reason)
|
||||
|
||||
# Use actual usage data from response if available
|
||||
input_tokens = response.usage.prompt_tokens if response.usage else 0
|
||||
output_tokens = response.usage.completion_tokens if response.usage else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
id=f"msg_{response.id}",
|
||||
model=response.model,
|
||||
content=[ClaudeTextBlock(text=content_text)],
|
||||
stop_reason=stop_reason,
|
||||
usage=ClaudeUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Claude Messages API streaming events from TokenChunks."""
|
||||
# Initial message_start event
|
||||
initial_message = ClaudeMessageStart(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
|
||||
)
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
|
||||
# content_block_stop
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason=stop_reason),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
|
||||
)
|
||||
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
|
||||
|
||||
# message_stop
|
||||
message_stop = ClaudeMessageStopEvent()
|
||||
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"
|
||||
199
src/exo/master/adapters/responses.py
Normal file
199
src/exo/master/adapters/responses.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""OpenAI Responses API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
|
||||
|
||||
def responses_request_to_chat_params(
|
||||
request: ResponsesRequest,
|
||||
) -> ChatCompletionTaskParams:
|
||||
"""Convert OpenAI Responses API request to internal ChatCompletionTaskParams."""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
# Add instructions as system message if present
|
||||
if request.instructions:
|
||||
messages.append(
|
||||
ChatCompletionMessage(role="system", content=request.instructions)
|
||||
)
|
||||
|
||||
# Convert input to messages
|
||||
if isinstance(request.input, str):
|
||||
messages.append(ChatCompletionMessage(role="user", content=request.input))
|
||||
else:
|
||||
for msg in request.input:
|
||||
messages.append(
|
||||
ChatCompletionMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
)
|
||||
)
|
||||
|
||||
return ChatCompletionTaskParams(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_output_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=request.stream,
|
||||
)
|
||||
|
||||
|
||||
def chat_response_to_responses_response(
|
||||
response: ChatCompletionResponse,
|
||||
) -> ResponsesResponse:
|
||||
"""Convert internal ChatCompletionResponse to OpenAI Responses API response."""
|
||||
output_text = ""
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
|
||||
output_text = (
|
||||
choice.message.content
|
||||
if isinstance(choice.message.content, str)
|
||||
else str(choice.message.content)
|
||||
)
|
||||
|
||||
item_id = f"item_{response.id}"
|
||||
output_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=output_text)],
|
||||
)
|
||||
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=response.usage.prompt_tokens,
|
||||
output_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return ResponsesResponse(
|
||||
id=f"resp_{response.id}",
|
||||
model=response.model,
|
||||
output=[output_item],
|
||||
output_text=output_text,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
|
||||
# response.created
|
||||
initial_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
created_event = ResponseCreatedEvent(response=initial_response)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(response=initial_response)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
output_index=0, content_index=0, part=initial_part
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
|
||||
accumulated_text = ""
|
||||
last_stats = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
output_index=0, content_index=0, text=accumulated_text
|
||||
)
|
||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.done
|
||||
final_part = ResponseOutputText(text=accumulated_text)
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
output_index=0, content_index=0, part=final_part
|
||||
)
|
||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.done
|
||||
final_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=[final_item],
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
completed_event = ResponseCompletedEvent(response=final_response)
|
||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||
@@ -1,20 +1,30 @@
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
chat_response_to_claude_response,
|
||||
claude_request_to_chat_params,
|
||||
generate_claude_stream,
|
||||
)
|
||||
from exo.master.adapters.responses import (
|
||||
chat_response_to_responses_response,
|
||||
generate_responses_stream,
|
||||
responses_request_to_chat_params,
|
||||
)
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.election import ElectionMessage
|
||||
@@ -30,10 +40,10 @@ from exo.shared.types.api import (
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
Logprobs,
|
||||
LogprobsContentItem,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -42,6 +52,10 @@ from exo.shared.types.api import (
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
)
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
@@ -57,9 +71,14 @@ from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
PrefillProgress,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
@@ -70,9 +89,35 @@ from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrefillProgressData:
|
||||
"""Data class for prefill progress events."""
|
||||
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
# Union type for stream events
|
||||
StreamEvent = TokenChunk | PrefillProgressData
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
# Build logprobs if available
|
||||
logprobs: Logprobs | None = None
|
||||
if chunk.logprob is not None:
|
||||
logprobs = Logprobs(
|
||||
content=[
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
bytes=list(chunk.text.encode("utf-8")),
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -81,6 +126,7 @@ def chunk_to_response(
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -123,7 +169,6 @@ class API:
|
||||
self.paused_ev: anyio.Event = anyio.Event()
|
||||
|
||||
self.app = FastAPI()
|
||||
self._setup_exception_handlers()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
|
||||
@@ -136,7 +181,7 @@ class API:
|
||||
name="dashboard",
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
|
||||
self._chat_completion_queues: dict[CommandId, Sender[StreamEvent]] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
@@ -154,20 +199,6 @@ class API:
|
||||
self.paused_ev.set()
|
||||
self.paused_ev = anyio.Event()
|
||||
|
||||
def _setup_exception_handlers(self) -> None:
|
||||
@self.app.exception_handler(HTTPException)
|
||||
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
|
||||
_: Request, exc: HTTPException
|
||||
) -> JSONResponse:
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=exc.detail,
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -191,6 +222,8 @@ class API:
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
|
||||
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -396,18 +429,18 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
async def _stream_events(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
) -> AsyncGenerator[StreamEvent, None]:
|
||||
"""Yield stream events (TokenChunks or PrefillProgressData) for a command."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
self._chat_completion_queues[command_id], recv = channel[StreamEvent]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
with recv as events:
|
||||
async for event in events:
|
||||
yield event
|
||||
if isinstance(event, TokenChunk) and event.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
@@ -423,33 +456,36 @@ class API:
|
||||
await self._send(command)
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield only TokenChunks, filtering out progress events."""
|
||||
|
||||
async for event in self._stream_events(command_id):
|
||||
if isinstance(event, TokenChunk):
|
||||
yield event
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
async for event in self._stream_events(command_id):
|
||||
if isinstance(event, PrefillProgressData):
|
||||
# Send prefill progress as a named SSE event
|
||||
progress_json = f'{{"processed":{event.processed_tokens},"total":{event.total_tokens}}}'
|
||||
yield f"event: prefill_progress\ndata: {progress_json}\n\n"
|
||||
else:
|
||||
# TokenChunk - regular token generation
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
event, command_id
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
if event.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId
|
||||
@@ -461,12 +497,6 @@ class API:
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -504,12 +534,6 @@ class API:
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -595,6 +619,75 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def claude_messages(
|
||||
self, payload: ClaudeMessagesRequest
|
||||
) -> ClaudeMessagesResponse | StreamingResponse:
|
||||
"""Handle Claude Messages API requests."""
|
||||
chat_params = claude_request_to_chat_params(payload)
|
||||
model_meta = await resolve_model_meta(chat_params.model)
|
||||
chat_params.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == chat_params.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(chat_params.model)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No instance found for model {chat_params.model}",
|
||||
)
|
||||
|
||||
command = ChatCompletion(request_params=chat_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_claude_stream(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._chat_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
response = await self._collect_chat_completion(command.command_id)
|
||||
return chat_response_to_claude_response(response)
|
||||
|
||||
async def openai_responses(
|
||||
self, payload: ResponsesRequest
|
||||
) -> ResponsesResponse | StreamingResponse:
|
||||
"""Handle OpenAI Responses API requests."""
|
||||
chat_params = responses_request_to_chat_params(payload)
|
||||
|
||||
model_meta = await resolve_model_meta(chat_params.model)
|
||||
chat_params.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == chat_params.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(chat_params.model)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No instance found for model {chat_params.model}",
|
||||
)
|
||||
|
||||
command = ChatCompletion(request_params=chat_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_responses_stream(
|
||||
command.command_id,
|
||||
payload.model,
|
||||
self._chat_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
response = await self._collect_chat_completion(command.command_id)
|
||||
return chat_response_to_responses_response(response)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
@@ -654,14 +747,24 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if (
|
||||
isinstance(event, ChunkGenerated)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
elif (
|
||||
isinstance(event, PrefillProgress)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
PrefillProgressData(
|
||||
processed_tokens=event.processed_tokens,
|
||||
total_tokens=event.total_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
# pyright: reportUnusedFunction=false, reportAny=false
|
||||
from typing import Any, get_args
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_http_exception_handler_formats_openai_style() -> None:
|
||||
"""Test that HTTPException is converted to OpenAI-style error format."""
|
||||
from exo.master.api import API
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Setup exception handler
|
||||
api = object.__new__(API)
|
||||
api.app = app
|
||||
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Add test routes that raise HTTPException
|
||||
@app.get("/test-error")
|
||||
async def _test_error() -> None:
|
||||
raise HTTPException(status_code=500, detail="Test error message")
|
||||
|
||||
@app.get("/test-not-found")
|
||||
async def _test_not_found() -> None:
|
||||
raise HTTPException(status_code=404, detail="Resource not found")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Test 500 error
|
||||
response = client.get("/test-error")
|
||||
assert response.status_code == 500
|
||||
data: dict[str, Any] = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Test error message"
|
||||
assert data["error"]["type"] == "Internal Server Error"
|
||||
assert data["error"]["code"] == 500
|
||||
|
||||
# Test 404 error
|
||||
response = client.get("/test-not-found")
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Resource not found"
|
||||
assert data["error"]["type"] == "Not Found"
|
||||
assert data["error"]["code"] == 404
|
||||
|
||||
|
||||
def test_finish_reason_includes_error() -> None:
|
||||
valid_reasons = get_args(FinishReason)
|
||||
assert "error" in valid_reasons
|
||||
|
||||
|
||||
def test_token_chunk_with_error_fields() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Something went wrong"
|
||||
|
||||
|
||||
def test_token_chunk_without_error() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=1,
|
||||
model=MODEL_A_ID,
|
||||
text="Hello",
|
||||
token_id=42,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
|
||||
|
||||
def test_error_response_construction() -> None:
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message="Generation failed",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
|
||||
assert error_response.error.message == "Generation failed"
|
||||
assert error_response.error.code == 500
|
||||
|
||||
|
||||
def test_normal_finish_reasons_still_work() -> None:
|
||||
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="done",
|
||||
token_id=100,
|
||||
finish_reason=reason, # type: ignore[arg-type]
|
||||
)
|
||||
assert chunk.finish_reason == reason
|
||||
392
src/exo/master/tests/test_claude_api.py
Normal file
392
src/exo/master/tests/test_claude_api.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Tests for Claude Messages API conversion functions and types."""
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
chat_response_to_claude_response,
|
||||
claude_request_to_chat_params,
|
||||
finish_reason_to_claude_stop_reason,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeMessage,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeUsage,
|
||||
)
|
||||
|
||||
|
||||
class TestFinishReasonToClaudeStopReason:
|
||||
"""Tests for finish_reason to Claude stop_reason mapping."""
|
||||
|
||||
def test_stop_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
|
||||
|
||||
def test_length_maps_to_max_tokens(self):
|
||||
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
|
||||
|
||||
def test_tool_calls_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
|
||||
|
||||
def test_function_call_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
|
||||
|
||||
def test_content_filter_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert finish_reason_to_claude_stop_reason(None) is None
|
||||
|
||||
|
||||
class TestClaudeRequestToChatParams:
|
||||
"""Tests for converting Claude Messages API requests to ChatCompletionTaskParams."""
|
||||
|
||||
def test_basic_request_conversion(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert params.model == "claude-3-opus"
|
||||
assert params.max_tokens == 100
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello"
|
||||
|
||||
def test_request_with_system_string(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
system="You are a helpful assistant.",
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are a helpful assistant."
|
||||
assert params.messages[1].role == "user"
|
||||
assert params.messages[1].content == "Hello"
|
||||
|
||||
def test_request_with_system_text_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
system=[
|
||||
ClaudeTextBlock(text="You are helpful. "),
|
||||
ClaudeTextBlock(text="Be concise."),
|
||||
],
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are helpful. Be concise."
|
||||
|
||||
def test_request_with_content_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(
|
||||
role="user",
|
||||
content=[
|
||||
ClaudeTextBlock(text="First part. "),
|
||||
ClaudeTextBlock(text="Second part."),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].content == "First part. Second part."
|
||||
|
||||
def test_request_with_multi_turn_conversation(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
ClaudeMessage(role="assistant", content="Hi there!"),
|
||||
ClaudeMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 3
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[1].role == "assistant"
|
||||
assert params.messages[2].role == "user"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
messages=[ClaudeMessage(role="user", content="Hello")],
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
stop_sequences=["STOP", "END"],
|
||||
stream=True,
|
||||
)
|
||||
params = claude_request_to_chat_params(request)
|
||||
|
||||
assert params.temperature == 0.7
|
||||
assert params.top_p == 0.9
|
||||
assert params.top_k == 40
|
||||
assert params.stop == ["STOP", "END"]
|
||||
assert params.stream is True
|
||||
|
||||
|
||||
class TestChatResponseToClaudeResponse:
|
||||
"""Tests for converting ChatCompletionResponse to Claude Messages API response."""
|
||||
|
||||
def test_basic_response_conversion(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?",
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(prompt_tokens=10, completion_tokens=7, total_tokens=17),
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.id == "msg_chatcmpl-123"
|
||||
assert claude_response.model == "llama-3.2-1b"
|
||||
assert claude_response.role == "assistant"
|
||||
assert claude_response.type == "message"
|
||||
assert len(claude_response.content) == 1
|
||||
assert claude_response.content[0].type == "text"
|
||||
assert claude_response.content[0].text == "Hello! How can I help you?"
|
||||
assert claude_response.stop_reason == "end_turn"
|
||||
assert claude_response.usage.input_tokens == 10
|
||||
assert claude_response.usage.output_tokens == 7
|
||||
|
||||
def test_response_with_length_finish_reason(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content="Truncated..."
|
||||
),
|
||||
finish_reason="length",
|
||||
)
|
||||
],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.stop_reason == "max_tokens"
|
||||
|
||||
def test_response_with_empty_content(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10),
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == ""
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
def test_response_with_no_choices(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == ""
|
||||
assert claude_response.stop_reason is None
|
||||
assert claude_response.usage.input_tokens == 0
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
def test_response_without_usage(self):
|
||||
"""Test response conversion when usage data is not available."""
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
claude_response = chat_response_to_claude_response(response)
|
||||
|
||||
assert claude_response.content[0].text == "Hello!"
|
||||
assert claude_response.usage.input_tokens == 0
|
||||
assert claude_response.usage.output_tokens == 0
|
||||
|
||||
|
||||
class TestClaudeMessagesRequestValidation:
|
||||
"""Tests for Claude Messages API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"max_tokens": 100,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_max_tokens(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_messages(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 100,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestClaudeStreamingEvents:
|
||||
"""Tests for Claude Messages API streaming event serialization."""
|
||||
|
||||
def test_message_start_event_format(self):
|
||||
message = ClaudeMessageStart(
|
||||
id="msg_123",
|
||||
model="claude-3-opus",
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
|
||||
)
|
||||
event = ClaudeMessageStartEvent(message=message)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_start"
|
||||
assert parsed["message"]["id"] == "msg_123"
|
||||
assert parsed["message"]["type"] == "message"
|
||||
assert parsed["message"]["role"] == "assistant"
|
||||
assert parsed["message"]["model"] == "claude-3-opus"
|
||||
|
||||
def test_content_block_start_event_format(self):
|
||||
event = ClaudeContentBlockStartEvent(
|
||||
index=0,
|
||||
content_block=ClaudeTextBlock(text=""),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_start"
|
||||
assert parsed["index"] == 0
|
||||
assert parsed["content_block"]["type"] == "text"
|
||||
assert parsed["content_block"]["text"] == ""
|
||||
|
||||
def test_content_block_delta_event_format(self):
|
||||
event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text="Hello"),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_delta"
|
||||
assert parsed["index"] == 0
|
||||
assert parsed["delta"]["type"] == "text_delta"
|
||||
assert parsed["delta"]["text"] == "Hello"
|
||||
|
||||
def test_content_block_stop_event_format(self):
|
||||
event = ClaudeContentBlockStopEvent(index=0)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "content_block_stop"
|
||||
assert parsed["index"] == 0
|
||||
|
||||
def test_message_delta_event_format(self):
|
||||
event = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason="end_turn"),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=25),
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_delta"
|
||||
assert parsed["delta"]["stop_reason"] == "end_turn"
|
||||
assert parsed["usage"]["output_tokens"] == 25
|
||||
|
||||
def test_message_stop_event_format(self):
|
||||
event = ClaudeMessageStopEvent()
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "message_stop"
|
||||
|
||||
def test_sse_format(self):
|
||||
"""Test that SSE format is correctly generated."""
|
||||
event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text="Hello"),
|
||||
)
|
||||
# Simulate the SSE format used in the streaming generator
|
||||
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
|
||||
|
||||
assert sse_line.startswith("event: content_block_delta\n")
|
||||
assert "data: " in sse_line
|
||||
assert sse_line.endswith("\n\n")
|
||||
414
src/exo/master/tests/test_openai_responses_api.py
Normal file
414
src/exo/master/tests/test_openai_responses_api.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""Tests for OpenAI Responses API conversion functions and types."""
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.responses import (
|
||||
chat_response_to_responses_response,
|
||||
responses_request_to_chat_params,
|
||||
)
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInputMessage,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
|
||||
|
||||
class TestResponsesRequestToChatParams:
|
||||
"""Tests for converting OpenAI Responses API requests to ChatCompletionTaskParams."""
|
||||
|
||||
def test_string_input_conversion(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello, how are you?",
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert params.model == "gpt-4o"
|
||||
assert len(params.messages) == 1
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello, how are you?"
|
||||
|
||||
def test_message_array_input_conversion(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
ResponseInputMessage(role="assistant", content="Hi there!"),
|
||||
ResponseInputMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 3
|
||||
assert params.messages[0].role == "user"
|
||||
assert params.messages[0].content == "Hello"
|
||||
assert params.messages[1].role == "assistant"
|
||||
assert params.messages[1].content == "Hi there!"
|
||||
assert params.messages[2].role == "user"
|
||||
assert params.messages[2].content == "How are you?"
|
||||
|
||||
def test_request_with_instructions(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
instructions="You are a helpful assistant. Be concise.",
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[0].content == "You are a helpful assistant. Be concise."
|
||||
assert params.messages[1].role == "user"
|
||||
assert params.messages[1].content == "Hello"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
max_output_tokens=500,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
stream=True,
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert params.max_tokens == 500
|
||||
assert params.temperature == 0.8
|
||||
assert params.top_p == 0.95
|
||||
assert params.stream is True
|
||||
|
||||
def test_request_with_system_role_in_messages(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="system", content="Be helpful"),
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "system"
|
||||
assert params.messages[1].role == "user"
|
||||
|
||||
def test_request_with_developer_role(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[
|
||||
ResponseInputMessage(role="developer", content="Internal note"),
|
||||
ResponseInputMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = responses_request_to_chat_params(request)
|
||||
|
||||
assert len(params.messages) == 2
|
||||
assert params.messages[0].role == "developer"
|
||||
|
||||
|
||||
class TestChatResponseToResponsesResponse:
|
||||
"""Tests for converting ChatCompletionResponse to OpenAI Responses API response."""
|
||||
|
||||
def test_basic_response_conversion(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?",
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.id == "resp_chatcmpl-123"
|
||||
assert responses_response.object == "response"
|
||||
assert responses_response.model == "llama-3.2-1b"
|
||||
assert responses_response.status == "completed"
|
||||
assert responses_response.output_text == "Hello! How can I help you?"
|
||||
assert len(responses_response.output) == 1
|
||||
assert responses_response.output[0].type == "message"
|
||||
assert responses_response.output[0].role == "assistant"
|
||||
assert len(responses_response.output[0].content) == 1
|
||||
assert responses_response.output[0].content[0].type == "output_text"
|
||||
assert (
|
||||
responses_response.output[0].content[0].text == "Hello! How can I help you?"
|
||||
)
|
||||
|
||||
def test_response_with_usage(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
),
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.usage is not None
|
||||
assert responses_response.usage.input_tokens == 10
|
||||
assert responses_response.usage.output_tokens == 5
|
||||
assert responses_response.usage.total_tokens == 15
|
||||
|
||||
def test_response_with_empty_content(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output_text == ""
|
||||
assert responses_response.output[0].content[0].text == ""
|
||||
|
||||
def test_response_with_no_choices(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output_text == ""
|
||||
|
||||
def test_response_without_usage(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.usage is None
|
||||
|
||||
def test_response_item_id_format(self):
|
||||
response = ChatCompletionResponse(
|
||||
id="chatcmpl-abc123",
|
||||
created=1234567890,
|
||||
model="llama-3.2-1b",
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
responses_response = chat_response_to_responses_response(response)
|
||||
|
||||
assert responses_response.output[0].id == "item_chatcmpl-abc123"
|
||||
|
||||
|
||||
class TestResponsesRequestValidation:
|
||||
"""Tests for OpenAI Responses API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"input": "Hello",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_input(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_accepts_string_input(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input="Hello",
|
||||
)
|
||||
assert request.input == "Hello"
|
||||
|
||||
def test_request_accepts_message_array_input(self):
|
||||
request = ResponsesRequest(
|
||||
model="gpt-4o",
|
||||
input=[ResponseInputMessage(role="user", content="Hello")],
|
||||
)
|
||||
assert len(request.input) == 1
|
||||
|
||||
|
||||
class TestResponsesStreamingEvents:
|
||||
"""Tests for OpenAI Responses API streaming event serialization."""
|
||||
|
||||
def test_response_created_event_format(self):
|
||||
response = ResponsesResponse(
|
||||
id="resp_123",
|
||||
model="gpt-4o",
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
event = ResponseCreatedEvent(response=response)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.created"
|
||||
assert parsed["response"]["id"] == "resp_123"
|
||||
assert parsed["response"]["object"] == "response"
|
||||
assert parsed["response"]["status"] == "in_progress"
|
||||
|
||||
def test_output_item_added_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_item.added"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["item"]["type"] == "message"
|
||||
assert parsed["item"]["id"] == "item_123"
|
||||
assert parsed["item"]["role"] == "assistant"
|
||||
|
||||
def test_content_part_added_event_format(self):
|
||||
part = ResponseOutputText(text="")
|
||||
event = ResponseContentPartAddedEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=part,
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.content_part.added"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["content_index"] == 0
|
||||
assert parsed["part"]["type"] == "output_text"
|
||||
|
||||
def test_text_delta_event_format(self):
|
||||
event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta="Hello",
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_text.delta"
|
||||
assert parsed["output_index"] == 0
|
||||
assert parsed["content_index"] == 0
|
||||
assert parsed["delta"] == "Hello"
|
||||
|
||||
def test_text_done_event_format(self):
|
||||
event = ResponseTextDoneEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
text="Hello, world!",
|
||||
)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_text.done"
|
||||
assert parsed["text"] == "Hello, world!"
|
||||
|
||||
def test_output_item_done_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="Hello, world!")],
|
||||
status="completed",
|
||||
)
|
||||
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.output_item.done"
|
||||
assert parsed["item"]["status"] == "completed"
|
||||
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
|
||||
|
||||
def test_response_completed_event_format(self):
|
||||
item = ResponseMessageItem(
|
||||
id="item_123",
|
||||
content=[ResponseOutputText(text="Hello!")],
|
||||
status="completed",
|
||||
)
|
||||
response = ResponsesResponse(
|
||||
id="resp_123",
|
||||
model="gpt-4o",
|
||||
status="completed",
|
||||
output=[item],
|
||||
output_text="Hello!",
|
||||
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
|
||||
)
|
||||
event = ResponseCompletedEvent(response=response)
|
||||
json_str = event.model_dump_json()
|
||||
parsed = cast(dict[str, Any], json.loads(json_str))
|
||||
|
||||
assert parsed["type"] == "response.completed"
|
||||
assert parsed["response"]["status"] == "completed"
|
||||
assert parsed["response"]["output_text"] == "Hello!"
|
||||
assert parsed["response"]["usage"]["total_tokens"] == 15
|
||||
|
||||
def test_sse_format(self):
|
||||
"""Test that SSE format is correctly generated."""
|
||||
event = ResponseTextDeltaEvent(
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta="Hello",
|
||||
)
|
||||
# Simulate the SSE format used in the streaming generator
|
||||
sse_line = (
|
||||
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
|
||||
)
|
||||
|
||||
assert sse_line.startswith("event: response.output_text.delta\n")
|
||||
assert "data: " in sse_line
|
||||
assert sse_line.endswith("\n\n")
|
||||
@@ -16,6 +16,7 @@ from exo.shared.types.events import (
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
NodeTimedOut,
|
||||
PrefillProgress,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
@@ -40,7 +41,7 @@ def event_apply(event: Event, state: State) -> State:
|
||||
"""Apply an event to state."""
|
||||
match event:
|
||||
case (
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged()
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged() | PrefillProgress()
|
||||
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
|
||||
return state
|
||||
case InstanceCreated():
|
||||
|
||||
@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
|
||||
|
||||
def logger_setup(log_file: Path | None, verbosity: int = 0):
|
||||
"""Set up logging for this process - formatting, file handles, verbosity and output"""
|
||||
|
||||
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
logger.remove()
|
||||
|
||||
# replace all stdlib loggers with _InterceptHandlers that log to loguru
|
||||
|
||||
@@ -11,21 +11,10 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call"
|
||||
]
|
||||
|
||||
|
||||
class ErrorInfo(BaseModel):
|
||||
message: str
|
||||
type: str
|
||||
param: str | None = None
|
||||
code: int
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: ErrorInfo
|
||||
|
||||
|
||||
class ModelListModel(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
@@ -157,6 +146,7 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
stream: bool = False
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo.shared.types.api import GenerationStats
|
||||
from exo.shared.types.api import GenerationStats, TopLogprobItem
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -20,9 +20,10 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
logprob: float | None = None # Log probability of the selected token
|
||||
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
|
||||
168
src/exo/shared/types/claude_api.py
Normal file
168
src/exo/shared/types/claude_api.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Claude Messages API types for request/response conversion."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Type aliases
|
||||
ClaudeRole = Literal["user", "assistant"]
|
||||
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
|
||||
|
||||
|
||||
# Content block types
|
||||
class ClaudeTextBlock(BaseModel, frozen=True):
|
||||
"""Text content block in Claude Messages API."""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeImageSource(BaseModel, frozen=True):
|
||||
"""Image source for Claude image blocks."""
|
||||
|
||||
type: Literal["base64", "url"]
|
||||
media_type: str | None = None
|
||||
data: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class ClaudeImageBlock(BaseModel, frozen=True):
|
||||
"""Image content block in Claude Messages API."""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
source: ClaudeImageSource
|
||||
|
||||
|
||||
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
|
||||
|
||||
|
||||
# Request types
|
||||
class ClaudeMessage(BaseModel, frozen=True):
|
||||
"""Message in Claude Messages API request."""
|
||||
|
||||
role: ClaudeRole
|
||||
content: str | list[ClaudeContentBlock]
|
||||
|
||||
|
||||
class ClaudeMessagesRequest(BaseModel):
|
||||
"""Request body for Claude Messages API."""
|
||||
|
||||
model: str
|
||||
max_tokens: int
|
||||
messages: list[ClaudeMessage]
|
||||
system: str | list[ClaudeTextBlock] | None = None
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool = False
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
class ClaudeUsage(BaseModel, frozen=True):
|
||||
"""Token usage in Claude Messages API response."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
|
||||
|
||||
class ClaudeMessagesResponse(BaseModel, frozen=True):
|
||||
"""Response body for Claude Messages API."""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ClaudeTextBlock]
|
||||
model: str
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
usage: ClaudeUsage
|
||||
|
||||
|
||||
# Streaming event types
|
||||
class ClaudeMessageStart(BaseModel, frozen=True):
|
||||
"""Partial message in message_start event."""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ClaudeTextBlock] = Field(default_factory=list)
|
||||
model: str
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
usage: ClaudeUsage
|
||||
|
||||
|
||||
class ClaudeMessageStartEvent(BaseModel, frozen=True):
|
||||
"""Event sent at start of message stream."""
|
||||
|
||||
type: Literal["message_start"] = "message_start"
|
||||
message: ClaudeMessageStart
|
||||
|
||||
|
||||
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
|
||||
"""Event sent at start of a content block."""
|
||||
|
||||
type: Literal["content_block_start"] = "content_block_start"
|
||||
index: int
|
||||
content_block: ClaudeTextBlock
|
||||
|
||||
|
||||
class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
"""Delta for text content block."""
|
||||
|
||||
type: Literal["text_delta"] = "text_delta"
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for content block delta."""
|
||||
|
||||
type: Literal["content_block_delta"] = "content_block_delta"
|
||||
index: int
|
||||
delta: ClaudeTextDelta
|
||||
|
||||
|
||||
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
||||
"""Event sent at end of a content block."""
|
||||
|
||||
type: Literal["content_block_stop"] = "content_block_stop"
|
||||
index: int
|
||||
|
||||
|
||||
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
|
||||
"""Usage in message_delta event."""
|
||||
|
||||
output_tokens: int
|
||||
|
||||
|
||||
class ClaudeMessageDelta(BaseModel, frozen=True):
|
||||
"""Delta in message_delta event."""
|
||||
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
stop_sequence: str | None = None
|
||||
|
||||
|
||||
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent with final message delta."""
|
||||
|
||||
type: Literal["message_delta"] = "message_delta"
|
||||
delta: ClaudeMessageDelta
|
||||
usage: ClaudeMessageDeltaUsage
|
||||
|
||||
|
||||
class ClaudeMessageStopEvent(BaseModel, frozen=True):
|
||||
"""Event sent at end of message stream."""
|
||||
|
||||
type: Literal["message_stop"] = "message_stop"
|
||||
|
||||
|
||||
ClaudeStreamEvent = (
|
||||
ClaudeMessageStartEvent
|
||||
| ClaudeContentBlockStartEvent
|
||||
| ClaudeContentBlockDeltaEvent
|
||||
| ClaudeContentBlockStopEvent
|
||||
| ClaudeMessageDeltaEvent
|
||||
| ClaudeMessageStopEvent
|
||||
)
|
||||
@@ -106,6 +106,12 @@ class ChunkGenerated(BaseEvent):
|
||||
chunk: GenerationChunk
|
||||
|
||||
|
||||
class PrefillProgress(BaseEvent):
|
||||
command_id: CommandId
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class TopologyEdgeCreated(BaseEvent):
|
||||
edge: Connection
|
||||
|
||||
@@ -131,6 +137,7 @@ Event = (
|
||||
| NodeMemoryMeasured
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
| PrefillProgress
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
)
|
||||
|
||||
162
src/exo/shared/types/openai_responses.py
Normal file
162
src/exo/shared/types/openai_responses.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""OpenAI Responses API types for request/response conversion."""
|
||||
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Type aliases
|
||||
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
|
||||
ResponseRole = Literal["user", "assistant", "system", "developer"]
|
||||
|
||||
|
||||
# Request types
|
||||
class ResponseInputMessage(BaseModel, frozen=True):
|
||||
"""Input message for Responses API."""
|
||||
|
||||
role: ResponseRole
|
||||
content: str
|
||||
|
||||
|
||||
class ResponsesRequest(BaseModel):
|
||||
"""Request body for OpenAI Responses API."""
|
||||
|
||||
model: str
|
||||
input: str | list[ResponseInputMessage]
|
||||
instructions: str | None = None
|
||||
max_output_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
stream: bool = False
|
||||
# previous_response_id not supported in MVP
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
class ResponseOutputText(BaseModel, frozen=True):
|
||||
"""Text content in response output."""
|
||||
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str
|
||||
annotations: list[dict[str, str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ResponseMessageItem(BaseModel, frozen=True):
|
||||
"""Message item in response output array."""
|
||||
|
||||
type: Literal["message"] = "message"
|
||||
id: str
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[ResponseOutputText]
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
|
||||
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
|
||||
|
||||
|
||||
class ResponseUsage(BaseModel, frozen=True):
|
||||
"""Token usage in Responses API response."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ResponsesResponse(BaseModel, frozen=True):
|
||||
"""Response body for OpenAI Responses API."""
|
||||
|
||||
id: str
|
||||
object: Literal["response"] = "response"
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
status: ResponseStatus = "completed"
|
||||
model: str
|
||||
output: list[ResponseItem]
|
||||
output_text: str
|
||||
usage: ResponseUsage | None = None
|
||||
|
||||
|
||||
# Streaming event types
|
||||
class ResponseCreatedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is created."""
|
||||
|
||||
type: Literal["response.created"] = "response.created"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
class ResponseInProgressEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response starts processing."""
|
||||
|
||||
type: Literal["response.in_progress"] = "response.in_progress"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when an output item is added."""
|
||||
|
||||
type: Literal["response.output_item.added"] = "response.output_item.added"
|
||||
output_index: int
|
||||
item: ResponseItem
|
||||
|
||||
|
||||
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a content part is added."""
|
||||
|
||||
type: Literal["response.content_part.added"] = "response.content_part.added"
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ResponseOutputText
|
||||
|
||||
|
||||
class ResponseTextDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for text delta during streaming."""
|
||||
|
||||
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseTextDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when text content is done."""
|
||||
|
||||
type: Literal["response.output_text.done"] = "response.output_text.done"
|
||||
output_index: int
|
||||
content_index: int
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a content part is done."""
|
||||
|
||||
type: Literal["response.content_part.done"] = "response.content_part.done"
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ResponseOutputText
|
||||
|
||||
|
||||
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when an output item is done."""
|
||||
|
||||
type: Literal["response.output_item.done"] = "response.output_item.done"
|
||||
output_index: int
|
||||
item: ResponseItem
|
||||
|
||||
|
||||
class ResponseCompletedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is completed."""
|
||||
|
||||
type: Literal["response.completed"] = "response.completed"
|
||||
response: ResponsesResponse
|
||||
|
||||
|
||||
ResponsesStreamEvent = (
|
||||
ResponseCreatedEvent
|
||||
| ResponseInProgressEvent
|
||||
| ResponseOutputItemAddedEvent
|
||||
| ResponseContentPartAddedEvent
|
||||
| ResponseTextDeltaEvent
|
||||
| ResponseTextDoneEvent
|
||||
| ResponseContentPartDoneEvent
|
||||
| ResponseOutputItemDoneEvent
|
||||
| ResponseCompletedEvent
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -13,10 +13,16 @@ class TokenizedResponse(BaseRunnerResponse):
|
||||
class GenerationResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
token: int
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
logprob: float | None = None # Log probability of the selected token
|
||||
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
|
||||
class PrefillProgressResponse(BaseRunnerResponse):
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
@@ -12,6 +12,7 @@ from exo.shared.types.api import (
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
@@ -81,7 +82,7 @@ def warmup_inference(
|
||||
max_tokens=50,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
@@ -115,10 +116,65 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs: mx.array,
|
||||
tokenizer: TokenizerWrapper,
|
||||
top_k: int,
|
||||
selected_token: int,
|
||||
) -> tuple[float, list[TopLogprobItem]]:
|
||||
"""Extract the selected token's logprob and top-k alternative tokens.
|
||||
|
||||
Args:
|
||||
logprobs: Full vocabulary logprobs array from MLX
|
||||
tokenizer: Tokenizer for decoding token IDs to strings
|
||||
top_k: Number of top alternatives to return
|
||||
selected_token: The token ID that was actually sampled
|
||||
|
||||
Returns:
|
||||
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
|
||||
"""
|
||||
# Get the logprob of the selected token
|
||||
selected_logprob = float(logprobs[selected_token].item())
|
||||
|
||||
# Get top-k indices (most probable tokens)
|
||||
# mx.argpartition gives indices that would partition the array
|
||||
# We negate logprobs since argpartition finds smallest, and we want largest
|
||||
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
|
||||
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
|
||||
|
||||
# Get the actual logprob values for these indices
|
||||
top_values = logprobs[top_indices]
|
||||
|
||||
# Sort by logprob (descending) for consistent ordering
|
||||
sort_order = mx.argsort(-top_values)
|
||||
top_indices = top_indices[sort_order]
|
||||
top_values = top_values[sort_order]
|
||||
|
||||
# Convert to list of TopLogprobItem
|
||||
top_logprob_items: list[TopLogprobItem] = []
|
||||
for i in range(top_k):
|
||||
token_id = int(top_indices[i].item())
|
||||
token_logprob = float(top_values[i].item())
|
||||
# Decode token ID to string
|
||||
token_str = tokenizer.decode([token_id])
|
||||
# Get byte representation
|
||||
token_bytes = list(token_str.encode("utf-8"))
|
||||
top_logprob_items.append(
|
||||
TopLogprobItem(
|
||||
token=token_str,
|
||||
logprob=token_logprob,
|
||||
bytes=token_bytes,
|
||||
)
|
||||
)
|
||||
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
on_prefill_progress: Callable[[int, int], None] | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -146,9 +202,24 @@ def mlx_generate(
|
||||
sampler = make_sampler(
|
||||
temp=task.temperature if task.temperature is not None else 0.7,
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
top_k=task.top_k if task.top_k is not None else 0,
|
||||
)
|
||||
|
||||
# Normalize stop sequences to a list
|
||||
stop_sequences: list[str] = (
|
||||
([task.stop] if isinstance(task.stop, str) else task.stop)
|
||||
if task.stop is not None
|
||||
else []
|
||||
)
|
||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
accumulated_text = ""
|
||||
|
||||
# Determine if we need to extract logprobs
|
||||
should_extract_logprobs = task.logprobs is True
|
||||
num_top_logprobs = task.top_logprobs if task.top_logprobs is not None else 5
|
||||
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -158,14 +229,47 @@ def mlx_generate(
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=on_prefill_progress,
|
||||
):
|
||||
logger.info(out.text)
|
||||
accumulated_text += out.text
|
||||
|
||||
# Check for stop sequences
|
||||
text = out.text
|
||||
finish_reason: FinishReason | None = cast(
|
||||
FinishReason | None, out.finish_reason
|
||||
)
|
||||
stop_matched = False
|
||||
|
||||
if stop_sequences:
|
||||
for stop_seq in stop_sequences:
|
||||
if stop_seq in accumulated_text:
|
||||
# Trim text to just before the stop sequence
|
||||
stop_index = accumulated_text.find(stop_seq)
|
||||
text_before_stop = accumulated_text[:stop_index]
|
||||
chunk_start = len(accumulated_text) - len(out.text)
|
||||
text = text_before_stop[chunk_start:]
|
||||
finish_reason = "stop"
|
||||
stop_matched = True
|
||||
break
|
||||
|
||||
# Extract logprobs if requested
|
||||
token_logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if should_extract_logprobs:
|
||||
token_logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs=out.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
top_k=num_top_logprobs,
|
||||
selected_token=out.token,
|
||||
)
|
||||
|
||||
is_done = finish_reason is not None
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
if is_done:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
@@ -173,22 +277,25 @@ def mlx_generate(
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
if not stop_matched and out.finish_reason not in get_args(FinishReason):
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
text=text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
logprob=token_logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
if is_done:
|
||||
break
|
||||
|
||||
# Limit accumulated_text to what's needed for stop sequence detection
|
||||
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
|
||||
accumulated_text = accumulated_text[-max_stop_len:]
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -2,9 +2,7 @@ import json
|
||||
import os
|
||||
import resource
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -84,45 +82,6 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
mlx_item: Any, # pyright: ignore[reportAny]
|
||||
timeout_seconds: float = 60.0,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> None:
|
||||
"""Evaluate MLX item with a hard timeout.
|
||||
|
||||
If on_timeout callback is provided, it will be called before terminating
|
||||
the process. This allows the runner to send a failure event before exit.
|
||||
"""
|
||||
completed = threading.Event()
|
||||
|
||||
def watchdog() -> None:
|
||||
if not completed.wait(timeout=timeout_seconds):
|
||||
logger.error(
|
||||
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
|
||||
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
|
||||
"Terminating process."
|
||||
)
|
||||
if on_timeout is not None:
|
||||
on_timeout()
|
||||
os._exit(1)
|
||||
|
||||
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
|
||||
watchdog_thread.start()
|
||||
|
||||
try:
|
||||
mx.eval(mlx_item) # pyright: ignore[reportAny]
|
||||
finally:
|
||||
completed.set()
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
@@ -229,9 +188,7 @@ def initialize_mlx(
|
||||
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
@@ -245,9 +202,7 @@ def load_mlx_items(
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
@@ -261,7 +216,6 @@ def load_mlx_items(
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
@@ -298,15 +252,7 @@ def shard_and_load(
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
@@ -17,23 +17,15 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
fast_synch_override != "off"
|
||||
and (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
)
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from functools import cache
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -15,15 +13,14 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
PrefillProgress,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -52,7 +49,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
@@ -62,33 +58,6 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@contextmanager
|
||||
def send_error_chunk_on_exception(
|
||||
event_sender: MpSender[Event],
|
||||
command_id: CommandId,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
@@ -150,20 +119,7 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
@@ -180,7 +136,7 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=cast(Model, model),
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
@@ -193,6 +149,8 @@ def main(
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -201,47 +159,55 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender,
|
||||
command_id,
|
||||
shard_metadata.model_meta.model_id,
|
||||
shard_metadata.device_rank,
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=cast(Model, model),
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
)
|
||||
# Define callback to send prefill progress events directly
|
||||
def on_prefill_progress(processed: int, total: int) -> None:
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
PrefillProgress(
|
||||
command_id=command_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
)
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
on_prefill_progress=on_prefill_progress,
|
||||
)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
# TODO: Add tool call parser here
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: something here ig
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
# pyright: reportAny=false
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated
|
||||
from exo.worker.runner.runner import send_error_chunk_on_exception
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_no_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
_ = 1 + 1
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_catches_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_called_once()
|
||||
call_args = event_sender.send.call_args[0][0]
|
||||
assert isinstance(call_args, ChunkGenerated)
|
||||
assert call_args.command_id == command_id
|
||||
assert isinstance(call_args.chunk, TokenChunk)
|
||||
assert call_args.chunk.finish_reason == "error"
|
||||
assert call_args.chunk.error_message == "test error"
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=1
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
@@ -1,64 +1,62 @@
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio import create_task_group
|
||||
import http.client
|
||||
import time
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
REACHABILITY_ATTEMPTS = 3
|
||||
BAD_STATUSLINE_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
expected_node_id: NodeId,
|
||||
self_node_id: NodeId,
|
||||
out: dict[NodeId, set[str]],
|
||||
client: httpx.AsyncClient,
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
if ":" in target_ip:
|
||||
# TODO: use real IpAddress types
|
||||
target_ip = f"[{target_ip}]"
|
||||
url = f"http://{target_ip}:52415/node_id"
|
||||
|
||||
remote_node_id = None
|
||||
last_error = None
|
||||
|
||||
for _ in range(REACHABILITY_ATTEMPTS):
|
||||
# TODO: use an async http client
|
||||
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
|
||||
try:
|
||||
r = await client.get(url)
|
||||
if r.status_code != 200:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
body = r.text.strip().strip('"')
|
||||
if not body:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
body = response.read().decode("utf-8").strip()
|
||||
|
||||
remote_node_id = NodeId(body)
|
||||
break
|
||||
# Strip quotes if present (JSON string response)
|
||||
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
|
||||
body = body[1:-1]
|
||||
|
||||
# expected failure cases
|
||||
except (
|
||||
httpx.TimeoutException,
|
||||
httpx.NetworkError,
|
||||
):
|
||||
await anyio.sleep(1)
|
||||
|
||||
# other failures should be logged on last attempt
|
||||
except httpx.HTTPError as e:
|
||||
last_error = e
|
||||
await anyio.sleep(1)
|
||||
|
||||
if last_error is not None:
|
||||
logger.warning(
|
||||
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.BadStatusLine:
|
||||
if attempt >= BAD_STATUSLINE_ATTEMPTS:
|
||||
logger.warning(
|
||||
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
|
||||
)
|
||||
return None
|
||||
time.sleep(1)
|
||||
return _fetch_remote_node_id(attempt=attempt + 1)
|
||||
except http.client.HTTPException as e:
|
||||
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
|
||||
if remote_node_id is None:
|
||||
return
|
||||
|
||||
if remote_node_id == self_node_id:
|
||||
return
|
||||
|
||||
if remote_node_id != expected_node_id:
|
||||
logger.warning(
|
||||
f"Discovered node with unexpected node_id; "
|
||||
@@ -76,33 +74,18 @@ async def check_reachable(
|
||||
topology: Topology, self_node_id: NodeId
|
||||
) -> dict[NodeId, set[str]]:
|
||||
"""Check which nodes are reachable and return their IPs."""
|
||||
|
||||
reachable: dict[NodeId, set[str]] = {}
|
||||
|
||||
# these are intentionally httpx's defaults so we can tune them later
|
||||
timeout = httpx.Timeout(timeout=5.0)
|
||||
limits = httpx.Limits(
|
||||
max_connections=100,
|
||||
max_keepalive_connections=20,
|
||||
keepalive_expiry=5,
|
||||
)
|
||||
|
||||
async with (
|
||||
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
|
||||
create_task_group() as tg,
|
||||
):
|
||||
async with create_task_group() as tg:
|
||||
for node in topology.list_nodes():
|
||||
if not node.node_profile:
|
||||
continue
|
||||
if node.node_id == self_node_id:
|
||||
continue
|
||||
for iface in node.node_profile.network_interfaces:
|
||||
tg.start_soon(
|
||||
check_reachability,
|
||||
iface.ip_address,
|
||||
node.node_id,
|
||||
self_node_id,
|
||||
reachable,
|
||||
client,
|
||||
)
|
||||
|
||||
return reachable
|
||||
|
||||
Reference in New Issue
Block a user