Compare commits

..

3 Commits

Author SHA1 Message Date
Alex Cheema
6ab1cb285b Enhance LaTeX rendering in dashboard markdown
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 14:49:28 +00:00
Antonio Lujano Luna
9c29eb7d48 Add proxy and custom SSL certificate support for corporate networks (#1189)
Support HTTPS_PROXY/HTTP_PROXY environment variables for proxy
configuration and SSL_CERT_FILE for custom CA certificates, enabling use
in corporate environments with SSL inspection.

## Motivation
Users in corporate environments often need to route traffic through HTTP
proxies and use custom CA certificates for SSL inspection. Without this
support, exo cannot download models in these network configurations.

## Changes
- Added `HTTPS_PROXY`/`HTTP_PROXY` environment variable support to
`create_http_session()` in `download_utils.py`
- Added `SSL_CERT_FILE` environment variable support for custom CA
certificate bundles, falling back to certifi's default bundle

## Why It Works
- `aiohttp.ClientSession` natively supports the `proxy` parameter for
routing requests through HTTP proxies
- `ssl.create_default_context(cafile=...)` accepts a custom CA bundle
path, allowing corporate CAs to be trusted
- Using environment variables is consistent with the codebase's existing
configuration patterns (e.g., `EXO_HOME`, `HF_ENDPOINT`)

## Test Plan
### Manual Testing
- Set `HTTPS_PROXY` environment variable and verified model downloads
route through proxy
- Set `SSL_CERT_FILE` to custom CA bundle and verified SSL verification
succeeds with corporate SSL inspection

### Automated Testing
- No automated tests added; this change is configuration-only and does
not alter existing behavior when environment variables are unset
2026-01-18 12:05:50 +00:00
Alex Cheema
c5158bee53 Add pre-commit checks documentation to AGENTS.md (#1184)
## Motivation

CI failures can be avoided by running checks locally before committing.
This adds clear documentation to AGENTS.md so that AI agents (and
humans) know exactly which checks must pass before pushing code.

## Changes

Added a new "Pre-Commit Checks (REQUIRED)" section to AGENTS.md that:
- Lists all 4 required checks (basedpyright, ruff, nix fmt, pytest)
- Provides a one-liner to run all checks in sequence
- Notes that `nix fmt` changes must be staged before committing
- Explains that CI runs `nix flake check` which verifies everything

## Why It Works

Clear documentation prevents CI failures by ensuring contributors run
checks locally first. The one-liner command makes it easy to run all
checks before committing.

## Test Plan

### Manual Testing
- Verified the documented commands work correctly

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 21:50:24 +00:00
15 changed files with 710 additions and 486 deletions

View File

@@ -40,6 +40,31 @@ uv run ruff check
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition

View File

@@ -53,62 +53,188 @@
marked.use({ renderer });
/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
// Storage for math expressions extracted before markdown processing
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
let mathCounter = 0;
// Use alphanumeric placeholders that won't be interpreted as HTML tags
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
/**
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
*/
function preprocessLaTeX(text: string): string {
// Protect code blocks
// Reset math storage
mathExpressions.clear();
mathCounter = 0;
// Protect code blocks first
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `<<CODE_${codeBlocks.length - 1}>>`;
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
});
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Remove LaTeX document commands
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\begin\{document\}/g, '');
processed = processed.replace(/\\end\{document\}/g, '');
processed = processed.replace(/\\maketitle/g, '');
processed = processed.replace(/\\title\{[^}]*\}/g, '');
processed = processed.replace(/\\author\{[^}]*\}/g, '');
processed = processed.replace(/\\date\{[^}]*\}/g, '');
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
processed = processed.replace(/\\require\{[^}]*\}/g, '');
// Remove unsupported LaTeX commands/environments (tikzpicture, etc.)
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, '[diagram]');
processed = processed.replace(/\\label\{[^}]*\}/g, '');
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
// Convert LaTeX math environments to display math
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*'];
for (const env of mathEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
processed = processed.replace(envRegex, (_, content) => {
// For align environments, wrap content properly for KaTeX
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
}
// Convert LaTeX proof environments to styled blocks
processed = processed.replace(
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
'<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">$1</div></div>'
);
// Convert LaTeX theorem-like environments
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
for (const env of theoremEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
const envName = env.charAt(0).toUpperCase() + env.slice(1);
processed = processed.replace(
envRegex,
`<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">$1</div></div>`
);
}
// Convert LaTeX text formatting commands
processed = processed.replace(/\\emph\{([^}]*)\}/g, '<em>$1</em>');
processed = processed.replace(/\\textit\{([^}]*)\}/g, '<em>$1</em>');
processed = processed.replace(/\\textbf\{([^}]*)\}/g, '<strong>$1</strong>');
processed = processed.replace(/\\texttt\{([^}]*)\}/g, '<code class="inline-code">$1</code>');
processed = processed.replace(/\\underline\{([^}]*)\}/g, '<u>$1</u>');
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\((.+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});
// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});
// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});
// Extract inline math ($...$) BEFORE markdown processing
// Skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});
// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Restore code blocks
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
/**
* Render math expressions with KaTeX after HTML is generated
* Render math expressions with KaTeX - restores placeholders with rendered math
*/
function renderMath(html: string): string {
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});
if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}
return html;
}
@@ -154,16 +280,50 @@
}
}
async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;
const source = decodeURIComponent(encodedSource);
try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}
$effect(() => {
@@ -424,28 +584,263 @@
color: #60a5fa;
}
/* KaTeX math styling */
/* KaTeX math styling - Base */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}
.markdown-content :global(.katex-display) {
/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}
/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}
.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}
.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}
/* Display math content area */
.markdown-content :global(.math-display-content) {
padding: 1rem 1.25rem;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
.markdown-content :global(.katex-display > .katex) {
/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}
.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}
.markdown-content :global(.math-display-content .katex-display > .katex) {
text-align: center;
}
/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}
.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}
/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}
/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}
/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}
.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}
.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}
/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}
/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.125rem 0.25rem;
padding: 0.25rem 0.5rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}
.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}
/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}
.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-proof-header::after) {
content: '.';
}
.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}
.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}
/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}
/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}
.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}
.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}
.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}
</style>

