Compare commits

...

12 Commits

Author SHA1 Message Date
Alex Cheema
d35b73c3c1 chore: gitignore hosts_*.json files
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-30 08:25:08 -08:00
Evan Quiney
cd946742f7 fix skipping logic in worker plan (#1342)
the worker plan function had some skipping logic missing, leading to
double-submitting tasks.
2026-01-30 14:31:40 +00:00
rltakashige
a5bc38ad1f Check all nodes to evict (#1341)
## Motivation

If nodes have uneven memory, one node may evict cache that remains on
another node. This will break prefill on some setups.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-30 13:42:09 +00:00
Evan Quiney
2a4e0d4629 make node-ids unique per-session (#1338)
we currently have no strict reuqirements that node ids persist across
sessions, so we can generate fresh nodeids each time

this avoids issues like #1332, but prevents further features such as
caching downloads or node-id dialling

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-30 13:33:31 +00:00
Evan Quiney
46a14153dd switch to ModelCard.load outside of download log (#1339)
some attempts to load model cards (i.e. build_base_shard) always went
through networking rather than using downloaded model cards. we should
always default to ModelCard.load in these scenarios
2026-01-30 11:20:20 +00:00
Evan
9ba61f3733 improve log message in shard downloader
closes #1336
2026-01-30 10:35:01 +00:00
rltakashige
d9eca75895 Add usage stats (#1333)
## Motivation

(Probably) the final missing piece of the Chat Completions API 

## Changes

Add UsageStats 

## Why It Works

OpenCode reviewed my PR and gave me stats:

<img width="1150" height="802" alt="image"
src="https://github.com/user-attachments/assets/ebc06bae-797f-4087-87d5-2f26cf60fc48"
/>


## Test Plan

### Automated Testing
No tests were broken.
2026-01-30 10:23:08 +00:00
rltakashige
9dabde7e57 Fix bench after recent updates (#1331)
## Motivation

A lot of changes happened without much attention to the state of exo
bench.

## Changes

Use TaggedModel for BenchChatCompletion so it serialises properly.
Don't break after gpt oss tool call to preserve parity with the rest of
the codebase.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<img width="2856" height="678" alt="image"
src="https://github.com/user-attachments/assets/2e18cf0d-c0f8-467c-9763-1a6a59c8a327"
/>

Also tested GPT OSS tool calling in OpenCode
2026-01-29 19:14:40 +00:00
ciaranbor
a31942ce12 Ciaran/image non streaming (#1328)
## Motivation

The dashboard UI attempted to parse all image generation responses as
SSE streams, even when streaming was disabled. This broke non-streaming
image generation.

## Changes

- Parse JSON responses directly when not streaming, use SSE parser only
when stream=true AND partialImages > 0
- explicitly disable partial images when not streaming

## Why It Works

Both API and dashboard now use the same condition (stream &&
partialImages > 0) to determine response format, ensuring correct
parsing.

## Test Plan

### Manual Testing

Non-streamed image generation results appear in the UI. Streamed image
generation still works
2026-01-29 17:24:32 +00:00
Alex Cheema
7cc313b22a Treat Swift/Xcode build warnings as errors (#1322)
## Motivation

Warnings that go unchecked tend to accumulate and hide real issues.
Treating them as errors ensures they are addressed immediately, both
locally during development and in CI.

## Changes

Added `SWIFT_TREAT_WARNINGS_AS_ERRORS = YES` and
`GCC_TREAT_WARNINGS_AS_ERRORS = YES` to the **project-level** Debug and
Release build configurations in `project.pbxproj`. This applies to all
targets (EXO, EXOTests, EXOUITests).

## Why It Works

Xcode's `SWIFT_TREAT_WARNINGS_AS_ERRORS` and
`GCC_TREAT_WARNINGS_AS_ERRORS` build settings promote Swift and C/ObjC
warnings to errors at compile time. Setting them at the project level
means all targets inherit the policy without needing per-target or
CI-level overrides.

## Test Plan

### Manual Testing
- Built the EXO scheme in Release configuration with `xcodebuild` — no
warning-as-error failures from Swift or C/ObjC sources.

### Automated Testing
- CI already builds with `-configuration Release`, so it will
automatically enforce warnings-as-errors via the inherited project
settings — no CI changes needed.
2026-01-29 17:15:49 +00:00
rltakashige
2837225dc7 Load pipeline layers sequentially (#1329)
## Motivation

Slightly annoyed by needing this change, but same story as for tensor
loading...
2026-01-29 17:08:38 +00:00
Jake Hillion
e4c6a7dbb4 nix: add Python packaging with uv2nix
Add uv2nix to build Python packages from uv.lock. This creates a fully
Nix-managed Python environment with the Rust bindings injected via overlay.

Changes:
- Add pyproject-nix, uv2nix, and pyproject-build-systems flake inputs
- Create python/parts.nix with overlays to inject Nix-built Rust wheel
- Export packages.exo on macOS (wraps exo/exo-master/exo-worker with dashboard)
- Add checks.lint (ruff, all platforms) and checks.pytest (macOS only)
- Simplify CI typecheck job using nicknovitski/nix-develop action
- Delete .github/actions/typecheck composite action (no longer needed)
- Add no-build-package for MLX packages in pyproject.toml (use wheels)

The Python build is currently macOS-only since MLX requires Metal. Linux
support will be added once the pyproject dependencies are simplified.

Test plan:
- Run `nix flake check` on macOS to verify pytest and lint pass
- Build exo package on macOS: `nix build .#exo`
- Verify CI pipeline passes with simplified typecheck job
2026-01-29 16:35:58 +00:00
28 changed files with 647 additions and 363 deletions

View File

@@ -1,12 +0,0 @@
name: Type Check
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

View File

@@ -26,73 +26,14 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in=0 }
/^\[tool\.basedpyright\]/ { in=1; print; next }
in && /^\[/ { in=0 } # next section
in && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
nix:
name: Build and check (${{ matrix.system }})
@@ -191,3 +132,14 @@ jobs:
- name: Run nix flake check
run: nix flake check
- name: Run pytest (macOS only)
if: runner.os == 'macOS'
run: |
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

3
.gitignore vendored
View File

@@ -28,3 +28,6 @@ target/
dashboard/build/
dashboard/node_modules/
dashboard/.svelte-kit/
# host config snapshots
hosts_*.json

View File

@@ -342,6 +342,8 @@
SDKROOT = macosx;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Debug;
};
@@ -397,6 +399,8 @@
MTL_FAST_MATH = YES;
SDKROOT = macosx;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Release;
};

View File

@@ -865,7 +865,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -905,7 +904,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1522,7 +1520,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1532,7 +1529,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1945,7 +1941,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2653,7 +2648,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2696,7 +2690,6 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2869,7 +2862,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3014,7 +3006,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3036,7 +3027,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -173,6 +173,11 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface ImageApiResponse {
created: number;
data: Array<{ b64_json?: string; url?: string }>;
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -2095,107 +2100,137 @@ class AppStore {
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await response.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const numImages = params.numImages;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img, index) => ({
type: "generated-image" as const,
name: `generated-image-${index + 1}.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const numImages = params.numImages;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
},
);
}
} catch (error) {
console.error("Error generating image:", error);
this.handleStreamingError(
@@ -2343,69 +2378,98 @@ class AppStore {
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
}
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img) => ({
type: "generated-image" as const,
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
},
);
}
} catch (error) {
console.error("Error editing image:", error);
this.handleStreamingError(

65
flake.lock generated
View File

@@ -21,7 +21,9 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1765953015,
@@ -149,19 +151,44 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"type": "github"
},
"original": {
@@ -178,7 +205,10 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
}
},
"rust-analyzer-src": {
@@ -239,6 +269,29 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,6 +24,26 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -48,6 +68,7 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -88,12 +109,6 @@
};
};
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);

93
python/parts.nix Normal file
View File

@@ -0,0 +1,93 @@
{ inputs, ... }:
{
perSystem =
{ config, self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard}
done
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
};
};
}

View File

@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
model_card = await ModelCard.load(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -166,9 +166,8 @@ class ResumableShardDownloader(ShardDownloader):
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.error("Error downloading shard:", e)
logger.warning(f"Error downloading shard: {type(e).__name__}")
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -65,7 +65,9 @@ from exo.shared.types.api import (
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
StreamOptions,
ToolCall,
Usage,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -113,7 +115,9 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
chunk: TokenChunk | ToolCallChunk,
command_id: CommandId,
usage: Usage | None,
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -138,21 +142,10 @@ def chunk_to_response(
finish_reason=chunk.finish_reason,
)
],
usage=usage,
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -274,7 +267,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -291,7 +284,7 @@ class API:
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
model_card = await ModelCard.load(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
@@ -319,7 +312,7 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_card = await ModelCard.load(model_id)
try:
placements = get_instance_placements(
@@ -522,9 +515,10 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, stream_options: StreamOptions | None = None
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
include_usage = stream_options.include_usage if stream_options else False
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
@@ -540,8 +534,10 @@ class API:
yield "data: [DONE]\n\n"
return
usage = chunk.usage if include_usage else None
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
chunk, command_id, usage=usage
)
logger.debug(f"chunk_response: {chunk_response}")
@@ -557,8 +553,9 @@ class API:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
usage: Usage | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ErrorChunk):
@@ -583,6 +580,9 @@ class API:
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.usage is not None:
usage = chunk.usage
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -604,6 +604,7 @@ class API:
finish_reason=finish_reason,
)
],
usage=usage,
)
async def _collect_chat_completion_with_stats(
@@ -611,7 +612,7 @@ class API:
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
@@ -664,7 +665,7 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
@@ -673,7 +674,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -691,7 +692,7 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, payload.stream_options),
media_type="text/event-stream",
)
@@ -700,7 +701,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -720,12 +721,12 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(ModelId(model))
model_card = await ModelCard.load(model)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
@@ -771,7 +772,7 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
command = ImageGeneration(
request_params=payload,
@@ -1016,7 +1017,7 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
@@ -1037,7 +1038,7 @@ class API:
self,
image: UploadFile,
prompt: str,
model: str,
model: ModelId,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1132,7 +1133,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1188,7 +1189,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,

View File

@@ -216,6 +216,8 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from loguru import logger
from pytest import LogCaptureFixture
from pytest import LogCaptureFixture, mark
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
os.remove(p)
@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10

View File

@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -116,8 +116,8 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: PromptTokensDetails | None = None
completion_tokens_details: CompletionTokensDetails | None = None
prompt_tokens_details: PromptTokensDetails
completion_tokens_details: CompletionTokensDetails
class StreamingChoiceResponse(BaseModel):
@@ -170,7 +170,13 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
class StreamOptions(BaseModel):
include_usage: bool = False
class ChatCompletionTaskParams(TaggedModel):
model_config = ConfigDict(extra="ignore")
model: str
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
@@ -184,6 +190,7 @@ class ChatCompletionTaskParams(BaseModel):
seed: int | None = None
stop: str | list[str] | None = None
stream: bool = False
stream_options: StreamOptions | None = None
temperature: float | None = None
top_p: float | None = None
tools: list[dict[str, Any]] | None = None

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -17,6 +17,7 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
usage: Usage | None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
@@ -28,6 +29,7 @@ class ErrorChunk(BaseChunk):
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
usage: Usage | None
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None

View File

@@ -2,6 +2,7 @@ from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -22,7 +23,7 @@ class TestCommand(BaseCommand):
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
class ImageGeneration(BaseCommand):

View File

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -54,7 +55,7 @@ class StartWarmup(BaseTask): # emitted by Worker
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: ChatCompletionTaskParams
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)

View File

@@ -6,6 +6,7 @@ from exo.shared.types.api import (
GenerationStats,
ImageGenerationStats,
ToolCallItem,
Usage,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -24,6 +25,7 @@ class GenerationResponse(BaseRunnerResponse):
# logprobs: list[float] | None = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
usage: Usage | None
class ImageGenerationResponse(BaseRunnerResponse):
@@ -57,6 +59,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -98,8 +98,8 @@ def generate_image(
partial_images = (
task.partial_images
if task.partial_images is not None
else (3 if task.stream else 0)
if task.partial_images is not None and task.stream is not None and task.stream
else 0
)
image_path: Path | None = None

View File

@@ -348,6 +348,7 @@ class DiffusionRunner:
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
time_steps=time_steps,
)
mx.eval(latents)

View File

@@ -201,6 +201,9 @@ def pipeline_auto_parallel(
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
for layer in layers:
mx.eval(layer) # type: ignore
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],

View File

@@ -3,6 +3,7 @@ from copy import deepcopy
from typing import Any, cast
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
@@ -12,25 +13,29 @@ from mlx_lm.models.cache import (
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.85
_DEFAULT_MEMORY_THRESHOLD = 0.9
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
class KVPrefixCache:
def __init__(self, tokenizer: TokenizerWrapper):
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
@@ -81,13 +86,13 @@ class KVPrefixCache:
best_snapshot_index, best_snapshot_length = None, 0
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
length = get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
cached_length = cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -109,7 +114,7 @@ class KVPrefixCache:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
cached_length = cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -131,29 +136,37 @@ class KVPrefixCache:
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory pressure is high."""
"""Evict least recently used entries while memory usage is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_THRESHOLD:
return
# Evict LRU entries until below threshold or only one entry left
while len(self.caches) > 0:
while (
len(self.caches) > 1
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
)
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_THRESHOLD:
break
def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_used_percentage()
if self._group is None:
return local_pressure
all_pressure = mx.distributed.all_gather(
mx.array([local_pressure], dtype=mx.float32),
group=self._group,
)
# .item() evals.
max_pressure = float(mx.max(all_pressure).item())
return max_pressure
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
@@ -168,13 +181,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
@@ -185,6 +198,17 @@ def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
return int(mx.sum(prefix_mask).item())
def get_available_memory() -> Memory:
mem: int = psutil.virtual_memory().available
return Memory.from_bytes(mem)
def get_memory_used_percentage() -> float:
mem = psutil.virtual_memory()
# percent is 0-100
return float(mem.percent / 100)
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:

View File

@@ -10,8 +10,11 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
CompletionTokensDetails,
FinishReason,
GenerationStats,
PromptTokensDetails,
Usage,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
@@ -39,7 +42,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
) -> tuple[float, int]:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
@@ -50,7 +53,7 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
return 0.0, 0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
@@ -85,7 +88,7 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
return tokens_per_sec, num_tokens
def warmup_inference(
@@ -169,6 +172,8 @@ def mlx_generate(
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
logger.info(f"{is_bench=}")
# Currently we support chat-completion tasks only.
logger.debug(f"task_params: {task}")
@@ -204,7 +209,9 @@ def mlx_generate(
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
@@ -212,28 +219,43 @@ def mlx_generate(
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
usage: Usage | None = None
in_thinking = False
reasoning_tokens = 0
think_start = tokenizer.think_start
think_end = tokenizer.think_end
for completion_tokens, out in enumerate(
stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
),
start=1,
):
generated_text_parts.append(out.text)
logger.info(out.text)
if think_start is not None and out.text == think_start:
in_thinking = True
elif think_end is not None and out.text == think_end:
in_thinking = False
if in_thinking:
reasoning_tokens += 1
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
@@ -245,11 +267,24 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=reasoning_tokens
),
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
usage=usage,
)
if out.finish_reason is not None:

View File

@@ -37,6 +37,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -111,8 +112,12 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
seen = set[TaskId]()
with task_receiver as tasks:
for task in tasks:
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -163,7 +168,7 @@ def main(
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache(tokenizer, group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
@@ -277,9 +282,11 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
completion_tokens = 0
for response in mlx_generator:
match response:
case GenerationResponse():
completion_tokens += 1
if (
device_rank == 0
and response.finish_reason == "error"
@@ -307,6 +314,7 @@ def main(
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
usage=response.usage,
finish_reason=response.finish_reason,
stats=response.stats,
),
@@ -320,6 +328,7 @@ def main(
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
),
)
)
@@ -535,10 +544,10 @@ def parse_gpt_oss(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
]
],
usage=response.usage,
)
tool_arg_parts = []
break
current_tool_name = recipient
# If inside a tool call, accumulate arguments
@@ -684,7 +693,7 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
except (
json.JSONDecodeError,

View File

@@ -127,20 +127,25 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
f"Skipping invalid task {task} as it has already been submitted"
)
return
if task.task_id in self.completed:
logger.info(
logger.warning(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
try:
self._task_sender.send(task)
await self._task_sender.send_async(task)
except ClosedResourceError:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:

View File

@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
_cache_length,
_get_prefix_length,
cache_length,
encode_prompt,
get_prefix_length,
make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 5
assert get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert _get_prefix_length(a, b) == 1
assert get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
class TestKVPrefix:
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert _cache_length(cache) == len(tokens)
assert cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache.add_kv_cache(prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
expected_prefix = get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
first_cache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
):
pass
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
firstcache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
pass
# The first stored cache must not have been mutated by the second generation
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)

View File

@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
@@ -147,6 +147,14 @@ class MockTokenizer:
has_tool_calling = False
class MockGroup:
def rank(self) -> int:
return 0
def size(self) -> int:
return 1
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -182,6 +190,8 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
text="hi",
token_id=0,
finish_reason="stop",
usage=None,
stats=None,
),
)

View File

@@ -0,0 +1,18 @@
{
"$schema": "https://opencode.ai/config.json",
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
"provider": {
"exo": {
"api": "http://localhost:52415/v1",
"models": {
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
"name": "GPT OSS 120B",
"limit": {
"context": 32768,
"output": 8192
}
}
}
}
}
}