Compare commits

...

14 Commits

Author SHA1 Message Date
Alex Cheema
8fc9b13a51 feat: add continuous batching for distributed inference
Implements continuous batching using mlx_lm's BatchGenerator for efficient
multi-request handling in distributed mode.

Key changes:
- Add BatchGenerationEngine that wraps mlx_lm's BatchGenerator for continuous
  batching with prefill batching (up to 8 requests) and decode batching
- Add TimeBudget pattern for controlling generation loop timing with periodic
  distributed sync
- Add distributed_sync utilities for broadcasting objects across ranks using
  mx.distributed.all_sum()
- Stream tokens immediately as generated for smooth streaming (not in batches)
- Fix distributed correctness: deferred shutdown handling, sync_completions
  always syncs in distributed mode to prevent deadlocks

Performance results on Kimi K2 Thinking (658GB) with Tensor RDMA:
- Batch 1:  10.7 tok/s (baseline)
- Batch 4:  34.6 tok/s (3.2x speedup)
- Batch 16: 41.8 tok/s (3.9x speedup)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 20:12:48 +00:00
Sami Khan
6e6567a802 resolve issue #1070 (#1076)
## Motivation

https://github.com/exo-explore/exo/issues/1070

## Changes

Added check in ChatForm.svelte to reset selectedChatModel when it no
longer matches any running instance.

## Why It Works

The $effect now detects when the selected model is stale (not in
availableModels()) and resets to the first available model.

## Test Plan

### Manual Testing

