mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-19 11:28:51 -05:00
Compare commits
3 Commits
simplify-m
...
evan/mlnix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0094eaea86 | ||
|
|
346b13e2c9 | ||
|
|
ea0588429b |
26
.github/workflows/pipeline.yml
vendored
26
.github/workflows/pipeline.yml
vendored
@@ -113,6 +113,29 @@ jobs:
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Select Xcode
|
||||
if: startsWith(matrix.runner, 'macos-')
|
||||
run: |
|
||||
XCODE_BASEDIR="$(printf '%s\n' /Applications/Xcode_*.app | sort -V | tail -n 1)"
|
||||
[[ -z "$XCODE_BASEDIR" ]] && exit 1
|
||||
|
||||
ls -ld "/Applications/Xcode.app"
|
||||
sudo /usr/bin/xcode-select -s "$XCODE_BASEDIR"
|
||||
/usr/bin/xcode-select -p || true
|
||||
/usr/bin/xcrun --toolchain default --find xcodebuild || true
|
||||
|
||||
- name: Install Metal toolchain component
|
||||
if: startsWith(matrix.runner, 'macos-')
|
||||
run: |
|
||||
set -e
|
||||
if ! xcrun --find metal >/dev/null 2>&1; then
|
||||
sudo xcodebuild -downloadComponent MetalToolchain
|
||||
fi
|
||||
xcrun --find metal
|
||||
xcrun --find metallib
|
||||
echo "GH_OVERRIDE_METAL=$(xcrun --find metal)" >> $GITHUB_ENV
|
||||
echo "GH_OVERRIDE_METALLIB=$(xcrun --find metallib)" >> $GITHUB_ENV
|
||||
|
||||
- uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
@@ -124,6 +147,9 @@ jobs:
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build all Nix outputs
|
||||
env:
|
||||
GH_OVERRIDE_METAL: ${{ env.GH_OVERRIDE_METAL }}
|
||||
GH_OVERRIDE_METALLIB: ${{ env.GH_OVERRIDE_METALLIB }}
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
[
|
||||
|
||||
@@ -53,62 +53,285 @@
|
||||
marked.use({ renderer });
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
|
||||
* Also protect code blocks from LaTeX processing
|
||||
* Unescape HTML entities that marked may have escaped
|
||||
*/
|
||||
function unescapeHtmlEntities(text: string): string {
|
||||
return text
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/&/g, '&')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
// Storage for math expressions extracted before markdown processing
|
||||
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
|
||||
let mathCounter = 0;
|
||||
|
||||
// Storage for HTML snippets that need protection from markdown
|
||||
const htmlSnippets: Map<string, string> = new Map();
|
||||
let htmlCounter = 0;
|
||||
|
||||
// Use alphanumeric placeholders that won't be interpreted as HTML tags
|
||||
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
|
||||
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
|
||||
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
|
||||
*/
|
||||
function preprocessLaTeX(text: string): string {
|
||||
// Protect code blocks
|
||||
// Reset storage
|
||||
mathExpressions.clear();
|
||||
mathCounter = 0;
|
||||
htmlSnippets.clear();
|
||||
htmlCounter = 0;
|
||||
|
||||
// Protect code blocks first
|
||||
const codeBlocks: string[] = [];
|
||||
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
|
||||
codeBlocks.push(match);
|
||||
return `<<CODE_${codeBlocks.length - 1}>>`;
|
||||
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
|
||||
});
|
||||
|
||||
// Convert \(...\) to $...$
|
||||
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
|
||||
|
||||
// Convert \[...\] to $$...$$
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
|
||||
// Remove LaTeX document commands
|
||||
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\begin\{document\}/g, '');
|
||||
processed = processed.replace(/\\end\{document\}/g, '');
|
||||
processed = processed.replace(/\\maketitle/g, '');
|
||||
processed = processed.replace(/\\title\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\author\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\date\{[^}]*\}/g, '');
|
||||
|
||||
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
|
||||
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
|
||||
processed = processed.replace(/\\require\{[^}]*\}/g, '');
|
||||
|
||||
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
|
||||
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
// Strip center environment (layout only, no content change)
|
||||
processed = processed.replace(/\\begin\{center\}/g, '');
|
||||
processed = processed.replace(/\\end\{center\}/g, '');
|
||||
// Strip other layout environments
|
||||
processed = processed.replace(/\\begin\{flushleft\}/g, '');
|
||||
processed = processed.replace(/\\end\{flushleft\}/g, '');
|
||||
processed = processed.replace(/\\begin\{flushright\}/g, '');
|
||||
processed = processed.replace(/\\end\{flushright\}/g, '');
|
||||
processed = processed.replace(/\\label\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
|
||||
|
||||
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
|
||||
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
|
||||
|
||||
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
|
||||
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
|
||||
for (const env of mathEnvs) {
|
||||
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
|
||||
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
|
||||
processed = processed.replace(wrappedRegex, (_, args, content) => {
|
||||
const cleanEnv = env.replace('\\*', '*');
|
||||
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Handle bare \begin{env}...\end{env} (without dollar signs)
|
||||
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
|
||||
processed = processed.replace(bareRegex, (_, args, content) => {
|
||||
const cleanEnv = env.replace('\\*', '*');
|
||||
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
|
||||
processed = processed.replace(
|
||||
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
|
||||
(_, content) => {
|
||||
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, html);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
}
|
||||
);
|
||||
|
||||
// Convert LaTeX theorem-like environments
|
||||
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
|
||||
for (const env of theoremEnvs) {
|
||||
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
|
||||
const envName = env.charAt(0).toUpperCase() + env.slice(1);
|
||||
processed = processed.replace(envRegex, (_, content) => {
|
||||
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, html);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
|
||||
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<em>${content}</em>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<em>${content}</em>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<u>${content}</u>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Handle LaTeX line breaks and spacing
|
||||
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
|
||||
processed = processed.replace(/\\newline/g, '\n');
|
||||
processed = processed.replace(/\\par\b/g, '\n\n');
|
||||
processed = processed.replace(/\\quad/g, ' ');
|
||||
processed = processed.replace(/\\qquad/g, ' ');
|
||||
processed = processed.replace(/~~/g, ' '); // non-breaking space
|
||||
|
||||
// Remove other common LaTeX commands that don't render
|
||||
processed = processed.replace(/\\centering/g, '');
|
||||
processed = processed.replace(/\\noindent/g, '');
|
||||
processed = processed.replace(/\\hfill/g, '');
|
||||
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
|
||||
|
||||
// Convert \(...\) to placeholder (display: false)
|
||||
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content, displayMode: false });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Convert \[...\] to placeholder (display: true)
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Extract display math ($$...$$) BEFORE markdown processing
|
||||
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Extract inline math ($...$) BEFORE markdown processing
|
||||
// Allow single-line only, skip currency patterns like $5 or $50
|
||||
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
|
||||
if (/^\d/.test(content.trim())) {
|
||||
return match; // Keep as-is for currency
|
||||
}
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Restore escaped dollar signs
|
||||
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
|
||||
|
||||
// Restore code blocks
|
||||
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
|
||||
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
|
||||
|
||||
// Clean up any remaining stray backslashes from unrecognized commands
|
||||
processed = processed.replace(/\\(?=[a-zA-Z])/g, ''); // Remove \ before letters (unrecognized commands)
|
||||
|
||||
return processed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Render math expressions with KaTeX after HTML is generated
|
||||
* Render math expressions with KaTeX and restore HTML placeholders
|
||||
*/
|
||||
function renderMath(html: string): string {
|
||||
// Render display math ($$...$$)
|
||||
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: true,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$$${math}$$</span>`;
|
||||
}
|
||||
});
|
||||
// Replace all math placeholders with rendered KaTeX
|
||||
for (const [placeholder, { content, displayMode }] of mathExpressions) {
|
||||
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const regex = new RegExp(escapedPlaceholder, 'g');
|
||||
|
||||
// Render inline math ($...$) but avoid matching currency like $5
|
||||
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
|
||||
// Skip if it looks like currency ($ followed by number)
|
||||
if (/^\d/.test(math.trim())) {
|
||||
return match;
|
||||
}
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: false,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$${math}$</span>`;
|
||||
}
|
||||
});
|
||||
html = html.replace(regex, () => {
|
||||
try {
|
||||
const rendered = katex.renderToString(content, {
|
||||
displayMode,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
|
||||
if (displayMode) {
|
||||
return `
|
||||
<div class="math-display-wrapper">
|
||||
<div class="math-display-header">
|
||||
<span class="math-label">LaTeX</span>
|
||||
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
|
||||
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="math-display-content">
|
||||
${rendered}
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
} else {
|
||||
return `<span class="math-inline">${rendered}</span>`;
|
||||
}
|
||||
} catch {
|
||||
const display = displayMode ? `$$${content}$$` : `$${content}$`;
|
||||
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Restore HTML placeholders (for \textbf, \emph, etc.)
|
||||
for (const [placeholder, htmlContent] of htmlSnippets) {
|
||||
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const regex = new RegExp(escapedPlaceholder, 'g');
|
||||
html = html.replace(regex, htmlContent);
|
||||
}
|
||||
|
||||
return html;
|
||||
}
|
||||
@@ -154,16 +377,50 @@
|
||||
}
|
||||
}
|
||||
|
||||
async function handleMathCopyClick(event: Event) {
|
||||
const target = event.currentTarget as HTMLButtonElement;
|
||||
const encodedSource = target.getAttribute('data-math-source');
|
||||
if (!encodedSource) return;
|
||||
|
||||
const source = decodeURIComponent(encodedSource);
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(source);
|
||||
// Show copied feedback
|
||||
const originalHtml = target.innerHTML;
|
||||
target.innerHTML = `
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M20 6L9 17l-5-5"/>
|
||||
</svg>
|
||||
`;
|
||||
target.classList.add('copied');
|
||||
setTimeout(() => {
|
||||
target.innerHTML = originalHtml;
|
||||
target.classList.remove('copied');
|
||||
}, 2000);
|
||||
} catch (error) {
|
||||
console.error('Failed to copy math:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function setupCopyButtons() {
|
||||
if (!containerRef || !browser) return;
|
||||
|
||||
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of buttons) {
|
||||
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of codeButtons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleCopyClick);
|
||||
}
|
||||
}
|
||||
|
||||
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
|
||||
for (const button of mathButtons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleMathCopyClick);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
@@ -424,28 +681,290 @@
|
||||
color: #60a5fa;
|
||||
}
|
||||
|
||||
/* KaTeX math styling */
|
||||
/* KaTeX math styling - Base */
|
||||
.markdown-content :global(.katex) {
|
||||
font-size: 1.1em;
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display) {
|
||||
/* Display math container wrapper */
|
||||
.markdown-content :global(.math-display-wrapper) {
|
||||
margin: 1rem 0;
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
border: 1px solid rgba(255, 215, 0, 0.15);
|
||||
background: rgba(0, 0, 0, 0.3);
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover) {
|
||||
border-color: rgba(255, 215, 0, 0.25);
|
||||
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
|
||||
}
|
||||
|
||||
/* Display math header - hidden by default, slides in on hover */
|
||||
.markdown-content :global(.math-display-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.375rem 0.75rem;
|
||||
background: rgba(255, 215, 0, 0.03);
|
||||
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
|
||||
opacity: 0;
|
||||
max-height: 0;
|
||||
padding-top: 0;
|
||||
padding-bottom: 0;
|
||||
overflow: hidden;
|
||||
transition:
|
||||
opacity 0.2s ease,
|
||||
max-height 0.2s ease,
|
||||
padding 0.2s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
|
||||
opacity: 1;
|
||||
max-height: 2.5rem;
|
||||
padding: 0.375rem 0.75rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-label) {
|
||||
color: rgba(255, 215, 0, 0.7);
|
||||
font-size: 0.65rem;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.1em;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
cursor: pointer;
|
||||
transition: color 0.2s;
|
||||
border-radius: 0.25rem;
|
||||
opacity: 0;
|
||||
transition:
|
||||
color 0.2s,
|
||||
opacity 0.15s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn:hover) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn.copied) {
|
||||
color: #22c55e;
|
||||
}
|
||||
|
||||
/* Display math content area */
|
||||
.markdown-content :global(.math-display-content) {
|
||||
padding: 1rem 1.25rem;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display > .katex) {
|
||||
/* Custom scrollbar for math overflow */
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
|
||||
background: rgba(255, 215, 0, 0.2);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
|
||||
background: rgba(255, 215, 0, 0.35);
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content .katex-display) {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content .katex-display > .katex) {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/* Inline math wrapper */
|
||||
.markdown-content :global(.math-inline) {
|
||||
display: inline;
|
||||
padding: 0 0.125rem;
|
||||
border-radius: 0.25rem;
|
||||
transition: background-color 0.15s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-inline:hover) {
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
}
|
||||
|
||||
/* Dark theme KaTeX overrides */
|
||||
.markdown-content :global(.katex .mord),
|
||||
.markdown-content :global(.katex .minner),
|
||||
.markdown-content :global(.katex .mop),
|
||||
.markdown-content :global(.katex .mbin),
|
||||
.markdown-content :global(.katex .mrel),
|
||||
.markdown-content :global(.katex .mpunct) {
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
/* Fraction lines and rules */
|
||||
.markdown-content :global(.katex .frac-line),
|
||||
.markdown-content :global(.katex .overline-line),
|
||||
.markdown-content :global(.katex .underline-line),
|
||||
.markdown-content :global(.katex .hline),
|
||||
.markdown-content :global(.katex .rule) {
|
||||
border-color: oklch(0.85 0 0) !important;
|
||||
background: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
/* Square roots and SVG elements */
|
||||
.markdown-content :global(.katex .sqrt-line) {
|
||||
border-color: oklch(0.85 0 0) !important;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex svg) {
|
||||
fill: oklch(0.85 0 0);
|
||||
stroke: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex svg path) {
|
||||
stroke: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
/* Delimiters (parentheses, brackets, braces) */
|
||||
.markdown-content :global(.katex .delimsizing),
|
||||
.markdown-content :global(.katex .delim-size1),
|
||||
.markdown-content :global(.katex .delim-size2),
|
||||
.markdown-content :global(.katex .delim-size3),
|
||||
.markdown-content :global(.katex .delim-size4),
|
||||
.markdown-content :global(.katex .mopen),
|
||||
.markdown-content :global(.katex .mclose) {
|
||||
color: oklch(0.75 0 0);
|
||||
}
|
||||
|
||||
/* Math error styling */
|
||||
.markdown-content :global(.math-error) {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.375rem;
|
||||
color: #f87171;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
background: rgba(248, 113, 113, 0.1);
|
||||
padding: 0.125rem 0.25rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid rgba(248, 113, 113, 0.2);
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-error-icon) {
|
||||
font-size: 0.875em;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
/* LaTeX proof environment */
|
||||
.markdown-content :global(.latex-proof) {
|
||||
margin: 1rem 0;
|
||||
padding: 1rem 1.25rem;
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
border-left: 3px solid rgba(255, 215, 0, 0.4);
|
||||
border-radius: 0 0.375rem 0.375rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-header) {
|
||||
font-weight: 600;
|
||||
font-style: italic;
|
||||
color: oklch(0.85 0 0);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-header::after) {
|
||||
content: '.';
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-content) {
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-content p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* QED symbol at end of proof */
|
||||
.markdown-content :global(.latex-proof-content::after) {
|
||||
content: '∎';
|
||||
display: block;
|
||||
text-align: right;
|
||||
color: oklch(0.7 0 0);
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
/* LaTeX theorem-like environments */
|
||||
.markdown-content :global(.latex-theorem) {
|
||||
margin: 1rem 0;
|
||||
padding: 1rem 1.25rem;
|
||||
background: rgba(255, 215, 0, 0.03);
|
||||
border: 1px solid rgba(255, 215, 0, 0.15);
|
||||
border-radius: 0.375rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-header) {
|
||||
font-weight: 700;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-header::after) {
|
||||
content: '.';
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-content) {
|
||||
color: oklch(0.9 0 0);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-content p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* LaTeX diagram/figure placeholder */
|
||||
.markdown-content :global(.latex-diagram-placeholder) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
margin: 1rem 0;
|
||||
padding: 1.5rem 2rem;
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
border: 1px dashed rgba(255, 215, 0, 0.25);
|
||||
border-radius: 0.5rem;
|
||||
color: rgba(255, 215, 0, 0.6);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-diagram-icon) {
|
||||
font-size: 1.25rem;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-diagram-text) {
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.75rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -87,6 +87,11 @@
|
||||
touch $out
|
||||
'';
|
||||
|
||||
packages =
|
||||
if pkgs.stdenv.isDarwin then {
|
||||
metal = pkgs.callPackage ./nix/metalWrapper.nix { metalVersion = "310"; };
|
||||
} else { };
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
|
||||
@@ -124,6 +129,7 @@
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
|
||||
|
||||
shellHook = ''
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
|
||||
${lib.optionalString stdenv.isLinux ''
|
||||
|
||||
79
nix/darwin-build-fixes.patch
Normal file
79
nix/darwin-build-fixes.patch
Normal file
@@ -0,0 +1,79 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 0ed30932..d8528132 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
- # Throw an error if xcrun not found
|
||||
- execute_process(
|
||||
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
- OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(MACOS_SDK_VERSION @sdkVersion@)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
- execute_process(
|
||||
- COMMAND
|
||||
- zsh "-c"
|
||||
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(
|
||||
+ MLX_METAL_VERSION @metalVersion@)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
|
||||
index 13db804a..5b385132 100644
|
||||
--- a/cmake/extension.cmake
|
||||
+++ b/cmake/extension.cmake
|
||||
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
- xcrun -sdk macosx metal
|
||||
+ metal
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
index 262b0495..5c7446ad 100644
|
||||
--- a/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
add_custom_command(
|
||||
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
+ COMMAND metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
@@ -170,7 +170,7 @@ endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
+ COMMAND metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
index bb55ed3a..94ea7dd7 100644
|
||||
--- a/mlx/backend/metal/make_compiled_preamble.sh
|
||||
+++ b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Use the metal compiler to get a list of headers (with depth)
|
||||
-CCC="xcrun -sdk macosx metal -x metal"
|
||||
+CCC="metal -x metal"
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
# Remove any included system frameworks (for MetalPerformancePrimitive headers)
|
||||
25
nix/metalWrapper.nix
Normal file
25
nix/metalWrapper.nix
Normal file
@@ -0,0 +1,25 @@
|
||||
{ stdenvNoCC
|
||||
, metalVersion
|
||||
}:
|
||||
assert stdenvNoCC.isDarwin;
|
||||
stdenvNoCC.mkDerivation {
|
||||
pname = "metal-wrapper-impure";
|
||||
version = metalVersion;
|
||||
|
||||
__noChroot = true;
|
||||
buildCommand = ''
|
||||
mkdir -p $out/bin && cd $out/bin
|
||||
|
||||
METALLIB_PATH=''${GH_OVERRIDE_METALLIB:-$(/usr/bin/xcrun --sdk macosx -f metallib)}
|
||||
METAL_PATH=''${GH_OVERRIDE_METAL:-"$(dirname "$METALLIB_PATH")/metal"}
|
||||
echo "$METAL_PATH"
|
||||
echo "$METALLIB_PATH"
|
||||
|
||||
ln -sf "$METAL_PATH" metal
|
||||
ln -sf "$METALLIB_PATH" metallib
|
||||
|
||||
[[ -e $out/bin/metal ]] && [[ -e $out/bin/metallib ]] || { echo ":(" && exit 1; }
|
||||
METAL_VERSION=$(echo __METAL_VERSION__ | "$METAL_PATH" -E -x metal -P - | tail -1 | tr -d '\n')
|
||||
[[ "$METAL_VERSION" == "${metalVersion}" ]] || { echo "Metal version $METAL_VERSION is not ${metalVersion}" && exit 1; }
|
||||
'';
|
||||
}
|
||||
154
nix/mlx.nix
Normal file
154
nix/mlx.nix
Normal file
@@ -0,0 +1,154 @@
|
||||
{ stdenv
|
||||
, lib
|
||||
, buildPythonPackage
|
||||
, fetchFromGitHub
|
||||
, replaceVars
|
||||
, fetchzip
|
||||
, setuptools
|
||||
, cmake
|
||||
, nanobind
|
||||
, pybind11
|
||||
, nlohmann_json
|
||||
, apple-sdk_26
|
||||
, metal
|
||||
, numpy
|
||||
, pytestCheckHook
|
||||
, python
|
||||
, runCommand
|
||||
, fmt
|
||||
}:
|
||||
assert stdenv.isDarwin;
|
||||
let
|
||||
# static dependencies included directly during compilation
|
||||
gguf-tools = fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
metal_cpp = fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
mlx = buildPythonPackage rec {
|
||||
pname = "mlx";
|
||||
version = "0.30.1";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-Vt0RH+70VBwUjXSfPTsNdRS3g0ookJHhzf2kvgEtgH8=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
(replaceVars ./darwin-build-fixes.patch {
|
||||
sdkVersion = apple-sdk_26.version;
|
||||
metalVersion = metal.version;
|
||||
})
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace pyproject.toml \
|
||||
--replace-fail "nanobind==2.10.2" "nanobind"
|
||||
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "$CXX"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
enableParallelBuilding = true;
|
||||
|
||||
# Allows multiple cores to be used in Python builds.
|
||||
postUnpack = ''
|
||||
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
|
||||
'';
|
||||
|
||||
# updates the wrong fetcher rev attribute
|
||||
passthru.skipBulkUpdate = true;
|
||||
|
||||
env = {
|
||||
DEV_RELEASE = 1;
|
||||
# NOTE The `metal` command-line utility used to build the Metal kernels is not open-source.
|
||||
# this is what the xcode wrapper is for - it patches in the system metal cli
|
||||
CMAKE_ARGS = toString [
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "filepath" "METAL_LIB"
|
||||
"${metal}/Metal.framework")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
];
|
||||
SDKROOT = apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
|
||||
};
|
||||
|
||||
build-system = [
|
||||
setuptools
|
||||
];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
metal
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
fmt
|
||||
gguf-tools
|
||||
nanobind
|
||||
pybind11
|
||||
apple-sdk_26
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "mlx" ];
|
||||
|
||||
# Run the mlx Python test suite.
|
||||
nativeCheckInputs = [
|
||||
numpy
|
||||
pytestCheckHook
|
||||
];
|
||||
|
||||
enabledTestPaths = [
|
||||
"python/tests/"
|
||||
];
|
||||
|
||||
# Additional testing by executing the example Python scripts supplied with mlx
|
||||
# using the version of the library we've built.
|
||||
passthru.tests = {
|
||||
mlxTest =
|
||||
runCommand "run-mlx-examples"
|
||||
{
|
||||
buildInputs = [ mlx ];
|
||||
nativeBuildInputs = [ python ];
|
||||
}
|
||||
''
|
||||
cp ${src}/examples/python/logistic_regression.py .
|
||||
${python.interpreter} logistic_regression.py
|
||||
rm logistic_regression.py
|
||||
|
||||
cp ${src}/examples/python/linear_regression.py .
|
||||
${python.interpreter} linear_regression.py
|
||||
rm linear_regression.py
|
||||
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
meta = {
|
||||
homepage = "https://github.com/ml-explore/mlx";
|
||||
description = "Array framework for Apple silicon";
|
||||
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
|
||||
};
|
||||
};
|
||||
in
|
||||
mlx
|
||||
@@ -46,9 +46,11 @@ class CustomMlxLayer(nn.Module):
|
||||
|
||||
def __init__(self, original_layer: _LayerCallable):
|
||||
super().__init__()
|
||||
# Set twice to avoid __setattr__ recursion
|
||||
object.__setattr__(self, "_original_layer", original_layer)
|
||||
self.original_layer: _LayerCallable = original_layer
|
||||
|
||||
@property
|
||||
def original_layer(self) -> _LayerCallable:
|
||||
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
|
||||
|
||||
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
|
||||
if not TYPE_CHECKING:
|
||||
@@ -58,7 +60,7 @@ class CustomMlxLayer(nn.Module):
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
original_layer = object.__getattribute__(self, "_original_layer")
|
||||
return object.__getattribute__(original_layer, name)
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
|
||||
202
src/exo/worker/tests/unittests/test_mlx/conftest.py
Normal file
202
src/exo/worker/tests/unittests/test_mlx/conftest.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# type: ignore
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
|
||||
|
||||
class MockLayer(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.custom_attr = "test_value"
|
||||
self.use_sliding = True
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
return x * 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTestConfig:
|
||||
model_path: Path
|
||||
total_layers: int
|
||||
base_port: int
|
||||
max_tokens: int
|
||||
|
||||
|
||||
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(hosts, f)
|
||||
hostfile_path = f.name
|
||||
|
||||
return hostfile_path, hosts
|
||||
|
||||
|
||||
# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour
|
||||
|
||||
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
|
||||
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
|
||||
total_layers=24,
|
||||
base_port=29600,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
|
||||
def run_gpt_oss_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
model_path: Path,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
max_tokens: int = 200,
|
||||
) -> None:
|
||||
import os
|
||||
import traceback
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
model, tokenizer = load(str(model_path))
|
||||
|
||||
# Generate a prompt of exact token length
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
base_len = len(base_tokens)
|
||||
|
||||
# Build prompt with approximate target length
|
||||
repeats = (prompt_tokens // base_len) + 2
|
||||
long_text = base_text * repeats
|
||||
tokens = tokenizer.encode(long_text)
|
||||
# Truncate to exact target length
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
formatted_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt_text}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
start_layer, end_layer = layer_splits[rank]
|
||||
|
||||
shard_meta = PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 20B",
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
n_layers=24,
|
||||
)
|
||||
|
||||
model = pipeline_auto_parallel(model, group, shard_meta)
|
||||
|
||||
# Barrier before generation
|
||||
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
|
||||
mlx_core.eval(barrier)
|
||||
|
||||
generated_text = ""
|
||||
for response in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=formatted_prompt,
|
||||
max_tokens=max_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
):
|
||||
generated_text += response.text
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def run_gpt_oss_tensor_parallel_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
model_path: Path,
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
max_tokens: int = 10,
|
||||
) -> None:
|
||||
import os
|
||||
import traceback
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
model, tokenizer = load(str(model_path))
|
||||
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
base_len = len(base_tokens)
|
||||
|
||||
repeats = (prompt_tokens // base_len) + 2
|
||||
long_text = base_text * repeats
|
||||
tokens = tokenizer.encode(long_text)
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
formatted_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt_text}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
model = tensor_auto_parallel(model, group)
|
||||
|
||||
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
|
||||
mlx_core.eval(barrier)
|
||||
|
||||
generated_text = ""
|
||||
for response in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=formatted_prompt,
|
||||
max_tokens=max_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
):
|
||||
generated_text += response.text
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
|
||||
137
src/exo/worker/tests/unittests/test_mlx/test_auto_parallel.py
Normal file
137
src/exo/worker/tests/unittests/test_mlx/test_auto_parallel.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import multiprocessing as mp
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
CustomMlxLayer,
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
)
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
|
||||
|
||||
|
||||
def run_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
) -> None:
|
||||
import os
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
import mlx.nn as mlx_nn
|
||||
|
||||
class MockLayerInner(mlx_nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.custom_attr = "test_value"
|
||||
|
||||
def __call__(
|
||||
self, x: mlx_core.array, *args: object, **kwargs: object
|
||||
) -> mlx_core.array:
|
||||
return x * 2
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
mock = MockLayerInner()
|
||||
first = PipelineFirstLayer(mock, r=rank, group=group)
|
||||
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
|
||||
|
||||
x = mlx_core.ones((1, 4))
|
||||
result = composed(x)
|
||||
mlx_core.eval(result)
|
||||
|
||||
success = result.shape == x.shape
|
||||
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def test_single_wrapper_delegates_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxLayer(mock)
|
||||
|
||||
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert wrapped.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_composed_wrappers_delegate_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
group = mx.distributed.init()
|
||||
|
||||
first = PipelineFirstLayer(mock, r=0, group=group)
|
||||
composed = PipelineLastLayer(first, r=0, s=1, group=group)
|
||||
|
||||
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert composed.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_missing_attribute_raises() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxLayer(mock)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_composed_call_works() -> None:
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
world_size = 2
|
||||
base_port = 29500
|
||||
|
||||
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(hosts, f)
|
||||
hostfile_path = f.name
|
||||
|
||||
try:
|
||||
result_queue: Any = ctx.Queue()
|
||||
|
||||
processes: list[Any] = []
|
||||
for rank in range(world_size):
|
||||
p = ctx.Process(
|
||||
target=run_pipeline_device,
|
||||
args=(rank, world_size, hostfile_path, result_queue),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes: # pyright: ignore[reportAny]
|
||||
p.join(timeout=10) # pyright: ignore[reportAny]
|
||||
|
||||
results: dict[int, Any] = {}
|
||||
errors: dict[int, str] = {}
|
||||
while not result_queue.empty(): # pyright: ignore[reportAny]
|
||||
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
|
||||
if success:
|
||||
results[rank] = value
|
||||
else:
|
||||
errors[rank] = value
|
||||
|
||||
assert len(results) == world_size, (
|
||||
f"Expected {world_size} results, got {len(results)}. Errors: {errors}"
|
||||
)
|
||||
|
||||
for rank in range(world_size):
|
||||
assert rank in results, (
|
||||
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
|
||||
)
|
||||
result_array = results[rank]
|
||||
# Both devices see the final result (4.0) after all_gather
|
||||
assert (result_array == 4.0).all(), (
|
||||
f"Device {rank}: expected 4.0, got {result_array}"
|
||||
)
|
||||
finally:
|
||||
os.unlink(hostfile_path)
|
||||
Reference in New Issue
Block a user