Compare commits

..

10 Commits

Author SHA1 Message Date
Jake Hillion
8a041dc2f8 nix: add PyInstaller package for macOS DMG
Added exo-pyinstaller derivation that builds the PyInstaller bundle using
Nix. Updated build-app.yml to use the Nix-built bundle instead of invoking
uv/pyinstaller directly.

Test plan:
- Build on macOS: `nix build .#exo-pyinstaller`
- Push to test-app branch to trigger build-app.yml workflow
2026-01-19 16:25:32 +00:00
Jake Hillion
8de4d862b2 nix: add Python packaging with uv2nix
Add uv2nix to build Python packages from uv.lock. This creates a fully
Nix-managed Python environment with the Rust bindings injected via overlay.

Changes:
- Add pyproject-nix, uv2nix, and pyproject-build-systems flake inputs
- Create python/parts.nix with overlays to inject Nix-built Rust wheel
- Export packages.exo on macOS (wraps exo/exo-master/exo-worker with dashboard)
- Add checks.lint (ruff, all platforms) and checks.pytest (macOS only)
- Simplify CI typecheck job using nicknovitski/nix-develop action
- Delete .github/actions/typecheck composite action (no longer needed)
- Add no-build-package for MLX packages in pyproject.toml (use wheels)

The Python build is currently macOS-only since MLX requires Metal. Linux
support will be added once the pyproject dependencies are simplified.