View File

@@ -71,36 +71,35 @@ export interface Instance {
};
}
// Split state interfaces
interface RawNodeIdentity {
modelId: string;
chipId: string;
friendlyName: string;
}
interface RawNodeMemory {
ramTotal: { inBytes: number };
ramAvailable: { inBytes: number };
swapTotal: { inBytes: number };
swapAvailable: { inBytes: number };
}
interface RawNodeSystem {
gpuUsage?: number;
temp?: number;
sysPower?: number;
pcpuUsage?: number;
ecpuUsage?: number;
anePower?: number;
}
interface RawNetworkInterface {
name: string;
ipAddress: string;
interface RawNodeProfile {
modelId?: string;
chipId?: string;
friendlyName?: string;
networkInterfaces?: Array<{
name?: string;
ipAddress?: string;
addresses?: Array<{ address?: string } | string>;
ipv4?: string;
ipv6?: string;
ipAddresses?: string[];
ips?: string[];
}>;
memory?: {
ramTotal?: { inBytes: number };
ramAvailable?: { inBytes: number };
swapTotal?: { inBytes: number };
swapAvailable?: { inBytes: number };
};
system?: {
gpuUsage?: number;
temp?: number;
sysPower?: number;
};
}
interface RawTopologyNode {
nodeId: string;
nodeProfile: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -116,6 +115,8 @@ interface RawTopology {
connections?: RawTopologyConnection[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
export interface DownloadProgress {
totalBytes: number;
downloadedBytes: number;
@@ -170,11 +171,7 @@ interface RawStateResponse {
>;
runners?: Record<string, unknown>;
downloads?: Record<string, unknown[]>;
// Split state fields
nodeIdentities?: Record<string, RawNodeIdentity>;
nodeMemories?: Record<string, RawNodeMemory>;
nodeSystems?: Record<string, RawNodeSystem>;
nodeNetworks?: Record<string, RawNetworkInterface[]>;
nodeProfiles?: RawNodeProfiles;
}
export interface MessageAttachment {
@@ -211,41 +208,66 @@ const STORAGE_KEY = "exo-conversations";
function transformTopology(
raw: RawTopology,
identities?: Record<string, RawNodeIdentity>,
memories?: Record<string, RawNodeMemory>,
systems?: Record<string, RawNodeSystem>,
networks?: Record<string, RawNetworkInterface[]>,
profiles?: RawNodeProfiles,
): TopologyData {
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
for (const node of raw.nodes || []) {
// Get split state fields (may be undefined if events haven't arrived yet)
const identity = identities?.[node.nodeId];
const memory = memories?.[node.nodeId];
const system = systems?.[node.nodeId];
const network = networks?.[node.nodeId];
const ramTotal = memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = memory?.ramAvailable?.inBytes ?? 0;
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
const networkInterfaces = (network ?? []).map((iface) => ({
name: iface.name,
addresses: [iface.ipAddress],
}));
const networkInterfaces = (profile?.networkInterfaces || []).map(
(iface) => {
const addresses: string[] = [];
if (iface.ipAddress && typeof iface.ipAddress === "string") {
addresses.push(iface.ipAddress);
}
if (Array.isArray(iface.addresses)) {
for (const addr of iface.addresses) {
if (typeof addr === "string") addresses.push(addr);
else if (addr && typeof addr === "object" && addr.address)
addresses.push(addr.address);
}
}
if (Array.isArray(iface.ipAddresses)) {
addresses.push(
...iface.ipAddresses.filter(
(a): a is string => typeof a === "string",
),
);
}
if (Array.isArray(iface.ips)) {
addresses.push(
...iface.ips.filter((a): a is string => typeof a === "string"),
);
}
if (iface.ipv4 && typeof iface.ipv4 === "string")
addresses.push(iface.ipv4);
if (iface.ipv6 && typeof iface.ipv6 === "string")
addresses.push(iface.ipv6);
return {
name: iface.name,
addresses: Array.from(new Set(addresses)),
};
},
);
const ipToInterface: Record<string, string> = {};
for (const iface of networkInterfaces) {
for (const addr of iface.addresses) {
ipToInterface[addr] = iface.name;
for (const addr of iface.addresses || []) {
ipToInterface[addr] = iface.name ?? "";
}
}
nodes[node.nodeId] = {
system_info: {
model_id: identity?.modelId ?? "Unknown",
chip: identity?.chipId,
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
memory: ramTotal,
},
network_interfaces: networkInterfaces,
@@ -256,15 +278,17 @@ function transformTopology(
ram_total: ramTotal,
},
temp:
system?.temp !== undefined
? { gpu_temp_avg: system.temp }
profile?.system?.temp !== undefined
? { gpu_temp_avg: profile.system.temp }
: undefined,
gpu_usage:
system?.gpuUsage !== undefined ? [0, system.gpuUsage] : undefined,
sys_power: system?.sysPower,
profile?.system?.gpuUsage !== undefined
? [0, profile.system.gpuUsage]
: undefined,
sys_power: profile?.system?.sysPower,
},
last_macmon_update: Date.now() / 1000,
friendly_name: identity?.friendlyName,
friendly_name: profile?.friendlyName,
};
}
@@ -844,13 +868,7 @@ class AppStore {
const data: RawStateResponse = await response.json();
if (data.topology) {
this.topologyData = transformTopology(
data.topology,
data.nodeIdentities,
data.nodeMemories,
data.nodeSystems,
data.nodeNetworks,
);
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
}
if (data.instances) {
this.instances = data.instances;

View File

@@ -599,8 +599,9 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for memory in self.state.node_memories.values():
total_available += memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available

View File

@@ -113,7 +113,6 @@ def place_instance(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
and node.node_profile.memory is not None
),
start=Memory(),
),

View File

@@ -25,10 +25,7 @@ class NodeWithProfile(BaseModel):
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(
node.node_profile is not None and node.node_profile.memory is not None
for node in nodes
)
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
@@ -39,14 +36,8 @@ def filter_cycles_by_memory(
if not narrow_all_nodes(cycle):
continue
# narrow_all_nodes guarantees memory is not None
total_mem = sum(
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile.memory is not None
),
start=Memory(),
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[NodeInfo], cycle))
@@ -62,13 +53,8 @@ def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
# NodeWithProfile guarantees memory is not None
cycle_memory = sum(
(
node.node_profile.memory.ram_available
for node in selected_cycle
if node.node_profile.memory is not None
),
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
total_layers = model_meta.n_layers
@@ -81,8 +67,6 @@ def get_shard_assignments_for_pipeline_parallel(
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
# NodeWithProfile guarantees memory is not None
assert node.node_profile.memory is not None
node_layers = round(
total_layers
* (

View File

@@ -19,13 +19,16 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeIdentityMeasured,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.worker.instances import (
@@ -72,39 +75,29 @@ async def test_master():
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject NodeIdentityMeasured and NodeMemoryMeasured events
logger.info("inject NodeIdentityMeasured event")
# inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
session=session_id,
event=(
NodeIdentityMeasured(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
model_id="maccy",
chip_id="arm",
friendly_name="test",
)
),
)
)
logger.info("inject NodeMemoryMeasured event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=1,
origin=sender_node_id,
session=session_id,
event=(
NodeMemoryMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -115,7 +108,7 @@ async def test_master():
logger.info("wait for initial topology event")
while len(list(master.state.topology.list_nodes())) == 0:
await anyio.sleep(0.001)
while len(master.state.node_identities) == 0:
while len(master.state.node_profiles) == 0:
await anyio.sleep(0.001)
logger.info("inject a CreateInstance Command")
@@ -162,19 +155,17 @@ async def test_master():
),
)
)
while len(_get_events()) < 4:
while len(_get_events()) < 3:
await anyio.sleep(0.01)
events = _get_events()
assert len(events) == 4
assert len(events) == 3
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert events[3].idx == 3
assert isinstance(events[0].event, NodeIdentityMeasured)
assert isinstance(events[1].event, NodeMemoryMeasured)
assert isinstance(events[2].event, InstanceCreated)
created_instance = events[2].event.instance
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
# Validate the shard assignments
@@ -206,10 +197,10 @@ async def test_master():
assert len(created_instance.hosts_by_node[node_id]) == 1
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
assert created_instance.ephemeral_port > 0
assert isinstance(events[3].event, TaskCreated)
assert events[3].event.task.task_status == TaskStatus.Pending
assert isinstance(events[3].event.task, ChatCompletionTask)
assert events[3].event.task.task_params == ChatCompletionTaskParams(
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")

View File

@@ -13,10 +13,8 @@ from exo.shared.types.events import (
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeIdentityMeasured,
NodeMemoryMeasured,
NodeNetworkMeasured,
NodeSystemMeasured,
NodePerformanceMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -29,13 +27,7 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
NodeIdentity,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import NodeInfo
@@ -59,12 +51,8 @@ def event_apply(event: Event, state: State) -> State:
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeIdentityMeasured():
return apply_node_identity_measured(event, state)
case NodeSystemMeasured():
return apply_node_system_measured(event, state)
case NodeNetworkMeasured():
return apply_node_network_measured(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeMemoryMeasured():
@@ -202,19 +190,8 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memories = {
key: value for key, value in state.node_memories.items() if key != event.node_id
}
node_systems = {
key: value for key, value in state.node_systems.items() if key != event.node_id
}
node_networks = {
key: value for key, value in state.node_networks.items() if key != event.node_id
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
}
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
@@ -222,120 +199,32 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
return state.model_copy(
update={
"topology": topology,
"node_identities": node_identities,
"node_memories": node_memories,
"node_systems": node_systems,
"node_networks": node_networks,
"node_profiles": node_profiles,
"last_seen": last_seen,
}
)
def _reconstruct_profile(
node_id: NodeId,
state: State,
*,
identity: NodeIdentity | None = None,
memory: MemoryPerformanceProfile | None = None,
system: SystemPerformanceProfile | None = None,
network_interfaces: list[NetworkInterfaceInfo] | None = None,
) -> NodePerformanceProfile:
"""Reconstruct a NodePerformanceProfile from split state storage.
Uses provided overrides, falling back to state values.
"""
ident = identity or state.node_identities.get(node_id)
mem = memory or state.node_memories.get(node_id)
sys = system or state.node_systems.get(node_id)
nets = (
network_interfaces
if network_interfaces is not None
else state.node_networks.get(node_id, [])
)
return NodePerformanceProfile(
model_id=ident.model_id if ident else None,
chip_id=ident.chip_id if ident else None,
friendly_name=ident.friendly_name if ident else None,
memory=mem,
network_interfaces=nets,
system=sys,
)
def apply_node_identity_measured(event: NodeIdentityMeasured, state: State) -> State:
topology = copy.copy(state.topology)
identity = NodeIdentity(
model_id=event.model_id,
chip_id=event.chip_id,
friendly_name=event.friendly_name,
)
new_identities: Mapping[NodeId, NodeIdentity] = {
**state.node_identities,
event.node_id: identity,
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
reconstructed = _reconstruct_profile(event.node_id, state, identity=identity)
topology.update_node_profile(event.node_id, reconstructed)
return state.model_copy(
update={
"node_identities": new_identities,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_system_measured(event: NodeSystemMeasured, state: State) -> State:
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
new_systems: Mapping[NodeId, SystemPerformanceProfile] = {
**state.node_systems,
event.node_id: event.system,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
reconstructed = _reconstruct_profile(event.node_id, state, system=event.system)
topology.update_node_profile(event.node_id, reconstructed)
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_systems": new_systems,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_network_measured(event: NodeNetworkMeasured, state: State) -> State:
topology = copy.copy(state.topology)
new_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {
**state.node_networks,
event.node_id: event.network_interfaces,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
reconstructed = _reconstruct_profile(
event.node_id, state, network_interfaces=event.network_interfaces
)
topology.update_node_profile(event.node_id, reconstructed)
return state.model_copy(
update={
"node_networks": new_networks,
"node_profiles": new_profiles,
"topology": topology,
"last_seen": last_seen,
}
@@ -343,26 +232,57 @@ def apply_node_network_measured(event: NodeNetworkMeasured, state: State) -> Sta
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
new_memories: Mapping[NodeId, MemoryPerformanceProfile] = {
**state.node_memories,
event.node_id: event.memory,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
reconstructed = _reconstruct_profile(event.node_id, state, memory=event.memory)
topology.update_node_profile(event.node_id, reconstructed)
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={
"node_memories": new_memories,
"topology": topology,
"last_seen": last_seen,
}
update={"node_profiles": updated_profiles, "topology": topology}
)

View File

@@ -2,14 +2,10 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
SystemPerformanceProfile,
)
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -89,35 +85,13 @@ class NodeTimedOut(BaseEvent):
node_id: NodeId
class NodeIdentityMeasured(BaseEvent):
"""Static identity info - emitted once at startup."""
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
model_id: str
chip_id: str
friendly_name: str
class NodeSystemMeasured(BaseEvent):
"""Dynamic system metrics (GPU, temp, power) - emitted at 1s intervals."""
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
system: SystemPerformanceProfile
class NodeNetworkMeasured(BaseEvent):
"""Semi-static network interface info - emitted at 30s intervals."""
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
network_interfaces: list[NetworkInterfaceInfo]
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
"""Dynamic memory metrics - emitted at 0.5s intervals."""
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
@@ -153,9 +127,7 @@ Event = (
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeIdentityMeasured
| NodeSystemMeasured
| NodeNetworkMeasured
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated

View File

@@ -52,21 +52,13 @@ class NetworkInterfaceInfo(CamelCaseModel):
ip_address: str
class NodeIdentity(CamelCaseModel):
"""Static identity info for a node."""
class NodePerformanceProfile(CamelCaseModel):
model_id: str
chip_id: str
friendly_name: str
class NodePerformanceProfile(CamelCaseModel):
model_id: str | None = None
chip_id: str | None = None
friendly_name: str | None = None
memory: MemoryPerformanceProfile | None = None
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile | None = None
system: SystemPerformanceProfile
class ConnectionProfile(CamelCaseModel):

View File

@@ -7,12 +7,7 @@ from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
NodeIdentity,
SystemPerformanceProfile,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -40,10 +35,7 @@ class State(CamelCaseModel):
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}
node_identities: Mapping[NodeId, NodeIdentity] = {}
node_memories: Mapping[NodeId, MemoryPerformanceProfile] = {}
node_systems: Mapping[NodeId, SystemPerformanceProfile] = {}
node_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {}
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
last_seen: Mapping[NodeId, datetime] = {}
topology: Topology = Field(default_factory=Topology)
last_event_applied_idx: int = Field(default=-1, ge=-1)

View File

@@ -245,12 +245,15 @@ def create_http_session(
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(cafile=certifi.where())
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,

View File

@@ -16,10 +16,8 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeIdentityMeasured,
NodeMemoryMeasured,
NodeNetworkMeasured,
NodeSystemMeasured,
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
@@ -27,11 +25,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
SystemPerformanceProfile,
)
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -57,13 +51,7 @@ from exo.worker.download.download_utils import (
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import (
IdentityMetrics,
start_polling_identity_metrics,
start_polling_memory_metrics,
start_polling_network_metrics,
start_polling_system_metrics,
)
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
@@ -110,51 +98,37 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
async def identity_callback(identity: IdentityMetrics) -> None:
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodeIdentityMeasured(
NodePerformanceMeasured(
node_id=self.node_id,
model_id=identity.model_id,
chip_id=identity.chip_id,
friendly_name=identity.friendly_name,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
async def system_callback(system: SystemPerformanceProfile) -> None:
await self.event_sender.send(
NodeSystemMeasured(
node_id=self.node_id,
system=system,
when=str(datetime.now(tz=timezone.utc)),
),
)
async def network_callback(interfaces: list[NetworkInterfaceInfo]) -> None:
await self.event_sender.send(
NodeNetworkMeasured(
node_id=self.node_id,
network_interfaces=interfaces,
when=str(datetime.now(tz=timezone.utc)),
),
)
async def memory_callback(memory: MemoryPerformanceProfile) -> None:
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_identity_metrics, identity_callback)
tg.start_soon(start_polling_system_metrics, system_callback)
tg.start_soon(start_polling_network_metrics, network_callback)
tg.start_soon(start_polling_memory_metrics, memory_callback)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)

View File

@@ -1,15 +1,6 @@
from .profile import (
IdentityMetrics,
start_polling_identity_metrics,
start_polling_memory_metrics,
start_polling_network_metrics,
start_polling_system_metrics,
)
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"IdentityMetrics",
"start_polling_identity_metrics",
"start_polling_node_metrics",
"start_polling_memory_metrics",
"start_polling_network_metrics",
"start_polling_system_metrics",
]

View File

@@ -1,7 +1,6 @@
import asyncio
import os
import platform
from dataclasses import dataclass
from typing import Any, Callable, Coroutine
import anyio
@@ -10,7 +9,7 @@ from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
)
@@ -28,13 +27,6 @@ from .system_info import (
)
@dataclass(frozen=True)
class IdentityMetrics:
model_id: str
chip_id: str
friendly_name: str
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
@@ -75,73 +67,48 @@ async def start_polling_memory_metrics(
await anyio.sleep(poll_interval_s)
async def start_polling_identity_metrics(
callback: Callable[[IdentityMetrics], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 30.0,
) -> None:
"""Continuously poll and emit identity metrics at 30s intervals."""
while True:
try:
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
await callback(
IdentityMetrics(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
)
)
except Exception as e:
logger.opt(exception=e).error("Failed to emit identity metrics")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_system_metrics(
callback: Callable[[SystemPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 1.0,
) -> None:
"""Continuously poll and emit system metrics (GPU, temp, power) at 1s intervals."""
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
await callback(
SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except asyncio.TimeoutError:
logger.warning(
"[system_monitor] Operation timed out after 30s, skipping this cycle."
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("System Monitor encountered error")
logger.opt(exception=e).error("Resource Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_network_metrics(
callback: Callable[[list[NetworkInterfaceInfo]], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 30.0,
) -> None:
"""Continuously poll and emit network interface info at 30s intervals."""
while True:
try:
network_interfaces = get_network_interfaces()
await callback(network_interfaces)
except Exception as e:
logger.opt(exception=e).error("Network Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)