Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-19 03:22:01 -05:00)

Compare commits: 3 commits
alexcheema ... enhance-la

| Author | SHA1 | Date |
|---|---|---|
| | 6ab1cb285b | |
| | 9c29eb7d48 | |
| | c5158bee53 | |

AGENTS.md (25 lines changed)
@@ -40,6 +40,31 @@ uv run ruff check
nix fmt
```

## Pre-Commit Checks (REQUIRED)

**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**

```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright

# 2. Linting - MUST pass
uv run ruff check

# 3. Formatting - MUST be applied
nix fmt

# 4. Tests - MUST pass
uv run pytest
```

Run all checks in sequence:

```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```

If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.

## Architecture

### Node Composition
@@ -53,62 +53,188 @@
|
||||
marked.use({ renderer });
|
||||
|
||||
/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
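
A minimal sketch, illustrative only, of why this unescaping step exists: marked HTML-escapes characters such as the alignment `&` inside `align` environments, so the entity form has to be reversed before KaTeX sees the math. The sample string below is made up.

```ts
// Illustrative only: marked turns '&' into '&amp;' in the HTML it emits, which
// would break KaTeX alignment syntax if passed through unchanged.
const escapedByMarked = 'a &amp;= b \\\\ c &amp;= d';
console.log(unescapeHtmlEntities(escapedByMarked)); // -> a &= b \\ c &= d
```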
|
||||
|
||||
// Storage for math expressions extracted before markdown processing
|
||||
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
|
||||
let mathCounter = 0;
|
||||
|
||||
// Use alphanumeric placeholders that won't be interpreted as HTML tags
|
||||
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
|
||||
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
|
||||
*/
|
||||
function preprocessLaTeX(text: string): string {
|
||||
// Protect code blocks
|
||||
// Reset math storage
|
||||
mathExpressions.clear();
|
||||
mathCounter = 0;
|
||||
|
||||
// Protect code blocks first
|
||||
const codeBlocks: string[] = [];
|
||||
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
|
||||
codeBlocks.push(match);
|
||||
return `<<CODE_${codeBlocks.length - 1}>>`;
|
||||
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
|
||||
});
|
||||
|
||||
// Convert \(...\) to $...$
|
||||
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
|
||||
|
||||
// Convert \[...\] to $$...$$
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
|
||||
// Remove LaTeX document commands
|
||||
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\begin\{document\}/g, '');
|
||||
processed = processed.replace(/\\end\{document\}/g, '');
|
||||
processed = processed.replace(/\\maketitle/g, '');
|
||||
processed = processed.replace(/\\title\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\author\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\date\{[^}]*\}/g, '');
|
||||
|
||||
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
|
||||
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
|
||||
processed = processed.replace(/\\require\{[^}]*\}/g, '');
|
||||
|
||||
// Remove unsupported LaTeX commands/environments (tikzpicture, etc.)
|
||||
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, '[diagram]');
|
||||
processed = processed.replace(/\\label\{[^}]*\}/g, '');
|
||||
|
||||
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
|
||||
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
|
||||
|
||||
// Convert LaTeX math environments to display math
|
||||
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*'];
|
||||
for (const env of mathEnvs) {
|
||||
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
|
||||
processed = processed.replace(envRegex, (_, content) => {
|
||||
// For align environments, wrap content properly for KaTeX
|
||||
const cleanEnv = env.replace('\\*', '*');
|
||||
const mathContent = `\\begin{${cleanEnv}}${content}\\end{${cleanEnv}}`;
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert LaTeX proof environments to styled blocks
|
||||
processed = processed.replace(
|
||||
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
|
||||
'<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">$1</div></div>'
|
||||
);
|
||||
|
||||
// Convert LaTeX theorem-like environments
|
||||
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
|
||||
for (const env of theoremEnvs) {
|
||||
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
|
||||
const envName = env.charAt(0).toUpperCase() + env.slice(1);
|
||||
processed = processed.replace(
|
||||
envRegex,
|
||||
`<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">$1</div></div>`
|
||||
);
|
||||
}
|
||||
|
||||
// Convert LaTeX text formatting commands
|
||||
processed = processed.replace(/\\emph\{([^}]*)\}/g, '<em>$1</em>');
|
||||
processed = processed.replace(/\\textit\{([^}]*)\}/g, '<em>$1</em>');
|
||||
processed = processed.replace(/\\textbf\{([^}]*)\}/g, '<strong>$1</strong>');
|
||||
processed = processed.replace(/\\texttt\{([^}]*)\}/g, '<code class="inline-code">$1</code>');
|
||||
processed = processed.replace(/\\underline\{([^}]*)\}/g, '<u>$1</u>');
|
||||
|
||||
// Convert \(...\) to placeholder (display: false)
|
||||
processed = processed.replace(/\\\((.+?)\\\)/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content, displayMode: false });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Convert \[...\] to placeholder (display: true)
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Extract display math ($$...$$) BEFORE markdown processing
|
||||
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Extract inline math ($...$) BEFORE markdown processing
|
||||
// Skip currency patterns like $5 or $50
|
||||
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
|
||||
if (/^\d/.test(content.trim())) {
|
||||
return match; // Keep as-is for currency
|
||||
}
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Restore escaped dollar signs
|
||||
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
|
||||
|
||||
// Restore code blocks
|
||||
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
|
||||
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
|
||||
|
||||
return processed;
|
||||
}
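
A minimal sketch, assuming the `preprocessLaTeX` implementation above and placeholder numbering starting at 0, of what the preprocessor is expected to do with a mixed input; the sample text is made up and the comments describe expected output rather than captured output.

```ts
const input = [
  'Euler: \\(e^{i\\pi} + 1 = 0\\), and a display block:',
  '\\[\\int_0^1 x\\,dx = \\tfrac{1}{2}\\]',
  'Costs $5 up front; code spans stay literal: `$a + $b`',
].join('\n');

const out = preprocessLaTeX(input);
// Expected: the \(...\) expression becomes MATHPLACEHOLDERINLINE0END,
// the \[...\] block becomes MATHPLACEHOLDERDISPLAY1END,
// "$5" is left alone (leading digit, so it is treated as currency),
// and the backtick code span is restored verbatim at the end.
```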
|
||||
|
||||
/**
|
||||
* Render math expressions with KaTeX after HTML is generated
|
||||
* Render math expressions with KaTeX - restores placeholders with rendered math
|
||||
*/
|
||||
function renderMath(html: string): string {
|
||||
// Render display math ($$...$$)
|
||||
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: true,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$$${math}$$</span>`;
|
||||
}
|
||||
});
|
||||
// Replace all math placeholders with rendered KaTeX
|
||||
for (const [placeholder, { content, displayMode }] of mathExpressions) {
|
||||
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const regex = new RegExp(escapedPlaceholder, 'g');
|
||||
|
||||
// Render inline math ($...$) but avoid matching currency like $5
|
||||
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
|
||||
// Skip if it looks like currency ($ followed by number)
|
||||
if (/^\d/.test(math.trim())) {
|
||||
return match;
|
||||
}
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: false,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$${math}$</span>`;
|
||||
}
|
||||
});
|
||||
html = html.replace(regex, () => {
|
||||
try {
|
||||
const rendered = katex.renderToString(content, {
|
||||
displayMode,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
|
||||
if (displayMode) {
|
||||
return `
|
||||
<div class="math-display-wrapper">
|
||||
<div class="math-display-header">
|
||||
<span class="math-label">LaTeX</span>
|
||||
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
|
||||
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="math-display-content">
|
||||
${rendered}
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
} else {
|
||||
return `<span class="math-inline">${rendered}</span>`;
|
||||
}
|
||||
} catch {
|
||||
const display = displayMode ? `$$${content}$$` : `$${content}$`;
|
||||
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return html;
|
||||
}
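
A hedged sketch of how the pieces compose; the `renderMarkdown` wrapper name is hypothetical, since the component's actual wiring is not shown in this diff.

```ts
import { marked } from 'marked';

// Hypothetical wrapper around the helpers defined above.
function renderMarkdown(raw: string): string {
  // 1. Swap math and code spans for inert placeholders so marked cannot mangle them.
  const withPlaceholders = preprocessLaTeX(raw);
  // 2. Run the markdown parser over the placeholder-laden text.
  const html = marked.parse(withPlaceholders, { async: false }) as string;
  // 3. Replace each placeholder with rendered KaTeX plus the copy-button header.
  return renderMath(html);
}
// Once the result is in the DOM, setupCopyButtons() binds click handlers for
// both .copy-code-btn and .copy-math-btn elements.
```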
|
||||
@@ -154,16 +280,50 @@
|
||||
}
|
||||
}
|
||||
|
||||
async function handleMathCopyClick(event: Event) {
|
||||
const target = event.currentTarget as HTMLButtonElement;
|
||||
const encodedSource = target.getAttribute('data-math-source');
|
||||
if (!encodedSource) return;
|
||||
|
||||
const source = decodeURIComponent(encodedSource);
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(source);
|
||||
// Show copied feedback
|
||||
const originalHtml = target.innerHTML;
|
||||
target.innerHTML = `
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M20 6L9 17l-5-5"/>
|
||||
</svg>
|
||||
`;
|
||||
target.classList.add('copied');
|
||||
setTimeout(() => {
|
||||
target.innerHTML = originalHtml;
|
||||
target.classList.remove('copied');
|
||||
}, 2000);
|
||||
} catch (error) {
|
||||
console.error('Failed to copy math:', error);
|
||||
}
|
||||
}
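
A small illustrative check of the `data-math-source` round trip: `renderMath` stores the LaTeX source URI-encoded so it survives a double-quoted HTML attribute, and the handler above decodes it back. The sample string is made up.

```ts
const source = '\\frac{a}{b} & "quoted"';
const button = document.createElement('button');
button.setAttribute('data-math-source', encodeURIComponent(source));
const roundTripped = decodeURIComponent(button.getAttribute('data-math-source') ?? '');
console.log(roundTripped === source); // true
```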
|
||||
|
||||
function setupCopyButtons() {
|
||||
if (!containerRef || !browser) return;
|
||||
|
||||
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of buttons) {
|
||||
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of codeButtons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleCopyClick);
|
||||
}
|
||||
}
|
||||
|
||||
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
|
||||
for (const button of mathButtons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleMathCopyClick);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
@@ -424,28 +584,263 @@
|
||||
color: #60a5fa;
|
||||
}
|
||||
|
||||
/* KaTeX math styling */
|
||||
/* KaTeX math styling - Base */
|
||||
.markdown-content :global(.katex) {
|
||||
font-size: 1.1em;
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display) {
|
||||
/* Display math container wrapper */
|
||||
.markdown-content :global(.math-display-wrapper) {
|
||||
margin: 1rem 0;
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
border: 1px solid rgba(255, 215, 0, 0.15);
|
||||
background: rgba(0, 0, 0, 0.3);
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover) {
|
||||
border-color: rgba(255, 215, 0, 0.25);
|
||||
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
|
||||
}
|
||||
|
||||
/* Display math header - hidden by default, slides in on hover */
|
||||
.markdown-content :global(.math-display-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.375rem 0.75rem;
|
||||
background: rgba(255, 215, 0, 0.03);
|
||||
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
|
||||
opacity: 0;
|
||||
max-height: 0;
|
||||
padding-top: 0;
|
||||
padding-bottom: 0;
|
||||
overflow: hidden;
|
||||
transition:
|
||||
opacity 0.2s ease,
|
||||
max-height 0.2s ease,
|
||||
padding 0.2s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
|
||||
opacity: 1;
|
||||
max-height: 2.5rem;
|
||||
padding: 0.375rem 0.75rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-label) {
|
||||
color: rgba(255, 215, 0, 0.7);
|
||||
font-size: 0.65rem;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.1em;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
cursor: pointer;
|
||||
transition: color 0.2s;
|
||||
border-radius: 0.25rem;
|
||||
opacity: 0;
|
||||
transition:
|
||||
color 0.2s,
|
||||
opacity 0.15s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn:hover) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn.copied) {
|
||||
color: #22c55e;
|
||||
}
|
||||
|
||||
/* Display math content area */
|
||||
.markdown-content :global(.math-display-content) {
|
||||
padding: 1rem 1.25rem;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display > .katex) {
|
||||
/* Custom scrollbar for math overflow */
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
|
||||
background: rgba(255, 215, 0, 0.2);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
|
||||
background: rgba(255, 215, 0, 0.35);
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content .katex-display) {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content .katex-display > .katex) {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/* Inline math wrapper */
|
||||
.markdown-content :global(.math-inline) {
|
||||
display: inline;
|
||||
padding: 0 0.125rem;
|
||||
border-radius: 0.25rem;
|
||||
transition: background-color 0.15s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-inline:hover) {
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
}
|
||||
|
||||
/* Dark theme KaTeX overrides */
|
||||
.markdown-content :global(.katex .mord),
|
||||
.markdown-content :global(.katex .minner),
|
||||
.markdown-content :global(.katex .mop),
|
||||
.markdown-content :global(.katex .mbin),
|
||||
.markdown-content :global(.katex .mrel),
|
||||
.markdown-content :global(.katex .mpunct) {
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
/* Fraction lines and rules */
|
||||
.markdown-content :global(.katex .frac-line),
|
||||
.markdown-content :global(.katex .overline-line),
|
||||
.markdown-content :global(.katex .underline-line),
|
||||
.markdown-content :global(.katex .hline),
|
||||
.markdown-content :global(.katex .rule) {
|
||||
border-color: oklch(0.85 0 0) !important;
|
||||
background: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
/* Square roots and SVG elements */
|
||||
.markdown-content :global(.katex .sqrt-line) {
|
||||
border-color: oklch(0.85 0 0) !important;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex svg) {
|
||||
fill: oklch(0.85 0 0);
|
||||
stroke: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex svg path) {
|
||||
stroke: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
/* Delimiters (parentheses, brackets, braces) */
|
||||
.markdown-content :global(.katex .delimsizing),
|
||||
.markdown-content :global(.katex .delim-size1),
|
||||
.markdown-content :global(.katex .delim-size2),
|
||||
.markdown-content :global(.katex .delim-size3),
|
||||
.markdown-content :global(.katex .delim-size4),
|
||||
.markdown-content :global(.katex .mopen),
|
||||
.markdown-content :global(.katex .mclose) {
|
||||
color: oklch(0.75 0 0);
|
||||
}
|
||||
|
||||
/* Math error styling */
|
||||
.markdown-content :global(.math-error) {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.375rem;
|
||||
color: #f87171;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
background: rgba(248, 113, 113, 0.1);
|
||||
padding: 0.125rem 0.25rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid rgba(248, 113, 113, 0.2);
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-error-icon) {
|
||||
font-size: 0.875em;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
/* LaTeX proof environment */
|
||||
.markdown-content :global(.latex-proof) {
|
||||
margin: 1rem 0;
|
||||
padding: 1rem 1.25rem;
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
border-left: 3px solid rgba(255, 215, 0, 0.4);
|
||||
border-radius: 0 0.375rem 0.375rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-header) {
|
||||
font-weight: 600;
|
||||
font-style: italic;
|
||||
color: oklch(0.85 0 0);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-header::after) {
|
||||
content: '.';
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-content) {
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-content p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* QED symbol at end of proof */
|
||||
.markdown-content :global(.latex-proof-content::after) {
|
||||
content: '∎';
|
||||
display: block;
|
||||
text-align: right;
|
||||
color: oklch(0.7 0 0);
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
/* LaTeX theorem-like environments */
|
||||
.markdown-content :global(.latex-theorem) {
|
||||
margin: 1rem 0;
|
||||
padding: 1rem 1.25rem;
|
||||
background: rgba(255, 215, 0, 0.03);
|
||||
border: 1px solid rgba(255, 215, 0, 0.15);
|
||||
border-radius: 0.375rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-header) {
|
||||
font-weight: 700;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-header::after) {
|
||||
content: '.';
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-content) {
|
||||
color: oklch(0.9 0 0);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-content p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -71,36 +71,35 @@ export interface Instance {
|
||||
};
|
||||
}
|
||||
|
||||
// Split state interfaces
|
||||
interface RawNodeIdentity {
|
||||
modelId: string;
|
||||
chipId: string;
|
||||
friendlyName: string;
|
||||
}
|
||||
|
||||
interface RawNodeMemory {
|
||||
ramTotal: { inBytes: number };
|
||||
ramAvailable: { inBytes: number };
|
||||
swapTotal: { inBytes: number };
|
||||
swapAvailable: { inBytes: number };
|
||||
}
|
||||
|
||||
interface RawNodeSystem {
|
||||
gpuUsage?: number;
|
||||
temp?: number;
|
||||
sysPower?: number;
|
||||
pcpuUsage?: number;
|
||||
ecpuUsage?: number;
|
||||
anePower?: number;
|
||||
}
|
||||
|
||||
interface RawNetworkInterface {
|
||||
name: string;
|
||||
ipAddress: string;
|
||||
interface RawNodeProfile {
|
||||
modelId?: string;
|
||||
chipId?: string;
|
||||
friendlyName?: string;
|
||||
networkInterfaces?: Array<{
|
||||
name?: string;
|
||||
ipAddress?: string;
|
||||
addresses?: Array<{ address?: string } | string>;
|
||||
ipv4?: string;
|
||||
ipv6?: string;
|
||||
ipAddresses?: string[];
|
||||
ips?: string[];
|
||||
}>;
|
||||
memory?: {
|
||||
ramTotal?: { inBytes: number };
|
||||
ramAvailable?: { inBytes: number };
|
||||
swapTotal?: { inBytes: number };
|
||||
swapAvailable?: { inBytes: number };
|
||||
};
|
||||
system?: {
|
||||
gpuUsage?: number;
|
||||
temp?: number;
|
||||
sysPower?: number;
|
||||
};
|
||||
}
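
An illustrative pair of payload shapes that `RawNodeProfile.networkInterfaces` is written to tolerate; the values are made up, and `transformTopology` below is expected to flatten either shape into `{ name, addresses }` with duplicates removed.

```ts
const newShape: NonNullable<RawNodeProfile['networkInterfaces']>[number] = {
  name: 'en0',
  addresses: [{ address: '192.168.1.12' }, 'fe80::1'],
  ipv4: '192.168.1.12', // duplicates the entry above; the Set in transformTopology dedupes it
};
const legacyShape: NonNullable<RawNodeProfile['networkInterfaces']>[number] = {
  name: 'en0',
  ipAddress: '192.168.1.12',
};
// Both are expected to normalise to { name: 'en0', addresses: [...] } with each address once.
```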
|
||||
|
||||
interface RawTopologyNode {
|
||||
nodeId: string;
|
||||
nodeProfile: RawNodeProfile;
|
||||
}
|
||||
|
||||
interface RawTopologyConnection {
|
||||
@@ -116,6 +115,8 @@ interface RawTopology {
|
||||
connections?: RawTopologyConnection[];
|
||||
}
|
||||
|
||||
type RawNodeProfiles = Record<string, RawNodeProfile>;
|
||||
|
||||
export interface DownloadProgress {
|
||||
totalBytes: number;
|
||||
downloadedBytes: number;
|
||||
@@ -170,11 +171,7 @@ interface RawStateResponse {
|
||||
>;
|
||||
runners?: Record<string, unknown>;
|
||||
downloads?: Record<string, unknown[]>;
|
||||
// Split state fields
|
||||
nodeIdentities?: Record<string, RawNodeIdentity>;
|
||||
nodeMemories?: Record<string, RawNodeMemory>;
|
||||
nodeSystems?: Record<string, RawNodeSystem>;
|
||||
nodeNetworks?: Record<string, RawNetworkInterface[]>;
|
||||
nodeProfiles?: RawNodeProfiles;
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -211,41 +208,66 @@ const STORAGE_KEY = "exo-conversations";
|
||||
|
||||
function transformTopology(
|
||||
raw: RawTopology,
|
||||
identities?: Record<string, RawNodeIdentity>,
|
||||
memories?: Record<string, RawNodeMemory>,
|
||||
systems?: Record<string, RawNodeSystem>,
|
||||
networks?: Record<string, RawNetworkInterface[]>,
|
||||
profiles?: RawNodeProfiles,
|
||||
): TopologyData {
|
||||
const nodes: Record<string, NodeInfo> = {};
|
||||
const edges: TopologyEdge[] = [];
|
||||
|
||||
for (const node of raw.nodes || []) {
|
||||
// Get split state fields (may be undefined if events haven't arrived yet)
|
||||
const identity = identities?.[node.nodeId];
|
||||
const memory = memories?.[node.nodeId];
|
||||
const system = systems?.[node.nodeId];
|
||||
const network = networks?.[node.nodeId];
|
||||
|
||||
const ramTotal = memory?.ramTotal?.inBytes ?? 0;
|
||||
const ramAvailable = memory?.ramAvailable?.inBytes ?? 0;
|
||||
const mergedProfile = profiles?.[node.nodeId];
|
||||
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
|
||||
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
|
||||
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
|
||||
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
|
||||
|
||||
const networkInterfaces = (network ?? []).map((iface) => ({
|
||||
name: iface.name,
|
||||
addresses: [iface.ipAddress],
|
||||
}));
|
||||
const networkInterfaces = (profile?.networkInterfaces || []).map(
|
||||
(iface) => {
|
||||
const addresses: string[] = [];
|
||||
if (iface.ipAddress && typeof iface.ipAddress === "string") {
|
||||
addresses.push(iface.ipAddress);
|
||||
}
|
||||
if (Array.isArray(iface.addresses)) {
|
||||
for (const addr of iface.addresses) {
|
||||
if (typeof addr === "string") addresses.push(addr);
|
||||
else if (addr && typeof addr === "object" && addr.address)
|
||||
addresses.push(addr.address);
|
||||
}
|
||||
}
|
||||
if (Array.isArray(iface.ipAddresses)) {
|
||||
addresses.push(
|
||||
...iface.ipAddresses.filter(
|
||||
(a): a is string => typeof a === "string",
|
||||
),
|
||||
);
|
||||
}
|
||||
if (Array.isArray(iface.ips)) {
|
||||
addresses.push(
|
||||
...iface.ips.filter((a): a is string => typeof a === "string"),
|
||||
);
|
||||
}
|
||||
if (iface.ipv4 && typeof iface.ipv4 === "string")
|
||||
addresses.push(iface.ipv4);
|
||||
if (iface.ipv6 && typeof iface.ipv6 === "string")
|
||||
addresses.push(iface.ipv6);
|
||||
|
||||
return {
|
||||
name: iface.name,
|
||||
addresses: Array.from(new Set(addresses)),
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const ipToInterface: Record<string, string> = {};
|
||||
for (const iface of networkInterfaces) {
|
||||
for (const addr of iface.addresses) {
|
||||
ipToInterface[addr] = iface.name;
|
||||
for (const addr of iface.addresses || []) {
|
||||
ipToInterface[addr] = iface.name ?? "";
|
||||
}
|
||||
}
|
||||
|
||||
nodes[node.nodeId] = {
|
||||
system_info: {
|
||||
model_id: identity?.modelId ?? "Unknown",
|
||||
chip: identity?.chipId,
|
||||
model_id: profile?.modelId ?? "Unknown",
|
||||
chip: profile?.chipId,
|
||||
memory: ramTotal,
|
||||
},
|
||||
network_interfaces: networkInterfaces,
|
||||
@@ -256,15 +278,17 @@ function transformTopology(
|
||||
ram_total: ramTotal,
|
||||
},
|
||||
temp:
|
||||
system?.temp !== undefined
|
||||
? { gpu_temp_avg: system.temp }
|
||||
profile?.system?.temp !== undefined
|
||||
? { gpu_temp_avg: profile.system.temp }
|
||||
: undefined,
|
||||
gpu_usage:
|
||||
system?.gpuUsage !== undefined ? [0, system.gpuUsage] : undefined,
|
||||
sys_power: system?.sysPower,
|
||||
profile?.system?.gpuUsage !== undefined
|
||||
? [0, profile.system.gpuUsage]
|
||||
: undefined,
|
||||
sys_power: profile?.system?.sysPower,
|
||||
},
|
||||
last_macmon_update: Date.now() / 1000,
|
||||
friendly_name: identity?.friendlyName,
|
||||
friendly_name: profile?.friendlyName,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -844,13 +868,7 @@ class AppStore {
|
||||
const data: RawStateResponse = await response.json();
|
||||
|
||||
if (data.topology) {
|
||||
this.topologyData = transformTopology(
|
||||
data.topology,
|
||||
data.nodeIdentities,
|
||||
data.nodeMemories,
|
||||
data.nodeSystems,
|
||||
data.nodeNetworks,
|
||||
);
|
||||
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
|
||||
}
|
||||
if (data.instances) {
|
||||
this.instances = data.instances;
|
||||
|
||||
@@ -599,8 +599,9 @@ class API:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
|
||||
for memory in self.state.node_memories.values():
|
||||
total_available += memory.ram_available
|
||||
for node in self.state.topology.list_nodes():
|
||||
if node.node_profile is not None:
|
||||
total_available += node.node_profile.memory.ram_available
|
||||
|
||||
return total_available
|
||||
|
||||
|
||||
@@ -113,7 +113,6 @@ def place_instance(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in cycle
|
||||
if node.node_profile is not None
|
||||
and node.node_profile.memory is not None
|
||||
),
|
||||
start=Memory(),
|
||||
),
|
||||
|
||||
@@ -25,10 +25,7 @@ class NodeWithProfile(BaseModel):
|
||||
|
||||
|
||||
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
|
||||
return all(
|
||||
node.node_profile is not None and node.node_profile.memory is not None
|
||||
for node in nodes
|
||||
)
|
||||
return all(node.node_profile is not None for node in nodes)
|
||||
|
||||
|
||||
def filter_cycles_by_memory(
|
||||
@@ -39,14 +36,8 @@ def filter_cycles_by_memory(
|
||||
if not narrow_all_nodes(cycle):
|
||||
continue
|
||||
|
||||
# narrow_all_nodes guarantees memory is not None
|
||||
total_mem = sum(
|
||||
(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in cycle
|
||||
if node.node_profile.memory is not None
|
||||
),
|
||||
start=Memory(),
|
||||
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
|
||||
)
|
||||
if total_mem >= required_memory:
|
||||
filtered_cycles.append(cast(list[NodeInfo], cycle))
|
||||
@@ -62,13 +53,8 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
):
|
||||
# NodeWithProfile guarantees memory is not None
|
||||
cycle_memory = sum(
|
||||
(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in selected_cycle
|
||||
if node.node_profile.memory is not None
|
||||
),
|
||||
(node.node_profile.memory.ram_available for node in selected_cycle),
|
||||
start=Memory(),
|
||||
)
|
||||
total_layers = model_meta.n_layers
|
||||
@@ -81,8 +67,6 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
if i == len(selected_cycle) - 1:
|
||||
node_layers = total_layers - layers_assigned
|
||||
else:
|
||||
# NodeWithProfile guarantees memory is not None
|
||||
assert node.node_profile.memory is not None
|
||||
node_layers = round(
|
||||
total_layers
|
||||
* (
|
||||
|
||||
@@ -19,13 +19,16 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
NodeIdentityMeasured,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.worker.instances import (
|
||||
@@ -72,39 +75,29 @@ async def test_master():
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
# inject NodeIdentityMeasured and NodeMemoryMeasured events
|
||||
logger.info("inject NodeIdentityMeasured event")
|
||||
# inject a NodePerformanceProfile event
|
||||
logger.info("inject a NodePerformanceProfile event")
|
||||
await local_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin_idx=0,
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodeIdentityMeasured(
|
||||
NodePerformanceMeasured(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info("inject NodeMemoryMeasured event")
|
||||
await local_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin_idx=1,
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodeMemoryMeasured(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
)
|
||||
),
|
||||
@@ -115,7 +108,7 @@ async def test_master():
|
||||
logger.info("wait for initial topology event")
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
while len(master.state.node_identities) == 0:
|
||||
while len(master.state.node_profiles) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
|
||||
logger.info("inject a CreateInstance Command")
|
||||
@@ -162,19 +155,17 @@ async def test_master():
|
||||
),
|
||||
)
|
||||
)
|
||||
while len(_get_events()) < 4:
|
||||
while len(_get_events()) < 3:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
events = _get_events()
|
||||
assert len(events) == 4
|
||||
assert len(events) == 3
|
||||
assert events[0].idx == 0
|
||||
assert events[1].idx == 1
|
||||
assert events[2].idx == 2
|
||||
assert events[3].idx == 3
|
||||
assert isinstance(events[0].event, NodeIdentityMeasured)
|
||||
assert isinstance(events[1].event, NodeMemoryMeasured)
|
||||
assert isinstance(events[2].event, InstanceCreated)
|
||||
created_instance = events[2].event.instance
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
created_instance = events[1].event.instance
|
||||
assert isinstance(created_instance, MlxRingInstance)
|
||||
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
|
||||
# Validate the shard assignments
|
||||
@@ -206,10 +197,10 @@ async def test_master():
|
||||
assert len(created_instance.hosts_by_node[node_id]) == 1
|
||||
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
|
||||
assert created_instance.ephemeral_port > 0
|
||||
assert isinstance(events[3].event, TaskCreated)
|
||||
assert events[3].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[3].event.task, ChatCompletionTask)
|
||||
assert events[3].event.task.task_params == ChatCompletionTaskParams(
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[2].event.task, ChatCompletionTask)
|
||||
assert events[2].event.task.task_params == ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hello, how are you?")
|
||||
|
||||
@@ -13,10 +13,8 @@ from exo.shared.types.events import (
|
||||
InstanceDeleted,
|
||||
NodeCreated,
|
||||
NodeDownloadProgress,
|
||||
NodeIdentityMeasured,
|
||||
NodeMemoryMeasured,
|
||||
NodeNetworkMeasured,
|
||||
NodeSystemMeasured,
|
||||
NodePerformanceMeasured,
|
||||
NodeTimedOut,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
@@ -29,13 +27,7 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
NodeIdentity,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
@@ -59,12 +51,8 @@ def event_apply(event: Event, state: State) -> State:
|
||||
return apply_topology_node_created(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodeIdentityMeasured():
|
||||
return apply_node_identity_measured(event, state)
|
||||
case NodeSystemMeasured():
|
||||
return apply_node_system_measured(event, state)
|
||||
case NodeNetworkMeasured():
|
||||
return apply_node_network_measured(event, state)
|
||||
case NodePerformanceMeasured():
|
||||
return apply_node_performance_measured(event, state)
|
||||
case NodeDownloadProgress():
|
||||
return apply_node_download_progress(event, state)
|
||||
case NodeMemoryMeasured():
|
||||
@@ -202,19 +190,8 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
state.topology.remove_node(event.node_id)
|
||||
node_identities = {
|
||||
key: value
|
||||
for key, value in state.node_identities.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_memories = {
|
||||
key: value for key, value in state.node_memories.items() if key != event.node_id
|
||||
}
|
||||
node_systems = {
|
||||
key: value for key, value in state.node_systems.items() if key != event.node_id
|
||||
}
|
||||
node_networks = {
|
||||
key: value for key, value in state.node_networks.items() if key != event.node_id
|
||||
node_profiles = {
|
||||
key: value for key, value in state.node_profiles.items() if key != event.node_id
|
||||
}
|
||||
last_seen = {
|
||||
key: value for key, value in state.last_seen.items() if key != event.node_id
|
||||
@@ -222,120 +199,32 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
return state.model_copy(
|
||||
update={
|
||||
"topology": topology,
|
||||
"node_identities": node_identities,
|
||||
"node_memories": node_memories,
|
||||
"node_systems": node_systems,
|
||||
"node_networks": node_networks,
|
||||
"node_profiles": node_profiles,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _reconstruct_profile(
|
||||
node_id: NodeId,
|
||||
state: State,
|
||||
*,
|
||||
identity: NodeIdentity | None = None,
|
||||
memory: MemoryPerformanceProfile | None = None,
|
||||
system: SystemPerformanceProfile | None = None,
|
||||
network_interfaces: list[NetworkInterfaceInfo] | None = None,
|
||||
) -> NodePerformanceProfile:
|
||||
"""Reconstruct a NodePerformanceProfile from split state storage.
|
||||
|
||||
Uses provided overrides, falling back to state values.
|
||||
"""
|
||||
ident = identity or state.node_identities.get(node_id)
|
||||
mem = memory or state.node_memories.get(node_id)
|
||||
sys = system or state.node_systems.get(node_id)
|
||||
nets = (
|
||||
network_interfaces
|
||||
if network_interfaces is not None
|
||||
else state.node_networks.get(node_id, [])
|
||||
)
|
||||
|
||||
return NodePerformanceProfile(
|
||||
model_id=ident.model_id if ident else None,
|
||||
chip_id=ident.chip_id if ident else None,
|
||||
friendly_name=ident.friendly_name if ident else None,
|
||||
memory=mem,
|
||||
network_interfaces=nets,
|
||||
system=sys,
|
||||
)
|
||||
|
||||
|
||||
def apply_node_identity_measured(event: NodeIdentityMeasured, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
identity = NodeIdentity(
|
||||
model_id=event.model_id,
|
||||
chip_id=event.chip_id,
|
||||
friendly_name=event.friendly_name,
|
||||
)
|
||||
new_identities: Mapping[NodeId, NodeIdentity] = {
|
||||
**state.node_identities,
|
||||
event.node_id: identity,
|
||||
def apply_node_performance_measured(
|
||||
event: NodePerformanceMeasured, state: State
|
||||
) -> State:
|
||||
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: event.node_profile,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
reconstructed = _reconstruct_profile(event.node_id, state, identity=identity)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_identities": new_identities,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_node_system_measured(event: NodeSystemMeasured, state: State) -> State:
|
||||
state = state.model_copy(update={"node_profiles": new_profiles})
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
new_systems: Mapping[NodeId, SystemPerformanceProfile] = {
|
||||
**state.node_systems,
|
||||
event.node_id: event.system,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
# TODO: NodeCreated
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
reconstructed = _reconstruct_profile(event.node_id, state, system=event.system)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
topology.update_node_profile(event.node_id, event.node_profile)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_systems": new_systems,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_node_network_measured(event: NodeNetworkMeasured, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
new_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {
|
||||
**state.node_networks,
|
||||
event.node_id: event.network_interfaces,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
reconstructed = _reconstruct_profile(
|
||||
event.node_id, state, network_interfaces=event.network_interfaces
|
||||
)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_networks": new_networks,
|
||||
"node_profiles": new_profiles,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
@@ -343,26 +232,57 @@ def apply_node_network_measured(event: NodeNetworkMeasured, state: State) -> Sta
|
||||
|
||||
|
||||
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
|
||||
existing = state.node_profiles.get(event.node_id)
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
new_memories: Mapping[NodeId, MemoryPerformanceProfile] = {
|
||||
**state.node_memories,
|
||||
event.node_id: event.memory,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
if existing is None:
|
||||
created = NodePerformanceProfile(
|
||||
model_id="unknown",
|
||||
chip_id="unknown",
|
||||
friendly_name="Unknown",
|
||||
memory=event.memory,
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(
|
||||
# TODO: flops_fp16=0.0,
|
||||
gpu_usage=0.0,
|
||||
temp=0.0,
|
||||
sys_power=0.0,
|
||||
pcpu_usage=0.0,
|
||||
ecpu_usage=0.0,
|
||||
ane_power=0.0,
|
||||
),
|
||||
)
|
||||
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: created,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
# TODO: NodeCreated
|
||||
topology.update_node_profile(event.node_id, created)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_profiles": created_profiles,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
updated = existing.model_copy(update={"memory": event.memory})
|
||||
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: updated,
|
||||
}
|
||||
# TODO: NodeCreated
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
reconstructed = _reconstruct_profile(event.node_id, state, memory=event.memory)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
topology.update_node_profile(event.node_id, updated)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_memories": new_memories,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
update={"node_profiles": updated_profiles, "topology": topology}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,14 +2,10 @@ from datetime import datetime
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
||||
from exo.shared.types.chunks import GenerationChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -89,35 +85,13 @@ class NodeTimedOut(BaseEvent):
|
||||
node_id: NodeId
|
||||
|
||||
|
||||
class NodeIdentityMeasured(BaseEvent):
|
||||
"""Static identity info - emitted once at startup."""
|
||||
|
||||
class NodePerformanceMeasured(BaseEvent):
|
||||
node_id: NodeId
|
||||
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
|
||||
model_id: str
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
|
||||
|
||||
class NodeSystemMeasured(BaseEvent):
|
||||
"""Dynamic system metrics (GPU, temp, power) - emitted at 1s intervals."""
|
||||
|
||||
node_id: NodeId
|
||||
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
|
||||
system: SystemPerformanceProfile
|
||||
|
||||
|
||||
class NodeNetworkMeasured(BaseEvent):
|
||||
"""Semi-static network interface info - emitted at 30s intervals."""
|
||||
|
||||
node_id: NodeId
|
||||
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
|
||||
network_interfaces: list[NetworkInterfaceInfo]
|
||||
node_profile: NodePerformanceProfile
|
||||
|
||||
|
||||
class NodeMemoryMeasured(BaseEvent):
|
||||
"""Dynamic memory metrics - emitted at 0.5s intervals."""
|
||||
|
||||
node_id: NodeId
|
||||
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
|
||||
memory: MemoryPerformanceProfile
|
||||
@@ -153,9 +127,7 @@ Event = (
|
||||
| RunnerDeleted
|
||||
| NodeCreated
|
||||
| NodeTimedOut
|
||||
| NodeIdentityMeasured
|
||||
| NodeSystemMeasured
|
||||
| NodeNetworkMeasured
|
||||
| NodePerformanceMeasured
|
||||
| NodeMemoryMeasured
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
|
||||
@@ -52,21 +52,13 @@ class NetworkInterfaceInfo(CamelCaseModel):
|
||||
ip_address: str
|
||||
|
||||
|
||||
class NodeIdentity(CamelCaseModel):
|
||||
"""Static identity info for a node."""
|
||||
|
||||
class NodePerformanceProfile(CamelCaseModel):
|
||||
model_id: str
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
|
||||
|
||||
class NodePerformanceProfile(CamelCaseModel):
|
||||
model_id: str | None = None
|
||||
chip_id: str | None = None
|
||||
friendly_name: str | None = None
|
||||
memory: MemoryPerformanceProfile | None = None
|
||||
memory: MemoryPerformanceProfile
|
||||
network_interfaces: list[NetworkInterfaceInfo] = []
|
||||
system: SystemPerformanceProfile | None = None
|
||||
system: SystemPerformanceProfile
|
||||
|
||||
|
||||
class ConnectionProfile(CamelCaseModel):
|
||||
|
||||
@@ -7,12 +7,7 @@ from pydantic.alias_generators import to_camel
|
||||
|
||||
from exo.shared.topology import Topology, TopologySnapshot
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
NodeIdentity,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -40,10 +35,7 @@ class State(CamelCaseModel):
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
node_identities: Mapping[NodeId, NodeIdentity] = {}
|
||||
node_memories: Mapping[NodeId, MemoryPerformanceProfile] = {}
|
||||
node_systems: Mapping[NodeId, SystemPerformanceProfile] = {}
|
||||
node_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {}
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
|
||||
last_seen: Mapping[NodeId, datetime] = {}
|
||||
topology: Topology = Field(default_factory=Topology)
|
||||
last_event_applied_idx: int = Field(default=-1, ge=-1)
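
For reference, a hypothetical fragment of the `/state` response this model is expected to serialise (CamelCaseModel emits camelCase keys), written here in the same TypeScript shape the web store's `RawNodeProfile` consumes; every value is a made-up sample.

```ts
const stateFragment = {
  nodeProfiles: {
    'node-123': {
      modelId: 'Mac16,10',      // sample values only
      chipId: 'Apple M4',
      friendlyName: 'studio',
      memory: {
        ramTotal: { inBytes: 68719476736 },
        ramAvailable: { inBytes: 42949672960 },
        swapTotal: { inBytes: 0 },
        swapAvailable: { inBytes: 0 },
      },
      networkInterfaces: [{ name: 'en0', ipAddress: '192.168.1.12' }],
      system: { gpuUsage: 0.12, temp: 41.5, sysPower: 9.8 },
    },
  },
};
```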
|
||||
|
||||
@@ -245,12 +245,15 @@ def create_http_session(
|
||||
sock_read_timeout = 1800
|
||||
sock_connect_timeout = 60
|
||||
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
ssl_context = ssl.create_default_context(
|
||||
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
|
||||
)
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
return aiohttp.ClientSession(
|
||||
auto_decompress=auto_decompress,
|
||||
connector=connector,
|
||||
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=total_timeout,
|
||||
connect=connect_timeout,
|
||||
|
||||
@@ -16,10 +16,8 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
NodeDownloadProgress,
|
||||
NodeIdentityMeasured,
|
||||
NodeMemoryMeasured,
|
||||
NodeNetworkMeasured,
|
||||
NodeSystemMeasured,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
TopologyEdgeCreated,
|
||||
@@ -27,11 +25,7 @@ from exo.shared.types.events import (
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
@@ -57,13 +51,7 @@ from exo.worker.download.download_utils import (
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
from exo.worker.utils import (
|
||||
IdentityMetrics,
|
||||
start_polling_identity_metrics,
|
||||
start_polling_memory_metrics,
|
||||
start_polling_network_metrics,
|
||||
start_polling_system_metrics,
|
||||
)
|
||||
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
|
||||
from exo.worker.utils.net_profile import check_reachable
|
||||
|
||||
|
||||
@@ -110,51 +98,37 @@ class Worker:
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
async def identity_callback(identity: IdentityMetrics) -> None:
|
||||
# TODO: CLEANUP HEADER
|
||||
async def resource_monitor_callback(
|
||||
node_performance_profile: NodePerformanceProfile,
|
||||
) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeIdentityMeasured(
|
||||
NodePerformanceMeasured(
|
||||
node_id=self.node_id,
|
||||
model_id=identity.model_id,
|
||||
chip_id=identity.chip_id,
|
||||
friendly_name=identity.friendly_name,
|
||||
node_profile=node_performance_profile,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
),
|
||||
)
|
||||
|
||||
async def system_callback(system: SystemPerformanceProfile) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeSystemMeasured(
|
||||
node_id=self.node_id,
|
||||
system=system,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
),
|
||||
)
|
||||
|
||||
async def network_callback(interfaces: list[NetworkInterfaceInfo]) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeNetworkMeasured(
|
||||
node_id=self.node_id,
|
||||
network_interfaces=interfaces,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
),
|
||||
)
|
||||
|
||||
async def memory_callback(memory: MemoryPerformanceProfile) -> None:
|
||||
async def memory_monitor_callback(
|
||||
memory_profile: MemoryPerformanceProfile,
|
||||
) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeMemoryMeasured(
|
||||
node_id=self.node_id,
|
||||
memory=memory,
|
||||
memory=memory_profile,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
)
|
||||
)
|
||||
|
||||
# END CLEANUP
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(start_polling_identity_metrics, identity_callback)
|
||||
tg.start_soon(start_polling_system_metrics, system_callback)
|
||||
tg.start_soon(start_polling_network_metrics, network_callback)
|
||||
tg.start_soon(start_polling_memory_metrics, memory_callback)
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
from .profile import (
|
||||
IdentityMetrics,
|
||||
start_polling_identity_metrics,
|
||||
start_polling_memory_metrics,
|
||||
start_polling_network_metrics,
|
||||
start_polling_system_metrics,
|
||||
)
|
||||
from .profile import start_polling_memory_metrics, start_polling_node_metrics
|
||||
|
||||
__all__ = [
|
||||
"IdentityMetrics",
|
||||
"start_polling_identity_metrics",
|
||||
"start_polling_node_metrics",
|
||||
"start_polling_memory_metrics",
|
||||
"start_polling_network_metrics",
|
||||
"start_polling_system_metrics",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import anyio
|
||||
@@ -10,7 +9,7 @@ from loguru import logger
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
|
||||
@@ -28,13 +27,6 @@ from .system_info import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IdentityMetrics:
|
||||
model_id: str
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
|
||||
|
||||
async def get_metrics_async() -> Metrics | None:
|
||||
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
|
||||
|
||||
@@ -75,73 +67,48 @@ async def start_polling_memory_metrics(
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||
|
||||
async def start_polling_identity_metrics(
|
||||
callback: Callable[[IdentityMetrics], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
poll_interval_s: float = 30.0,
|
||||
) -> None:
|
||||
"""Continuously poll and emit identity metrics at 30s intervals."""
|
||||
while True:
|
||||
try:
|
||||
model_id, chip_id = await get_model_and_chip()
|
||||
friendly_name = await get_friendly_name()
|
||||
await callback(
|
||||
IdentityMetrics(
|
||||
model_id=model_id,
|
||||
chip_id=chip_id,
|
||||
friendly_name=friendly_name,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("Failed to emit identity metrics")
|
||||
finally:
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||
|
||||
async def start_polling_system_metrics(
|
||||
callback: Callable[[SystemPerformanceProfile], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
poll_interval_s: float = 1.0,
|
||||
) -> None:
|
||||
"""Continuously poll and emit system metrics (GPU, temp, power) at 1s intervals."""
|
||||
async def start_polling_node_metrics(
|
||||
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
|
||||
):
|
||||
poll_interval_s = 1.0
|
||||
while True:
|
||||
try:
|
||||
metrics = await get_metrics_async()
|
||||
if metrics is None:
|
||||
return
|
||||
|
||||
network_interfaces = get_network_interfaces()
|
||||
# these awaits could be joined but realistically they should be cached
|
||||
model_id, chip_id = await get_model_and_chip()
|
||||
friendly_name = await get_friendly_name()
|
||||
|
||||
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
|
||||
memory_profile = get_memory_profile()
|
||||
|
||||
await callback(
|
||||
SystemPerformanceProfile(
|
||||
gpu_usage=metrics.gpu_usage[1],
|
||||
temp=metrics.temp.gpu_temp_avg,
|
||||
sys_power=metrics.sys_power,
|
||||
pcpu_usage=metrics.pcpu_usage[1],
|
||||
ecpu_usage=metrics.ecpu_usage[1],
|
||||
ane_power=metrics.ane_power,
|
||||
NodePerformanceProfile(
|
||||
model_id=model_id,
|
||||
chip_id=chip_id,
|
||||
friendly_name=friendly_name,
|
||||
network_interfaces=network_interfaces,
|
||||
memory=memory_profile,
|
||||
system=SystemPerformanceProfile(
|
||||
gpu_usage=metrics.gpu_usage[1],
|
||||
temp=metrics.temp.gpu_temp_avg,
|
||||
sys_power=metrics.sys_power,
|
||||
pcpu_usage=metrics.pcpu_usage[1],
|
||||
ecpu_usage=metrics.ecpu_usage[1],
|
||||
ane_power=metrics.ane_power,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"[system_monitor] Operation timed out after 30s, skipping this cycle."
|
||||
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
|
||||
)
|
||||
except MacMonError as e:
|
||||
logger.opt(exception=e).error("System Monitor encountered error")
|
||||
logger.opt(exception=e).error("Resource Monitor encountered error")
|
||||
return
|
||||
finally:
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||
|
||||
async def start_polling_network_metrics(
|
||||
callback: Callable[[list[NetworkInterfaceInfo]], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
poll_interval_s: float = 30.0,
|
||||
) -> None:
|
||||
"""Continuously poll and emit network interface info at 30s intervals."""
|
||||
while True:
|
||||
try:
|
||||
network_interfaces = get_network_interfaces()
|
||||
await callback(network_interfaces)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("Network Monitor encountered error")
|
||||
finally:
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||