1. Create instance of Model A → Delete it → Create instance of Model B →
Chat
2. Verify request goes to Model B (not Model A)

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-15 20:00:41 +00:00
rltakashige
a735dad667 Parse GPT OSS in runner (#1160)
## Motivation

Simplification of API + moving model specific code to the runner

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Test Plan

### Manual Testing
Tested that GPT OSS outputs are parsed correctly on the dashboard.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-15 19:53:55 +00:00
rltakashige
aaf4e36bc3 FIX GPT OSS (#1165)
## Motivation

Adds several unmerged fixes for GPT OSS.
Also adds GPT OSS 20B MXFP4 Q8 instead of Q4 for numerical stability (as
this is unstable for MLX LM too)
<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->


## Test Plan

### Manual Testing
Manually tested. No further gibberish responses.

### Automated Testing
Ran EXO Bench - pipeline, tensor and single node work on both 20B and
120B models
2026-01-15 19:20:17 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps this is a connection
timing out issue? lets also re-try after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
Evan Quiney
c22dad8a7d dashboard: add peer: true to package lock (#1162)
this happens every time i run npm install - lets upstream it

## testing
dashboard builds and renders
2026-01-15 17:01:43 +00:00
Evan
4bc4d50685 rust: remove dead code
the system custodian has been made unnecessary with the swift app - we
can remove it

## testing
everything still builds
2026-01-15 16:51:46 +00:00
Jake Hillion
e0aab46fd8 model_cards.py: clean up commented out code
Clean up the commented out code and make sure the comments are unified.
Carrying around the commented out code means people making changes to
model_cards are supposed to update it, but that's not clear and won't be
picked up by type checking etc. Drop it for now - it's in the git
history.

Also make the rest of the comments a bit more uniform, and place
comments about a specific model card inside the model card (instead of
above) so they don't get lost when code is added/moved around.

Test plan:
- my eyes
2026-01-15 13:21:58 +00:00
Evan Quiney
82ba42bae9 add glm-47, minimax-m21 (#1147)
Adds support glm 4.7 and MiniMax M2.1

Manual testing:
Tensor + Pipeline execution of both models.

Closes #1141 and #1142
2026-01-14 16:33:17 +00:00
Jake Hillion
3671528fa4 nix: add dashboard build with dream2nix
Continue working towards a fully Nix based build by building the
dashboard with Nix. Continuing the theme of using the existing lock
files, use dream2nix to parse the lock file and build the tree of
dependency derivations.

dream2nix doesn't like the bundleDependencies, so we apply a small patch
to the lock file that drops all dependencies that are bundled. This
should ideally be contributed upstream but that can be done later.

Use this new dashboard build in the build-app CI workflow, meaning
future macOS apps will include this reproducible dashboard.

Test plan:
- Built a DMG, shipped to a cluster, loaded in a browser with no cache
  and the dashboard looks good.

- Directory layout is as expected:
```
$ nix build .#dashboard
$ find result/
...
result/_app/immutable/entry
result/_app/immutable/entry/app.CTPAnMjf.js
result/_app/immutable/entry/start.fUSEa-2O.js
result/_app/immutable/nodes
result/_app/immutable/nodes/3.DqQr1Obm.js
result/_app/immutable/nodes/0.DgEY44RO.js
result/_app/immutable/nodes/2.BjZg_lJh.js
result/_app/immutable/nodes/1.D6vGUYYT.js
result/_app/env.js
result/_app/version.json
result/exo-logo.png
result/favicon.ico
result/index.html
```
2026-01-14 15:58:16 +01:00
Jake Hillion
e6434ec446 nix: add Rust builds with crane and fenix
The Rust workspace lacked Nix build support, making it difficult to
build packages reproducibly or run checks in CI.

Added a flake-parts module at rust/parts.nix that uses crane for Rust
builds and fenix for the nightly toolchain. The source filter isolates
rust/ and root Cargo files to prevent Python/docs changes from
triggering Rust rebuilds. Exports packages (system_custodian,
exo_pyo3_bindings wheel, exo-rust-workspace) and checks (cargo-nextest,
cargo-doc) for all three target platforms.

The devShell now uses inputsFrom to inherit build dependencies from the
workspace package, removing the need for manual pkg-config/openssl setup.

Test plan:
- Ran `nix flake check` successfully
- Built `nix build ".#checks.x86_64-linux.cargo-nextest"` and tests pass
- Built `nix build ".#exo_pyo3_bindings"` and wheel is produced
2026-01-14 11:52:29 +00:00
Jake Hillion
bdb43e1dbb nix: drop noisy echos from devshell
Drop all the printing when entering a devshell. It's annoying, and not a
super accurate description of how to develop exo anyway.
2026-01-14 10:04:57 +00:00
Jake Hillion
e4a01e2b0e chore(deps): nix lock file maintenance
Update nix flake inputs. Add a second input as Swift is currently broken
in nixpkgs on Linux for `swift-format` as we want `nix fmt` to continue
being reproducible everywhere.
2026-01-13 19:57:14 +01:00
Evan Quiney
1200a7db64 Add tensor sharding for GPT-OSS (#1144)
## Motivation

GPT OSS did not previously support tensor sharding

## Changes

Add GPT sharding support in tensor_auto_parallel.
Code is mostly @rltakashige's

## Test Plan

### Manual Testing
Tested GPT-OSS - MLX Fast Sync causes issues in Tensor RDMA - this is a general problem at the moment.
2026-01-13 17:25:52 +00:00
34 changed files with 2055 additions and 562 deletions

View File

@@ -113,11 +113,22 @@ jobs:
uv python install
uv sync --locked
- name: Install Nix
uses: cachix/install-nix-action@v31
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Configure Cachix
uses: cachix/cachix-action@v14
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build dashboard
run: |
cd dashboard
npm ci
npm run build
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
mkdir -p dashboard/build
cp -r "$DASHBOARD_OUT"/* dashboard/build/
- name: Install Sparkle CLI
run: |

View File

@@ -276,24 +276,23 @@ class BatchGenerator:
logprobs: mx.array
finish_reason: Optional[str]
unprocessed_prompts: List[Any]
def __init__(
self,
model,
model: nn.Module,
max_tokens: int = ...,
stop_tokens: Optional[set] = ...,
stop_tokens: Optional[set[int]] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
) -> List[int]: ...
def stats(self) -> BatchStats: ...
def next(self) -> List[Response]: ...
def batch_generate(
model,

View File

@@ -39,12 +39,18 @@ class StreamingDetokenizer:
"""
__slots__ = ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
tokens: list[int]
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self):
def text(self) -> str:
"""The full text decoded so far."""
...
@property
def last_segment(self) -> str:
"""Return the last segment of readable text since last time this property was accessed."""
...
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -108,6 +114,7 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int

View File

@@ -91,6 +91,45 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Model Storage
Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).
## Creating Model Instances via API
When testing with the API, you must first create a model instance before sending chat completions:
```bash
# 1. Get instance previews for a model
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
# 2. Create an instance from the first valid preview
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"
# 3. Wait for the runner to become ready (check logs for "runner ready")
# 4. Send chat completions using the full model ID
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```
## Logs
Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
### Distributed Testing
When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:
```bash
# On each machine in the test cluster, use the same unique namespace
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
```
This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.

19
Cargo.lock generated
View File

@@ -4340,25 +4340,6 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,7 +3,6 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -25,7 +24,6 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

60
dashboard/dashboard.nix Normal file
View File

@@ -0,0 +1,60 @@
{ lib
, config
, dream2nix
, ...
}:
let
# Read and parse the lock file
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
# For packages with bundleDependencies, filter out deps that are bundled
# (bundled deps are inside the tarball, not separate lockfile entries)
fixedPackages = lib.mapAttrs
(path: entry:
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
then entry // {
dependencies = lib.filterAttrs
(name: _: !(lib.elem name entry.bundleDependencies))
(entry.dependencies or { });
}
else entry
)
(rawLockFile.packages or { });
fixedLockFile = rawLockFile // { packages = fixedPackages; };
in
{
imports = [
dream2nix.modules.dream2nix.nodejs-package-lock-v3
dream2nix.modules.dream2nix.nodejs-granular-v3
];
name = "exo-dashboard";
version = "1.0.0";
mkDerivation = {
src = config.deps.dashboardSrc;
buildPhase = ''
runHook preBuild
npm run build
runHook postBuild
'';
installPhase = ''
runHook preInstall
cp -r build $out/build
runHook postInstall
'';
};
deps = { nixpkgs, ... }: {
inherit (nixpkgs) stdenv;
dashboardSrc = null; # Injected by parts.nix
};
nodejs-package-lock-v3 = {
# Don't use packageLockFile - provide the fixed lock content directly
packageLock = fixedLockFile;
};
}

View File

@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

44
dashboard/parts.nix Normal file
View File

@@ -0,0 +1,44 @@
{ inputs, ... }:
{
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to only include dashboard directory
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
inDashboardDir =
(lib.hasInfix "/dashboard/" path)
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|| (baseName == "dashboard" && type == "directory");
in
inDashboardDir;
};
# Build the dashboard with dream2nix (includes node_modules in output)
dashboardFull = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
# Inject the filtered source
{
deps.dashboardSrc = lib.mkForce "${src}/dashboard";
}
];
};
in
{
# Extract just the static site from the full build
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
cp -r ${dashboardFull}/build $out
'';
};
}

View File

@@ -60,12 +60,39 @@
return models;
});
// Auto-select the first available model if none is selected
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
const models = availableModels();
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);

162
flake.lock generated
View File

@@ -1,5 +1,42 @@
{
"nodes": {
"crane": {
"locked": {
"lastModified": 1767744144,
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
"owner": "ipetkov",
"repo": "crane",
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
"type": "github"
},
"original": {
"owner": "ipetkov",
"repo": "crane",
"type": "github"
}
},
"dream2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
"owner": "nix-community",
"repo": "dream2nix",
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "dream2nix",
"type": "github"
}
},
"fenix": {
"inputs": {
"nixpkgs": [
@@ -8,11 +45,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"owner": "nix-community",
"repo": "fenix",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"type": "github"
},
"original": {
@@ -21,6 +58,22 @@
"type": "github"
}
},
"flake-compat": {
"flake": false,
"locked": {
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-parts": {
"inputs": {
"nixpkgs-lib": [
@@ -43,11 +96,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
@@ -57,22 +110,85 @@
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
},
"original": {
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
"purescript-overlay": {
"inputs": {
"flake-compat": "flake-compat",
"nixpkgs": [
"dream2nix",
"nixpkgs"
],
"slimlock": "slimlock"
},
"locked": {
"lastModified": 1728546539,
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
"type": "github"
},
"original": {
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"type": "github"
}
},
"root": {
"inputs": {
"crane": "crane",
"dream2nix": "dream2nix",
"fenix": "fenix",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"rev": "725349602e525df37f377701e001fe8aab807878",
"type": "github"
},
"original": {
@@ -82,6 +198,28 @@
"type": "github"
}
},
"slimlock": {
"inputs": {
"nixpkgs": [
"dream2nix",
"purescript-overlay",
"nixpkgs"
]
},
"locked": {
"lastModified": 1688756706,
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
"owner": "thomashoneyman",
"repo": "slimlock",
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
"type": "github"
},
"original": {
"owner": "thomashoneyman",
"repo": "slimlock",
"type": "github"
}
},
"treefmt-nix": {
"inputs": {
"nixpkgs": [
@@ -89,11 +227,11 @@
]
},
"locked": {
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"type": "github"
},
"original": {

View File

@@ -9,6 +9,8 @@
inputs.nixpkgs-lib.follows = "nixpkgs";
};
crane.url = "github:ipetkov/crane";
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
@@ -18,6 +20,14 @@
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};
nixConfig = {
@@ -36,12 +46,16 @@
imports = [
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
];
perSystem =
{ config, inputs', pkgs, lib, ... }:
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
@@ -54,13 +68,16 @@
};
rustfmt = {
enable = true;
package = fenixToolchain.rustfmt;
package = config.rust.toolchain;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format.enable = true;
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
};
};
@@ -71,6 +88,8 @@
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];
packages =
[
# FORMATTING
@@ -83,14 +102,8 @@
basedpyright
# RUST
(fenixToolchain.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
config.rust.toolchain
maturin
# NIX
nixpkgs-fmt
@@ -102,30 +115,20 @@
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
++ lib.optionals stdenv.isLinux [
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
]
++ lib.optionals stdenv.isDarwin [
macmon
]);
];
OPENSSL_NO_VENDOR = "1";
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
${lib.optionalString stdenv.isLinux ''
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
'';
};
};

View File

@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

145
rust/parts.nix Normal file
View File

@@ -0,0 +1,145 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
"rust-analyzer"
];
# Crane with fenix toolchain
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
# Source filtering - only include rust/ directory and root Cargo files
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
parentDir = builtins.dirOf path;
inRustDir =
(lib.hasInfix "/rust/" path)
|| (lib.hasSuffix "/rust" parentDir)
|| (baseName == "rust" && type == "directory");
isRootCargoFile =
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
&& (builtins.dirOf path == toString inputs.self);
in
isRootCargoFile
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
};
# Common arguments for all Rust builds
commonArgs = {
inherit src;
pname = "exo-rust";
version = "0.0.1";
strictDeps = true;
nativeBuildInputs = [
pkgs.pkg-config
pkgs.python313 # Required for pyo3-build-config
];
buildInputs = [
pkgs.openssl
pkgs.python313 # Required for pyo3 tests
];
OPENSSL_NO_VENDOR = "1";
# Required for pyo3 tests to find libpython
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
};
# Build dependencies once for caching
cargoArtifacts = craneLib.buildDepsOnly (
commonArgs
// {
cargoExtraArgs = "--workspace";
}
);
in
{
# Export toolchain for use in treefmt and devShell
options.rust = {
toolchain = lib.mkOption {
type = lib.types.package;
default = rustToolchain;
description = "The Rust toolchain to use";
};
};
config = {
packages = {
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
pname = "exo_pyo3_bindings";
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
pkgs.maturin
];
buildPhaseCargoCommand = ''
maturin build \
--release \
--manylinux off \
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
--features "pyo3/extension-module,pyo3/experimental-async" \
--interpreter ${pkgs.python313}/bin/python \
--out dist
'';
# Don't use crane's default install behavior
doNotPostBuildInstallCargoBinaries = true;
installPhaseCommand = ''
mkdir -p $out
cp dist/*.whl $out/
'';
}
);
};
checks = {
# Full workspace build (all crates)
cargo-build = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Run tests with nextest
cargo-nextest = craneLib.cargoNextest (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Build documentation
cargo-doc = craneLib.cargoDoc (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
};
};
};
}

View File

@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -1,4 +0,0 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which Exo application use to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and propper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -13,12 +13,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -67,8 +61,6 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -381,35 +373,8 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -417,16 +382,10 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -442,11 +401,11 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -458,7 +417,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -466,7 +425,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if model is None:
model = chunk.model
@@ -495,7 +454,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -503,7 +462,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if model is None:
model = chunk.model
@@ -544,8 +503,6 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -562,17 +519,16 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
return await self._collect_chat_completion(command.command_id)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -589,10 +545,7 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
def _calculate_total_available_memory(self) -> Memory:

View File

@@ -14,32 +14,6 @@ class ModelCard(CamelCaseModel):
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
# "deepseek-v3-0324:4bit": ModelCard(
# short_id="deepseek-v3-0324:4bit",
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
# name="DeepSeek V3 0324 (4-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
# pretty_name="DeepSeek V3 0324 (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# ),
# ),
# "deepseek-v3-0324": ModelCard(
# short_id="deepseek-v3-0324",
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
# name="DeepSeek V3 0324 (8-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
# pretty_name="DeepSeek V3 0324 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# ),
# ),
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -70,63 +44,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
# short_id="deepseek-v3.2",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# name="DeepSeek V3.2 (8-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# pretty_name="DeepSeek V3.2 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-v3.2-4bit": ModelCard(
# short_id="deepseek-v3.2-4bit",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# name="DeepSeek V3.2 (4-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# pretty_name="DeepSeek V3.2 (4-bit)",
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# deepseek r1
# "deepseek-r1-0528-4bit": ModelCard(
# short_id="deepseek-r1-0528-4bit",
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
# name="DeepSeek-R1-0528 (4-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
# pretty_name="DeepSeek R1 671B (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-r1-0528": ModelCard(
# short_id="deepseek-r1-0528",
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
# name="DeepSeek-R1-0528 (8-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
# pretty_name="DeepSeek R1 671B (8-bit)",
# storage_size=Memory.from_bytes(754998771712),
# n_layers=61,
# . hidden_size=7168,
# ),
# ),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
@@ -508,23 +425,24 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# Needs to be quantized g32 or g16.
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
@@ -554,19 +472,81 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -50,7 +50,9 @@ class RunnerReady(BaseRunnerStatus):
class RunnerRunning(BaseRunnerStatus):
pass
"""Runner is processing requests and can accept more (continuous batching)."""
active_requests: int = 0
class RunnerShuttingDown(BaseRunnerStatus):

View File

@@ -10,18 +10,24 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.cache import (
_BaseCache, # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
)
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
class _LayerCallable(Protocol):
@@ -91,8 +97,6 @@ class PipelineLastLayer(CustomMlxLayer):
x, *args, **kwargs
).arguments.get("cache", None)
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
@@ -100,7 +104,6 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# This change happened upstream - check out mlx github somewhere??
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
@@ -132,24 +135,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
return layers
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(model, DeepseekV3Model) and hasattr(
inner_model_instance, "num_layers"
):
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
def pipeline_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
@@ -165,8 +150,7 @@ def pipeline_auto_parallel(
"""
inner_model_instance: nn.Module = _inner_model(model)
# Handle both model.layers and model.h cases
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
@@ -180,6 +164,17 @@ def pipeline_auto_parallel(
group=group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -204,18 +199,44 @@ def tensor_auto_parallel(
group=group,
)
segments: int = 1
def _all_to_sharded(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - all to sharded")
return weight.ndim - 1, segments
return max(weight.ndim - 2, 0), segments
all_to_sharded_linear_in_place = partial(
shard_inplace,
sharding="all-to-sharded",
group=group,
)
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding="sharded-to-all",
sharding=_all_to_sharded, # type: ignore
group=group,
)
if isinstance(model, LlamaModel):
n = group.size()
def _sharded_to_all(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - sharded to all")
weight /= n
return None
return -1, segments
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding=_sharded_to_all, # type: ignore
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -223,7 +244,8 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, DeepseekV3Model):
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -231,7 +253,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Qwen3MoeModel):
elif isinstance(model, MiniMaxModel):
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -239,6 +269,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -284,6 +323,32 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
) and hasattr(inner_model_instance, "num_layers"):
logger.info(
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif isinstance(model, Qwen3MoeModel):
logger.info(
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.num_hidden_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
@@ -304,7 +369,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.num_heads //= self.N
# Shard the MLP
if isinstance(layer.mlp, DeepseekV3MLP):
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -338,6 +403,35 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.block_sparse_moe.switch_mlp.down_proj
)
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(Qwen3MoeModel, model)
@@ -352,11 +446,13 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
if isinstance(
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # type: ignore
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -380,3 +476,50 @@ class ShardedQwenMoE(CustomMlxLayer):
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn.num_key_value_groups = (
layer.self_attn.num_attention_heads
// layer.self_attn.num_key_value_heads
)
layer.self_attn.sinks = layer.self_attn.sinks[
layer.self_attn.num_attention_heads
* self.group.rank() : layer.self_attn.num_attention_heads
* (self.group.rank() + 1)
]
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -0,0 +1,302 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
import time
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger
@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""
command_id: CommandId
task_id: TaskId
uid: int # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)
@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""
command_id: CommandId
task_id: TaskId
response: GenerationResponse
class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""
def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._pending_completions: list[
int
] = [] # UIDs completed but not yet synced/removed
self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1
sampler = make_sampler(temp=0.7, top_p=1.0)
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)
logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion. Only rank 0 should call this.
In distributed mode, rank 0 receives tasks from the control plane and
queues them here. The actual insertion happens in sync_and_insert_pending()
which ensures all ranks insert the same requests together.
"""
assert self.rank == 0, "Only rank 0 should queue requests"
self._pending_inserts.append((command_id, task_id, task_params))
logger.info(
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})"
)
def sync_and_insert_pending(self) -> list[int]:
"""Sync pending inserts across ranks and insert them. Returns UIDs.
This method ensures all ranks insert the same requests in the same order.
In non-distributed mode, it simply inserts all pending requests.
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.
Batches all pending inserts into a single batch_gen.insert() call for
efficient prefill batching.
"""
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]
if not self.is_distributed:
# Non-distributed: just insert directly from pending
inserts_to_process = list(self._pending_inserts)
else:
# Distributed: broadcast pending inserts from rank 0 to all ranks
assert self.group is not None
pending_data = self._pending_inserts if self.rank == 0 else None
synced_data = share_object(pending_data, self.rank, self.group)
if synced_data is None:
self._pending_inserts.clear()
return []
inserts_to_process = synced_data
if not inserts_to_process:
self._pending_inserts.clear()
return []
# Prepare all requests for batched insertion
all_tokens: list[list[int]] = []
all_max_tokens: list[int] = []
all_prompt_tokens: list[int] = []
request_info: list[tuple[CommandId, TaskId]] = []
for cmd_id, task_id, params in inserts_to_process:
prompt_str = apply_chat_template(self.tokenizer, params)
tokens: list[int] = self.tokenizer.encode(
prompt_str, add_special_tokens=False
)
max_tokens = params.max_tokens or self.max_tokens
all_tokens.append(tokens)
all_max_tokens.append(max_tokens)
all_prompt_tokens.append(len(tokens))
request_info.append((cmd_id, task_id))
# Single batched insert for efficient prefill
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
# Track all inserted requests
for i, uid in enumerate(uids):
cmd_id, task_id = request_info[i]
self.active_requests[uid] = ActiveRequest(
command_id=cmd_id,
task_id=task_id,
uid=uid,
detokenizer=self.tokenizer.detokenizer,
prompt_tokens=all_prompt_tokens[i],
)
logger.info(
f"Inserted request {cmd_id} with uid={uid}, prompt_tokens={all_prompt_tokens[i]}, max_tokens={all_max_tokens[i]}"
)
self._pending_inserts.clear()
return uids
def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
responses = self.batch_gen.next()
if not responses:
return []
results: list[BatchedGenerationResponse] = []
for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue
req.tokens_generated += 1
# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment
stats: GenerationStats | None = None
finish_reason: FinishReason | None = None
raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment
elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
stats = GenerationStats(
prompt_tps=0.0, # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)
if raw_finish_reason == "stop":
finish_reason = "stop"
elif raw_finish_reason == "length":
finish_reason = "length"
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"
# Track completion but don't remove yet - wait for sync_completions()
self._pending_completions.append(uid)
logger.info(
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
)
results.append(
BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(
text=text, token=token, finish_reason=finish_reason, stats=stats
),
)
)
# In non-distributed mode, clean up completions immediately
if not self.is_distributed:
self._remove_completed()
return results
def sync_completions(self) -> None:
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
if not self.is_distributed:
# Non-distributed: early return if nothing to do
if not self._pending_completions:
return
self._remove_completed()
return
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
# This prevents deadlock if one rank has completions and another doesn't
assert self.group is not None
synced_uids = share_object(
self._pending_completions if self.rank == 0 else None,
self.rank,
self.group,
)
if synced_uids:
self._pending_completions = synced_uids
self._remove_completed()
def _remove_completed(self) -> None:
"""Remove completed requests from tracking."""
for uid in self._pending_completions:
if uid in self.active_requests:
del self.active_requests[uid]
self._pending_completions.clear()
@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)
@property
def active_count(self) -> int:
return len(self.active_requests)
@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def has_pending_completions(self) -> bool:
return bool(self._pending_completions)

View File

@@ -0,0 +1,30 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
# pyright: reportAny=false
import pickle
from typing import TypeVar, cast
import mlx.core as mx
T = TypeVar("T")
def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
if rank == 0:
if obj is None:
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
return None
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return obj
else:
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
if size == 0:
return None
data = mx.zeros(size, dtype=mx.uint8)
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))

View File

@@ -0,0 +1,104 @@
"""Time budget iterator for controlling generation loop timing in distributed mode.
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
rather than syncing every token. This reduces distributed sync overhead.
"""
import time
from typing import Iterator
import mlx.core as mx
from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
class TimeBudget(Iterator[None]):
"""Controls generation loop timing, syncing across ranks periodically.
In distributed mode, periodically syncs timing across all ranks to
dynamically adjust iteration count based on actual performance.
In non-distributed mode, simply runs for the time budget.
Usage:
for _ in TimeBudget(budget=0.5):
batch_engine.step()
# ... process responses ...
"""
def __init__(
self,
budget: float = 0.5,
iterations: int = 25,
sync_frequency: int = 10,
group: mx.distributed.Group | None = None,
):
"""Initialize TimeBudget.
Args:
budget: Time budget in seconds before yielding control
iterations: Initial number of iterations per budget period (distributed only)
sync_frequency: How often to sync timing across ranks (distributed only)
group: Distributed group, or None for non-distributed mode
"""
self._budget = budget
self._iterations = iterations
self._sync_frequency = sync_frequency
self._group = group
self._is_distributed = group is not None and group.size() > 1
# Runtime state
self._start: float = 0.0
self._current_iterations: int = 0
self._loops: int = 0
self._time_spent: float = 0.0
def __iter__(self) -> "TimeBudget":
self._start = time.perf_counter()
self._current_iterations = 0
return self
def __next__(self) -> None:
if not self._is_distributed:
# Non-distributed: just check time budget
if time.perf_counter() - self._start > self._budget:
raise StopIteration()
return None
# Distributed mode: iteration-based with periodic timing sync
self._current_iterations += 1
if self._current_iterations > self._iterations:
self._loops += 1
self._time_spent += time.perf_counter() - self._start
if self._loops % self._sync_frequency == 0:
# Sync timing across all ranks
assert self._group is not None
with mx.stream(generation_stream):
time_array = mx.array([self._time_spent], dtype=mx.float32)
total_time = mx.distributed.all_sum(time_array, group=self._group)
mx.eval(total_time)
loop_time = float(total_time.item())
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
if avg_loop_time > 0:
factor = self._budget / avg_loop_time
self._iterations = max(round(self._iterations * factor), 1)
logger.debug(
f"TimeBudget adjusted iterations to {self._iterations}"
)
self._loops = 0
self._time_spent = 0.0
raise StopIteration()
return None
@property
def iterations(self) -> int:
"""Current iterations per budget period."""
return self._iterations

View File

@@ -20,6 +20,7 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -365,6 +366,8 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -396,6 +399,11 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -277,12 +277,14 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

@@ -1,6 +1,8 @@
import gc
import time
import mlx.core as mx
from anyio import WouldBlock
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -21,9 +23,6 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -39,7 +38,9 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.generator.batch_engine import BatchGenerationEngine
from exo.worker.engines.mlx.generator.generate import warmup_inference
from exo.worker.engines.mlx.generator.time_budget import TimeBudget
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
@@ -69,142 +70,318 @@ def main(
model = None
tokenizer = None
group = None
batch_engine: BatchGenerationEngine | None = None
pending_shutdown: Shutdown | None = None
current_status: RunnerStatus = RunnerIdle()
def send_status(status: RunnerStatus) -> None:
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=status)
)
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
send_status(current_status)
logger.info("runner connected")
current_status = RunnerConnected()
def handle_task(task: Task, is_deferred: bool = False) -> bool:
nonlocal current_status, model, tokenizer, group, batch_engine, pending_shutdown
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# For Shutdown, check if we need to defer BEFORE sending Running/Acknowledged
if (
isinstance(task, Shutdown)
and not is_deferred
and batch_engine is not None
and (batch_engine.has_active_requests or batch_engine.has_pending_inserts)
):
logger.info("deferring shutdown until active requests complete")
pending_shutdown = task
return True
model, tokenizer = load_mlx_items(bound_instance, group)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
send_status(current_status)
group = initialize_mlx(bound_instance)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
logger.info("runner connected")
current_status = RunnerConnected()
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
send_status(current_status)
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
send_status(current_status)
model, tokenizer = load_mlx_items(bound_instance, group)
current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
send_status(current_status)
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model is not None
assert tokenizer is not None
current_status = RunnerWarmingUp()
logger.info("runner warming up")
send_status(current_status)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(model=model, tokenizer=tokenizer)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
batch_engine = BatchGenerationEngine(
model=model, tokenizer=tokenizer, group=group
)
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
assert task_params.messages[0].content is not None
)
send_status(current_status)
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, (RunnerReady, RunnerRunning))
):
assert batch_engine is not None
# In distributed mode, only rank 0 should queue requests
# Other ranks should skip - they'll participate in sync_and_insert_pending()
is_distributed_mode = group is not None and group.size() > 1
if is_distributed_mode and shard_metadata.device_rank != 0:
logger.debug(
f"Rank {shard_metadata.device_rank} skipping ChatCompletionTask (only rank 0 queues)"
)
return True
if task_params.messages and task_params.messages[0].content is not None:
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
# Queue the request - actual insertion happens in sync_and_insert_pending()
batch_engine.queue_request(
command_id=command_id, task_id=task.task_id, task_params=task_params
)
# Status will be updated after actual insertion in the main loop
# For now, set to RunnerRunning to indicate we're processing
current_status = RunnerRunning(
active_requests=batch_engine.active_count
+ batch_engine.pending_insert_count
)
send_status(current_status)
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
send_status(current_status)
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
current_status = RunnerShutdown()
send_status(current_status)
return False
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
return True
with task_receiver as tasks:
running = True
is_rank_0 = shard_metadata.device_rank == 0
while running:
# Use batch_engine.is_distributed since it's set correctly after group initialization
# (the group variable is None at loop start, but set by ConnectToGroup task)
if batch_engine is not None and batch_engine.is_distributed:
assert group is not None
assert batch_engine is not None
# Distributed mode: tasks wake up all ranks, then we sync and generate
# Check deferred shutdown FIRST - all ranks must check and process together
# This must run before any collective operations to prevent deadlock
if (
pending_shutdown is not None
and not batch_engine.has_active_requests
and not batch_engine.has_pending_inserts
):
handle_task(pending_shutdown, is_deferred=True)
running = False
continue
# When idle, block waiting for task (exo sends tasks to all ranks)
# When active, poll non-blocking to batch incoming requests
if (
not batch_engine.has_active_requests
and not batch_engine.has_pending_inserts
):
# IDLE: Block until task arrives (all ranks receive the same task)
task = tasks.receive()
task_result = handle_task(task)
if not task_result:
running = False
continue
else:
# ACTIVE: Poll for new tasks without blocking
while True:
try:
task = tasks.receive_nowait()
task_result = handle_task(task)
if not task_result:
running = False
break
except WouldBlock:
break
if not running:
continue
# Sync and insert pending requests (collective operation)
# Rank 0 broadcasts its pending to all ranks
inserted = batch_engine.sync_and_insert_pending()
if is_rank_0 and inserted:
current_status = RunnerRunning(
active_requests=batch_engine.active_count
)
send_status(current_status)
# Run generation for time budget
if batch_engine.has_active_requests:
time_budget = TimeBudget(budget=0.5, group=group)
for _ in time_budget:
if not batch_engine.has_active_requests:
break
for resp in batch_engine.step():
# Send token IMMEDIATELY for smooth streaming (only rank 0)
if is_rank_0:
event_sender.send(
ChunkGenerated(
command_id=resp.command_id,
chunk=TokenChunk(
idx=resp.response.token,
model=shard_metadata.model_meta.model_id,
text=resp.response.text,
token_id=resp.response.token,
finish_reason=resp.response.finish_reason,
stats=resp.response.stats,
),
)
)
if resp.response.finish_reason is not None:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
TaskStatusUpdated(
task_id=resp.task_id,
task_status=TaskStatus.Complete,
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
# Sync completions at budget boundary (always call - it's a collective operation)
batch_engine.sync_completions()
# Update status after budget
if is_rank_0:
current_status = (
RunnerRunning(active_requests=batch_engine.active_count)
if batch_engine.has_active_requests
else RunnerReady()
)
send_status(current_status)
else:
# Non-distributed mode: original logic with queue + insert
while True:
try:
task = tasks.receive_nowait()
running = handle_task(task)
if not running:
break
except WouldBlock:
break
if not running:
break
# Insert any queued requests (non-distributed just inserts directly)
# Status was already sent in handle_task when queueing
if batch_engine is not None and batch_engine.has_pending_inserts:
batch_engine.sync_and_insert_pending()
if batch_engine is not None and batch_engine.has_active_requests:
for resp in batch_engine.step():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=resp.command_id,
chunk=TokenChunk(
idx=resp.response.token,
model=shard_metadata.model_meta.model_id,
text=resp.response.text,
token_id=resp.response.token,
finish_reason=resp.response.finish_reason,
stats=resp.response.stats,
),
)
)
if resp.response.finish_reason is not None:
event_sender.send(
TaskStatusUpdated(
task_id=resp.task_id,
task_status=TaskStatus.Complete,
)
)
if batch_engine.has_active_requests:
current_status = RunnerRunning(
active_requests=batch_engine.active_count
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
mx.clear_cache()
import gc
else:
current_status = RunnerReady()
send_status(current_status)
gc.collect()
break
# Process deferred shutdown after all requests complete
if (
pending_shutdown is not None
and not batch_engine.has_active_requests
and not batch_engine.has_pending_inserts
):
running = handle_task(pending_shutdown, is_deferred=True)
else:
task = tasks.receive()
running = handle_task(task)
# Cleanup
del model, tokenizer, group, batch_engine
mx.clear_cache()
gc.collect()
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

View File

@@ -105,7 +105,7 @@ class RunnerSupervisor:
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
logger.warning("Runner process didn't shutdown successfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
@@ -128,9 +128,11 @@ class RunnerSupervisor:
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Skipping task {task.task_id} - already completed")
return
if task.task_id in self.pending:
logger.info(f"Skipping task {task.task_id} - already pending")
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -149,13 +151,17 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
# Just set the event to unblock start_task, but keep in pending
# to prevent duplicate forwarding until completion
if event.task_id in self.pending:
self.pending[event.task_id].set()
continue
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
if isinstance(event, TaskStatusUpdated) and event.task_status in (
TaskStatus.Complete,
TaskStatus.TimedOut,
TaskStatus.Failed,
):
# If a task has just been completed, we should be working on it.
# If a task has just finished, we should be working on it.
assert isinstance(
self.status,
(
@@ -166,6 +172,8 @@ class RunnerSupervisor:
RunnerShuttingDown,
),
)
# Now safe to remove from pending and add to completed
self.pending.pop(event.task_id, None)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:

View File

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):

View File

@@ -0,0 +1,319 @@
"""
Tests for continuous batching behavior in the runner.
These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately
"""
# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false
from typing import Any
from unittest.mock import MagicMock
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
class FakeBatchEngineWithTokens:
"""
Fake batch engine that generates a specified number of tokens per request.
This simulates realistic batch generation behavior where:
- Requests are queued on insert
- Each step() call generates one token for all active requests
- Requests complete when they've generated all their tokens
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._uid_counter = 0
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, task_params in self._pending_inserts:
uid = self._do_insert(command_id, task_id, task_params)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams | None,
) -> int:
uid = self._uid_counter
self._uid_counter += 1
# Track: (command_id, task_id, tokens_generated, max_tokens)
max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
return uid
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish_reason = "stop" if tokens_gen >= max_tokens else None
text = f"token{tokens_gen}"
if finish_reason:
uids_to_remove.append(uid)
else:
self._active_requests[uid] = (
command_id,
task_id,
tokens_gen,
max_tokens,
)
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=text,
finish_reason=finish_reason,
),
)
)
for uid in uids_to_remove:
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
"""
Run tasks through the runner, adding shutdown at the end.
Tasks are sent in order, with shutdown sent last.
The batch engine processes between task handling.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
shutdown_task = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
with task_sender, event_receiver:
# Send all tasks including shutdown
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown_task)
# Disable cleanup methods to prevent issues
event_sender.close = lambda: None
event_sender.join = lambda: None
task_receiver.close = lambda: None
task_receiver.join = lambda: None
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_receiver.collect()
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> ChatCompletion:
return ChatCompletion(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)
def test_single_request_generates_tokens(patch_batch_engine: None):
"""
Verify a single request generates the expected tokens through the batch path.
Note: With the current non-blocking design, shutdown is processed before
batch steps run when all tasks are queued together. This test verifies
the runner status reflects active requests.
"""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events - this shows the request was inserted
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
"""Verify RunnerRunning status includes active_requests count."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_chat_task_acknowledged(patch_batch_engine: None):
"""Verify chat completion task is acknowledged with proper status updates."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find the chat task status events
chat_running = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Running
]
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
def test_multiple_requests_tracked(patch_batch_engine: None):
"""Verify multiple concurrent requests are tracked in active_requests."""
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
# Should have at least 2 RunnerRunning events (one per request inserted)
assert len(running_events) >= 2, (
f"Expected at least 2 RunnerRunning events, got {len(running_events)}"
)
# First should have 1 active request, second should have 2
assert running_events[0].runner_status.active_requests == 1
assert running_events[1].runner_status.active_requests == 2

View File

@@ -1,12 +1,17 @@
# Check tasks are complete before runner is ever ready.
# pyright: reportAny=false
from collections.abc import Iterable
from typing import Callable
from typing import Any, Callable
from unittest.mock import MagicMock
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -22,6 +27,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
@@ -38,6 +44,9 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -107,18 +116,89 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
assert test_event == true_event, f"{test_event} != {true_event}"
class FakeBatchEngine:
"""
Fake batch engine for testing.
Queues requests on insert, returns one token per step.
The runner's non-blocking loop drains all tasks before running batch steps,
so this engine queues requests and has_active_requests returns True only
after at least one request has been inserted.
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._uid_counter = 0
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, _task_params in self._pending_inserts:
uid = self._uid_counter
self._uid_counter += 1
self._active_requests[uid] = (command_id, task_id)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
# Process all active requests - return one token and complete
for uid, (command_id, task_id) in list(self._active_requests.items()):
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=0,
text="hi",
finish_reason="stop",
),
)
)
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
# initialize_mlx returns a fake "group" (non-None for state machine)
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)
def _run(tasks: Iterable[Task]):
@@ -148,7 +228,8 @@ def _run(tasks: Iterable[Task]):
return event_receiver.collect()
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPatch):
"""Verify chat completion generates tokens, completes, and runner returns to Ready."""
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
expected_chunk = ChunkGenerated(
@@ -191,7 +272,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerRunning(active_requests=1)
),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -206,7 +289,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)

View File

@@ -1,4 +1,5 @@
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
@@ -6,6 +7,8 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
@@ -15,8 +18,9 @@ async def check_reachability(
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
@@ -32,7 +36,16 @@ async def check_reachability(
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()

View File

@@ -89,6 +89,12 @@ async def assert_downloads():
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def ring_backend(test: Tests):