Compare commits

...

15 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
0c5c87cd9d Fix gibberish outputs 2026-01-19 16:26:50 +00:00
Ryuichi Leo Takashige
4c1af11f14 Fix model hanging 2026-01-19 15:09:31 +00:00
Ryuichi Leo Takashige
f654b98d97 Fix model hanging 2026-01-19 14:35:16 +00:00
Ryuichi Leo Takashige
060dc8a3d8 Failing test 2026-01-19 13:11:41 +00:00
rltakashige
ea0588429b Custom mlx layer composition (#1201)
## Motivation

With a single pipeline layer, PipelineFirstLayer gets composed with
PipelineLastLayer.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing


### Automated Testing
Made failing tests. Fixed them!
2026-01-19 12:36:25 +00:00
rltakashige
73b3f87e07 Set swa_idx and ga_idx for single layer (#1202)
## Motivation

Layer types does not contain either "sliding_attention" or
"full_attention" for pipeline parallel (single layer).

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Manually tested single layer of GPT OSS. Doesn't crash

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-19 12:31:11 +00:00
Evan Quiney
746589ba6b tidy: remove context manager from api (#1199) 2026-01-19 11:58:13 +00:00
rltakashige
f82f862fd7 Fix several issues with placement (#1200)
## Motivation

Uneven placements were causing issues for some users with lopsided
setups. While fixing, I ran into another issue with impossible
allocation of memory.

## Changes

- Allocate at least 1 layer per device.
- Catch overallocation of memory with an error.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Tested that GPT OSS is placed correctly.

### Automated Testing
Added breaking tests in the first commit. Resolved with new placement
algorithm in the second one.
2026-01-19 11:52:35 +00:00
Alex Cheema
7ff937d8a1 Add dashboard screenshots to README (#1185)
## Motivation

The README showcases exo's features and benchmarks but doesn't show what
the dashboard actually looks like. Adding a screenshot helps users
understand what they'll get when they run exo.

## Changes

- Added dashboard screenshot to `docs/imgs/dashboard-cluster-view.png`:
Shows the cluster topology view with 4 × 512GB M3 Ultra Mac Studio
running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)
- Added a new "Dashboard" section to README.md below Features,
displaying the screenshot with caption

## Why It Works

Visual documentation helps users understand what exo offers before they
install it. The screenshot demonstrates the cluster management
capabilities.

## Test Plan

### Manual Testing
- Verified image renders correctly in GitHub markdown preview

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 10:43:27 +00:00
Evan Quiney
d19bf02404 re-raise exceptions in the runner (#1198)
## Motivation

Runners that crash can swallow errors - we should re-raise. Also the
exception handler annoyed me.

## Changes

The try: except in the runner's chat now re-raises.
2026-01-19 10:35:23 +00:00
rltakashige
618cee5223 Resolve test event ordering flakiness (#1194)
## Motivation

mp sender occasionally does not have time to flush its events before
collect() is called, making the event ordering test fail.

## Changes

- Replace mp_channel with simple collector for event ordering test
- Also suppress warning for <frozen importlib._bootstrap>:488 <frozen
importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject
has no __module__ attribute


## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
Ran the test 100 times without it failing.
2026-01-18 20:33:20 +00:00
Antonio Lujano Luna
9c29eb7d48 Add proxy and custom SSL certificate support for corporate networks (#1189)
Support HTTPS_PROXY/HTTP_PROXY environment variables for proxy
configuration and SSL_CERT_FILE for custom CA certificates, enabling use
in corporate environments with SSL inspection.

## Motivation
Users in corporate environments often need to route traffic through HTTP
proxies and use custom CA certificates for SSL inspection. Without this
support, exo cannot download models in these network configurations.

## Changes
- Added `HTTPS_PROXY`/`HTTP_PROXY` environment variable support to
`create_http_session()` in `download_utils.py`
- Added `SSL_CERT_FILE` environment variable support for custom CA
certificate bundles, falling back to certifi's default bundle

## Why It Works
- `aiohttp.ClientSession` natively supports the `proxy` parameter for
routing requests through HTTP proxies
- `ssl.create_default_context(cafile=...)` accepts a custom CA bundle
path, allowing corporate CAs to be trusted
- Using environment variables is consistent with the codebase's existing
configuration patterns (e.g., `EXO_HOME`, `HF_ENDPOINT`)

## Test Plan
### Manual Testing
- Set `HTTPS_PROXY` environment variable and verified model downloads
route through proxy
- Set `SSL_CERT_FILE` to custom CA bundle and verified SSL verification
succeeds with corporate SSL inspection

### Automated Testing
- No automated tests added; this change is configuration-only and does
not alter existing behavior when environment variables are unset
2026-01-18 12:05:50 +00:00
Alex Cheema
c5158bee53 Add pre-commit checks documentation to AGENTS.md (#1184)
## Motivation

CI failures can be avoided by running checks locally before committing.
This adds clear documentation to AGENTS.md so that AI agents (and
humans) know exactly which checks must pass before pushing code.

## Changes

Added a new "Pre-Commit Checks (REQUIRED)" section to AGENTS.md that:
- Lists all 4 required checks (basedpyright, ruff, nix fmt, pytest)
- Provides a one-liner to run all checks in sequence
- Notes that `nix fmt` changes must be staged before committing
- Explains that CI runs `nix flake check` which verifies everything

## Why It Works

Clear documentation prevents CI failures by ensuring contributors run
checks locally first. The one-liner command makes it easy to run all
checks before committing.

## Test Plan

### Manual Testing
- Verified the documented commands work correctly

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 21:50:24 +00:00
rltakashige
5c8a237940 Handle model timeouts (#1177)
- Add eval with a timeout.
- Add fast synch flag

## Motivation

Because of the experimental FAST SYNCH flag, some models may not work.
This PR catches when this occurs and allows users to specify a run
without fast synch

## Changes

- Adds a flag to enable or disable fast synch (--fast-synch and
--no-fast-synch)
- Adds a heuristic timeout
- Reduces exo_bench default timeout to 10 minutes.

## Why It Works

Heuristic timeout assumes normal loading times on Mac devices (60 +
model size in gb / 5: e.g. DeepSeek takes up to 120 seconds to load on
tensor parallel, and timeout is set to 60 + 120 = 180s.

We could raise this value if necessary.

## Test Plan

### Manual Testing
Catches that GPT OSS fails to load in Tensor RDMA
Can launch with --no-fast-synch flag to launch GPT OSS.

**GPT OSS 20B**
TP with fast synch
<img width="3064" height="456" alt="image"
src="https://github.com/user-attachments/assets/f6e25cd8-8621-4e99-99fe-292ee05c4035"
/>

TP without fast synch
<img width="3098" height="496" alt="image"
src="https://github.com/user-attachments/assets/d36453d9-6686-4cfe-aa7c-a7d458369d4d"
/>
[Note: the performance is really not great as fast synch is off]

(As a sanity check)
PP with fast synch
<img width="3124" height="496" alt="image"
src="https://github.com/user-attachments/assets/e97d4547-c6fa-483d-badb-4b371b900b4c"
/>

PP without fast synch
<img width="3078" height="508" alt="image"
src="https://github.com/user-attachments/assets/b2e20dfd-4b0e-4295-8a92-417dfe745c28"
/>

PP without RDMA
<img width="3070" height="498" alt="image"
src="https://github.com/user-attachments/assets/a8509d68-0aef-4cda-bca5-a67d39a0801e"
/>

TP without RDMA
<img width="3068" height="496" alt="image"
src="https://github.com/user-attachments/assets/b5691429-89f4-4369-bcf2-8fde2ad7154a"
/>
2026-01-16 20:25:12 +00:00
rltakashige
745343c705 Return error responses for Chat Completions (#1173)
- Error chunks
- Use error handling in exo_bench.py

## Motivation

Return when an error occurs so that generation stops. Adding timeouts is
a separate TODO for model loading and chat completions.

## Changes

- Return HTTP exceptions as JSON responses in an OpenAI compatible
format.
- Context manager for generation to catch and return error messages.
- Use error handling in exo_bench.py.

## Test Plan

### Manual Testing
Manually tested that exo_bench returns on failures within and outside
generation

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-16 19:24:37 +00:00
22 changed files with 1262 additions and 100 deletions

View File

@@ -40,6 +40,31 @@ uv run ruff check
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition

View File

@@ -27,6 +27,15 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -299,6 +324,12 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -320,7 +351,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -399,7 +430,7 @@ def main() -> int:
):
continue
if 0 < n <= args.max_nodes:
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -441,7 +472,13 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)
@@ -453,17 +490,17 @@ def main() -> int:
logger.debug(f" warmup {i + 1}/{args.warmup} done")
for pp in pp_list:
if (
pp * n_nodes > 2048
and "ring" in instance_meta.lower()
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.metadata.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
)
continue
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 187 KiB

View File

@@ -126,3 +126,6 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -205,6 +205,14 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -218,6 +226,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -259,6 +268,20 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,13 +1,14 @@
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
import anyio
from anyio import create_task_group
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
@@ -29,6 +30,8 @@ from exo.shared.types.api import (
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ModelList,
@@ -49,7 +52,12 @@ from exo.shared.types.commands import (
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
@@ -115,6 +123,7 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -145,6 +154,21 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -406,6 +430,18 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -426,6 +462,12 @@ class API:
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -463,6 +505,12 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -607,14 +655,14 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -49,33 +49,83 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
if not selected_cycle:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
for node in selected_cycle
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node.node_profile.memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node.node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(

View File

@@ -0,0 +1,107 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -70,7 +70,7 @@ def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
((312, 468, 1092), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(

View File

@@ -3,6 +3,7 @@ from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
@@ -165,6 +166,9 @@ def test_get_smallest_cycles(
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
@@ -397,3 +401,96 @@ def test_get_mlx_jaccl_coordinators(
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node(900 * 1024, node_a_id)
node_b = create_node(50 * 1024, node_b_id)
node_c = create_node(10 * 1024, node_c_id) # Insufficient memory
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline)

View File

@@ -11,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call"
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"

View File

@@ -22,6 +22,7 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):

View File

@@ -245,12 +245,15 @@ def create_http_session(
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(cafile=certifi.where())
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,

View File

@@ -41,14 +41,16 @@ class _LayerCallable(Protocol):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class CustomMlxLayer(nn.Module):
class CustomMlxModule(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: _LayerCallable):
super().__init__()
# Set twice to avoid __setattr__ recursion
object.__setattr__(self, "_original_layer", original_layer)
self.original_layer: _LayerCallable = original_layer
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -58,10 +60,10 @@ class CustomMlxLayer(nn.Module):
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
return object.__getattribute__(original_layer, name)
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
class PipelineFirstLayer(CustomMlxModule):
def __init__(
self,
original_layer: _LayerCallable,
@@ -78,7 +80,7 @@ class PipelineFirstLayer(CustomMlxLayer):
return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(CustomMlxLayer):
class PipelineLastLayer(CustomMlxModule):
def __init__(
self,
original_layer: _LayerCallable,
@@ -168,11 +170,21 @@ def pipeline_auto_parallel(
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
# We can assume the model has at least one layer thanks to placement.
# If a layer type doesn't exist, we can set it to 0.
inner_model_instance.swa_idx = (
0
if "sliding_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
inner_model_instance.ga_idx = (
0
if "full_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
)
_set_layers(model, layers)
@@ -181,7 +193,32 @@ def pipeline_auto_parallel(
"Expected a list of layers after auto-parallel initialisation"
)
return model
return PipelineParallelModel(model, group)
class PipelineParallelModel(CustomMlxModule):
def __init__(self, model: nn.Module, group: mx.distributed.Group):
super().__init__(model)
self.original_call_signature = signature(self.original_layer.__call__)
self.group = group
dict.__setitem__(self, "original_layer", model)
def __call__(
self,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = self.original_layer(*args, **kwargs) # type: ignore
cache = self.original_call_signature.bind_partial(
*args, **kwargs
).arguments.get("cache", None)
if cache is not None:
for c in cache: # type: ignore
if hasattr(c, "state") and c.state is not None: # type: ignore
c.state = mx.depends(c.state, logits) # type: ignore
return logits
def tensor_auto_parallel(
@@ -389,7 +426,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
return model
class ShardedDeepseekV3MoE(CustomMlxLayer):
class ShardedDeepseekV3MoE(CustomMlxModule):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
@@ -464,7 +501,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
return model
class ShardedQwenMoE(CustomMlxLayer):
class ShardedQwenMoE(CustomMlxModule):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
@@ -511,7 +548,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
return model
class ShardedGptOssMoE(CustomMlxLayer):
class ShardedGptOssMoE(CustomMlxModule):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None

View File

@@ -2,7 +2,9 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -82,6 +84,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -188,7 +229,9 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -202,7 +245,9 @@ def load_mlx_items(
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
@@ -216,6 +261,7 @@ def load_mlx_items(
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
@@ -252,7 +298,15 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)

View File

@@ -17,15 +17,23 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
)
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main

View File

@@ -67,6 +67,7 @@ def main(
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
device_rank = shard_metadata.device_rank
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
@@ -118,7 +119,20 @@ def main(
)
)
model, tokenizer = load_mlx_items(bound_instance, group)
def on_model_load_timeout() -> None:
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(
error_message="Model loading timed out"
),
)
)
time.sleep(0.5)
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -148,8 +162,6 @@ def main(
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -158,41 +170,61 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert model
assert tokenizer
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
)
try:
_check_for_debug_prompts(task_params.messages[0].content)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
)
# TODO: Add tool call parser here
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
)
# case TokenizedResponse():
# TODO: something here ig
# can we make this more explicit?
except Exception as e:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_meta.model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")

View File

@@ -0,0 +1,220 @@
# type: ignore
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
class MockLayer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
self.use_sliding = True
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x * 2
@dataclass(frozen=True)
class PipelineTestConfig:
model_path: Path
total_layers: int
base_port: int
max_tokens: int
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
import json
import tempfile
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
return hostfile_path, hosts
# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
total_layers=24,
base_port=29600,
max_tokens=200,
)
DEFAULT_GPT_OSS_MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
def run_gpt_oss_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 200,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=False,
),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=24,
)
model, tokenizer = shard_and_load(shard_meta, group)
model = cast(Model, model)
# Generate a prompt of exact token length
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
# Build prompt with approximate target length
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
# Truncate to exact target length
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
):
generated_text += response.text
if response.finish_reason is not None:
break
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
def run_gpt_oss_tensor_parallel_device(
rank: int,
world_size: int,
hostfile_path: str,
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 10,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
# For tensor parallelism, all devices run all layers
shard_meta = TensorShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
device_rank=rank,
world_size=world_size,
start_layer=0,
end_layer=24,
n_layers=24,
)
model, tokenizer = shard_and_load(shard_meta, group)
model = cast(Model, model)
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
):
generated_text += response.text
if response.finish_reason is not None:
break
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]

View File

@@ -0,0 +1,154 @@
import multiprocessing as mp
from typing import Any
import mlx.core as mx
import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxModule,
PipelineFirstLayer,
PipelineLastLayer,
PipelineParallelModel,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
def run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
result_queue: Any, # pyright: ignore[reportAny]
) -> None:
import os
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
class MockLayerInner(mlx_nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
def __call__(
self, x: mlx_core.array, *args: object, **kwargs: object
) -> mlx_core.array:
return x * 2
class MockModel(mlx_nn.Module):
def __init__(self, layers: list[mlx_nn.Module]) -> None:
super().__init__()
self.layers = layers
def __call__(
self, x: mlx_core.array, *args: object, **kwargs: object
) -> mlx_core.array:
for layer in self.layers:
x = layer(x, *args, **kwargs) # pyright: ignore[reportUnknownVariableType]
return x # pyright: ignore[reportUnknownVariableType]
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
inner_model = MockModel([composed])
model = PipelineParallelModel(inner_model, group)
x = mlx_core.ones((1, 4))
result = model(x)
mlx_core.eval(result)
success = result.shape == x.shape
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
def test_single_wrapper_delegates_attributes() -> None:
mock = MockLayer()
wrapped = CustomMlxModule(mock)
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
assert wrapped.use_sliding is True # type: ignore[attr-defined]
def test_composed_wrappers_delegate_attributes() -> None:
mock = MockLayer()
group = mx.distributed.init()
first = PipelineFirstLayer(mock, r=0, group=group)
composed = PipelineLastLayer(first, r=0, s=1, group=group)
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
assert composed.use_sliding is True # type: ignore[attr-defined]
def test_missing_attribute_raises() -> None:
mock = MockLayer()
wrapped = CustomMlxModule(mock)
with pytest.raises(AttributeError):
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
def test_composed_call_works() -> None:
import json
import os
import tempfile
ctx = mp.get_context("spawn")
world_size = 2
base_port = 29500
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=run_pipeline_device,
args=(rank, world_size, hostfile_path, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=10) # pyright: ignore[reportAny]
results: dict[int, Any] = {}
errors: dict[int, str] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
if success:
results[rank] = value
else:
errors[rank] = value
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}. Errors: {errors}"
)
for rank in range(world_size):
assert rank in results, (
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
)
result_array = results[rank]
# Both devices see the final result (4.0) after all_gather
assert (result_array == 4.0).all(), (
f"Device {rank}: expected 4.0, got {result_array}"
)
finally:
os.unlink(hostfile_path)

View File

@@ -0,0 +1,230 @@
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import Any, Callable
import pytest
from exo.worker.tests.unittests.test_mlx.conftest import (
DEFAULT_GPT_OSS_CONFIG,
create_hostfile,
run_gpt_oss_pipeline_device,
run_gpt_oss_tensor_parallel_device,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
pytestmark = [
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
),
]
@dataclass
class DistributedTestResult:
timed_out: bool
world_size: int
results: dict[int, tuple[bool, str]]
@property
def all_success(self) -> bool:
if len(self.results) != self.world_size:
return False
return all(r[0] for r in self.results.values())
def run_distributed_test(
world_size: int,
port_offset: int,
process_timeout: int,
target: Callable[..., None],
make_args: Callable[[int], tuple[Any, ...]],
) -> DistributedTestResult:
ctx = mp.get_context("spawn")
hostfile_path, _ = create_hostfile(
world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset
)
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
args = make_args(rank)
p = ctx.Process(
target=target,
args=(rank, world_size, hostfile_path, *args, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=process_timeout) # pyright: ignore[reportAny]
timed_out = any(p.is_alive() for p in processes) # pyright: ignore[reportAny]
for p in processes: # pyright: ignore[reportAny]
if p.is_alive(): # pyright: ignore[reportAny]
p.terminate() # pyright: ignore[reportAny]
p.join(timeout=5) # pyright: ignore[reportAny]
results: dict[int, tuple[bool, str]] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
results[rank] = (success, value)
return DistributedTestResult(
timed_out=timed_out, world_size=world_size, results=results
)
finally:
os.unlink(hostfile_path)
def run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
layer_splits,
prompt_tokens,
prefill_step_size,
)
return run_distributed_test(
world_size=len(layer_splits),
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_pipeline_device,
make_args=make_args,
)
def run_tensor_test(
prompt_tokens: int,
prefill_step_size: int,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
prompt_tokens,
prefill_step_size,
)
return run_distributed_test(
world_size=2,
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_tensor_parallel_device,
make_args=make_args,
)
class TestPipelineParallelFix:
BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]
def test_pipeline_single_layer_first_device(self) -> None:
result = run_pipeline_test(
layer_splits=self.BUG_TRIGGER_SPLITS,
prompt_tokens=100,
prefill_step_size=64,
process_timeout=60,
)
assert not result.timed_out, "Unexpected timeout - fix may not be working"
assert result.all_success, f"Failures: {result.results}"
class TestPipelineSplitConfigurations:
@pytest.mark.parametrize(
"layer_splits",
[
[(0, 1), (1, 24)],
[(0, 6), (6, 24)],
[(0, 12), (12, 24)],
],
ids=["1_23", "6_18", "12_12"],
)
def test_pipeline_splits(
self,
layer_splits: list[tuple[int, int]],
) -> None:
result = run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=600,
prefill_step_size=512,
port_offset=100,
)
assert not result.timed_out, f"Timeout with {layer_splits}"
assert result.all_success, f"Failures with {layer_splits}: {result.results}"
class TestPrefillStepSizeBoundaries:
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_boundary_conditions(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_pipeline_test(
layer_splits=[(0, 12), (12, 24)],
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
port_offset=200,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelFix:
def test_tensor_parallel(self) -> None:
result = run_tensor_test(
prompt_tokens=100,
prefill_step_size=64,
port_offset=400,
)
assert not result.timed_out, "Unexpected timeout"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelBoundaries:
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_tensor_parallel_boundaries(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_tensor_test(
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
port_offset=500,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"

View File

@@ -121,6 +121,21 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
class EventCollector:
def __init__(self) -> None:
self.events: list[Event] = []
def send(self, event: Event) -> None:
self.events.append(event)
def close(self) -> None:
pass
def join(self) -> None:
pass
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -130,22 +145,20 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
event_sender = EventCollector()
with task_sender, event_receiver:
with task_sender:
for t in tasks:
task_sender.send(t)
# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver)
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
return event_receiver.collect()
return event_sender.events
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):