mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-19 11:28:51 -05:00
Compare commits
15 Commits
v1.0.63
...
evan/nonse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a041dc2f8 | ||
|
|
8de4d862b2 | ||
|
|
f11492a4af | ||
|
|
346b13e2c9 | ||
|
|
ea0588429b | ||
|
|
73b3f87e07 | ||
|
|
746589ba6b | ||
|
|
f82f862fd7 | ||
|
|
7ff937d8a1 | ||
|
|
d19bf02404 | ||
|
|
618cee5223 | ||
|
|
9c29eb7d48 | ||
|
|
c5158bee53 | ||
|
|
5c8a237940 | ||
|
|
745343c705 |
12
.github/actions/typecheck/action.yml
vendored
12
.github/actions/typecheck/action.yml
vendored
@@ -1,12 +0,0 @@
|
||||
name: Type Check
|
||||
|
||||
description: "Run type checker"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Run type checker
|
||||
run: |
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
|
||||
shell: bash
|
||||
22
.github/workflows/build-app.yml
vendored
22
.github/workflows/build-app.yml
vendored
@@ -161,17 +161,6 @@ jobs:
|
||||
- name: Install Homebrew packages
|
||||
run: brew install just awscli macmon
|
||||
|
||||
- name: Install UV
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: uv.lock
|
||||
|
||||
- name: Setup Python
|
||||
run: |
|
||||
uv python install
|
||||
uv sync --locked
|
||||
|
||||
- name: Install Nix
|
||||
uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
@@ -183,12 +172,6 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build dashboard
|
||||
run: |
|
||||
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
|
||||
mkdir -p dashboard/build
|
||||
cp -r "$DASHBOARD_OUT"/* dashboard/build/
|
||||
|
||||
- name: Install Sparkle CLI
|
||||
run: |
|
||||
CLI_URL="${SPARKLE_CLI_URL:-https://github.com/sparkle-project/Sparkle/releases/download/${SPARKLE_VERSION}/Sparkle-${SPARKLE_VERSION}.tar.xz}"
|
||||
@@ -244,7 +227,10 @@ jobs:
|
||||
# ============================================================
|
||||
|
||||
- name: Build PyInstaller bundle
|
||||
run: uv run pyinstaller packaging/pyinstaller/exo.spec
|
||||
run: |
|
||||
PYINSTALLER_OUT=$(nix build .#exo-pyinstaller --print-build-logs --no-link --print-out-paths)
|
||||
mkdir -p dist
|
||||
cp -r "$PYINSTALLER_OUT" dist/exo
|
||||
|
||||
- name: Build Swift app
|
||||
env:
|
||||
|
||||
97
.github/workflows/pipeline.yml
vendored
97
.github/workflows/pipeline.yml
vendored
@@ -26,73 +26,14 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Pull LFS files
|
||||
run: |
|
||||
echo "Pulling Git LFS files..."
|
||||
git lfs pull
|
||||
shell: bash
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix binary exists directly
|
||||
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
nix --version
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Found nix profile script, sourcing..."
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
nix --version
|
||||
elif command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
nix --version
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Configure basedpyright include for local MLX
|
||||
run: |
|
||||
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
|
||||
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
|
||||
awk '
|
||||
BEGIN { in=0 }
|
||||
/^\[tool\.basedpyright\]/ { in=1; print; next }
|
||||
in && /^\[/ { in=0 } # next section
|
||||
in && /^[ \t]*include[ \t]*=/ {
|
||||
print "include = [\"/Users/Shared/mlx\"]"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
|
||||
|
||||
echo "New [tool.basedpyright] section:"
|
||||
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
|
||||
else
|
||||
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
|
||||
fi
|
||||
else
|
||||
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
@@ -113,6 +54,29 @@ jobs:
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Select Xcode
|
||||
if: startsWith(matrix.runner, 'macos-')
|
||||
run: |
|
||||
XCODE_BASEDIR="$(printf '%s\n' /Applications/Xcode_*.app | sort -V | tail -n 1)"
|
||||
[[ -z "$XCODE_BASEDIR" ]] && exit 1
|
||||
|
||||
ls -ld "/Applications/Xcode.app"
|
||||
sudo /usr/bin/xcode-select -s "$XCODE_BASEDIR"
|
||||
/usr/bin/xcode-select -p || true
|
||||
/usr/bin/xcrun --toolchain default --find xcodebuild || true
|
||||
|
||||
- name: Install Metal toolchain component
|
||||
if: startsWith(matrix.runner, 'macos-')
|
||||
run: |
|
||||
set -e
|
||||
if ! xcrun --find metal >/dev/null 2>&1; then
|
||||
sudo xcodebuild -downloadComponent MetalToolchain
|
||||
fi
|
||||
xcrun --find metal
|
||||
xcrun --find metallib
|
||||
echo "GH_OVERRIDE_METAL=$(xcrun --find metal)" >> $GITHUB_ENV
|
||||
echo "GH_OVERRIDE_METALLIB=$(xcrun --find metallib)" >> $GITHUB_ENV
|
||||
|
||||
- uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
@@ -124,6 +88,9 @@ jobs:
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build all Nix outputs
|
||||
env:
|
||||
GH_OVERRIDE_METAL: ${{ env.GH_OVERRIDE_METAL }}
|
||||
GH_OVERRIDE_METALLIB: ${{ env.GH_OVERRIDE_METALLIB }}
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
[
|
||||
|
||||
25
AGENTS.md
25
AGENTS.md
@@ -40,6 +40,31 @@ uv run ruff check
|
||||
nix fmt
|
||||
```
|
||||
|
||||
## Pre-Commit Checks (REQUIRED)
|
||||
|
||||
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
|
||||
|
||||
```bash
|
||||
# 1. Type checking - MUST pass with 0 errors
|
||||
uv run basedpyright
|
||||
|
||||
# 2. Linting - MUST pass
|
||||
uv run ruff check
|
||||
|
||||
# 3. Formatting - MUST be applied
|
||||
nix fmt
|
||||
|
||||
# 4. Tests - MUST pass
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
Run all checks in sequence:
|
||||
```bash
|
||||
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
|
||||
```
|
||||
|
||||
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Node Composition
|
||||
|
||||
@@ -27,6 +27,15 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
|
||||
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
|
||||
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
|
||||
|
||||
## Dashboard
|
||||
|
||||
exo includes a built-in dashboard for managing your cluster and chatting with models.
|
||||
|
||||
<p align="center">
|
||||
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
|
||||
</p>
|
||||
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
|
||||
|
||||
## Benchmarks
|
||||
|
||||
<details>
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
@@ -299,6 +324,12 @@ def main() -> int:
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
@@ -320,7 +351,7 @@ def main() -> int:
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
@@ -399,7 +430,7 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
@@ -441,7 +472,13 @@ def main() -> int:
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize placement: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
continue
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@@ -53,62 +53,285 @@
|
||||
marked.use({ renderer });
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
|
||||
* Also protect code blocks from LaTeX processing
|
||||
* Unescape HTML entities that marked may have escaped
|
||||
*/
|
||||
function unescapeHtmlEntities(text: string): string {
|
||||
return text
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/&/g, '&')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
// Storage for math expressions extracted before markdown processing
|
||||
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
|
||||
let mathCounter = 0;
|
||||
|
||||
// Storage for HTML snippets that need protection from markdown
|
||||
const htmlSnippets: Map<string, string> = new Map();
|
||||
let htmlCounter = 0;
|
||||
|
||||
// Use alphanumeric placeholders that won't be interpreted as HTML tags
|
||||
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
|
||||
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
|
||||
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
|
||||
*/
|
||||
function preprocessLaTeX(text: string): string {
|
||||
// Protect code blocks
|
||||
// Reset storage
|
||||
mathExpressions.clear();
|
||||
mathCounter = 0;
|
||||
htmlSnippets.clear();
|
||||
htmlCounter = 0;
|
||||
|
||||
// Protect code blocks first
|
||||
const codeBlocks: string[] = [];
|
||||
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
|
||||
codeBlocks.push(match);
|
||||
return `<<CODE_${codeBlocks.length - 1}>>`;
|
||||
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
|
||||
});
|
||||
|
||||
// Convert \(...\) to $...$
|
||||
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
|
||||
|
||||
// Convert \[...\] to $$...$$
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
|
||||
// Remove LaTeX document commands
|
||||
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\begin\{document\}/g, '');
|
||||
processed = processed.replace(/\\end\{document\}/g, '');
|
||||
processed = processed.replace(/\\maketitle/g, '');
|
||||
processed = processed.replace(/\\title\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\author\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\date\{[^}]*\}/g, '');
|
||||
|
||||
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
|
||||
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
|
||||
processed = processed.replace(/\\require\{[^}]*\}/g, '');
|
||||
|
||||
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
|
||||
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
// Strip center environment (layout only, no content change)
|
||||
processed = processed.replace(/\\begin\{center\}/g, '');
|
||||
processed = processed.replace(/\\end\{center\}/g, '');
|
||||
// Strip other layout environments
|
||||
processed = processed.replace(/\\begin\{flushleft\}/g, '');
|
||||
processed = processed.replace(/\\end\{flushleft\}/g, '');
|
||||
processed = processed.replace(/\\begin\{flushright\}/g, '');
|
||||
processed = processed.replace(/\\end\{flushright\}/g, '');
|
||||
processed = processed.replace(/\\label\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
|
||||
|
||||
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
|
||||
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
|
||||
|
||||
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
|
||||
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
|
||||
for (const env of mathEnvs) {
|
||||
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
|
||||
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
|
||||
processed = processed.replace(wrappedRegex, (_, args, content) => {
|
||||
const cleanEnv = env.replace('\\*', '*');
|
||||
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Handle bare \begin{env}...\end{env} (without dollar signs)
|
||||
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
|
||||
processed = processed.replace(bareRegex, (_, args, content) => {
|
||||
const cleanEnv = env.replace('\\*', '*');
|
||||
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
|
||||
processed = processed.replace(
|
||||
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
|
||||
(_, content) => {
|
||||
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, html);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
}
|
||||
);
|
||||
|
||||
// Convert LaTeX theorem-like environments
|
||||
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
|
||||
for (const env of theoremEnvs) {
|
||||
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
|
||||
const envName = env.charAt(0).toUpperCase() + env.slice(1);
|
||||
processed = processed.replace(envRegex, (_, content) => {
|
||||
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, html);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
|
||||
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<em>${content}</em>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<em>${content}</em>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<u>${content}</u>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Handle LaTeX line breaks and spacing
|
||||
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
|
||||
processed = processed.replace(/\\newline/g, '\n');
|
||||
processed = processed.replace(/\\par\b/g, '\n\n');
|
||||
processed = processed.replace(/\\quad/g, ' ');
|
||||
processed = processed.replace(/\\qquad/g, ' ');
|
||||
processed = processed.replace(/~~/g, ' '); // non-breaking space
|
||||
|
||||
// Remove other common LaTeX commands that don't render
|
||||
processed = processed.replace(/\\centering/g, '');
|
||||
processed = processed.replace(/\\noindent/g, '');
|
||||
processed = processed.replace(/\\hfill/g, '');
|
||||
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
|
||||
|
||||
// Convert \(...\) to placeholder (display: false)
|
||||
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content, displayMode: false });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Convert \[...\] to placeholder (display: true)
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Extract display math ($$...$$) BEFORE markdown processing
|
||||
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Extract inline math ($...$) BEFORE markdown processing
|
||||
// Allow single-line only, skip currency patterns like $5 or $50
|
||||
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
|
||||
if (/^\d/.test(content.trim())) {
|
||||
return match; // Keep as-is for currency
|
||||
}
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Restore escaped dollar signs
|
||||
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
|
||||
|
||||
// Restore code blocks
|
||||
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
|
||||
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
|
||||
|
||||
// Clean up any remaining stray backslashes from unrecognized commands
|
||||
processed = processed.replace(/\\(?=[a-zA-Z])/g, ''); // Remove \ before letters (unrecognized commands)
|
||||
|
||||
return processed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Render math expressions with KaTeX after HTML is generated
|
||||
* Render math expressions with KaTeX and restore HTML placeholders
|
||||
*/
|
||||
function renderMath(html: string): string {
|
||||
// Render display math ($$...$$)
|
||||
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: true,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$$${math}$$</span>`;
|
||||
}
|
||||
});
|
||||
// Replace all math placeholders with rendered KaTeX
|
||||
for (const [placeholder, { content, displayMode }] of mathExpressions) {
|
||||
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const regex = new RegExp(escapedPlaceholder, 'g');
|
||||
|
||||
// Render inline math ($...$) but avoid matching currency like $5
|
||||
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
|
||||
// Skip if it looks like currency ($ followed by number)
|
||||
if (/^\d/.test(math.trim())) {
|
||||
return match;
|
||||
}
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: false,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$${math}$</span>`;
|
||||
}
|
||||
});
|
||||
html = html.replace(regex, () => {
|
||||
try {
|
||||
const rendered = katex.renderToString(content, {
|
||||
displayMode,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
|
||||
if (displayMode) {
|
||||
return `
|
||||
<div class="math-display-wrapper">
|
||||
<div class="math-display-header">
|
||||
<span class="math-label">LaTeX</span>
|
||||
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
|
||||
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="math-display-content">
|
||||
${rendered}
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
} else {
|
||||
return `<span class="math-inline">${rendered}</span>`;
|
||||
}
|
||||
} catch {
|
||||
const display = displayMode ? `$$${content}$$` : `$${content}$`;
|
||||
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Restore HTML placeholders (for \textbf, \emph, etc.)
|
||||
for (const [placeholder, htmlContent] of htmlSnippets) {
|
||||
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const regex = new RegExp(escapedPlaceholder, 'g');
|
||||
html = html.replace(regex, htmlContent);
|
||||
}
|
||||
|
||||
return html;
|
||||
}
|
||||
@@ -154,16 +377,50 @@
|
||||
}
|
||||
}
|
||||
|
||||
async function handleMathCopyClick(event: Event) {
|
||||
const target = event.currentTarget as HTMLButtonElement;
|
||||
const encodedSource = target.getAttribute('data-math-source');
|
||||
if (!encodedSource) return;
|
||||
|
||||
const source = decodeURIComponent(encodedSource);
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(source);
|
||||
// Show copied feedback
|
||||
const originalHtml = target.innerHTML;
|
||||
target.innerHTML = `
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M20 6L9 17l-5-5"/>
|
||||
</svg>
|
||||
`;
|
||||
target.classList.add('copied');
|
||||
setTimeout(() => {
|
||||
target.innerHTML = originalHtml;
|
||||
target.classList.remove('copied');
|
||||
}, 2000);
|
||||
} catch (error) {
|
||||
console.error('Failed to copy math:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function setupCopyButtons() {
|
||||
if (!containerRef || !browser) return;
|
||||
|
||||
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of buttons) {
|
||||
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of codeButtons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleCopyClick);
|
||||
}
|
||||
}
|
||||
|
||||
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
|
||||
for (const button of mathButtons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleMathCopyClick);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
@@ -424,28 +681,290 @@
|
||||
color: #60a5fa;
|
||||
}
|
||||
|
||||
/* KaTeX math styling */
|
||||
/* KaTeX math styling - Base */
|
||||
.markdown-content :global(.katex) {
|
||||
font-size: 1.1em;
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display) {
|
||||
/* Display math container wrapper */
|
||||
.markdown-content :global(.math-display-wrapper) {
|
||||
margin: 1rem 0;
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
border: 1px solid rgba(255, 215, 0, 0.15);
|
||||
background: rgba(0, 0, 0, 0.3);
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover) {
|
||||
border-color: rgba(255, 215, 0, 0.25);
|
||||
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
|
||||
}
|
||||
|
||||
/* Display math header - hidden by default, slides in on hover */
|
||||
.markdown-content :global(.math-display-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.375rem 0.75rem;
|
||||
background: rgba(255, 215, 0, 0.03);
|
||||
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
|
||||
opacity: 0;
|
||||
max-height: 0;
|
||||
padding-top: 0;
|
||||
padding-bottom: 0;
|
||||
overflow: hidden;
|
||||
transition:
|
||||
opacity 0.2s ease,
|
||||
max-height 0.2s ease,
|
||||
padding 0.2s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
|
||||
opacity: 1;
|
||||
max-height: 2.5rem;
|
||||
padding: 0.375rem 0.75rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-label) {
|
||||
color: rgba(255, 215, 0, 0.7);
|
||||
font-size: 0.65rem;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.1em;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
cursor: pointer;
|
||||
transition: color 0.2s;
|
||||
border-radius: 0.25rem;
|
||||
opacity: 0;
|
||||
transition:
|
||||
color 0.2s,
|
||||
opacity 0.15s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn:hover) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-math-btn.copied) {
|
||||
color: #22c55e;
|
||||
}
|
||||
|
||||
/* Display math content area */
|
||||
.markdown-content :global(.math-display-content) {
|
||||
padding: 1rem 1.25rem;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display > .katex) {
|
||||
/* Custom scrollbar for math overflow */
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
|
||||
background: rgba(255, 215, 0, 0.2);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
|
||||
background: rgba(255, 215, 0, 0.35);
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content .katex-display) {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-display-content .katex-display > .katex) {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/* Inline math wrapper */
|
||||
.markdown-content :global(.math-inline) {
|
||||
display: inline;
|
||||
padding: 0 0.125rem;
|
||||
border-radius: 0.25rem;
|
||||
transition: background-color 0.15s ease;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-inline:hover) {
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
}
|
||||
|
||||
/* Dark theme KaTeX overrides */
|
||||
.markdown-content :global(.katex .mord),
|
||||
.markdown-content :global(.katex .minner),
|
||||
.markdown-content :global(.katex .mop),
|
||||
.markdown-content :global(.katex .mbin),
|
||||
.markdown-content :global(.katex .mrel),
|
||||
.markdown-content :global(.katex .mpunct) {
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
/* Fraction lines and rules */
|
||||
.markdown-content :global(.katex .frac-line),
|
||||
.markdown-content :global(.katex .overline-line),
|
||||
.markdown-content :global(.katex .underline-line),
|
||||
.markdown-content :global(.katex .hline),
|
||||
.markdown-content :global(.katex .rule) {
|
||||
border-color: oklch(0.85 0 0) !important;
|
||||
background: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
/* Square roots and SVG elements */
|
||||
.markdown-content :global(.katex .sqrt-line) {
|
||||
border-color: oklch(0.85 0 0) !important;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex svg) {
|
||||
fill: oklch(0.85 0 0);
|
||||
stroke: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex svg path) {
|
||||
stroke: oklch(0.85 0 0);
|
||||
}
|
||||
|
||||
/* Delimiters (parentheses, brackets, braces) */
|
||||
.markdown-content :global(.katex .delimsizing),
|
||||
.markdown-content :global(.katex .delim-size1),
|
||||
.markdown-content :global(.katex .delim-size2),
|
||||
.markdown-content :global(.katex .delim-size3),
|
||||
.markdown-content :global(.katex .delim-size4),
|
||||
.markdown-content :global(.katex .mopen),
|
||||
.markdown-content :global(.katex .mclose) {
|
||||
color: oklch(0.75 0 0);
|
||||
}
|
||||
|
||||
/* Math error styling */
|
||||
.markdown-content :global(.math-error) {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.375rem;
|
||||
color: #f87171;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
background: rgba(248, 113, 113, 0.1);
|
||||
padding: 0.125rem 0.25rem;
|
||||
padding: 0.25rem 0.5rem;
|
||||
border-radius: 0.25rem;
|
||||
border: 1px solid rgba(248, 113, 113, 0.2);
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-error-icon) {
|
||||
font-size: 0.875em;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
/* LaTeX proof environment */
|
||||
.markdown-content :global(.latex-proof) {
|
||||
margin: 1rem 0;
|
||||
padding: 1rem 1.25rem;
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
border-left: 3px solid rgba(255, 215, 0, 0.4);
|
||||
border-radius: 0 0.375rem 0.375rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-header) {
|
||||
font-weight: 600;
|
||||
font-style: italic;
|
||||
color: oklch(0.85 0 0);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-header::after) {
|
||||
content: '.';
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-content) {
|
||||
color: oklch(0.9 0 0);
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-proof-content p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* QED symbol at end of proof */
|
||||
.markdown-content :global(.latex-proof-content::after) {
|
||||
content: '∎';
|
||||
display: block;
|
||||
text-align: right;
|
||||
color: oklch(0.7 0 0);
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
/* LaTeX theorem-like environments */
|
||||
.markdown-content :global(.latex-theorem) {
|
||||
margin: 1rem 0;
|
||||
padding: 1rem 1.25rem;
|
||||
background: rgba(255, 215, 0, 0.03);
|
||||
border: 1px solid rgba(255, 215, 0, 0.15);
|
||||
border-radius: 0.375rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-header) {
|
||||
font-weight: 700;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-header::after) {
|
||||
content: '.';
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-content) {
|
||||
color: oklch(0.9 0 0);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-theorem-content p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* LaTeX diagram/figure placeholder */
|
||||
.markdown-content :global(.latex-diagram-placeholder) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
margin: 1rem 0;
|
||||
padding: 1.5rem 2rem;
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
border: 1px dashed rgba(255, 215, 0, 0.25);
|
||||
border-radius: 0.5rem;
|
||||
color: rgba(255, 215, 0, 0.6);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-diagram-icon) {
|
||||
font-size: 1.25rem;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.markdown-content :global(.latex-diagram-text) {
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.75rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
</style>
|
||||
|
||||
BIN
docs/imgs/dashboard-cluster-view.png
Normal file
BIN
docs/imgs/dashboard-cluster-view.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 187 KiB |
65
flake.lock
generated
65
flake.lock
generated
@@ -21,7 +21,9 @@
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
@@ -149,19 +151,44 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
],
|
||||
"uv2nix": [
|
||||
"uv2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763662255,
|
||||
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"lastModified": 1764134915,
|
||||
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -178,7 +205,10 @@
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"treefmt-nix": "treefmt-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
@@ -239,6 +269,29 @@
|
||||
"repo": "treefmt-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1767701098,
|
||||
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
31
flake.nix
31
flake.nix
@@ -24,6 +24,26 @@
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
};
|
||||
|
||||
# Python packaging with uv2nix
|
||||
pyproject-nix = {
|
||||
url = "github:pyproject-nix/pyproject.nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
uv2nix = {
|
||||
url = "github:pyproject-nix/uv2nix";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
pyproject-build-systems = {
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.uv2nix.follows = "uv2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
@@ -48,6 +68,7 @@
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
./python/parts.nix
|
||||
];
|
||||
|
||||
perSystem =
|
||||
@@ -81,11 +102,10 @@
|
||||
};
|
||||
};
|
||||
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
packages =
|
||||
if pkgs.stdenv.isDarwin then {
|
||||
metal = pkgs.callPackage ./nix/metalWrapper.nix { metalVersion = "310"; };
|
||||
} else { };
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
@@ -124,6 +144,7 @@
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
|
||||
|
||||
shellHook = ''
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
|
||||
${lib.optionalString stdenv.isLinux ''
|
||||
|
||||
79
nix/darwin-build-fixes.patch
Normal file
79
nix/darwin-build-fixes.patch
Normal file
@@ -0,0 +1,79 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 0ed30932..d8528132 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
- # Throw an error if xcrun not found
|
||||
- execute_process(
|
||||
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
- OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(MACOS_SDK_VERSION @sdkVersion@)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
- execute_process(
|
||||
- COMMAND
|
||||
- zsh "-c"
|
||||
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(
|
||||
+ MLX_METAL_VERSION @metalVersion@)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
|
||||
index 13db804a..5b385132 100644
|
||||
--- a/cmake/extension.cmake
|
||||
+++ b/cmake/extension.cmake
|
||||
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
- xcrun -sdk macosx metal
|
||||
+ metal
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
index 262b0495..5c7446ad 100644
|
||||
--- a/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
add_custom_command(
|
||||
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
+ COMMAND metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
@@ -170,7 +170,7 @@ endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
+ COMMAND metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
index bb55ed3a..94ea7dd7 100644
|
||||
--- a/mlx/backend/metal/make_compiled_preamble.sh
|
||||
+++ b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Use the metal compiler to get a list of headers (with depth)
|
||||
-CCC="xcrun -sdk macosx metal -x metal"
|
||||
+CCC="metal -x metal"
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
# Remove any included system frameworks (for MetalPerformancePrimitive headers)
|
||||
25
nix/metalWrapper.nix
Normal file
25
nix/metalWrapper.nix
Normal file
@@ -0,0 +1,25 @@
|
||||
{ stdenvNoCC
|
||||
, metalVersion
|
||||
}:
|
||||
assert stdenvNoCC.isDarwin;
|
||||
stdenvNoCC.mkDerivation {
|
||||
pname = "metal-wrapper-impure";
|
||||
version = metalVersion;
|
||||
|
||||
__noChroot = true;
|
||||
buildCommand = ''
|
||||
mkdir -p $out/bin && cd $out/bin
|
||||
|
||||
METALLIB_PATH=''${GH_OVERRIDE_METALLIB:-$(/usr/bin/xcrun --sdk macosx -f metallib)}
|
||||
METAL_PATH=''${GH_OVERRIDE_METAL:-"$(dirname "$METALLIB_PATH")/metal"}
|
||||
echo "$METAL_PATH"
|
||||
echo "$METALLIB_PATH"
|
||||
|
||||
ln -sf "$METAL_PATH" metal
|
||||
ln -sf "$METALLIB_PATH" metallib
|
||||
|
||||
[[ -e $out/bin/metal ]] && [[ -e $out/bin/metallib ]] || { echo ":(" && exit 1; }
|
||||
METAL_VERSION=$(echo __METAL_VERSION__ | "$METAL_PATH" -E -x metal -P - | tail -1 | tr -d '\n')
|
||||
[[ "$METAL_VERSION" == "${metalVersion}" ]] || { echo "Metal version $METAL_VERSION is not ${metalVersion}" && exit 1; }
|
||||
'';
|
||||
}
|
||||
154
nix/mlx.nix
Normal file
154
nix/mlx.nix
Normal file
@@ -0,0 +1,154 @@
|
||||
{ stdenv
|
||||
, lib
|
||||
, buildPythonPackage
|
||||
, fetchFromGitHub
|
||||
, replaceVars
|
||||
, fetchzip
|
||||
, setuptools
|
||||
, cmake
|
||||
, nanobind
|
||||
, pybind11
|
||||
, nlohmann_json
|
||||
, apple-sdk_26
|
||||
, metal
|
||||
, numpy
|
||||
, pytestCheckHook
|
||||
, python
|
||||
, runCommand
|
||||
, fmt
|
||||
}:
|
||||
assert stdenv.isDarwin;
|
||||
let
|
||||
# static dependencies included directly during compilation
|
||||
gguf-tools = fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
metal_cpp = fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
mlx = buildPythonPackage rec {
|
||||
pname = "mlx";
|
||||
version = "0.30.1";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-Vt0RH+70VBwUjXSfPTsNdRS3g0ookJHhzf2kvgEtgH8=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
(replaceVars ./darwin-build-fixes.patch {
|
||||
sdkVersion = apple-sdk_26.version;
|
||||
metalVersion = metal.version;
|
||||
})
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace pyproject.toml \
|
||||
--replace-fail "nanobind==2.10.2" "nanobind"
|
||||
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "$CXX"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
enableParallelBuilding = true;
|
||||
|
||||
# Allows multiple cores to be used in Python builds.
|
||||
postUnpack = ''
|
||||
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
|
||||
'';
|
||||
|
||||
# updates the wrong fetcher rev attribute
|
||||
passthru.skipBulkUpdate = true;
|
||||
|
||||
env = {
|
||||
DEV_RELEASE = 1;
|
||||
# NOTE The `metal` command-line utility used to build the Metal kernels is not open-source.
|
||||
# this is what the xcode wrapper is for - it patches in the system metal cli
|
||||
CMAKE_ARGS = toString [
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "filepath" "METAL_LIB"
|
||||
"${metal}/Metal.framework")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
];
|
||||
SDKROOT = apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
|
||||
};
|
||||
|
||||
build-system = [
|
||||
setuptools
|
||||
];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
metal
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
fmt
|
||||
gguf-tools
|
||||
nanobind
|
||||
pybind11
|
||||
apple-sdk_26
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "mlx" ];
|
||||
|
||||
# Run the mlx Python test suite.
|
||||
nativeCheckInputs = [
|
||||
numpy
|
||||
pytestCheckHook
|
||||
];
|
||||
|
||||
enabledTestPaths = [
|
||||
"python/tests/"
|
||||
];
|
||||
|
||||
# Additional testing by executing the example Python scripts supplied with mlx
|
||||
# using the version of the library we've built.
|
||||
passthru.tests = {
|
||||
mlxTest =
|
||||
runCommand "run-mlx-examples"
|
||||
{
|
||||
buildInputs = [ mlx ];
|
||||
nativeBuildInputs = [ python ];
|
||||
}
|
||||
''
|
||||
cp ${src}/examples/python/logistic_regression.py .
|
||||
${python.interpreter} logistic_regression.py
|
||||
rm logistic_regression.py
|
||||
|
||||
cp ${src}/examples/python/linear_regression.py .
|
||||
${python.interpreter} linear_regression.py
|
||||
rm linear_regression.py
|
||||
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
meta = {
|
||||
homepage = "https://github.com/ml-explore/mlx";
|
||||
description = "Array framework for Apple silicon";
|
||||
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
|
||||
};
|
||||
};
|
||||
in
|
||||
mlx
|
||||
@@ -126,3 +126,6 @@ env = [
|
||||
"EXO_TESTS=1"
|
||||
]
|
||||
addopts = "-m 'not slow'"
|
||||
filterwarnings = [
|
||||
"ignore:builtin type Swig:DeprecationWarning",
|
||||
]
|
||||
|
||||
143
python/parts.nix
Normal file
143
python/parts.nix
Normal file
@@ -0,0 +1,143 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ self', pkgs, lib, system, ... }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = inputs.self;
|
||||
};
|
||||
|
||||
# Create overlay from workspace
|
||||
overlay = workspace.mkPyprojectOverlay { };
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: _: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
src = self'.packages.exo_pyo3_bindings;
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
};
|
||||
};
|
||||
|
||||
python = pkgs.python313;
|
||||
|
||||
# Overlay to provide build systems for source builds
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
|
||||
# Build MLX from source with proper dependencies
|
||||
mlx = pkgs.callPythonPackage ./nix/mlx.nix;
|
||||
};
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
inputs.pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set DASHBOARD_DIR ${self'.packages.dashboard}
|
||||
done
|
||||
'';
|
||||
|
||||
pyinstallerPackage =
|
||||
let
|
||||
venv = pythonSet.mkVirtualEnv "exo-pyinstaller-env" (
|
||||
workspace.deps.default
|
||||
// {
|
||||
# Include pyinstaller in the environment
|
||||
exo = [ "dev" ];
|
||||
}
|
||||
);
|
||||
in
|
||||
pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyinstaller";
|
||||
version = "0.3.0";
|
||||
|
||||
src = inputs.self;
|
||||
|
||||
nativeBuildInputs = [ venv pkgs.makeWrapper pkgs.macmon pkgs.darwin.system_cmds ];
|
||||
|
||||
buildPhase = ''
|
||||
# macmon must be in PATH for PyInstaller to bundle it
|
||||
export PATH="${pkgs.macmon}/bin:$PATH"
|
||||
# HOME must be writable for PyInstaller's cache
|
||||
export HOME="$TMPDIR"
|
||||
|
||||
# Copy dashboard to expected location
|
||||
mkdir -p dashboard/build
|
||||
cp -r ${self'.packages.dashboard}/* dashboard/build/
|
||||
|
||||
# Run PyInstaller
|
||||
${venv}/bin/python -m PyInstaller packaging/pyinstaller/exo.spec
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
cp -r dist/exo $out
|
||||
'';
|
||||
};
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS for now due to the dependency on
|
||||
# mlx/mlx-cpu being tricky to build on Linux. We can either remove this
|
||||
# dependency in the PyProject or build it with Nix.
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
exo-pyinstaller = pyinstallerPackage;
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
}
|
||||
# Pytest only on macOS (requires MLX)
|
||||
// lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
pytest = pkgs.runCommand "pytest"
|
||||
{
|
||||
nativeBuildInputs = [ testVenv ];
|
||||
} ''
|
||||
export HOME="$TMPDIR"
|
||||
export EXO_TESTS=1
|
||||
cd ${inputs.self}
|
||||
${testVenv}/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -205,6 +205,14 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
logger.info("FAST_SYNCH forced ON")
|
||||
elif args.fast_synch is False:
|
||||
os.environ["EXO_FAST_SYNCH"] = "off"
|
||||
logger.info("FAST_SYNCH forced OFF")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
logger.info("EXO Shutdown complete")
|
||||
@@ -218,6 +226,7 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -259,6 +268,20 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
action="store_true",
|
||||
dest="fast_synch",
|
||||
default=None,
|
||||
help="Force MLX FAST_SYNCH on (for JACCL backend)",
|
||||
)
|
||||
fast_synch_group.add_argument(
|
||||
"--no-fast-synch",
|
||||
action="store_false",
|
||||
dest="fast_synch",
|
||||
help="Force MLX FAST_SYNCH off",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
from anyio import create_task_group
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from hypercorn.config import Config
|
||||
@@ -29,6 +30,8 @@ from exo.shared.types.api import (
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
@@ -49,7 +52,12 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
@@ -115,6 +123,7 @@ class API:
|
||||
self.paused_ev: anyio.Event = anyio.Event()
|
||||
|
||||
self.app = FastAPI()
|
||||
self._setup_exception_handlers()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
|
||||
@@ -145,6 +154,21 @@ class API:
|
||||
self.paused_ev.set()
|
||||
self.paused_ev = anyio.Event()
|
||||
|
||||
def _setup_exception_handlers(self) -> None:
|
||||
self.app.exception_handler(HTTPException)(self.http_exception_handler)
|
||||
|
||||
async def http_exception_handler(
|
||||
self, _: Request, exc: HTTPException
|
||||
) -> JSONResponse:
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=exc.detail,
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -406,6 +430,18 @@ class API:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
@@ -426,6 +462,12 @@ class API:
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -463,6 +505,12 @@ class API:
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -607,14 +655,14 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if (
|
||||
isinstance(event, ChunkGenerated)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
if isinstance(event, ChunkGenerated):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -49,33 +49,83 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
|
||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||
|
||||
|
||||
def allocate_layers_proportionally(
|
||||
total_layers: int,
|
||||
memory_fractions: list[float],
|
||||
) -> list[int]:
|
||||
n = len(memory_fractions)
|
||||
if n == 0:
|
||||
raise ValueError("Cannot allocate layers to an empty node list")
|
||||
if total_layers < n:
|
||||
raise ValueError(
|
||||
f"Cannot distribute {total_layers} layers across {n} nodes "
|
||||
"(need at least 1 layer per node)"
|
||||
)
|
||||
|
||||
# Largest remainder: floor each, then distribute remainder by fractional part
|
||||
raw = [f * total_layers for f in memory_fractions]
|
||||
result = [int(r) for r in raw]
|
||||
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
|
||||
for i in range(total_layers - sum(result)):
|
||||
result[by_remainder[i]] += 1
|
||||
|
||||
# Ensure minimum 1 per node by taking from the largest
|
||||
for i in range(n):
|
||||
if result[i] == 0:
|
||||
max_idx = max(range(n), key=lambda j: result[j])
|
||||
assert result[max_idx] > 1
|
||||
result[max_idx] -= 1
|
||||
result[i] = 1
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_shard_assignments_for_pipeline_parallel(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
):
|
||||
if not selected_cycle:
|
||||
raise ValueError("Cannot create shard assignments for empty node cycle")
|
||||
|
||||
cycle_memory = sum(
|
||||
(node.node_profile.memory.ram_available for node in selected_cycle),
|
||||
start=Memory(),
|
||||
)
|
||||
|
||||
if cycle_memory.in_bytes == 0:
|
||||
raise ValueError("Cannot create shard assignments: total available memory is 0")
|
||||
|
||||
total_layers = model_meta.n_layers
|
||||
world_size = len(selected_cycle)
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
layers_assigned = 0
|
||||
for i, node in enumerate(selected_cycle):
|
||||
if i == len(selected_cycle) - 1:
|
||||
node_layers = total_layers - layers_assigned
|
||||
else:
|
||||
node_layers = round(
|
||||
total_layers
|
||||
* (
|
||||
node.node_profile.memory.ram_available.in_bytes
|
||||
/ cycle_memory.in_bytes
|
||||
)
|
||||
)
|
||||
node_layers = max(1, node_layers)
|
||||
layer_allocations = allocate_layers_proportionally(
|
||||
total_layers=total_layers,
|
||||
memory_fractions=[
|
||||
node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
|
||||
for node in selected_cycle
|
||||
],
|
||||
)
|
||||
|
||||
# Validate each node has sufficient memory for its assigned layers
|
||||
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
|
||||
for i, (node, node_layers) in enumerate(
|
||||
zip(selected_cycle, layer_allocations, strict=True)
|
||||
):
|
||||
required_memory = node_layers * memory_per_layer
|
||||
available_memory = node.node_profile.memory.ram_available.in_bytes
|
||||
if required_memory > available_memory:
|
||||
raise ValueError(
|
||||
f"Node {i} ({node.node_id}) has insufficient memory: "
|
||||
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
|
||||
f"but only has {available_memory / (1024**3):.2f} GB available"
|
||||
)
|
||||
|
||||
layers_assigned = 0
|
||||
for i, (node, node_layers) in enumerate(
|
||||
zip(selected_cycle, layer_allocations, strict=True)
|
||||
):
|
||||
runner_id = RunnerId()
|
||||
|
||||
shard = PipelineShardMetadata(
|
||||
|
||||
107
src/exo/master/tests/test_api_error_handling.py
Normal file
107
src/exo/master/tests/test_api_error_handling.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# pyright: reportUnusedFunction=false, reportAny=false
|
||||
from typing import Any, get_args
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_http_exception_handler_formats_openai_style() -> None:
|
||||
"""Test that HTTPException is converted to OpenAI-style error format."""
|
||||
from exo.master.api import API
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Setup exception handler
|
||||
api = object.__new__(API)
|
||||
api.app = app
|
||||
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Add test routes that raise HTTPException
|
||||
@app.get("/test-error")
|
||||
async def _test_error() -> None:
|
||||
raise HTTPException(status_code=500, detail="Test error message")
|
||||
|
||||
@app.get("/test-not-found")
|
||||
async def _test_not_found() -> None:
|
||||
raise HTTPException(status_code=404, detail="Resource not found")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Test 500 error
|
||||
response = client.get("/test-error")
|
||||
assert response.status_code == 500
|
||||
data: dict[str, Any] = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Test error message"
|
||||
assert data["error"]["type"] == "Internal Server Error"
|
||||
assert data["error"]["code"] == 500
|
||||
|
||||
# Test 404 error
|
||||
response = client.get("/test-not-found")
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Resource not found"
|
||||
assert data["error"]["type"] == "Not Found"
|
||||
assert data["error"]["code"] == 404
|
||||
|
||||
|
||||
def test_finish_reason_includes_error() -> None:
|
||||
valid_reasons = get_args(FinishReason)
|
||||
assert "error" in valid_reasons
|
||||
|
||||
|
||||
def test_token_chunk_with_error_fields() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Something went wrong"
|
||||
|
||||
|
||||
def test_token_chunk_without_error() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=1,
|
||||
model=MODEL_A_ID,
|
||||
text="Hello",
|
||||
token_id=42,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
|
||||
|
||||
def test_error_response_construction() -> None:
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message="Generation failed",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
|
||||
assert error_response.error.message == "Generation failed"
|
||||
assert error_response.error.code == 500
|
||||
|
||||
|
||||
def test_normal_finish_reasons_still_work() -> None:
|
||||
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="done",
|
||||
token_id=100,
|
||||
finish_reason=reason, # type: ignore[arg-type]
|
||||
)
|
||||
assert chunk.finish_reason == reason
|
||||
@@ -70,7 +70,7 @@ def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
|
||||
[
|
||||
((500, 500, 1000), 12, (3, 3, 6)),
|
||||
((500, 500, 500), 12, (4, 4, 4)),
|
||||
((312, 518, 1024), 12, (2, 3, 7)),
|
||||
((312, 468, 1092), 12, (2, 3, 7)),
|
||||
],
|
||||
)
|
||||
def test_get_instance_placements_create_instance(
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Callable
|
||||
import pytest
|
||||
|
||||
from exo.master.placement_utils import (
|
||||
allocate_layers_proportionally,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_jaccl_coordinators,
|
||||
@@ -165,6 +166,9 @@ def test_get_smallest_cycles(
|
||||
((500, 500, 1000), 12, (3, 3, 6)),
|
||||
((500, 500, 500), 12, (4, 4, 4)),
|
||||
((312, 518, 1024), 12, (2, 3, 7)),
|
||||
# Edge case: one node has ~90% of memory - should not over-allocate.
|
||||
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
|
||||
((900, 50, 50), 20, (18, 1, 1)),
|
||||
],
|
||||
)
|
||||
def test_get_shard_assignments(
|
||||
@@ -397,3 +401,96 @@ def test_get_mlx_jaccl_coordinators(
|
||||
assert coordinators[node_c_id] == (
|
||||
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_c should use the IP from conn_c_a"
|
||||
|
||||
|
||||
class TestAllocateLayersProportionally:
|
||||
def test_empty_node_list_raises(self):
|
||||
with pytest.raises(ValueError, match="empty node list"):
|
||||
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
|
||||
|
||||
def test_zero_layers_raises(self):
|
||||
with pytest.raises(ValueError, match="need at least 1 layer per node"):
|
||||
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
|
||||
|
||||
def test_negative_layers_raises(self):
|
||||
with pytest.raises(ValueError, match="need at least 1 layer per node"):
|
||||
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
|
||||
|
||||
def test_fewer_layers_than_nodes_raises(self):
|
||||
with pytest.raises(ValueError, match="need at least 1 layer per node"):
|
||||
allocate_layers_proportionally(
|
||||
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
|
||||
)
|
||||
|
||||
def test_equal_distribution(self):
|
||||
result = allocate_layers_proportionally(
|
||||
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
|
||||
)
|
||||
assert result == [3, 3, 3, 3]
|
||||
assert sum(result) == 12
|
||||
|
||||
def test_proportional_distribution(self):
|
||||
result = allocate_layers_proportionally(
|
||||
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
|
||||
)
|
||||
assert result == [3, 3, 6]
|
||||
assert sum(result) == 12
|
||||
|
||||
def test_extreme_imbalance_ensures_minimum(self):
|
||||
result = allocate_layers_proportionally(
|
||||
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
|
||||
)
|
||||
assert all(layers >= 1 for layers in result)
|
||||
assert sum(result) == 20
|
||||
# Small nodes get minimum 1 layer
|
||||
assert result == [18, 1, 1]
|
||||
|
||||
def test_single_node_gets_all_layers(self):
|
||||
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
|
||||
assert result == [10]
|
||||
|
||||
def test_minimum_viable_allocation(self):
|
||||
result = allocate_layers_proportionally(
|
||||
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
|
||||
)
|
||||
assert result == [1, 1, 1]
|
||||
assert sum(result) == 3
|
||||
|
||||
|
||||
def test_get_shard_assignments_insufficient_memory_raises(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
|
||||
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
|
||||
node_a = create_node(900 * 1024, node_a_id)
|
||||
node_b = create_node(50 * 1024, node_b_id)
|
||||
node_c = create_node(10 * 1024, node_c_id) # Insufficient memory
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
|
||||
model_meta = ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
pretty_name="Test Model",
|
||||
n_layers=20,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
)
|
||||
cycles = topology.get_cycles()
|
||||
selected_cycle = cycles[0]
|
||||
|
||||
with pytest.raises(ValueError, match="insufficient memory"):
|
||||
get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline)
|
||||
|
||||
@@ -11,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call"
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
]
|
||||
|
||||
|
||||
class ErrorInfo(BaseModel):
|
||||
message: str
|
||||
type: str
|
||||
param: str | None = None
|
||||
code: int
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: ErrorInfo
|
||||
|
||||
|
||||
class ModelListModel(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
|
||||
@@ -22,6 +22,7 @@ class TokenChunk(BaseChunk):
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
|
||||
@@ -245,12 +245,15 @@ def create_http_session(
|
||||
sock_read_timeout = 1800
|
||||
sock_connect_timeout = 60
|
||||
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
ssl_context = ssl.create_default_context(
|
||||
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
|
||||
)
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
return aiohttp.ClientSession(
|
||||
auto_decompress=auto_decompress,
|
||||
connector=connector,
|
||||
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=total_timeout,
|
||||
connect=connect_timeout,
|
||||
|
||||
@@ -46,9 +46,11 @@ class CustomMlxLayer(nn.Module):
|
||||
|
||||
def __init__(self, original_layer: _LayerCallable):
|
||||
super().__init__()
|
||||
# Set twice to avoid __setattr__ recursion
|
||||
object.__setattr__(self, "_original_layer", original_layer)
|
||||
self.original_layer: _LayerCallable = original_layer
|
||||
|
||||
@property
|
||||
def original_layer(self) -> _LayerCallable:
|
||||
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
|
||||
|
||||
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
|
||||
if not TYPE_CHECKING:
|
||||
@@ -58,7 +60,7 @@ class CustomMlxLayer(nn.Module):
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
original_layer = object.__getattribute__(self, "_original_layer")
|
||||
return object.__getattribute__(original_layer, name)
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
@@ -168,11 +170,21 @@ def pipeline_auto_parallel(
|
||||
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
|
||||
start_layer:end_layer
|
||||
]
|
||||
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"sliding_attention"
|
||||
# We can assume the model has at least one layer thanks to placement.
|
||||
# If a layer type doesn't exist, we can set it to 0.
|
||||
inner_model_instance.swa_idx = (
|
||||
0
|
||||
if "sliding_attention" not in inner_model_instance.layer_types # type: ignore
|
||||
else inner_model_instance.layer_types.index( # type: ignore
|
||||
"sliding_attention"
|
||||
)
|
||||
)
|
||||
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"full_attention"
|
||||
inner_model_instance.ga_idx = (
|
||||
0
|
||||
if "full_attention" not in inner_model_instance.layer_types # type: ignore
|
||||
else inner_model_instance.layer_types.index( # type: ignore
|
||||
"full_attention"
|
||||
)
|
||||
)
|
||||
|
||||
_set_layers(model, layers)
|
||||
|
||||
@@ -2,7 +2,9 @@ import json
|
||||
import os
|
||||
import resource
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -82,6 +84,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
mlx_item: Any, # pyright: ignore[reportAny]
|
||||
timeout_seconds: float = 60.0,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> None:
|
||||
"""Evaluate MLX item with a hard timeout.
|
||||
|
||||
If on_timeout callback is provided, it will be called before terminating
|
||||
the process. This allows the runner to send a failure event before exit.
|
||||
"""
|
||||
completed = threading.Event()
|
||||
|
||||
def watchdog() -> None:
|
||||
if not completed.wait(timeout=timeout_seconds):
|
||||
logger.error(
|
||||
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
|
||||
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
|
||||
"Terminating process."
|
||||
)
|
||||
if on_timeout is not None:
|
||||
on_timeout()
|
||||
os._exit(1)
|
||||
|
||||
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
|
||||
watchdog_thread.start()
|
||||
|
||||
try:
|
||||
mx.eval(mlx_item) # pyright: ignore[reportAny]
|
||||
finally:
|
||||
completed.set()
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
@@ -188,7 +229,9 @@ def initialize_mlx(
|
||||
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
@@ -202,7 +245,9 @@ def load_mlx_items(
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
@@ -216,6 +261,7 @@ def load_mlx_items(
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
@@ -252,7 +298,15 @@ def shard_and_load(
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
mx.eval(model.parameters())
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
@@ -17,15 +17,23 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
fast_synch_override != "off"
|
||||
and (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
)
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
@@ -67,6 +67,7 @@ def main(
|
||||
bound_instance.bound_runner_id,
|
||||
bound_instance.bound_shard,
|
||||
)
|
||||
device_rank = shard_metadata.device_rank
|
||||
logger.info("hello from the runner")
|
||||
if getattr(shard_metadata, "immediate_exception", False):
|
||||
raise Exception("Fake exception - runner failed to spin up.")
|
||||
@@ -118,7 +119,20 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
@@ -148,8 +162,6 @@ def main(
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -158,41 +170,61 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
)
|
||||
try:
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
# TODO: Add tool call parser here
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: something here ig
|
||||
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
|
||||
202
src/exo/worker/tests/unittests/test_mlx/conftest.py
Normal file
202
src/exo/worker/tests/unittests/test_mlx/conftest.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# type: ignore
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
|
||||
|
||||
class MockLayer(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.custom_attr = "test_value"
|
||||
self.use_sliding = True
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
return x * 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTestConfig:
|
||||
model_path: Path
|
||||
total_layers: int
|
||||
base_port: int
|
||||
max_tokens: int
|
||||
|
||||
|
||||
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(hosts, f)
|
||||
hostfile_path = f.name
|
||||
|
||||
return hostfile_path, hosts
|
||||
|
||||
|
||||
# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour
|
||||
|
||||
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
|
||||
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
|
||||
total_layers=24,
|
||||
base_port=29600,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
|
||||
def run_gpt_oss_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
model_path: Path,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
max_tokens: int = 200,
|
||||
) -> None:
|
||||
import os
|
||||
import traceback
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
model, tokenizer = load(str(model_path))
|
||||
|
||||
# Generate a prompt of exact token length
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
base_len = len(base_tokens)
|
||||
|
||||
# Build prompt with approximate target length
|
||||
repeats = (prompt_tokens // base_len) + 2
|
||||
long_text = base_text * repeats
|
||||
tokens = tokenizer.encode(long_text)
|
||||
# Truncate to exact target length
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
formatted_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt_text}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
start_layer, end_layer = layer_splits[rank]
|
||||
|
||||
shard_meta = PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 20B",
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
n_layers=24,
|
||||
)
|
||||
|
||||
model = pipeline_auto_parallel(model, group, shard_meta)
|
||||
|
||||
# Barrier before generation
|
||||
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
|
||||
mlx_core.eval(barrier)
|
||||
|
||||
generated_text = ""
|
||||
for response in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=formatted_prompt,
|
||||
max_tokens=max_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
):
|
||||
generated_text += response.text
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def run_gpt_oss_tensor_parallel_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
model_path: Path,
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
max_tokens: int = 10,
|
||||
) -> None:
|
||||
import os
|
||||
import traceback
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
model, tokenizer = load(str(model_path))
|
||||
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
base_len = len(base_tokens)
|
||||
|
||||
repeats = (prompt_tokens // base_len) + 2
|
||||
long_text = base_text * repeats
|
||||
tokens = tokenizer.encode(long_text)
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
formatted_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt_text}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
model = tensor_auto_parallel(model, group)
|
||||
|
||||
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
|
||||
mlx_core.eval(barrier)
|
||||
|
||||
generated_text = ""
|
||||
for response in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=formatted_prompt,
|
||||
max_tokens=max_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
):
|
||||
generated_text += response.text
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
|
||||
137
src/exo/worker/tests/unittests/test_mlx/test_auto_parallel.py
Normal file
137
src/exo/worker/tests/unittests/test_mlx/test_auto_parallel.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import multiprocessing as mp
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
CustomMlxLayer,
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
)
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
|
||||
|
||||
|
||||
def run_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
) -> None:
|
||||
import os
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
import mlx.nn as mlx_nn
|
||||
|
||||
class MockLayerInner(mlx_nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.custom_attr = "test_value"
|
||||
|
||||
def __call__(
|
||||
self, x: mlx_core.array, *args: object, **kwargs: object
|
||||
) -> mlx_core.array:
|
||||
return x * 2
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
mock = MockLayerInner()
|
||||
first = PipelineFirstLayer(mock, r=rank, group=group)
|
||||
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
|
||||
|
||||
x = mlx_core.ones((1, 4))
|
||||
result = composed(x)
|
||||
mlx_core.eval(result)
|
||||
|
||||
success = result.shape == x.shape
|
||||
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def test_single_wrapper_delegates_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxLayer(mock)
|
||||
|
||||
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert wrapped.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_composed_wrappers_delegate_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
group = mx.distributed.init()
|
||||
|
||||
first = PipelineFirstLayer(mock, r=0, group=group)
|
||||
composed = PipelineLastLayer(first, r=0, s=1, group=group)
|
||||
|
||||
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert composed.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_missing_attribute_raises() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxLayer(mock)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_composed_call_works() -> None:
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
world_size = 2
|
||||
base_port = 29500
|
||||
|
||||
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(hosts, f)
|
||||
hostfile_path = f.name
|
||||
|
||||
try:
|
||||
result_queue: Any = ctx.Queue()
|
||||
|
||||
processes: list[Any] = []
|
||||
for rank in range(world_size):
|
||||
p = ctx.Process(
|
||||
target=run_pipeline_device,
|
||||
args=(rank, world_size, hostfile_path, result_queue),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes: # pyright: ignore[reportAny]
|
||||
p.join(timeout=10) # pyright: ignore[reportAny]
|
||||
|
||||
results: dict[int, Any] = {}
|
||||
errors: dict[int, str] = {}
|
||||
while not result_queue.empty(): # pyright: ignore[reportAny]
|
||||
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
|
||||
if success:
|
||||
results[rank] = value
|
||||
else:
|
||||
errors[rank] = value
|
||||
|
||||
assert len(results) == world_size, (
|
||||
f"Expected {world_size} results, got {len(results)}. Errors: {errors}"
|
||||
)
|
||||
|
||||
for rank in range(world_size):
|
||||
assert rank in results, (
|
||||
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
|
||||
)
|
||||
result_array = results[rank]
|
||||
# Both devices see the final result (4.0) after all_gather
|
||||
assert (result_array == 4.0).all(), (
|
||||
f"Device {rank}: expected 4.0, got {result_array}"
|
||||
)
|
||||
finally:
|
||||
os.unlink(hostfile_path)
|
||||
@@ -121,6 +121,21 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
|
||||
# Use a fake event_sender to remove test flakiness.
|
||||
class EventCollector:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[Event] = []
|
||||
|
||||
def send(self, event: Event) -> None:
|
||||
self.events.append(event)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
def join(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -130,22 +145,20 @@ def _run(tasks: Iterable[Task]):
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
event_sender, event_receiver = mp_channel[Event]()
|
||||
event_sender = EventCollector()
|
||||
|
||||
with task_sender, event_receiver:
|
||||
with task_sender:
|
||||
for t in tasks:
|
||||
task_sender.send(t)
|
||||
|
||||
# worst monkeypatch known to man
|
||||
# this is some c++ nonsense
|
||||
event_sender.close = nothin
|
||||
event_sender.join = nothin
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver)
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
|
||||
|
||||
return event_receiver.collect()
|
||||
return event_sender.events
|
||||
|
||||
|
||||
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
|
||||
Reference in New Issue
Block a user