Test plan:
- Run `nix flake check` on macOS to verify pytest and lint pass
- Build exo package on macOS: `nix build .#exo`
- Verify CI pipeline passes with simplified typecheck job
2026-01-19 16:25:32 +00:00
Evan
f11492a4af nix mlx compilation for better portability 2026-01-19 16:19:52 +00:00
Alex Cheema
346b13e2c9 Enhance LaTeX rendering in dashboard markdown (#1197)
## Motivation

When models output LaTeX-formatted math proofs, the dashboard was not
rendering them correctly. Issues included:
- `\documentclass`, `\begin{document}`, `\usepackage` showing as raw
text
- `$...$` inline math with complex expressions (like `\frac`, `\ldots`)
not rendering due to markdown escaping backslashes
- `\begin{align*}...\end{align*}` and other math environments showing as
raw text
- `\emph{...}`, `\textbf{...}` LaTeX formatting commands not being
converted
- `$\require{...}$` (MathJax-specific) causing KaTeX errors
- `\begin{proof}...\end{proof}` showing as raw text

## Changes

Enhanced `MarkdownContent.svelte` with comprehensive LaTeX support:

**Math extraction before markdown processing:**
- Extract `$...$`, `$$...$$`, `\(...\)`, `\[...\]` into placeholders
before markdown processes the text
- Use alphanumeric placeholders (`MATHPLACEHOLDERINLINE0END`) that won't
be interpreted as HTML tags
- Restore and render with KaTeX after markdown processing

**LaTeX document command removal:**
- Strip `\documentclass{...}`, `\usepackage{...}`, `\begin{document}`,
`\end{document}`
- Strip `\maketitle`, `\title{...}`, `\author{...}`, `\date{...}`
- Strip `\require{...}` (MathJax-specific, not KaTeX)
- Replace `tikzpicture` environments with `[diagram]` placeholder
- Strip `\label{...}` cross-reference commands

**LaTeX math environments:**
- Convert `\begin{align*}`, `\begin{equation}`, `\begin{gather}`, etc.
to display math blocks

**LaTeX text formatting:**
- `\emph{...}` and `\textit{...}` → `<em>...</em>`
- `\textbf{...}` → `<strong>...</strong>`
- `\texttt{...}` → `<code>...</code>`
- `\underline{...}` → `<u>...</u>`

**LaTeX environments styling:**
- `\begin{proof}...\end{proof}` → styled proof block with QED symbol
- `\begin{theorem}`, `\begin{lemma}`, etc. → styled theorem blocks

**Display math enhancements:**
- Wrapped in styled container with subtle gold border
- "LaTeX" label and copy button appear on hover
- Dark theme KaTeX color overrides for better readability
- Custom scrollbar for overflow

## Why It Works

The key insight is that markdown processing was escaping backslashes in
LaTeX before KaTeX could see them. By extracting all math expressions
into alphanumeric placeholders *before* markdown runs, then restoring
them *after*, the LaTeX content passes through to KaTeX unmodified.

Using purely alphanumeric placeholders like `MATHPLACEHOLDERINLINE0END`
instead of `<<MATH_INLINE_0>>` prevents markdown from interpreting them
as HTML tags and stripping them.

## Test Plan

### Manual Testing
- Hardware: Any machine with the dashboard
- What you did:
  - Ask model to "write a proof in latex"
  - Verify inline math like `$x \in S$` renders correctly
- Verify display math like `\begin{align*}...\end{align*}` renders as
block
  - Verify `\documentclass`, `\begin{document}` are stripped (not shown)
  - Verify `\emph{...}` converts to italics
  - Verify copy button works on display math blocks
- Test edge cases: `$5` (currency) stays as text, `\$50` (escaped)
becomes `$50`

Before:
<img width="799" height="637" alt="Screenshot 2026-01-19 at 11 51 22 AM"
src="https://github.com/user-attachments/assets/62a705b8-b3c2-47b8-afd0-5d0c1b240e44"
/>

After:
<img width="809" height="642" alt="Screenshot 2026-01-19 at 11 46 58 AM"
src="https://github.com/user-attachments/assets/4f35fa1d-333c-4285-bc68-58a50f8f148e"
/>


### Automated Testing
- Dashboard builds successfully with `npm run build`
- Existing functionality preserved

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:50:41 +00:00
rltakashige
ea0588429b Custom mlx layer composition (#1201)
## Motivation

With a single pipeline layer, PipelineFirstLayer gets composed with
PipelineLastLayer.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing


### Automated Testing
Made failing tests. Fixed them!
2026-01-19 12:36:25 +00:00
rltakashige
73b3f87e07 Set swa_idx and ga_idx for single layer (#1202)
## Motivation

Layer types does not contain either "sliding_attention" or
"full_attention" for pipeline parallel (single layer).

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Manually tested single layer of GPT OSS. Doesn't crash

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-19 12:31:11 +00:00
Evan Quiney
746589ba6b tidy: remove context manager from api (#1199) 2026-01-19 11:58:13 +00:00
rltakashige
f82f862fd7 Fix several issues with placement (#1200)
## Motivation

Uneven placements were causing issues for some users with lopsided
setups. While fixing, I ran into another issue with impossible
allocation of memory.

## Changes

- Allocate at least 1 layer per device.
- Catch overallocation of memory with an error.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Tested that GPT OSS is placed correctly.

### Automated Testing
Added breaking tests in the first commit. Resolved with new placement
algorithm in the second one.
2026-01-19 11:52:35 +00:00
Alex Cheema
7ff937d8a1 Add dashboard screenshots to README (#1185)
## Motivation

The README showcases exo's features and benchmarks but doesn't show what
the dashboard actually looks like. Adding a screenshot helps users
understand what they'll get when they run exo.

## Changes

- Added dashboard screenshot to `docs/imgs/dashboard-cluster-view.png`:
Shows the cluster topology view with 4 × 512GB M3 Ultra Mac Studio
running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)
- Added a new "Dashboard" section to README.md below Features,
displaying the screenshot with caption

## Why It Works

Visual documentation helps users understand what exo offers before they
install it. The screenshot demonstrates the cluster management
capabilities.

## Test Plan

### Manual Testing
- Verified image renders correctly in GitHub markdown preview

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 10:43:27 +00:00
Evan Quiney
d19bf02404 re-raise exceptions in the runner (#1198)
## Motivation

Runners that crash can swallow errors - we should re-raise. Also the
exception handler annoyed me.

## Changes

The try: except in the runner's chat now re-raises.
2026-01-19 10:35:23 +00:00
35 changed files with 1800 additions and 965 deletions

View File

@@ -1,12 +0,0 @@
name: Type Check
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

View File

@@ -161,17 +161,6 @@ jobs:
- name: Install Homebrew packages
run: brew install just awscli macmon
- name: Install UV
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: uv.lock
- name: Setup Python
run: |
uv python install
uv sync --locked
- name: Install Nix
uses: cachix/install-nix-action@v31
with:
@@ -183,12 +172,6 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build dashboard
run: |
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
mkdir -p dashboard/build
cp -r "$DASHBOARD_OUT"/* dashboard/build/
- name: Install Sparkle CLI
run: |
CLI_URL="${SPARKLE_CLI_URL:-https://github.com/sparkle-project/Sparkle/releases/download/${SPARKLE_VERSION}/Sparkle-${SPARKLE_VERSION}.tar.xz}"
@@ -244,7 +227,10 @@ jobs:
# ============================================================
- name: Build PyInstaller bundle
run: uv run pyinstaller packaging/pyinstaller/exo.spec
run: |
PYINSTALLER_OUT=$(nix build .#exo-pyinstaller --print-build-logs --no-link --print-out-paths)
mkdir -p dist
cp -r "$PYINSTALLER_OUT" dist/exo
- name: Build Swift app
env:

View File

@@ -26,73 +26,14 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in=0 }
/^\[tool\.basedpyright\]/ { in=1; print; next }
in && /^\[/ { in=0 } # next section
in && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
nix:
name: Build and check (${{ matrix.system }})
@@ -113,6 +54,29 @@ jobs:
with:
lfs: false
- name: Select Xcode
if: startsWith(matrix.runner, 'macos-')
run: |
XCODE_BASEDIR="$(printf '%s\n' /Applications/Xcode_*.app | sort -V | tail -n 1)"
[[ -z "$XCODE_BASEDIR" ]] && exit 1
ls -ld "/Applications/Xcode.app"
sudo /usr/bin/xcode-select -s "$XCODE_BASEDIR"
/usr/bin/xcode-select -p || true
/usr/bin/xcrun --toolchain default --find xcodebuild || true
- name: Install Metal toolchain component
if: startsWith(matrix.runner, 'macos-')
run: |
set -e
if ! xcrun --find metal >/dev/null 2>&1; then
sudo xcodebuild -downloadComponent MetalToolchain
fi
xcrun --find metal
xcrun --find metallib
echo "GH_OVERRIDE_METAL=$(xcrun --find metal)" >> $GITHUB_ENV
echo "GH_OVERRIDE_METALLIB=$(xcrun --find metallib)" >> $GITHUB_ENV
- uses: cachix/install-nix-action@v31
with:
nix_path: nixpkgs=channel:nixos-unstable
@@ -124,6 +88,9 @@ jobs:
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build all Nix outputs
env:
GH_OVERRIDE_METAL: ${{ env.GH_OVERRIDE_METAL }}
GH_OVERRIDE_METALLIB: ${{ env.GH_OVERRIDE_METALLIB }}
run: |
nix flake show --json | jq -r '
[

View File

@@ -27,6 +27,15 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>

View File

@@ -53,62 +53,285 @@
marked.use({ renderer });
/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
// Storage for math expressions extracted before markdown processing
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
let mathCounter = 0;
// Storage for HTML snippets that need protection from markdown
const htmlSnippets: Map<string, string> = new Map();
let htmlCounter = 0;
// Use alphanumeric placeholders that won't be interpreted as HTML tags
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
/**
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
*/
function preprocessLaTeX(text: string): string {
// Protect code blocks
// Reset storage
mathExpressions.clear();
mathCounter = 0;
htmlSnippets.clear();
htmlCounter = 0;
// Protect code blocks first
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `<<CODE_${codeBlocks.length - 1}>>`;
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
});
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Remove LaTeX document commands
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\begin\{document\}/g, '');
processed = processed.replace(/\\end\{document\}/g, '');
processed = processed.replace(/\\maketitle/g, '');
processed = processed.replace(/\\title\{[^}]*\}/g, '');
processed = processed.replace(/\\author\{[^}]*\}/g, '');
processed = processed.replace(/\\date\{[^}]*\}/g, '');
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
processed = processed.replace(/\\require\{[^}]*\}/g, '');
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
htmlCounter++;
return placeholder;
});
// Strip center environment (layout only, no content change)
processed = processed.replace(/\\begin\{center\}/g, '');
processed = processed.replace(/\\end\{center\}/g, '');
// Strip other layout environments
processed = processed.replace(/\\begin\{flushleft\}/g, '');
processed = processed.replace(/\\end\{flushleft\}/g, '');
processed = processed.replace(/\\begin\{flushright\}/g, '');
processed = processed.replace(/\\end\{flushright\}/g, '');
processed = processed.replace(/\\label\{[^}]*\}/g, '');
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
for (const env of mathEnvs) {
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
processed = processed.replace(wrappedRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
// Handle bare \begin{env}...\end{env} (without dollar signs)
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
processed = processed.replace(bareRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
}
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
processed = processed.replace(
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
(_, content) => {
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
}
);
// Convert LaTeX theorem-like environments
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
for (const env of theoremEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
const envName = env.charAt(0).toUpperCase() + env.slice(1);
processed = processed.replace(envRegex, (_, content) => {
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
});
}
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<u>${content}</u>`);
htmlCounter++;
return placeholder;
});
// Handle LaTeX line breaks and spacing
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
processed = processed.replace(/\\newline/g, '\n');
processed = processed.replace(/\\par\b/g, '\n\n');
processed = processed.replace(/\\quad/g, ' ');
processed = processed.replace(/\\qquad/g, ' ');
processed = processed.replace(/~~/g, ' '); // non-breaking space
// Remove other common LaTeX commands that don't render
processed = processed.replace(/\\centering/g, '');
processed = processed.replace(/\\noindent/g, '');
processed = processed.replace(/\\hfill/g, '');
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});
// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});
// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});
// Extract inline math ($...$) BEFORE markdown processing
// Allow single-line only, skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});
// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Restore code blocks
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
// Clean up any remaining stray backslashes from unrecognized commands
processed = processed.replace(/\\(?=[a-zA-Z])/g, ''); // Remove \ before letters (unrecognized commands)
return processed;
}
/**
* Render math expressions with KaTeX after HTML is generated
* Render math expressions with KaTeX and restore HTML placeholders
*/
function renderMath(html: string): string {
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});
if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}
// Restore HTML placeholders (for \textbf, \emph, etc.)
for (const [placeholder, htmlContent] of htmlSnippets) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
html = html.replace(regex, htmlContent);
}
return html;
}
@@ -154,16 +377,50 @@
}
}
async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;
const source = decodeURIComponent(encodedSource);
try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}
$effect(() => {
@@ -424,28 +681,290 @@
color: #60a5fa;
}
/* KaTeX math styling */
/* KaTeX math styling - Base */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}
.markdown-content :global(.katex-display) {
/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}
/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}
.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}
.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}
/* Display math content area */
.markdown-content :global(.math-display-content) {
padding: 1rem 1.25rem;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
.markdown-content :global(.katex-display > .katex) {
/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}
.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}
.markdown-content :global(.math-display-content .katex-display > .katex) {
text-align: center;
}
/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}
.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}
/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}
/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}
/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}
.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}
.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}
/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}
/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.125rem 0.25rem;
padding: 0.25rem 0.5rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}
.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}
/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}
.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-proof-header::after) {
content: '.';
}
.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}
.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}
/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}
/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}
.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}
.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}
.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}
/* LaTeX diagram/figure placeholder */
.markdown-content :global(.latex-diagram-placeholder) {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
margin: 1rem 0;
padding: 1.5rem 2rem;
background: rgba(255, 255, 255, 0.02);
border: 1px dashed rgba(255, 215, 0, 0.25);
border-radius: 0.5rem;
color: rgba(255, 215, 0, 0.6);
font-size: 0.875rem;
}
.markdown-content :global(.latex-diagram-icon) {
font-size: 1.25rem;
opacity: 0.8;
}
.markdown-content :global(.latex-diagram-text) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
}
</style>

View File

@@ -69,8 +69,6 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
interface RawNodeProfile {

View File

@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -59,7 +59,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
instanceType: InstanceMeta;
minNodes: number;
}
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
@@ -88,16 +88,16 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
@@ -109,19 +109,11 @@ const sidebarVisible = $derived(chatSidebarVisible());
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Draft model edit modal state
let editingDraftInstanceId = $state<string | null>(null);
let editDraftModel = $state<string | null>(null);
let editNumDraftTokens = $state<number>(4);
let isDraftEditDropdownOpen = $state(false);
let draftEditDropdownSearch = $state('');
let isSavingDraftModel = $state(false);
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
// Slider dragging state
let isDraggingSlider = $state(false);
let sliderTrackElement: HTMLDivElement | null = $state(null);
@@ -370,36 +362,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
if (!modelId || launchingModelId) return;
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let response: Response;
// Use /place_instance endpoint - it handles placement and creation in one step
const placePayload = {
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
};
response = await fetch('/place_instance', {
let instanceData: unknown;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error('Failed to get placement:', errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch('/instance', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(placePayload)
body: JSON.stringify({ instance: instanceData })
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
const scrollToBottom = () => {
@@ -794,52 +797,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
}
// Open draft model edit modal for an instance
function openDraftModelEdit(instanceId: string, currentDraftModel: string | null, currentNumTokens: number | null) {
editingDraftInstanceId = instanceId;
editDraftModel = currentDraftModel;
editNumDraftTokens = currentNumTokens ?? 4;
isDraftEditDropdownOpen = false;
draftEditDropdownSearch = '';
}
// Close draft model edit modal
function closeDraftModelEdit() {
editingDraftInstanceId = null;
editDraftModel = null;
editNumDraftTokens = 4;
isDraftEditDropdownOpen = false;
draftEditDropdownSearch = '';
}
// Save draft model settings for an instance
async function saveDraftModel() {
if (!editingDraftInstanceId || isSavingDraftModel) return;
isSavingDraftModel = true;
try {
const response = await fetch(`/instance/${editingDraftInstanceId}/draft_model`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
draft_model: editDraftModel,
num_draft_tokens: editNumDraftTokens,
})
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to set draft model:', errorText);
} else {
closeDraftModelEdit();
}
} catch (error) {
console.error('Error setting draft model:', error);
} finally {
isSavingDraftModel = false;
}
}
// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
function getTagged(obj: unknown): [string | null, unknown] {
if (!obj || typeof obj !== 'object') return [null, null];
@@ -859,34 +816,30 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
}
// Instance type from tag
let instanceType = 'Unknown';
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
};
draftModel?: string;
numDraftTokens?: number;
}
};
// Sharding strategy from first shard
let sharding = 'Unknown';
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -897,7 +850,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
}
// Node names from topology
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const nodeIds = Object.keys(nodeToRunner);
@@ -905,12 +858,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
});
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
}
function formatLastUpdate(): string {
@@ -1386,31 +1335,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
<!-- Draft Model Button -->
<button
onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1745,7 +1679,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
</div>
</div>
<!-- Selected Model Preview -->
<div class="space-y-3">
{#if models.length === 0}
@@ -1904,31 +1838,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
<!-- Draft Model Button -->
<button
onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -2059,120 +1978,4 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/if}
</main>
<!-- Draft Model Edit Modal -->
{#if editingDraftInstanceId}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="fixed inset-0 z-50 flex items-center justify-center bg-black/70 backdrop-blur-sm"
onclick={closeDraftModelEdit}
onkeydown={(e) => e.key === 'Escape' && closeDraftModelEdit()}
>
<!-- svelte-ignore a11y_click_events_have_key_events -->
<div
class="bg-exo-dark-gray border border-exo-medium-gray/50 rounded-lg shadow-2xl p-6 w-full max-w-md mx-4"
onclick={(e) => e.stopPropagation()}
>
<div class="flex items-center justify-between mb-4">
<h3 class="text-lg font-mono text-exo-yellow tracking-wide">Speculative Decoding</h3>
<button
onclick={closeDraftModelEdit}
class="text-white/60 hover:text-white transition-colors cursor-pointer"
aria-label="Close"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
</div>
<p class="text-white/60 text-sm font-mono mb-4">
Configure a draft model for faster generation. The draft model proposes tokens that the main model verifies.
</p>
<!-- Draft Model Dropdown -->
<div class="mb-4">
<div class="text-xs text-white/70 font-mono mb-2">Draft Model:</div>
<div class="relative">
<button
onclick={() => { isDraftEditDropdownOpen = !isDraftEditDropdownOpen; draftEditDropdownSearch = ''; }}
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {editDraftModel ? 'bg-transparent text-cyan-400 border-cyan-500/50' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-cyan-500/50'}"
>
<span class="truncate">{editDraftModel ? editDraftModel.split('/').pop() : 'None'}</span>
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftEditDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</button>
{#if isDraftEditDropdownOpen}
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
<div class="p-2 border-b border-exo-medium-gray/30">
<input
type="text"
bind:value={draftEditDropdownSearch}
placeholder="Search models..."
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-cyan-500/50"
/>
</div>
<div class="overflow-y-auto max-h-36">
<!-- None option -->
<button
onclick={() => { editDraftModel = null; isDraftEditDropdownOpen = false; }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {editDraftModel === null ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
>
<span>None (Disable)</span>
</button>
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftEditDropdownSearch.toLowerCase())) as model}
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
{@const modelHfId = model.hugging_face_id ?? model.id}
<button
onclick={() => { editDraftModel = modelHfId; isDraftEditDropdownOpen = false; }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {editDraftModel === modelHfId ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs text-white/50">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
{/each}
</div>
</div>
{/if}
</div>
</div>
<!-- Draft Tokens -->
{#if editDraftModel}
<div class="mb-6">
<div class="text-xs text-white/70 font-mono mb-2">Draft Tokens per Iteration:</div>
<div class="flex items-center gap-2">
{#each [2, 3, 4, 5, 6] as n}
<button
onclick={() => editNumDraftTokens = n}
class="w-8 h-8 text-sm font-mono rounded transition-all {editNumDraftTokens === n ? 'bg-cyan-500/20 text-cyan-400 border border-cyan-500/50' : 'text-white/50 hover:text-white/80 border border-exo-medium-gray/50 hover:border-white/30'} cursor-pointer"
>{n}</button>
{/each}
</div>
</div>
{/if}
<!-- Action Buttons -->
<div class="flex items-center justify-end gap-3">
<button
onclick={closeDraftModelEdit}
class="px-4 py-2 text-sm font-mono text-white/70 hover:text-white transition-colors cursor-pointer"
>
Cancel
</button>
<button
onclick={saveDraftModel}
disabled={isSavingDraftModel}
class="px-4 py-2 text-sm font-mono border border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500 transition-all disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
>
{isSavingDraftModel ? 'Saving...' : 'Save'}
</button>
</div>
</div>
</div>
{/if}
</div>

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 187 KiB

65
flake.lock generated
View File

@@ -21,7 +21,9 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1765953015,
@@ -149,19 +151,44 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"type": "github"
},
"original": {
@@ -178,7 +205,10 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
}
},
"rust-analyzer-src": {
@@ -239,6 +269,29 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,6 +24,26 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -48,6 +68,7 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -81,11 +102,10 @@
};
};
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
packages =
if pkgs.stdenv.isDarwin then {
metal = pkgs.callPackage ./nix/metalWrapper.nix { metalVersion = "310"; };
} else { };
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];
@@ -124,6 +144,7 @@
OPENSSL_NO_VENDOR = "1";
shellHook = ''
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
${lib.optionalString stdenv.isLinux ''

View File

@@ -0,0 +1,79 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ed30932..d8528132 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
- # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_SDK_VERSION
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+ set(MACOS_SDK_VERSION @sdkVersion@)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(
+ MLX_METAL_VERSION @metalVersion@)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 13db804a..5b385132 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
- xcrun -sdk macosx metal
+ metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 262b0495..5c7446ad 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
+ COMMAND metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
@@ -170,7 +170,7 @@ endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
+ COMMAND metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index bb55ed3a..94ea7dd7 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# Use the metal compiler to get a list of headers (with depth)
-CCC="xcrun -sdk macosx metal -x metal"
+CCC="metal -x metal"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
# Remove any included system frameworks (for MetalPerformancePrimitive headers)

25
nix/metalWrapper.nix Normal file
View File

@@ -0,0 +1,25 @@
{ stdenvNoCC
, metalVersion
}:
assert stdenvNoCC.isDarwin;
stdenvNoCC.mkDerivation {
pname = "metal-wrapper-impure";
version = metalVersion;
__noChroot = true;
buildCommand = ''
mkdir -p $out/bin && cd $out/bin
METALLIB_PATH=''${GH_OVERRIDE_METALLIB:-$(/usr/bin/xcrun --sdk macosx -f metallib)}
METAL_PATH=''${GH_OVERRIDE_METAL:-"$(dirname "$METALLIB_PATH")/metal"}
echo "$METAL_PATH"
echo "$METALLIB_PATH"
ln -sf "$METAL_PATH" metal
ln -sf "$METALLIB_PATH" metallib
[[ -e $out/bin/metal ]] && [[ -e $out/bin/metallib ]] || { echo ":(" && exit 1; }
METAL_VERSION=$(echo __METAL_VERSION__ | "$METAL_PATH" -E -x metal -P - | tail -1 | tr -d '\n')
[[ "$METAL_VERSION" == "${metalVersion}" ]] || { echo "Metal version $METAL_VERSION is not ${metalVersion}" && exit 1; }
'';
}

154
nix/mlx.nix Normal file
View File

@@ -0,0 +1,154 @@
{ stdenv
, lib
, buildPythonPackage
, fetchFromGitHub
, replaceVars
, fetchzip
, setuptools
, cmake
, nanobind
, pybind11
, nlohmann_json
, apple-sdk_26
, metal
, numpy
, pytestCheckHook
, python
, runCommand
, fmt
}:
assert stdenv.isDarwin;
let
# static dependencies included directly during compilation
gguf-tools = fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
mlx = buildPythonPackage rec {
pname = "mlx";
version = "0.30.1";
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-Vt0RH+70VBwUjXSfPTsNdRS3g0ookJHhzf2kvgEtgH8=";
};
patches = [
(replaceVars ./darwin-build-fixes.patch {
sdkVersion = apple-sdk_26.version;
metalVersion = metal.version;
})
];
postPatch = ''
substituteInPlace pyproject.toml \
--replace-fail "nanobind==2.10.2" "nanobind"
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "$CXX"
'';
dontUseCmakeConfigure = true;
enableParallelBuilding = true;
# Allows multiple cores to be used in Python builds.
postUnpack = ''
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
'';
# updates the wrong fetcher rev attribute
passthru.skipBulkUpdate = true;
env = {
DEV_RELEASE = 1;
# NOTE The `metal` command-line utility used to build the Metal kernels is not open-source.
# this is what the xcode wrapper is for - it patches in the system metal cli
CMAKE_ARGS = toString [
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "filepath" "METAL_LIB"
"${metal}/Metal.framework")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
];
SDKROOT = apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
};
build-system = [
setuptools
];
nativeBuildInputs = [
cmake
metal
];
buildInputs = [
fmt
gguf-tools
nanobind
pybind11
apple-sdk_26
];
pythonImportsCheck = [ "mlx" ];
# Run the mlx Python test suite.
nativeCheckInputs = [
numpy
pytestCheckHook
];
enabledTestPaths = [
"python/tests/"
];
# Additional testing by executing the example Python scripts supplied with mlx
# using the version of the library we've built.
passthru.tests = {
mlxTest =
runCommand "run-mlx-examples"
{
buildInputs = [ mlx ];
nativeBuildInputs = [ python ];
}
''
cp ${src}/examples/python/logistic_regression.py .
${python.interpreter} logistic_regression.py
rm logistic_regression.py
cp ${src}/examples/python/linear_regression.py .
${python.interpreter} linear_regression.py
rm linear_regression.py
touch $out
'';
};
meta = {
homepage = "https://github.com/ml-explore/mlx";
description = "Array framework for Apple silicon";
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
license = lib.licenses.mit;
platforms = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
};
};
in
mlx

143
python/parts.nix Normal file
View File

@@ -0,0 +1,143 @@
{ inputs, ... }:
{
perSystem =
{ self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
overlay = workspace.mkPyprojectOverlay { };
# Override overlay to inject Nix-built components
exoOverlay = final: _: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems for source builds
buildSystemsOverlay = final: prev: {
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
# Build MLX from source with proper dependencies
mlx = pkgs.callPythonPackage ./nix/mlx.nix;
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard}
done
'';
pyinstallerPackage =
let
venv = pythonSet.mkVirtualEnv "exo-pyinstaller-env" (
workspace.deps.default
// {
# Include pyinstaller in the environment
exo = [ "dev" ];
}
);
in
pkgs.stdenv.mkDerivation {
pname = "exo-pyinstaller";
version = "0.3.0";
src = inputs.self;
nativeBuildInputs = [ venv pkgs.makeWrapper pkgs.macmon pkgs.darwin.system_cmds ];
buildPhase = ''
# macmon must be in PATH for PyInstaller to bundle it
export PATH="${pkgs.macmon}/bin:$PATH"
# HOME must be writable for PyInstaller's cache
export HOME="$TMPDIR"
# Copy dashboard to expected location
mkdir -p dashboard/build
cp -r ${self'.packages.dashboard}/* dashboard/build/
# Run PyInstaller
${venv}/bin/python -m PyInstaller packaging/pyinstaller/exo.spec
'';
installPhase = ''
cp -r dist/exo $out
'';
};
in
{
# Python package only available on macOS for now due to the dependency on
# mlx/mlx-cpu being tricky to build on Linux. We can either remove this
# dependency in the PyProject or build it with Nix.
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
exo-pyinstaller = pyinstallerPackage;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
}
# Pytest only on macOS (requires MLX)
// lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
pytest = pkgs.runCommand "pytest"
{
nativeBuildInputs = [ testVenv ];
} ''
export HOME="$TMPDIR"
export EXO_TESTS=1
cd ${inputs.self}
${testVenv}/bin/python -m pytest src -m "not slow" --import-mode=importlib
touch $out
'';
};
};
}

View File

@@ -39,8 +39,6 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
SetDraftModelParams,
SetDraftModelResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
@@ -51,7 +49,6 @@ from exo.shared.types.commands import (
DeleteInstance,
ForwarderCommand,
PlaceInstance,
SetInstanceDraftModel,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -158,18 +155,19 @@ class API:
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
@@ -188,7 +186,6 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.put("/instance/{instance_id}/draft_model")(self.set_draft_model)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions", response_model=None)(
@@ -204,8 +201,6 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
@@ -402,24 +397,6 @@ class API:
instance_id=instance_id,
)
async def set_draft_model(
self, instance_id: InstanceId, payload: SetDraftModelParams
) -> SetDraftModelResponse:
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
command = SetInstanceDraftModel(
instance_id=instance_id,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
return SetDraftModelResponse(
message="Command received.",
command_id=command.command_id,
instance_id=instance_id,
)
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:

View File

@@ -18,7 +18,6 @@ from exo.shared.types.commands import (
ForwarderCommand,
PlaceInstance,
RequestEventLog,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
)
@@ -28,7 +27,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -175,14 +173,6 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -3,6 +3,8 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
@@ -53,6 +55,7 @@ def place_instance(
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
@@ -125,6 +128,10 @@ def place_instance(
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
# TODO: Single node instances
@@ -144,8 +151,6 @@ def place_instance(
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -159,8 +164,6 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -49,33 +49,83 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
if not selected_cycle:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
for node in selected_cycle
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node.node_profile.memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node.node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(

View File

@@ -70,7 +70,7 @@ def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
((312, 468, 1092), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(

View File

@@ -3,6 +3,7 @@ from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
@@ -165,6 +166,9 @@ def test_get_smallest_cycles(
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
@@ -397,3 +401,96 @@ def test_get_mlx_jaccl_coordinators(
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node(900 * 1024, node_a_id)
node_b = create_node(50 * 1024, node_b_id)
node_c = create_node(10 * 1024, node_c_id) # Insufficient memory
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline)

View File

@@ -11,7 +11,6 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeCreated,
NodeDownloadProgress,
NodeMemoryMeasured,
@@ -48,8 +47,6 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
@@ -172,25 +169,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -161,8 +161,6 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -174,8 +172,6 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
@@ -217,14 +213,3 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class SetDraftModelParams(BaseModel):
draft_model: ModelId | None = None # None to disable speculative decoding
num_draft_tokens: int = 4
class SetDraftModelResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId

View File

@@ -2,7 +2,7 @@ from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,8 +25,6 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):
@@ -37,14 +35,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -60,7 +50,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
)

View File

@@ -5,7 +5,6 @@ from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
@@ -68,14 +67,6 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -132,7 +123,6 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated

View File

@@ -36,12 +36,6 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -66,21 +60,12 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
| SetDraftModel
)

View File

@@ -3,7 +3,6 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -20,8 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -46,9 +46,11 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
# Set twice to avoid __setattr__ recursion
object.__setattr__(self, "_original_layer", original_layer)
self.original_layer: _LayerCallable = original_layer
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -58,7 +60,7 @@ class CustomMlxLayer(nn.Module):
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
return object.__getattribute__(original_layer, name)
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
@@ -168,11 +170,21 @@ def pipeline_auto_parallel(
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
# We can assume the model has at least one layer thanks to placement.
# If a layer type doesn't exist, we can set it to 0.
inner_model_instance.swa_idx = (
0
if "sliding_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
inner_model_instance.ga_idx = (
0
if "full_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
)
_set_layers(model, layers)

View File

@@ -48,8 +48,6 @@ def maybe_quantize_kv_cache(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -68,30 +66,25 @@ def warmup_inference(
tokens_generated = 0
cache = make_kv_cache(
model=model,
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": warmup_prompt,
"max_tokens": 50,
"sampler": sampler,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Warm up with draft model if provided (speculative decoding path)
if draft_model is not None:
logger.info("Warming up with speculative decoding (draft model)")
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
logger.info("Generating warmup tokens")
for _r in stream_generate(**generate_kwargs): # type: ignore[arg-type]
for _r in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=warmup_prompt,
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info("Generated warmup token: " + str(_r.text))
tokens_generated += 1
@@ -126,8 +119,6 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -144,6 +135,8 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -156,31 +149,19 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info(out.text)
stats: GenerationStats | None = None

View File

@@ -258,27 +258,6 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: str) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,

View File

@@ -29,9 +29,7 @@ from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformance
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadDraftModel,
DownloadModel,
SetDraftModel,
Shutdown,
Task,
TaskStatus,
@@ -50,7 +48,6 @@ from exo.utils.event_buffer import OrderedBuffer
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.impl_shard_downloader import build_full_shard
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -205,10 +202,42 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
await self._handle_download(shard, task)
case DownloadDraftModel(model_id=model_id):
shard = await build_full_shard(model_id)
await self._handle_download(shard, task)
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -219,25 +248,6 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
):
runner = self.runners[self._task_to_runner_id(task)]
await runner.start_task(task)
# Update bound_instance to reflect new/cleared draft model
updated_instance = runner.bound_instance.instance.model_copy(
update={
"draft_model": (
ModelId(draft_model_id)
if draft_model_id is not None
else None
),
"num_draft_tokens": num_tokens,
}
)
runner.bound_instance = runner.bound_instance.model_copy(
update={"instance": updated_instance}
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -330,46 +340,6 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
async def _handle_download(self, shard: ShardMetadata, task: Task) -> None:
"""Handle model download - shared logic for main and draft models."""
model_id = shard.model_meta.model_id
if model_id not in self.download_status:
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
else:
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
download_task = DownloadModel(
instance_id=task.instance_id,
shard_metadata=shard,
task_id=task.task_id,
task_status=task.task_status,
)
self._handle_shard_download_process(download_task, initial_progress)
def _handle_shard_download_process(
self,
task: DownloadModel,

View File

@@ -8,10 +8,8 @@ from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadDraftModel,
DownloadModel,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -40,16 +38,6 @@ from exo.shared.types.worker.runners import (
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def _is_download_in_progress_or_complete(
model_id: ModelId,
download_status: Mapping[ModelId, DownloadProgress],
) -> bool:
"""Check if model download is in progress or complete."""
return model_id in download_status and isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
@@ -67,11 +55,9 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _draft_model_needs_download(runners, download_status, instances)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status, download_status)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _set_draft_model(runners, instances, download_status)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -129,9 +115,12 @@ def _model_needs_download(
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(
runner.status, RunnerIdle
) and not _is_download_in_progress_or_complete(model_id, download_status):
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
@@ -139,43 +128,6 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
instances: Mapping[InstanceId, Instance],
) -> DownloadDraftModel | None:
"""Check if draft model needs download for rank 0 runner.
Triggers download when:
- RunnerIdle with draft model (initial setup)
- RunnerReady with new draft model (updated via API)
"""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, (RunnerIdle, RunnerReady)):
return None
# Use current instance state (may have been updated via API)
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
draft_model_id = current_instance.draft_model
if draft_model_id is None:
return None
if _is_download_in_progress_or_complete(draft_model_id, download_status):
return None
return DownloadDraftModel(
instance_id=instance_id,
model_id=str(draft_model_id),
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -230,12 +182,10 @@ def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
download_status: Mapping[ModelId, DownloadProgress],
) -> LoadModel | None:
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
shard = runner.bound_instance.bound_shard
all_local_downloads_complete = all(
nid in global_download_status
@@ -249,14 +199,6 @@ def _load_model(
if not all_local_downloads_complete:
continue
# Rank 0 with draft model must wait for draft download before loading
if shard.device_rank == 0:
draft_model_id = instance.draft_model
if draft_model_id is not None and not isinstance(
download_status.get(draft_model_id), DownloadCompleted
):
continue
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
@@ -316,53 +258,6 @@ def _ready_to_warmup(
return None
def _set_draft_model(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
download_status: Mapping[ModelId, DownloadProgress],
) -> SetDraftModel | None:
"""Check if rank 0 runner needs to load or clear a draft model."""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, RunnerReady):
return None
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
# Compare runner's bound draft model vs current instance draft model
runner_draft_model = rank_0_runner.bound_instance.instance.draft_model
current_draft_model = current_instance.draft_model
if runner_draft_model == current_draft_model:
return None
# Draft model changed - need to update
if current_draft_model is None:
# Clear draft model
return SetDraftModel(
instance_id=instance_id,
model_id=None,
num_draft_tokens=4,
)
# Wait for draft model to be downloaded
if not isinstance(download_status.get(current_draft_model), DownloadCompleted):
return None
return SetDraftModel(
instance_id=instance_id,
model_id=str(current_draft_model),
num_draft_tokens=current_instance.num_draft_tokens,
)
def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],

View File

@@ -1,8 +1,6 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -15,7 +13,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -23,12 +20,10 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -53,44 +48,15 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -101,6 +67,7 @@ def main(
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
device_rank = shard_metadata.device_rank
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
@@ -112,7 +79,6 @@ def main(
model = None
tokenizer = None
group = None
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -168,16 +134,6 @@ def main(
bound_instance, group, on_timeout=on_model_load_timeout
)
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -193,10 +149,9 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
model=model,
tokenizer=tokenizer,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -215,24 +170,18 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
assert model
assert tokenizer
assert task_params.messages[0].content is not None
assert model
assert tokenizer
assert task_params.messages[0].content is not None
try:
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses (draft_model loaded at warmup if configured)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=cast(Model, model),
model=model,
tokenizer=tokenizer,
task=task_params,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
@@ -244,7 +193,7 @@ def main(
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -259,52 +208,26 @@ def main(
)
)
# can we make this more explicit?
except Exception as e:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_meta.model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
) if isinstance(current_status, RunnerReady):
current_status = RunnerWarmingUp()
logger.info("runner warming up (setting draft model)")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert model is not None
assert tokenizer is not None
if draft_model_id is None:
# Clear draft model
logger.info("Clearing draft model")
draft_model = None
instance = instance.model_copy(
update={
"draft_model": None,
"num_draft_tokens": 4,
}
)
else:
# Load new draft model
logger.info(f"Loading draft model: {draft_model_id}")
draft_model = cast(Model, load_draft_model(draft_model_id))
instance = instance.model_copy(
update={
"draft_model": ModelId(draft_model_id),
"num_draft_tokens": num_tokens,
}
)
# Warm up with speculative decoding
logger.info("Warming up with new draft model")
warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
draft_model=draft_model,
num_draft_tokens=num_tokens,
)
logger.info("Draft model loaded and warmed up")
current_status = RunnerReady()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
@@ -325,7 +248,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group, draft_model
del model, tokenizer, group
mx.clear_cache()
import gc

View File

@@ -0,0 +1,202 @@
# type: ignore
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
class MockLayer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
self.use_sliding = True
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x * 2
@dataclass(frozen=True)
class PipelineTestConfig:
model_path: Path
total_layers: int
base_port: int
max_tokens: int
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
import json
import tempfile
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
return hostfile_path, hosts
# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
total_layers=24,
base_port=29600,
max_tokens=200,
)
def run_gpt_oss_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 200,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
# Generate a prompt of exact token length
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
# Build prompt with approximate target length
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
# Truncate to exact target length
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=False,
),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=24,
)
model = pipeline_auto_parallel(model, group, shard_meta)
# Barrier before generation
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
):
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
def run_gpt_oss_tensor_parallel_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 10,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
model = tensor_auto_parallel(model, group)
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
):
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]

View File

@@ -0,0 +1,137 @@
import multiprocessing as mp
from typing import Any
import mlx.core as mx
import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
PipelineFirstLayer,
PipelineLastLayer,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
def run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
result_queue: Any, # pyright: ignore[reportAny]
) -> None:
import os
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
class MockLayerInner(mlx_nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
def __call__(
self, x: mlx_core.array, *args: object, **kwargs: object
) -> mlx_core.array:
return x * 2
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
x = mlx_core.ones((1, 4))
result = composed(x)
mlx_core.eval(result)
success = result.shape == x.shape
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
def test_single_wrapper_delegates_attributes() -> None:
mock = MockLayer()
wrapped = CustomMlxLayer(mock)
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
assert wrapped.use_sliding is True # type: ignore[attr-defined]
def test_composed_wrappers_delegate_attributes() -> None:
mock = MockLayer()
group = mx.distributed.init()
first = PipelineFirstLayer(mock, r=0, group=group)
composed = PipelineLastLayer(first, r=0, s=1, group=group)
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
assert composed.use_sliding is True # type: ignore[attr-defined]
def test_missing_attribute_raises() -> None:
mock = MockLayer()
wrapped = CustomMlxLayer(mock)
with pytest.raises(AttributeError):
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
def test_composed_call_works() -> None:
import json
import os
import tempfile
ctx = mp.get_context("spawn")
world_size = 2
base_port = 29500
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=run_pipeline_device,
args=(rank, world_size, hostfile_path, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=10) # pyright: ignore[reportAny]
results: dict[int, Any] = {}
errors: dict[int, str] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
if success:
results[rank] = value
else:
errors[rank] = value
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}. Errors: {errors}"
)
for rank in range(world_size):
assert rank in results, (
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
)
result_array = results[rank]
# Both devices see the final result (4.0) after all_gather
assert (result_array == 4.0).all(), (
f"Device {rank}: expected 4.0, got {result_array}"
)
finally:
os.unlink(hostfile_path)

View File

@@ -1,50 +0,0 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()