Compare commits


3 Commits

Author SHA1 Message Date
Alex Cheema
6eb8f9d9f5 feat: add prefill progress bar for long prompts
Shows real-time progress during prompt processing (prefill phase).
Progress is sent via named SSE events that preserve OpenAI API compatibility (wire format sketched below).

- Add PrefillProgress event type
- Wire prompt_progress_callback through MLX stream_generate
- Send progress events directly from callback for real-time updates
- Add PrefillProgressBar.svelte component
- Parse "event: prefill_progress" SSE events in the dashboard

Note: prefill_step_size temporarily set to 256 for testing (normally 2048)
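On the wire, progress arrives as a named SSE event interleaved with the usual `data:` chunks. A minimal Python sketch of the parsing rule (the dashboard's actual parser is the Svelte store further down this diff; the sample stream below is hypothetical):

```python
import json

def parse_sse_lines(lines):
    """Minimal SSE parser sketch: track the optional `event:` name so
    named prefill_progress events can be told apart from the ordinary
    OpenAI-style token chunks. A blank line terminates an event."""
    event_type = ""
    for raw in lines:
        line = raw.strip()
        if not line:
            event_type = ""  # blank line resets the event name
            continue
        if line.startswith("event: "):
            event_type = line[len("event: "):]
        elif line.startswith("data: "):
            data = line[len("data: "):]
            if data == "[DONE]":
                return
            yield event_type or "message", json.loads(data)

# Hypothetical wire input: one progress event, then a normal token chunk.
stream = [
    "event: prefill_progress",
    'data: {"processed": 512, "total": 4096}',
    "",
    'data: {"choices":[{"delta":{"content":"Hi"}}]}',
    "",
]
for name, payload in parse_sse_lines(stream):
    print(name, payload)
# prefill_progress {'processed': 512, 'total': 4096}
# message {'choices': [{'delta': {'content': 'Hi'}}]}
```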

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 17:12:59 +00:00
Alex Cheema
663a0faaeb feat: add uncertainty visualization with token-level logprobs
Wire log probabilities from MLX through the API to enable uncertainty
visualization in the dashboard:

Backend:
- Extract top-k logprobs from MLX stream_generate output
- Add logprob and top_logprobs fields to GenerationResponse and TokenChunk
- Populate Logprobs in streaming API response when requested

Dashboard:
- Add TokenHeatmap component with color-coded token confidence
- Parse logprobs from SSE responses and store on messages
- Add toggle button to switch between normal and uncertainty view
- Hover tooltip shows exact probability and top-5 alternatives

Color scheme:
- Green (>80%): High confidence
- Yellow (50-80%): Medium confidence
- Orange (20-50%): Low confidence
- Red (<20%): Very low confidence
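Since the API reports logprobs, the dashboard recovers a probability as exp(logprob) and buckets it by the thresholds above. A quick Python sketch of the same mapping (not the Svelte component itself):

```python
import math

def confidence_bucket(logprob: float) -> str:
    """Map a token logprob to the dashboard's color buckets via
    probability = exp(logprob)."""
    p = math.exp(logprob)
    if p > 0.8:
        return "green (high)"
    if p > 0.5:
        return "yellow (medium)"
    if p > 0.2:
        return "orange (low)"
    return "red (very low)"

print(confidence_bucket(-0.1))  # exp(-0.1) ~ 0.90 -> green (high)
print(confidence_bucket(-2.0))  # exp(-2.0) ~ 0.14 -> red (very low)
```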

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 16:45:23 +00:00
Alex Cheema
0a58aa73ec feat: add Claude Messages API and OpenAI Responses API support
Adds two new API endpoints that wrap the existing chat completions:

- /v1/messages - Claude Messages API compatible endpoint
- /v1/responses - OpenAI Responses API compatible endpoint

Both support streaming (SSE) and non-streaming modes with proper
token usage reporting from actual inference stats.

Also adds top_k sampling parameter and stop sequence support to the
MLX inference engine.
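For example, a non-streaming call against the new Claude-compatible endpoint might look like this. The host, port, and model id are hypothetical placeholders; the request/response shapes follow the ClaudeMessagesRequest/ClaudeMessagesResponse models added in this PR:

```python
import json
import urllib.request

# Hypothetical base URL and model id; adjust for your deployment.
req = urllib.request.Request(
    "http://localhost:52415/v1/messages",
    data=json.dumps({
        "model": "llama-3.2-3b",
        "max_tokens": 128,
        "messages": [{"role": "user", "content": "Hello!"}],
    }).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)
    # Claude-style response: content is a list of text blocks.
    print(body["content"][0]["text"])
```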

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 15:14:21 +00:00
65 changed files with 4565 additions and 2413 deletions

View File

@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:
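For intuition, the check boils down to "can this process emit a UDP datagram toward 224.0.0.251:5353". A rough Python analog under that assumption (a sketch only; the app itself uses NWConnection as shown above, which also triggers the TCC permission prompt):

```python
import socket

def local_network_reachable(timeout: float = 3.0) -> bool:
    """Open a UDP socket "connected" to the mDNS multicast address and
    send a probe. On a macOS app whose local network access is blocked,
    this surfaces as EHOSTUNREACH/ECONNRESET-style OSErrors."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.settimeout(timeout)
    try:
        s.connect(("224.0.0.251", 5353))
        s.send(b"\x00")  # mDNS peers simply ignore this probe
        return True
    except OSError:
        return False
    finally:
        s.close()
```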

View File

@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
@@ -28,6 +28,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -241,9 +241,6 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import {
messages,
currentResponse,
import {
messages,
currentResponse,
isLoading,
deleteMessage,
editAndRegenerate,
@@ -9,6 +9,8 @@
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import MarkdownContent from './MarkdownContent.svelte';
import TokenHeatmap from './TokenHeatmap.svelte';
import PrefillProgressBar from './PrefillProgressBar.svelte';
interface Props {
class?: string;
@@ -95,6 +97,23 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty view state - tracks which messages show token heatmap
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
function toggleUncertaintyView(messageId: string) {
const newSet = new Set(uncertaintyViewMessageIds);
if (newSet.has(messageId)) {
newSet.delete(messageId);
} else {
newSet.add(messageId);
}
uncertaintyViewMessageIds = newSet;
}
function isUncertaintyViewEnabled(messageId: string): boolean {
return uncertaintyViewMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString('en-US', {
hour12: false,
@@ -330,6 +349,10 @@ function isThinkingExpanded(messageId: string): boolean {
{:else}
<!-- Assistant message styling -->
<div class="p-3 sm:p-4">
{#if message.prefillProgress}
<!-- Prefill progress bar -->
<PrefillProgressBar progress={message.prefillProgress} class="mb-3" />
{/if}
{#if message.thinking && message.thinking.trim().length > 0}
<div class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40">
<button
@@ -366,7 +389,13 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
<!-- Uncertainty heatmap view -->
<TokenHeatmap tokens={message.tokens} />
{:else}
<!-- Normal markdown view -->
<MarkdownContent content={message.content || (loading ? response : '')} />
{/if}
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -419,6 +448,19 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
</button>
{/if}
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleUncertaintyView(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
>
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
</svg>
</button>
{/if}
<!-- Delete button -->
<button

View File

@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;

View File

@@ -0,0 +1,67 @@
<script lang="ts">
import type { PrefillProgress } from '$lib/stores/app.svelte';
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = '' }: Props = $props();
const percentage = $derived(
progress.total > 0 ? Math.round((progress.processed / progress.total) * 100) : 0
);
function formatTokenCount(count: number): string {
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div class="flex items-center justify-between text-xs text-gray-400 mb-1">
<span class="flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5 animate-spin"
fill="none"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
>
<circle
class="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
stroke-width="4"
></circle>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
></path>
</svg>
<span>Processing prompt</span>
</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(progress.total)} tokens
</span>
</div>
<div class="h-1.5 bg-gray-700 rounded-full overflow-hidden">
<div
class="h-full bg-blue-500 rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-gray-500 mt-0.5">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -0,0 +1,121 @@
<script lang="ts">
import type { TokenData } from '$lib/stores/app.svelte';
interface Props {
tokens: TokenData[];
class?: string;
}
let { tokens, class: className = '' }: Props = $props();
// Tooltip state
let hoveredToken = $state<{ token: TokenData; x: number; y: number } | null>(null);
/**
* Get confidence level based on probability
* High: >0.8 (logprob > -0.22)
* Medium: 0.5-0.8 (logprob -0.69 to -0.22)
* Low: 0.2-0.5 (logprob -1.61 to -0.69)
* Very Low: <0.2 (logprob < -1.61)
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return 'bg-green-500/30 text-green-100';
if (probability > 0.5) return 'bg-yellow-500/30 text-yellow-100';
if (probability > 0.2) return 'bg-orange-500/30 text-orange-100';
return 'bg-red-500/40 text-red-100';
}
/**
* Get border color for token based on probability
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return 'border-green-500/50';
if (probability > 0.5) return 'border-yellow-500/50';
if (probability > 0.2) return 'border-orange-500/50';
return 'border-red-500/50';
}
function handleMouseEnter(event: MouseEvent, token: TokenData) {
const rect = (event.target as HTMLElement).getBoundingClientRect();
hoveredToken = {
token,
x: rect.left + rect.width / 2,
y: rect.top - 10
};
}
function handleMouseLeave() {
hoveredToken = null;
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + '%';
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData)}
onmouseleave={handleMouseLeave}
>{tokenData.token}</span>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50 pointer-events-none"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
>
<div class="bg-gray-900 border border-gray-700 rounded-lg shadow-xl p-3 text-sm min-w-48">
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-400 text-xs">Token:</span>
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
<span class="text-green-400 ml-2">{formatProbability(hoveredToken.token.probability)}</span>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700 mt-2 pt-2">
<div class="text-gray-400 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
</div>
{/each}
</div>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: NodeInfo | undefined): string | null {
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}
if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || '?';
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
quadrantEdges.forEach(edge => {
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;

View File

@@ -99,36 +99,20 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
nodeProfile: RawNodeProfile;
}
// New connection edge types from Python SocketConnection/RDMAConnection
interface RawSocketConnection {
sinkMultiaddr?: {
address?: string;
// Multiaddr uses snake_case (no camelCase alias)
ip_address?: string;
ipAddress?: string; // fallback in case it changes
address_type?: string;
port?: number;
};
interface RawTopologyConnection {
localNodeId: string;
sendBackNodeId: string;
sendBackMultiaddr?:
| { multiaddr?: string; address?: string; ip_address?: string }
| string;
}
interface RawRDMAConnection {
sourceRdmaIface?: string;
sinkRdmaIface?: string;
}
type RawConnectionEdge = RawSocketConnection | RawRDMAConnection;
// New nested mapping format: { source: { sink: [edge1, edge2, ...] } }
type RawConnectionsMap = Record<string, Record<string, RawConnectionEdge[]>>;
interface RawTopology {
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
// New nested mapping format
connections?: RawConnectionsMap;
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -198,6 +182,26 @@ export interface MessageAttachment {
mimeType?: string;
}
// Token-level data for uncertainty visualization
export interface TopLogprob {
token: string;
logprob: number;
bytes?: number[];
}
export interface TokenData {
token: string;
logprob: number;
probability: number; // exp(logprob)
topLogprobs: TopLogprob[];
}
// Prefill progress data for long prompts
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -207,6 +211,8 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
tokens?: TokenData[]; // Token-level data for uncertainty visualization
prefillProgress?: PrefillProgress | null; // Prefill progress for long prompts
}
export interface Conversation {
@@ -229,18 +235,9 @@ function transformTopology(
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === "string" ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode =
typeof node === "object" ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -289,7 +286,7 @@ function transformTopology(
}
}
nodes[nodeId] = {
nodes[node.nodeId] = {
system_info: {
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
@@ -317,34 +314,29 @@ function transformTopology(
};
}
// Handle connections - nested mapping format { source: { sink: [edges] } }
const connections = raw.connections;
if (connections && typeof connections === "object") {
for (const [source, sinks] of Object.entries(connections)) {
if (!sinks || typeof sinks !== "object") continue;
for (const [sink, edgeList] of Object.entries(sinks)) {
if (!Array.isArray(edgeList)) continue;
for (const edge of edgeList) {
// Extract IP from SocketConnection (uses snake_case: ip_address)
let sendBackIp: string | undefined;
if (edge && typeof edge === "object" && "sinkMultiaddr" in edge) {
const multiaddr = edge.sinkMultiaddr;
if (multiaddr) {
// Try both snake_case (actual) and camelCase (in case it changes)
sendBackIp =
multiaddr.ip_address ||
multiaddr.ipAddress ||
extractIpFromMultiaddr(multiaddr.address);
}
}
// RDMAConnection (sourceRdmaIface/sinkRdmaIface) has no IP - edge just shows connection exists
for (const conn of raw.connections || []) {
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
if (nodes[source] && nodes[sink] && source !== sink) {
edges.push({ source, target: sink, sendBackIp });
}
}
let sendBackIp: string | undefined;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (typeof multi === "string") {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
sendBackIp =
multi.ip_address ||
extractIpFromMultiaddr(multi.multiaddr) ||
extractIpFromMultiaddr(multi.address);
}
}
edges.push({
source: conn.localNodeId,
target: conn.sendBackNodeId,
sendBackIp,
});
}
return { nodes, edges };
@@ -1137,6 +1129,8 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1438,6 +1432,8 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1454,6 +1450,8 @@ class AppStore {
const decoder = new TextDecoder();
let fullContent = "";
let buffer = "";
const collectedTokens: TokenData[] = [];
let currentEventType = ""; // Track SSE event type
while (true) {
const { done, value } = await reader.read();
@@ -1467,14 +1465,43 @@ class AppStore {
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
if (!trimmed) {
// Empty line resets event type
currentEventType = "";
continue;
}
// Handle event type declaration
if (trimmed.startsWith("event: ")) {
currentEventType = trimmed.slice(7);
continue;
}
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
if (data === "[DONE]") {
currentEventType = "";
continue;
}
try {
const parsed = JSON.parse(data);
// Handle prefill progress events
if (currentEventType === "prefill_progress") {
const idx = this.messages.findIndex(
(m) => m.id === assistantMessage.id,
);
if (idx !== -1) {
this.messages[idx].prefillProgress = {
processed: parsed.processed,
total: parsed.total,
};
}
continue;
}
// Handle regular token data
const tokenContent = parsed.choices?.[0]?.delta?.content;
if (tokenContent) {
// Track first token for TTFT
@@ -1483,6 +1510,14 @@ class AppStore {
this.ttftMs = firstTokenTime - requestStartTime;
}
// Clear prefill progress when first token arrives
const msgIdx = this.messages.findIndex(
(m) => m.id === assistantMessage.id,
);
if (msgIdx !== -1 && this.messages[msgIdx].prefillProgress) {
this.messages[msgIdx].prefillProgress = null;
}
// Count tokens (each SSE chunk is typically one token)
tokenCount += 1;
this.totalTokens = tokenCount;
@@ -1493,6 +1528,25 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000;
}
// Extract logprobs for uncertainty visualization
const logprobsData = parsed.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
const tokenData: TokenData = {
token: logprobItem.token || tokenContent,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: { token: string; logprob: number; bytes?: number[] }) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
};
collectedTokens.push(tokenData);
}
fullContent += tokenContent;
// Strip thinking tags for display and extract thinking content
@@ -1507,6 +1561,8 @@ class AppStore {
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
// Update tokens during streaming for real-time visualization
this.messages[idx].tokens = [...collectedTokens];
}
this.persistActiveConversation();
}
@@ -1554,6 +1610,10 @@ class AppStore {
if (this.tps !== null) {
this.messages[idx].tps = this.tps;
}
// Store token data for uncertainty visualization
if (collectedTokens.length > 0) {
this.messages[idx].tokens = collectedTokens;
}
}
this.persistActiveConversation();
} catch (error) {

View File

@@ -915,7 +915,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
return { runnerId, tag, deviceRank };
});

View File

@@ -23,7 +23,6 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
]
[project.scripts]

View File

@@ -0,0 +1 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -0,0 +1,184 @@
"""Claude Messages API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
FinishReason,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.tasks import ChatCompletionTaskParams
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def claude_request_to_chat_params(
request: ClaudeMessagesRequest,
) -> ChatCompletionTaskParams:
"""Convert Claude Messages API request to internal ChatCompletionTaskParams."""
messages: list[ChatCompletionMessage] = []
# Add system message if present
if request.system:
if isinstance(request.system, str):
messages.append(
ChatCompletionMessage(role="system", content=request.system)
)
else:
# List of text blocks
system_text = "".join(block.text for block in request.system)
messages.append(ChatCompletionMessage(role="system", content=system_text))
# Convert messages
for msg in request.messages:
content: str
if isinstance(msg.content, str):
content = msg.content
else:
# Concatenate text blocks (images not supported for MVP)
text_parts: list[str] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
content = "".join(text_parts)
messages.append(ChatCompletionMessage(role=msg.role, content=content))
return ChatCompletionTaskParams(
model=request.model,
messages=messages,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
)
def chat_response_to_claude_response(
response: ChatCompletionResponse,
) -> ClaudeMessagesResponse:
"""Convert internal ChatCompletionResponse to Claude Messages API response."""
content_text = ""
stop_reason: ClaudeStopReason | None = None
if response.choices:
choice = response.choices[0]
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
content_text = (
choice.message.content
if isinstance(choice.message.content, str)
else str(choice.message.content)
)
stop_reason = finish_reason_to_claude_stop_reason(choice.finish_reason)
# Use actual usage data from response if available
input_tokens = response.usage.prompt_tokens if response.usage else 0
output_tokens = response.usage.completion_tokens if response.usage else 0
return ClaudeMessagesResponse(
id=f"msg_{response.id}",
model=response.model,
content=[ClaudeTextBlock(text=content_text)],
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
async for chunk in chunk_stream:
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -0,0 +1,199 @@
"""OpenAI Responses API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
from exo.shared.types.tasks import ChatCompletionTaskParams
def responses_request_to_chat_params(
request: ResponsesRequest,
) -> ChatCompletionTaskParams:
"""Convert OpenAI Responses API request to internal ChatCompletionTaskParams."""
messages: list[ChatCompletionMessage] = []
# Add instructions as system message if present
if request.instructions:
messages.append(
ChatCompletionMessage(role="system", content=request.instructions)
)
# Convert input to messages
if isinstance(request.input, str):
messages.append(ChatCompletionMessage(role="user", content=request.input))
else:
for msg in request.input:
messages.append(
ChatCompletionMessage(
role=msg.role,
content=msg.content,
)
)
return ChatCompletionTaskParams(
model=request.model,
messages=messages,
max_tokens=request.max_output_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=request.stream,
)
def chat_response_to_responses_response(
response: ChatCompletionResponse,
) -> ResponsesResponse:
"""Convert internal ChatCompletionResponse to OpenAI Responses API response."""
output_text = ""
if response.choices:
choice = response.choices[0]
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
output_text = (
choice.message.content
if isinstance(choice.message.content, str)
else str(choice.message.content)
)
item_id = f"item_{response.id}"
output_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=output_text)],
)
usage = None
if response.usage:
usage = ResponseUsage(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return ResponsesResponse(
id=f"resp_{response.id}",
model=response.model,
output=[output_item],
output_text=output_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(response=initial_response)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(response=initial_response)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
output_index=0, content_index=0, part=initial_part
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
last_stats = None
async for chunk in chunk_stream:
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
output_index=0, content_index=0, text=accumulated_text
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
output_index=0, content_index=0, part=final_part
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[final_item],
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(response=final_response)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -1,5 +1,6 @@
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import cast
import anyio
@@ -14,6 +15,16 @@ from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from exo.master.adapters.claude import (
chat_response_to_claude_response,
claude_request_to_chat_params,
generate_claude_stream,
)
from exo.master.adapters.responses import (
chat_response_to_responses_response,
generate_responses_stream,
responses_request_to_chat_params,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
@@ -31,6 +42,8 @@ from exo.shared.types.api import (
DeleteInstanceResponse,
FinishReason,
GenerationStats,
Logprobs,
LogprobsContentItem,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -39,6 +52,10 @@ from exo.shared.types.api import (
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -49,9 +66,19 @@ from exo.shared.types.commands import (
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -62,9 +89,35 @@ from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
@dataclass
class PrefillProgressData:
"""Data class for prefill progress events."""
processed_tokens: int
total_tokens: int
# Union type for stream events
StreamEvent = TokenChunk | PrefillProgressData
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
bytes=list(chunk.text.encode("utf-8")),
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -73,6 +126,7 @@ def chunk_to_response(
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
@@ -127,7 +181,7 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._chat_completion_queues: dict[CommandId, Sender[StreamEvent]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -168,6 +222,8 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -228,7 +284,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -284,7 +339,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -375,18 +429,18 @@ class API:
instance_id=instance_id,
)
async def _chat_chunk_stream(
async def _stream_events(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
) -> AsyncGenerator[StreamEvent, None]:
"""Yield stream events (TokenChunks or PrefillProgressData) for a command."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
self._chat_completion_queues[command_id], recv = channel[StreamEvent]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
with recv as events:
async for event in events:
yield event
if isinstance(event, TokenChunk) and event.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
@@ -402,21 +456,36 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield only TokenChunks, filtering out progress events."""
async for event in self._stream_events(command_id):
if isinstance(event, TokenChunk):
yield event
async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
async for event in self._stream_events(command_id):
if isinstance(event, PrefillProgressData):
# Send prefill progress as a named SSE event
progress_json = f'{{"processed":{event.processed_tokens},"total":{event.total_tokens}}}'
yield f"event: prefill_progress\ndata: {progress_json}\n\n"
else:
# TokenChunk - regular token generation
chunk_response: ChatCompletionResponse = chunk_to_response(
event, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
if event.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
@@ -550,12 +619,82 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Handle Claude Messages API requests."""
chat_params = claude_request_to_chat_params(payload)
model_meta = await resolve_model_meta(chat_params.model)
chat_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == chat_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(chat_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {chat_params.model}",
)
command = ChatCompletion(request_params=chat_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
response = await self._collect_chat_completion(command.command_id)
return chat_response_to_claude_response(response)
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""Handle OpenAI Responses API requests."""
chat_params = responses_request_to_chat_params(payload)
model_meta = await resolve_model_meta(chat_params.model)
chat_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == chat_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(chat_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {chat_params.model}",
)
command = ChatCompletion(request_params=chat_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
response = await self._collect_chat_completion(command.command_id)
return chat_response_to_responses_response(response)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available
@@ -616,6 +755,16 @@ class API:
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif (
isinstance(event, PrefillProgress)
and event.command_id in self._chat_completion_queues
):
await self._chat_completion_queues[event.command_id].send(
PrefillProgressData(
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -27,7 +27,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -159,7 +158,6 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -202,7 +200,9 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
@@ -237,8 +237,6 @@ class Master:
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)

View File

@@ -6,10 +6,9 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
@@ -20,11 +19,10 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,14 +52,19 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_meta.storage_size
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
if len(cycles_with_sufficient_memory) == 0:
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
@@ -89,38 +92,44 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
if any(topology.node_is_leaf(node.node_id) for node in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle),
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
shard_assignments = get_shard_assignments(
command.model_meta, selected_cycle, command.sharding, node_profiles
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -128,20 +137,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
[node_id for node_id in selected_cycle],
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle.node_ids[0],
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
@@ -150,7 +158,6 @@ def place_instance(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,

View File

@@ -1,13 +1,15 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -17,55 +19,58 @@ from exo.shared.types.worker.shards import (
)
class NodeWithProfile(BaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
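# Illustrative note (not prescriptive): because narrow_all_nodes is a TypeGuard,
# inside
#
#     if narrow_all_nodes(cycle):
#         ...  # type checkers now treat `cycle` as list[NodeWithProfile]
#
# every node.node_profile access is known to be non-None, which is what lets
# filter_cycles_by_memory below sum ram_available without Optional checks.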
def filter_cycles_by_memory(
cycles: list[Cycle],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[Cycle]:
filtered_cycles: list[Cycle] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not all(node in node_profiles for node in cycle):
if not narrow_all_nodes(cycle):
continue
total_mem = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
start=Memory(),
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cycle)
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
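# Illustrative sketch of the typical call site, mirroring place_instance above
# (names as defined in this diff; assumptions only, not prescriptive):
#
#     cycles = topology.get_cycles()                     # list[list[NodeInfo]]
#     viable = filter_cycles_by_memory(cycles, command.model_meta.storage_size)
#
# A cycle is kept only when every node carries a profile and the summed
# ram_available across the cycle covers the requested model size.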
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
cycle: Cycle,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
selected_cycle: list[NodeWithProfile],
):
cycle_memory = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
total_layers = model_meta.n_layers
world_size = len(cycle)
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
for i, node_id in enumerate(cycle):
if i == len(cycle) - 1:
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node_profiles[node_id].memory.ram_available.in_bytes
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
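# Worked example of the split above (illustrative numbers): with
# total_layers = 32 and ram_available of 16, 8 and 8 units, cycle_memory is 32,
# so the first two nodes get round(32 * 16/32) = 16 and round(32 * 8/32) = 8
# layers, and the last node takes the remainder 32 - (16 + 8) = 8, guaranteeing
# every layer is assigned despite rounding.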
@@ -83,7 +88,7 @@ def get_shard_assignments_for_pipeline_parallel(
)
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
@@ -97,14 +102,14 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_meta: ModelMetadata,
cycle: Cycle,
selected_cycle: list[NodeWithProfile],
):
total_layers = model_meta.n_layers
world_size = len(cycle)
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
for i, node in enumerate(selected_cycle):
shard = TensorShardMetadata(
model_meta=model_meta,
device_rank=i,
@@ -117,7 +122,7 @@ def get_shard_assignments_for_tensor_parallel(
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
@@ -130,21 +135,21 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
cycle: Cycle,
selected_cycle: list[NodeInfo],
sharding: Sharding,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_meta=model_meta,
cycle=cycle,
node_profiles=node_profiles,
selected_cycle=selected_cycle,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_meta=model_meta,
cycle=cycle,
selected_cycle=selected_cycle,
)
@@ -159,40 +164,38 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
)
return []
cycle = cycles[0]
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycle):
if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
get_thunderbolt = True
logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")
cycle = cycles[0]
hosts: list[Host] = []
for i in range(len(cycle)):
current_node = cycle.node_ids[i]
next_node = cycle.node_ids[(i + 1) % len(cycle)]
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.get_all_connections_between(
source=current_node, sink=next_node
):
if not isinstance(connection, SocketConnection):
continue
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
return hosts
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -211,37 +214,72 @@ def get_mlx_jaccl_devices_matrix(
if i == j:
continue
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
"Current ibv backend requires all-to-all rdma connections"
)
return matrix
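# Illustrative shape (assumed 3-node cycle, interface names hypothetical):
#
#     [[None,        "rdma_en3",  "rdma_en4"],
#      ["rdma_en3",  None,        "rdma_en5"],
#      ["rdma_en4",  "rdma_en5",  None      ]]
#
# matrix[i][j] names the RDMA interface on node i that node j reaches it
# through; the diagonal stays None because a device never dials itself.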
def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
# Restrict the search to en2-en7; other interface names are not treated as RDMA candidates here.
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
    continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def _find_interface_name_for_ip(
ip_address: str, node_profile: NodePerformanceProfile
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_profile.network_interfaces:
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -249,10 +287,7 @@ def _find_interface_name_for_ip(
def _find_ip_prioritised(
node_id: NodeId,
other_node_id: NodeId,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -263,12 +298,9 @@ def _find_ip_prioritised(
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(ip, node_profiles[other_node_id]): ip
for ip, _ in ips
}
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
en0_ip = iface_map.get("en0")
if en0_ip:
@@ -292,10 +324,9 @@ def _find_ip_prioritised(
def get_mlx_ring_hosts_by_node(
selected_cycle: Cycle,
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
@@ -310,13 +341,14 @@ def get_mlx_ring_hosts_by_node(
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node_id in enumerate(selected_cycle):
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
hosts_for_node: list[Host] = []
for idx, other_node_id in enumerate(selected_cycle):
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
@@ -326,12 +358,10 @@ def get_mlx_ring_hosts_by_node(
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_profiles
)
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
@@ -345,34 +375,31 @@ def get_mlx_ring_hosts_by_node(
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
logger.info(f"Selecting coordinator: {coordinator}")
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_profiles)
if ip is not None:
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n: f"{get_ip_for_node(n)}:{coordinator_port}"
for n in cycle_digraph.list_nodes()
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}
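# Illustrative result (hypothetical IDs, IPs and port):
#
#     {
#         rank0_id:  "0.0.0.0:52143",      # rank 0 binds locally on all interfaces
#         node_b_id: "169.254.0.2:52143",  # the IP node B uses to reach rank 0
#         node_c_id: "169.254.0.3:52143",
#     }
#
# i.e. one shared coordinator_port, but each node addresses rank 0 by the
# route it can actually reach.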

View File

@@ -1,39 +1,67 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
for i in range(10)
],
system=SystemPerformanceProfile(),
)
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# Assign a unique IP to each connection
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
return _create_connection

View File

@@ -0,0 +1,392 @@
"""Tests for Claude Messages API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.claude import (
chat_response_to_claude_response,
claude_request_to_chat_params,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
Usage,
)
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessage,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToChatParams:
"""Tests for converting Claude Messages API requests to ChatCompletionTaskParams."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert params.model == "claude-3-opus"
assert params.max_tokens == 100
assert len(params.messages) == 1
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello"
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are a helpful assistant."
assert params.messages[1].role == "user"
assert params.messages[1].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are helpful. Be concise."
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 1
assert params.messages[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 3
assert params.messages[0].role == "user"
assert params.messages[1].role == "assistant"
assert params.messages[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_chat_params(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestChatResponseToClaudeResponse:
"""Tests for converting ChatCompletionResponse to Claude Messages API response."""
def test_basic_response_conversion(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content="Hello! How can I help you?",
),
finish_reason="stop",
)
],
usage=Usage(prompt_tokens=10, completion_tokens=7, total_tokens=17),
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.id == "msg_chatcmpl-123"
assert claude_response.model == "llama-3.2-1b"
assert claude_response.role == "assistant"
assert claude_response.type == "message"
assert len(claude_response.content) == 1
assert claude_response.content[0].type == "text"
assert claude_response.content[0].text == "Hello! How can I help you?"
assert claude_response.stop_reason == "end_turn"
assert claude_response.usage.input_tokens == 10
assert claude_response.usage.output_tokens == 7
def test_response_with_length_finish_reason(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content="Truncated..."
),
finish_reason="length",
)
],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.stop_reason == "max_tokens"
def test_response_with_empty_content(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content=""),
finish_reason="stop",
)
],
usage=Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10),
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == ""
assert claude_response.usage.output_tokens == 0
def test_response_with_no_choices(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == ""
assert claude_response.stop_reason is None
assert claude_response.usage.input_tokens == 0
assert claude_response.usage.output_tokens == 0
def test_response_without_usage(self):
"""Test response conversion when usage data is not available."""
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == "Hello!"
assert claude_response.usage.input_tokens == 0
assert claude_response.usage.output_tokens == 0
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)
class TestClaudeStreamingEvents:
"""Tests for Claude Messages API streaming event serialization."""
def test_message_start_event_format(self):
message = ClaudeMessageStart(
id="msg_123",
model="claude-3-opus",
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
)
event = ClaudeMessageStartEvent(message=message)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_start"
assert parsed["message"]["id"] == "msg_123"
assert parsed["message"]["type"] == "message"
assert parsed["message"]["role"] == "assistant"
assert parsed["message"]["model"] == "claude-3-opus"
def test_content_block_start_event_format(self):
event = ClaudeContentBlockStartEvent(
index=0,
content_block=ClaudeTextBlock(text=""),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_start"
assert parsed["index"] == 0
assert parsed["content_block"]["type"] == "text"
assert parsed["content_block"]["text"] == ""
def test_content_block_delta_event_format(self):
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_delta"
assert parsed["index"] == 0
assert parsed["delta"]["type"] == "text_delta"
assert parsed["delta"]["text"] == "Hello"
def test_content_block_stop_event_format(self):
event = ClaudeContentBlockStopEvent(index=0)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_stop"
assert parsed["index"] == 0
def test_message_delta_event_format(self):
event = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason="end_turn"),
usage=ClaudeMessageDeltaUsage(output_tokens=25),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_delta"
assert parsed["delta"]["stop_reason"] == "end_turn"
assert parsed["usage"]["output_tokens"] == 25
def test_message_stop_event_format(self):
event = ClaudeMessageStopEvent()
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_stop"
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
# Simulate the SSE format used in the streaming generator
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
assert sse_line.startswith("event: content_block_delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")
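    def test_sse_wire_bytes_example(self):
        """Illustrative sketch: the exact bytes a client receives for one delta
        frame, assuming the same framing as test_sse_format above."""
        event = ClaudeContentBlockDeltaEvent(
            index=0,
            delta=ClaudeTextDelta(text="Hi"),
        )
        sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
        # A spec-compliant SSE frame carries one `event:` field and one `data:`
        # field, terminated by a blank line; the event field comes first.
        field_names = [line.split(":", 1)[0] for line in sse_line.strip().split("\n")]
        assert field_names == ["event", "data"]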

View File

@@ -19,13 +19,15 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeGatheredInfo,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -81,14 +83,21 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -154,7 +163,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)

View File

@@ -0,0 +1,414 @@
"""Tests for OpenAI Responses API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.responses import (
chat_response_to_responses_response,
responses_request_to_chat_params,
)
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
Usage,
)
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseCreatedEvent,
ResponseInputMessage,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
class TestResponsesRequestToChatParams:
"""Tests for converting OpenAI Responses API requests to ChatCompletionTaskParams."""
def test_string_input_conversion(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello, how are you?",
)
params = responses_request_to_chat_params(request)
assert params.model == "gpt-4o"
assert len(params.messages) == 1
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello, how are you?"
def test_message_array_input_conversion(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="user", content="Hello"),
ResponseInputMessage(role="assistant", content="Hi there!"),
ResponseInputMessage(role="user", content="How are you?"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 3
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello"
assert params.messages[1].role == "assistant"
assert params.messages[1].content == "Hi there!"
assert params.messages[2].role == "user"
assert params.messages[2].content == "How are you?"
def test_request_with_instructions(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
instructions="You are a helpful assistant. Be concise.",
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are a helpful assistant. Be concise."
assert params.messages[1].role == "user"
assert params.messages[1].content == "Hello"
def test_request_with_optional_parameters(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
max_output_tokens=500,
temperature=0.8,
top_p=0.95,
stream=True,
)
params = responses_request_to_chat_params(request)
assert params.max_tokens == 500
assert params.temperature == 0.8
assert params.top_p == 0.95
assert params.stream is True
def test_request_with_system_role_in_messages(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="system", content="Be helpful"),
ResponseInputMessage(role="user", content="Hello"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[1].role == "user"
def test_request_with_developer_role(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="developer", content="Internal note"),
ResponseInputMessage(role="user", content="Hello"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "developer"
class TestChatResponseToResponsesResponse:
"""Tests for converting ChatCompletionResponse to OpenAI Responses API response."""
def test_basic_response_conversion(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content="Hello! How can I help you?",
),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.id == "resp_chatcmpl-123"
assert responses_response.object == "response"
assert responses_response.model == "llama-3.2-1b"
assert responses_response.status == "completed"
assert responses_response.output_text == "Hello! How can I help you?"
assert len(responses_response.output) == 1
assert responses_response.output[0].type == "message"
assert responses_response.output[0].role == "assistant"
assert len(responses_response.output[0].content) == 1
assert responses_response.output[0].content[0].type == "output_text"
assert (
responses_response.output[0].content[0].text == "Hello! How can I help you?"
)
def test_response_with_usage(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
usage=Usage(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.usage is not None
assert responses_response.usage.input_tokens == 10
assert responses_response.usage.output_tokens == 5
assert responses_response.usage.total_tokens == 15
def test_response_with_empty_content(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content=""),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output_text == ""
assert responses_response.output[0].content[0].text == ""
def test_response_with_no_choices(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output_text == ""
def test_response_without_usage(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.usage is None
def test_response_item_id_format(self):
response = ChatCompletionResponse(
id="chatcmpl-abc123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output[0].id == "item_chatcmpl-abc123"
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model="gpt-4o",
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1
class TestResponsesStreamingEvents:
"""Tests for OpenAI Responses API streaming event serialization."""
def test_response_created_event_format(self):
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="in_progress",
output=[],
output_text="",
)
event = ResponseCreatedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.created"
assert parsed["response"]["id"] == "resp_123"
assert parsed["response"]["object"] == "response"
assert parsed["response"]["status"] == "in_progress"
def test_output_item_added_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="")],
status="in_progress",
)
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.added"
assert parsed["output_index"] == 0
assert parsed["item"]["type"] == "message"
assert parsed["item"]["id"] == "item_123"
assert parsed["item"]["role"] == "assistant"
def test_content_part_added_event_format(self):
part = ResponseOutputText(text="")
event = ResponseContentPartAddedEvent(
output_index=0,
content_index=0,
part=part,
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.content_part.added"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["part"]["type"] == "output_text"
def test_text_delta_event_format(self):
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.delta"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["delta"] == "Hello"
def test_text_done_event_format(self):
event = ResponseTextDoneEvent(
output_index=0,
content_index=0,
text="Hello, world!",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.done"
assert parsed["text"] == "Hello, world!"
def test_output_item_done_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello, world!")],
status="completed",
)
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.done"
assert parsed["item"]["status"] == "completed"
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
def test_response_completed_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello!")],
status="completed",
)
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="completed",
output=[item],
output_text="Hello!",
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
)
event = ResponseCompletedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.completed"
assert parsed["response"]["status"] == "completed"
assert parsed["response"]["output_text"] == "Hello!"
assert parsed["response"]["usage"]["total_tokens"] == 15
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
# Simulate the SSE format used in the streaming generator
sse_line = (
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
)
assert sse_line.startswith("event: response.output_text.delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")
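    def test_sse_client_parse_example(self):
        """Illustrative sketch: how a client splits a frame back into event name
        and JSON payload, assuming the same framing as test_sse_format above."""
        event = ResponseTextDeltaEvent(
            output_index=0,
            content_index=0,
            delta="Hi",
        )
        frame = f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
        header, payload = frame.strip().split("\n", 1)
        assert header == "event: response.output_text.delta"
        parsed = cast(dict[str, Any], json.loads(payload.removeprefix("data: ")))
        assert parsed["delta"] == "Hi"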

View File

@@ -1,23 +1,20 @@
from typing import Callable
import pytest
from loguru import logger
from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_node_profile,
create_rdma_connection,
create_socket_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -29,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -75,57 +77,34 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)
)
conn_a_c = Connection(
source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)
)
conn_b_a = Connection(
source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
)
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(conn_b_a)
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
@@ -151,11 +130,12 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -166,7 +146,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -177,11 +157,12 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -192,7 +173,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -203,11 +184,12 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
@@ -220,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, profiles)
place_instance(cic, topology, {})
def test_get_transition_events_no_change(instance: Instance):
@@ -265,130 +247,179 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_leaf_nodes(
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
model_meta.storage_size = Memory.from_bytes(1000)
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology (directed)
topology.add_connection(
Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
)
cic = place_instance_command(model_meta=model_meta)
# Act
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, profiles)
# assert
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance = list(placements.values())[0]
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(
node_id_c,
node_id_d,
)
)
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="10.0.0.1",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
ip_address="192.168.1.100",
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
# RDMA connections (directed)
topology.add_connection(
Connection(source=node_a, sink=node_b, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_a, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_c, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_c, sink=node_b, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_a, sink=node_c, edge=create_rdma_connection(5))
)
topology.add_connection(
Connection(source=node_c, sink=node_a, edge=create_rdma_connection(5))
)
# Ethernet connections (directed)
topology.add_connection(Connection(source=node_a, sink=node_b, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_a, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_b, edge=ethernet_conn))
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = PlaceInstance(
sharding=Sharding.Tensor,
@@ -398,34 +429,35 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert isinstance(instance, MlxJacclInstance)
assert instance.jaccl_devices is not None
assert instance.ibv_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.jaccl_devices
matrix = instance.ibv_devices
assert len(matrix) == 3
for i in range(3):
assert matrix[i][i] is None
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3
@@ -437,5 +469,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
if node_id == assigned_nodes[0]:
assert coordinator.startswith("0.0.0.0:")
else:
# Non-rank-0 nodes should have valid IP addresses (can be link-local)
ip_part = coordinator.split(":")[0]
# Loose sanity check: four dot-separated fields, not full IP validation
assert len(ip_part.split(".")) == 4

View File

@@ -1,4 +1,4 @@
from copy import copy
from typing import Callable
import pytest
@@ -9,178 +9,154 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_node_profile, create_socket_connection
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = [c for c in topology.get_cycles() if len(c) != 1]
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 2
assert set(n for n in filtered_cycles[0]) == {node1_id, node2_id}
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
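Taken together, these two tests imply that filter_cycles_by_memory keeps a cycle only when the combined available memory of its nodes covers the requested amount. A rough sketch of that predicate (the accessors are assumptions for illustration, not the real API):
def cycle_has_enough_memory(cycle: list[NodeInfo], required: Memory) -> bool:
    # Assumed accessors; the real Memory/profile API may differ.
    total_bytes = sum(
        node.node_profile.memory.ram_available.bytes
        for node in cycle
        if node.node_profile is not None
    )
    return total_bytes >= required.bytes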
def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 3
assert set(n for n in filtered_cycles[0]) == {
assert set(n.node_id for n in filtered_cycles[0]) == {
node_a_id,
node_b_id,
node_c_id,
}
def test_get_smallest_cycles():
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(cycles)
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
@pytest.mark.parametrize(
@@ -192,6 +168,9 @@ def test_get_smallest_cycles():
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -201,37 +180,18 @@ def test_get_shard_assignments(
node_b_id = NodeId()
node_c_id = NodeId()
# create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(4)
)
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
@@ -241,22 +201,23 @@ def test_get_shard_assignments(
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
# pick the 3-node cycle deterministically (cycle ordering can vary)
selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)
selected_cycle = cycles[0]
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_profiles=node_profiles
model_meta, selected_cycle, Sharding.Pipeline
)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
assert (
shard_assignments.runner_to_shard[runner_id_a].end_layer
- shard_assignments.runner_to_shard[runner_id_a].start_layer
@@ -267,37 +228,30 @@ def test_get_shard_assignments(
- shard_assignments.runner_to_shard[runner_id_b].start_layer
== expected_layers[1]
)
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
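The parametrized cases suggest pipeline sharding splits total_layers roughly in proportion to each node's available memory. A back-of-envelope sketch of such a split (an assumption about the strategy, not the actual implementation):
def split_layers(total_layers: int, memories: list[int]) -> list[int]:
    # Proportional floor split, with leftover layers handed to the
    # largest-memory nodes first.
    total_mem = sum(memories)
    counts = [total_layers * m // total_mem for m in memories]
    for i in sorted(range(len(memories)), key=lambda i: -memories[i]):
        if sum(counts) == total_layers:
            break
        counts[i] += 1
    return counts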
def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -305,78 +259,95 @@ def test_get_hosts_from_subgraph():
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip="169.254.0.1", port=1234),
Host(ip="169.254.0.2", port=1234),
Host(ip="169.254.0.3", port=1234),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(5)
)
conn_a_c = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
npp = NodePerformanceProfile(
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=0,
ram_available=0,
swap_total=0,
swap_available=0,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
npp_a = copy(npp)
npp_a.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
npp_b = copy(npp)
npp_b.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
npp_c = copy(npp)
npp_c.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
node_profiles = {
node_a_id: npp_a,
node_b_id: npp_b,
node_c_id: npp_c,
}
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
@@ -385,12 +356,11 @@ def test_get_mlx_jaccl_coordinators():
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
node_a_id,
coordinator_port=5000,
cycle_digraph=topology,
node_profiles=node_profiles,
cycle, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -411,20 +381,19 @@ def test_get_mlx_jaccl_coordinators():
f"Coordinator for {node_id} should use port 5000"
)
# Rank 0 (node_a) treats this as the listen socket so should listen on all IPs
# Rank 0 (node_a) treats this as the listen socket so should listen on all
# IPs
assert coordinators[node_a_id].startswith("0.0.0.0:"), (
"Rank 0 node should use 0.0.0.0 as coordinator listen address"
"Rank 0 node should use localhost as coordinator"
)
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert isinstance(conn_b_a.edge, SocketConnection)
assert (
coordinators[node_b_id] == f"{conn_b_a.edge.sink_multiaddr.ip_address}:5000"
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert isinstance(conn_c_a.edge, SocketConnection)
assert (
coordinators[node_c_id] == f"{conn_c_a.edge.sink_multiaddr.ip_address}:5000"
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"

View File

@@ -1,14 +1,13 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:
@pytest.fixture
def socket_connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -39,91 +43,162 @@ def node_profile() -> NodePerformanceProfile:
)
def test_add_node(topology: Topology):
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
# act
topology.add_node(node_id)
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
assert topology.node_is_leaf(node_id)
data = topology.get_node_profile(node_id)
assert data == node_profile
def test_add_connection(topology: Topology, socket_connection: SocketConnection):
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
connection = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
data = list(topology.list_connections())
data = topology.get_connection_profile(connection)
# assert
assert data == [connection]
assert data == connection.connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(
topology: Topology, socket_connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_connection(conn)
topology.remove_connection(connection)
# assert
assert list(topology.get_all_connections_between(node_a, node_b)) == []
assert topology.get_connection_profile(connection) is None
def test_remove_node_still_connected(
topology: Topology, socket_connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_node(node_b)
topology.remove_node(connection.local_node_id)
# assert
assert list(topology.out_edges(node_a)) == []
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(topology: Topology, socket_connection: SocketConnection):
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeId) for node in nodes)
assert set(node for node in nodes) == set([node_a, node_b])
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}

View File

@@ -11,9 +11,12 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -25,42 +28,36 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacThunderboltConnections,
MacThunderboltIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged()
TestEvent() | ChunkGenerated() | TaskAcknowledged() | PrefillProgress()
): # TaskAcknowledged should never be sent by a worker, but I don't mind if it just gets ignored
return state
case InstanceCreated():
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -192,7 +189,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +197,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -213,68 +206,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacThunderboltIdentifiers():
profile.tb_interfaces = info.idents
case MacThunderboltConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
Connection(
source=event.node_id,
sink=conn_map[tb_conn.sink_uuid][0],
edge=RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_connection(event.conn)
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.conn)
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -38,7 +38,6 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

View File

@@ -11,6 +11,9 @@ class InterceptLogger(HypercornLogger):
def __init__(self, config: Config):
super().__init__(config)
assert self.error_logger
# TODO: Decide if we want to provide access logs
# assert self.access_logger
# self.access_logger.handlers = [_InterceptHandler()]
self.error_logger.handlers = [_InterceptHandler()]
@@ -26,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -43,4 +43,7 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
@@ -12,11 +12,9 @@ def test_state_serialization_roundtrip() -> None:
node_b = NodeId("node-b")
connection = Connection(
source=node_a,
sink=node_b,
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
),
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
@@ -25,11 +23,5 @@ def test_state_serialization_roundtrip() -> None:
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert (
state.topology.to_snapshot().nodes
== restored_state.topology.to_snapshot().nodes
)
assert set(state.topology.to_snapshot().connections) == set(
restored_state.topology.to_snapshot().connections
)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr

View File

@@ -1,227 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.topology import (
Connection,
Cycle,
RDMAConnection,
SocketConnection,
)
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
class TopologySnapshot(BaseModel):
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
nodes: list[NodeInfo]
connections: list[Connection]
model_config = ConfigDict(frozen=True, extra="forbid")
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
@dataclass
class Topology:
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()), connections=self.map_connections()
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node_id in snapshot.nodes:
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node_id)
topology.add_node(node)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for edge in snapshot.connections[source][sink]:
topology.add_connection(
Connection(source=source, sink=sink, edge=edge)
)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
return
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(self, node_id: NodeId) -> Iterable[Connection]:
if node_id not in self._vertex_indices:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return (
Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)
for source, sink, edge in self._graph.out_edges(
self._vertex_indices[node_id]
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
)
]
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._vertex_indices
return node_id in self._node_id_to_rx_id_map
def add_connection(self, conn: Connection) -> None:
source, sink, edge = conn.source, conn.sink, conn.edge
del conn
if edge in self.get_all_connections_between(source, sink):
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
connection: Connection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
_ = self._graph.add_edge(src_id, sink_id, edge)
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def list_connections(
self,
) -> Iterable[Connection]:
return (
(
Connection(
source=self._graph[src_id],
sink=self._graph[sink_id],
edge=connection,
)
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._vertex_indices:
if node_id not in self._node_id_to_rx_id_map:
return
rx_idx = self._vertex_indices[node_id]
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph.remove_node(rx_idx)
del self._vertex_indices[node_id]
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def replace_all_out_rdma_connections(
self, source: NodeId, new_connections: Sequence[Connection]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for conn in new_connections:
self.add_connection(conn)
def remove_connection(self, conn: Connection) -> None:
if (
conn.source not in self._vertex_indices
or conn.sink not in self._vertex_indices
):
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[conn.source], self._vertex_indices[conn.sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[Cycle]:
"""Get simple cycles in the graph, including singleton cycles"""
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
for node_id in self.list_nodes():
cycles.append(Cycle(node_ids=[node_id]))
return cycles
def get_cycles_tb(self) -> list[Cycle]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if connection.source in node_ids and connection.sink in node_ids:
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

View File

@@ -146,6 +146,7 @@ class ChatCompletionTaskParams(BaseModel):
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None

View File

@@ -1,6 +1,6 @@
from enum import Enum
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -20,6 +20,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
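Since logprob is a natural-log probability, a consumer recovers the sampled token's probability with exp; for example:
import math

chunk_logprob = -0.105  # illustrative value from a TokenChunk
probability = math.exp(chunk_logprob)  # ~0.90, i.e. roughly 90% confidence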

View File

@@ -0,0 +1,168 @@
"""Claude Messages API types for request/response conversion."""
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ClaudeRole = Literal["user", "assistant"]
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
# Content block types
class ClaudeTextBlock(BaseModel, frozen=True):
"""Text content block in Claude Messages API."""
type: Literal["text"] = "text"
text: str
class ClaudeImageSource(BaseModel, frozen=True):
"""Image source for Claude image blocks."""
type: Literal["base64", "url"]
media_type: str | None = None
data: str | None = None
url: str | None = None
class ClaudeImageBlock(BaseModel, frozen=True):
"""Image content block in Claude Messages API."""
type: Literal["image"] = "image"
source: ClaudeImageSource
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
# Request types
class ClaudeMessage(BaseModel, frozen=True):
"""Message in Claude Messages API request."""
role: ClaudeRole
content: str | list[ClaudeContentBlock]
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
model: str
max_tokens: int
messages: list[ClaudeMessage]
system: str | list[ClaudeTextBlock] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, str] | None = None
# Response types
class ClaudeUsage(BaseModel, frozen=True):
"""Token usage in Claude Messages API response."""
input_tokens: int
output_tokens: int
class ClaudeMessagesResponse(BaseModel, frozen=True):
"""Response body for Claude Messages API."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock]
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
# Streaming event types
class ClaudeMessageStart(BaseModel, frozen=True):
"""Partial message in message_start event."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock] = Field(default_factory=list)
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
class ClaudeMessageStartEvent(BaseModel, frozen=True):
"""Event sent at start of message stream."""
type: Literal["message_start"] = "message_start"
message: ClaudeMessageStart
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
"""Event sent at start of a content block."""
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock
class ClaudeTextDelta(BaseModel, frozen=True):
"""Delta for text content block."""
type: Literal["text_delta"] = "text_delta"
text: str
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
"""Event sent for content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
"""Event sent at end of a content block."""
type: Literal["content_block_stop"] = "content_block_stop"
index: int
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
"""Usage in message_delta event."""
output_tokens: int
class ClaudeMessageDelta(BaseModel, frozen=True):
"""Delta in message_delta event."""
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
"""Event sent with final message delta."""
type: Literal["message_delta"] = "message_delta"
delta: ClaudeMessageDelta
usage: ClaudeMessageDeltaUsage
class ClaudeMessageStopEvent(BaseModel, frozen=True):
"""Event sent at end of message stream."""
type: Literal["message_stop"] = "message_stop"
ClaudeStreamEvent = (
ClaudeMessageStartEvent
| ClaudeContentBlockStartEvent
| ClaudeContentBlockDeltaEvent
| ClaudeContentBlockStopEvent
| ClaudeMessageDeltaEvent
| ClaudeMessageStopEvent
)
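A minimal request built from these models might look like the following (values illustrative):
request = ClaudeMessagesRequest(
    model="test-model",
    max_tokens=256,
    system="You are a helpful assistant.",
    messages=[ClaudeMessage(role="user", content="Hello!")],
)
When stream=True, the event types above are emitted in the order the Claude API documents: message_start, content_block_start, repeated content_block_delta, content_block_stop, message_delta, then message_stop.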

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overridden by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overridden by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
class NodeDownloadProgress(BaseEvent):
@@ -96,12 +106,18 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
edge: Connection
class TopologyEdgeDeleted(BaseEvent):
conn: Connection
edge: Connection
Event = (
@@ -115,10 +131,13 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
)
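As an illustration, a worker partway through prefill would emit something like the following (assuming BaseEvent requires no extra fields here):
event = PrefillProgress(
    command_id=command_id,  # the command whose prompt is being processed
    processed_tokens=512,
    total_tokens=2048,
)
fraction_done = event.processed_tokens / event.total_tokens  # 0.25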

View File

@@ -1,11 +1,10 @@
import re
from typing import ClassVar
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -0,0 +1,162 @@
"""OpenAI Responses API types for request/response conversion."""
import time
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
ResponseRole = Literal["user", "assistant", "system", "developer"]
# Request types
class ResponseInputMessage(BaseModel, frozen=True):
"""Input message for Responses API."""
role: ResponseRole
content: str
class ResponsesRequest(BaseModel):
"""Request body for OpenAI Responses API."""
model: str
input: str | list[ResponseInputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stream: bool = False
# previous_response_id not supported in MVP
metadata: dict[str, str] | None = None
# Response types
class ResponseOutputText(BaseModel, frozen=True):
"""Text content in response output."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list[dict[str, str]] = Field(default_factory=list)
class ResponseMessageItem(BaseModel, frozen=True):
"""Message item in response output array."""
type: Literal["message"] = "message"
id: str
role: Literal["assistant"] = "assistant"
content: list[ResponseOutputText]
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
class ResponseUsage(BaseModel, frozen=True):
"""Token usage in Responses API response."""
input_tokens: int
output_tokens: int
total_tokens: int
class ResponsesResponse(BaseModel, frozen=True):
"""Response body for OpenAI Responses API."""
id: str
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
status: ResponseStatus = "completed"
model: str
output: list[ResponseItem]
output_text: str
usage: ResponseUsage | None = None
# Streaming event types
class ResponseCreatedEvent(BaseModel, frozen=True):
"""Event sent when response is created."""
type: Literal["response.created"] = "response.created"
response: ResponsesResponse
class ResponseInProgressEvent(BaseModel, frozen=True):
"""Event sent when response starts processing."""
type: Literal["response.in_progress"] = "response.in_progress"
response: ResponsesResponse
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
"""Event sent when an output item is added."""
type: Literal["response.output_item.added"] = "response.output_item.added"
output_index: int
item: ResponseItem
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a content part is added."""
type: Literal["response.content_part.added"] = "response.content_part.added"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for text delta during streaming."""
type: Literal["response.output_text.delta"] = "response.output_text.delta"
output_index: int
content_index: int
delta: str
class ResponseTextDoneEvent(BaseModel, frozen=True):
"""Event sent when text content is done."""
type: Literal["response.output_text.done"] = "response.output_text.done"
output_index: int
content_index: int
text: str
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a content part is done."""
type: Literal["response.content_part.done"] = "response.content_part.done"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
"""Event sent when an output item is done."""
type: Literal["response.output_item.done"] = "response.output_item.done"
output_index: int
item: ResponseItem
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
type: Literal["response.completed"] = "response.completed"
response: ResponsesResponse
ResponsesStreamEvent = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseOutputItemAddedEvent
| ResponseContentPartAddedEvent
| ResponseTextDeltaEvent
| ResponseTextDoneEvent
| ResponseContentPartDoneEvent
| ResponseOutputItemDoneEvent
| ResponseCompletedEvent
)
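And a minimal non-streaming response assembled from these models (values illustrative):
response = ResponsesResponse(
    id="resp_123",
    model="test-model",
    output=[
        ResponseMessageItem(
            id="msg_123",
            content=[ResponseOutputText(text="Hello!")],
        )
    ],
    output_text="Hello!",
    usage=ResponseUsage(input_tokens=3, output_tokens=2, total_tokens=5),
)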

View File

@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import ThunderboltIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,12 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[ThunderboltIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float

View File

@@ -1,81 +0,0 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class ThunderboltConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class ThunderboltIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
## Intentionally minimal, only collecting data we care about - there's a lot more
class _ReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
class _ConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
class ThunderboltConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[_ConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: _ReceptacleTag | None = None
def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces  # doesn't need to be an assertion, but I'm confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return ThunderboltIdentifier(
rdma_interface=iface, domain_uuid=self.domain_uuid_key
)
def conn(self) -> ThunderboltConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None
return ThunderboltConnection(
source_uuid=self.domain_uuid_key, sink_uuid=sink_key
)
class ThunderboltConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[ThunderboltConnectivityData] = []
@classmethod
async def gather(cls) -> list[ThunderboltConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return ThunderboltConnectivity.model_validate_json(
proc.stdout
).SPThunderboltDataType

View File

@@ -1,41 +1,37 @@
from collections.abc import Iterator
from dataclasses import dataclass
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
@dataclass(frozen=True)
class Cycle:
node_ids: list[NodeId]
def __len__(self) -> int:
return self.node_ids.__len__()
def __iter__(self) -> Iterator[NodeId]:
return self.node_ids.__iter__()
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
def is_thunderbolt(self) -> bool:
return True
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
class Connection(FrozenModel):
source: NodeId
sink: NodeId
edge: RDMAConnection | SocketConnection
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
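A hedged usage sketch of the new topology types (Connection, SocketConnection, NodeId, Multiaddr, and the 169.254 heuristic are taken from this diff; the concrete values are illustrative):

conn = Connection(
    source=NodeId("node-a"),
    sink=NodeId("node-b"),
    edge=SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/169.254.10.2/tcp/52415")
    ),
)
# macOS Thunderbolt bridges self-assign IPv4 link-local (169.254.0.0/16)
# addresses, which is exactly what is_thunderbolt() keys on.
assert conn.edge.is_thunderbolt()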

View File

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
ibv_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]

View File

@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)
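A hypothetical wiring of the monitor above (assumes the stub collect() methods are given real implementations; the handler here just prints the profile type):

import asyncio

monitor = ResourceMonitor()
monitor.data_collectors = [SystemResourceCollector(), MemoryResourceCollector()]
monitor.effect_handlers = {lambda profile: print(type(profile).__name__)}

# One polling tick: gather all profiles concurrently, then fan out to handlers.
asyncio.run(monitor.collect())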

View File

@@ -1,4 +1,4 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
@@ -13,10 +13,16 @@ class TokenizedResponse(BaseRunnerResponse):
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -1,235 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
ThunderboltConnectivity,
ThunderboltIdentifier,
)
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacThunderboltIdentifiers(TaggedModel):
idents: Sequence[ThunderboltIdentifier]
class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if IS_DARWIN:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

View File

@@ -1,70 +0,0 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))

View File

@@ -1,113 +0,0 @@
from collections.abc import Mapping
import anyio
import httpx
from anyio import create_task_group
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
REACHABILITY_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
remote_node_id = NodeId(body)
break
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
if remote_node_id is None:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(
topology: Topology,
self_node_id: NodeId,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node_id in topology.list_nodes():
if node_id not in node_profiles:
continue
if node_id == self_node_id:
continue
for iface in node_profiles[node_id].network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node_id,
reachable,
client,
)
return reachable

View File

@@ -1,24 +0,0 @@
import sys
import pytest
from exo.shared.types.thunderbolt import (
ThunderboltConnectivity,
)
from exo.utils.info_gatherer.info_gatherer import (
_gather_iface_map, # pyright: ignore[reportPrivateUsage]
)
@pytest.mark.anyio
@pytest.mark.skipif(
sys.platform != "darwin", reason="Thunderbolt info can only be gathered on macos"
)
async def test_tb_parsing():
data = await ThunderboltConnectivity.gather()
ifaces = await _gather_iface_map()
assert ifaces
assert data
for datum in data:
datum.ident(ifaces)
datum.conn()

View File

@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):

View File

@@ -28,8 +28,9 @@ def bar(send: MpSender[str]):
send.close()
# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_ipc():
async def test_channel_setup():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -12,6 +12,7 @@ from exo.shared.types.api import (
ChatCompletionMessage,
FinishReason,
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
@@ -81,7 +82,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -115,10 +116,65 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def extract_top_logprobs(
logprobs: mx.array,
tokenizer: TokenizerWrapper,
top_k: int,
selected_token: int,
) -> tuple[float, list[TopLogprobItem]]:
"""Extract the selected token's logprob and top-k alternative tokens.
Args:
logprobs: Full vocabulary logprobs array from MLX
tokenizer: Tokenizer for decoding token IDs to strings
top_k: Number of top alternatives to return
selected_token: The token ID that was actually sampled
Returns:
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
"""
# Get the logprob of the selected token
selected_logprob = float(logprobs[selected_token].item())
# Get top-k indices (most probable tokens)
# mx.argpartition gives indices that would partition the array
# We negate logprobs since argpartition finds smallest, and we want largest
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
# Get the actual logprob values for these indices
top_values = logprobs[top_indices]
# Sort by logprob (descending) for consistent ordering
sort_order = mx.argsort(-top_values)
top_indices = top_indices[sort_order]
top_values = top_values[sort_order]
# Convert to list of TopLogprobItem
top_logprob_items: list[TopLogprobItem] = []
for i in range(top_k):
token_id = int(top_indices[i].item())
token_logprob = float(top_values[i].item())
# Decode token ID to string
token_str = tokenizer.decode([token_id])
# Get byte representation
token_bytes = list(token_str.encode("utf-8"))
top_logprob_items.append(
TopLogprobItem(
token=token_str,
logprob=token_logprob,
bytes=token_bytes,
)
)
return selected_logprob, top_logprob_items
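Downstream, the dashboard buckets each token by exp(logprob); a small standalone sketch of that mapping (the function name and exact thresholds here are assumptions for illustration):

import math

def confidence_bucket(logprob: float) -> str:
    """Map a token's logprob to an uncertainty heatmap bucket."""
    p = math.exp(logprob)  # logprob -> probability in [0, 1]
    if p > 0.8:
        return "green"   # high confidence
    if p > 0.5:
        return "yellow"  # medium confidence
    if p > 0.2:
        return "orange"  # low confidence
    return "red"         # very low confidence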
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -146,9 +202,24 @@ def mlx_generate(
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
top_k=task.top_k if task.top_k is not None else 0,
)
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
if task.stop is not None
else []
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
max_tokens = task.max_tokens or MAX_TOKENS
accumulated_text = ""
# Determine if we need to extract logprobs
should_extract_logprobs = task.logprobs is True
num_top_logprobs = task.top_logprobs if task.top_logprobs is not None else 5
for out in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -158,14 +229,47 @@ def mlx_generate(
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=on_prefill_progress,
):
logger.info(out.text)
accumulated_text += out.text
# Check for stop sequences
text = out.text
finish_reason: FinishReason | None = cast(
FinishReason | None, out.finish_reason
)
stop_matched = False
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in accumulated_text:
# Trim text to just before the stop sequence
stop_index = accumulated_text.find(stop_seq)
text_before_stop = accumulated_text[:stop_index]
chunk_start = len(accumulated_text) - len(out.text)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
stop_matched = True
break
# Extract logprobs if requested
token_logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if should_extract_logprobs:
token_logprob, top_logprobs = extract_top_logprobs(
logprobs=out.logprobs,
tokenizer=tokenizer,
top_k=num_top_logprobs,
selected_token=out.token,
)
is_done = finish_reason is not None
stats: GenerationStats | None = None
if out.finish_reason is not None:
if is_done:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
@@ -173,22 +277,25 @@ def mlx_generate(
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't raise here, as this failure case is fairly benign;
# just log the error and move on.
if not stop_matched and out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
text=text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
logprob=token_logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
)
if out.finish_reason is not None:
if is_done:
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
# TODO: Do we want an mx_barrier?
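For intuition on the tail trimming above, a standalone sketch (not from the diff) showing a stop sequence that straddles a chunk boundary:

# Keeping the last max_stop_len characters of history is enough to catch
# a stop sequence split across two chunks, e.g. stop="END":
chunks = ["hello E", "ND world"]
accumulated, max_stop_len = "", len("END")
for chunk in chunks:
    accumulated += chunk
    if "END" in accumulated:
        print("stop detected")  # fires on the second chunk
        break
    accumulated = accumulated[-max_stop_len:]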

View File

@@ -145,26 +145,20 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
assert all(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
ibv_devices_json = json.dumps(ibv_devices)
with open(coordination_file, "w") as f:
_ = f.write(jaccl_devices_json)
_ = f.write(ibv_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)

View File

@@ -16,7 +16,8 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
@@ -24,6 +25,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -32,7 +34,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -43,14 +45,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -84,7 +86,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._tg: TaskGroup | None = None
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -96,13 +98,37 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
@@ -116,17 +142,6 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -146,6 +161,7 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -236,7 +252,8 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
if self._tg:
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -253,28 +270,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
async def _nack_request(self, since_idx: int) -> None:
@@ -323,6 +336,7 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -385,6 +399,7 @@ class Worker:
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -405,14 +420,9 @@ class Worker:
async def _poll_connection_updates(self):
while True:
edges = set(
conn.edge for conn in self.state.topology.out_edges(self.node_id)
)
conns = await check_reachable(
self.state.topology,
self.node_id,
self.state.node_profiles,
)
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
@@ -420,33 +430,26 @@ class Worker:
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
conn=Connection(
source=self.node_id, sink=nid, edge=edge
)
)
)
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
for nid, conn in self.state.topology.out_edges(self.node_id):
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
not in conns.get(conn.source, set())
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)

View File

@@ -19,7 +19,7 @@ def entrypoint(
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

View File

@@ -16,6 +16,7 @@ from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -161,11 +162,23 @@ def main(
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
if shard_metadata.device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
processed_tokens=processed,
total_tokens=total,
)
)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
on_prefill_progress=on_prefill_progress,
)
# GPT-OSS specific parsing to match other model formats.
@@ -186,6 +199,8 @@ def main(
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
finish_reason=response.finish_reason,
stats=response.stats,
),
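On the wire, the progress callback above presumably surfaces as a named prefill_progress SSE event; a hypothetical client-side sketch, assuming the event serializes with camelCase field names as elsewhere in the codebase:

import json

def prefill_fraction(sse_data: str) -> float:
    """Turn one prefill_progress SSE data payload into a 0..1 fraction."""
    payload = json.loads(sse_data)
    return payload["processedTokens"] / payload["totalTokens"]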

View File

@@ -0,0 +1,6 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]

View File

@@ -0,0 +1,103 @@
import platform
import shutil
from subprocess import CalledProcessError
from typing import cast
from anyio import run_process
from pydantic import BaseModel, ConfigDict, ValidationError
class MacMonError(Exception):
"""Exception raised for errors in the MacMon functions."""
def _get_binary_path() -> str:
"""
Get the path to the macmon binary.
Raises:
MacMonError: If the platform is unsupported or the binary isn't found in PATH.
"""
# Check for macOS with ARM chip
system = platform.system().lower()
machine = platform.machine().lower()
if system != "darwin" or not (
"arm" in machine or "m1" in machine or "m2" in machine
):
raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips")
path = shutil.which("macmon")
if path is None:
raise MacMonError("MacMon not found in PATH")
return path
class TempMetrics(BaseModel):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
model_config = ConfigDict(extra="ignore")
class Metrics(BaseModel):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
all_power: float
ane_power: float
cpu_power: float
ecpu_usage: tuple[int, float]
gpu_power: float
gpu_ram_power: float
gpu_usage: tuple[int, float]
pcpu_usage: tuple[int, float]
ram_power: float
sys_power: float
temp: TempMetrics
timestamp: str
model_config = ConfigDict(extra="ignore")
async def get_metrics_async() -> Metrics:
"""
Asynchronously run the macmon binary and return the parsed metrics.
Returns:
A Metrics model containing system metrics.
Raises:
MacMonError: If there's an error running the binary.
"""
path = _get_binary_path()
try:
# TODO: Keep Macmon running in the background?
result = await run_process([path, "pipe", "-s", "1"])
return Metrics.model_validate_json(result.stdout.decode().strip())
except ValidationError as e:
raise MacMonError(f"Error parsing JSON output: {e}") from e
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
raise MacMonError(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
) from e
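A minimal usage sketch for the helper above (macOS on Apple Silicon assumed; the function name show_power is hypothetical):

import anyio

async def show_power() -> None:
    try:
        metrics = await get_metrics_async()
        # sys_power comes straight from macmon's pipe output
        print(f"system power: {metrics.sys_power}")
    except MacMonError as err:
        print(f"macmon unavailable: {err}")

anyio.run(show_power)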

View File

@@ -0,0 +1,91 @@
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = response.read().decode("utf-8").strip()
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
return NodeId(body) or None
except OSError:
return None
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
)
return reachable
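A hypothetical call site mirroring the worker's polling loop (topology and self_node_id are assumed to be in scope; logger is the loguru logger already imported in this module):

async def log_reachable(topology: Topology, self_node_id: NodeId) -> None:
    reachable = await check_reachable(topology, self_node_id)
    for node_id, ips in reachable.items():
        logger.info(f"{node_id} reachable at {sorted(ips)}")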

View File

@@ -0,0 +1,114 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine
import anyio
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()
def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.
Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be run concurrently, but realistically the results should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# take the memory profile last so the reading is fresh and doesn't conflict with the dedicated memory polling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except asyncio.TimeoutError:
logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)
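A hypothetical wiring for the memory poller above; this prints available RAM roughly twice a second until cancelled (log_memory is an illustrative name):

import anyio

async def log_memory(profile: MemoryPerformanceProfile) -> None:
    print(f"RAM available: {profile.ram_available}")

anyio.run(start_polling_memory_metrics, log_memory)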

View File

@@ -0,0 +1,77 @@
"""Tests for macmon error handling.
These tests verify that MacMon errors are handled gracefully without
crashing the application or spamming logs.
"""
import platform
from subprocess import CalledProcessError
from unittest.mock import AsyncMock, patch
import pytest
from exo.worker.utils.macmon import MacMonError, get_metrics_async
@pytest.mark.skipif(
platform.system().lower() != "darwin" or "arm" not in platform.machine().lower(),
reason="MacMon only supports macOS with Apple Silicon",
)
class TestMacMonErrorHandling:
"""Test MacMon error handling."""
async def test_called_process_error_wrapped_as_macmon_error(self) -> None:
"""CalledProcessError should be wrapped as MacMonError."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=b"some error message",
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "some error message" in str(exc_info.value)
async def test_called_process_error_with_no_stderr(self) -> None:
"""CalledProcessError with no stderr should be handled gracefully."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=None,
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "no stderr" in str(exc_info.value)
async def test_macmon_not_found_raises_macmon_error(self) -> None:
"""When macmon is not found in PATH, MacMonError should be raised."""
with patch("exo.worker.utils.macmon.shutil.which", return_value=None):
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon not found in PATH" in str(exc_info.value)

View File

@@ -34,8 +34,7 @@ from exo.shared.types.worker.instances import (
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.channels import MpReceiver, MpSender, mp_channel
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
@@ -66,7 +65,6 @@ async def main():
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
app, # type: ignore
@@ -78,15 +76,6 @@ async def main():
shutdown.set()
async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
@@ -220,16 +209,16 @@ async def jaccl_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
return await execute_test(test, jaccl_instance(test, iid, hn), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
ibv_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs

uv.lock generated
View File

File diff suppressed because it is too large.