Compare commits

..

11 Commits

Author SHA1 Message Date
Alex Cheema
c7b7523bb4 Fix merge conflicts and post-merge issues
- Resolve 5 merge conflicts from origin/main merge
- Fix `await runner.shutdown()` → `runner.shutdown()` (now sync)
- Apply nix fmt formatting
- Exclude start_distributed_test.py from pytest collection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 10:13:51 -08:00
Alex Cheema
3ee1076d44 Merge remote-tracking branch 'origin/main' into alexcheema/dashboard-stop-button
# Conflicts:
#	src/exo/master/placement.py
#	src/exo/worker/engines/mlx/generator/generate.py
#	src/exo/worker/main.py
#	src/exo/worker/runner/runner.py
#	src/exo/worker/runner/runner_supervisor.py
2026-02-13 10:12:24 -08:00
Alex Cheema
c2f47802c8 Merge latest runner-cancellation (force-push) and resolve test conflict
Keep mx.distributed.all_gather monkeypatch which matches how runner.py
actually calls all_gather.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 06:48:42 -08:00
Alex Cheema
a067fd0753 Merge runner-cancellation into dashboard-stop-button
Resolve conflicts by keeping runner_id on CancelTask and using it
directly in worker dispatch and plan generation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 06:45:45 -08:00
Evan
333119ec2f api cancellation
closing the http request to the api now
- sends a cancellation from the api
- writes that canellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks
2026-02-05 14:42:05 +00:00
Evan
79661ea3a3 api cancellation
closing the http request to the api now
- sends a cancellation from the api
- writes that canellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks
2026-02-05 14:33:22 +00:00
Alex Cheema
196f4fb35f Merge remote-tracking branch 'origin/main' into alexcheema/dashboard-stop-button 2026-02-05 05:38:55 -08:00
Alex Cheema
08a4ca59c6 fix: repair test_event_ordering monkeypatch and runner formatting
- Fix broken monkeypatch of `mx.all_gather` by patching `mx.distributed.all_gather` correctly
- Apply formatter to runner.py for CI compliance

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 05:28:13 -08:00
Alex Cheema
45ed3903f3 fix: apply formatter to runner.py for CI compliance
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 05:21:52 -08:00
Alex Cheema
6657d8600f feat: add stop button to dashboard for cancelling ongoing requests
When a chat, image generation, or image edit request is streaming,
the send button is replaced with a stop button. Clicking it aborts the
HTTP connection, which triggers the backend cancellation chain.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 04:58:40 -08:00
Evan
bf06580e0e api cancellation
closing the http request to the api now
- sends a cancellation from the api
- writes that canellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks
2026-02-05 12:39:11 +00:00
37 changed files with 349 additions and 1159 deletions

View File

@@ -1,15 +0,0 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

View File

@@ -1,42 +0,0 @@
name: e2e-tests
on:
push:
pull_request:
branches:
- staging
- main
jobs:
e2e:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
/opt/microsoft /opt/az
docker system prune -af
df -h /
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build E2E image with cache
uses: docker/build-push-action@v6
with:
context: .
file: e2e/Dockerfile
tags: exo-e2e:latest
load: true
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Run E2E tests
run: python3 e2e/run_all.py

View File

@@ -5,21 +5,21 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).

View File

@@ -1 +0,0 @@
collect_ignore = ["tests/start_distributed_test.py"]

View File

@@ -1,6 +1,7 @@
<script lang="ts">
import {
isLoading,
stopGeneration,
sendMessage,
generateImage,
editImage,
@@ -648,86 +649,92 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => stopGeneration()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
aria-label="Stop generation"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
class="w-2.5 h-2.5 sm:w-3 sm:h-3"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
fill="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
<rect x="4" y="4" width="16" height="16" rx="2" />
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">STOP</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>
<!-- Bottom accent line -->

View File

@@ -514,6 +514,7 @@ class AppStore {
messages = $state<Message[]>([]);
currentResponse = $state("");
isLoading = $state(false);
private currentAbortController: AbortController | null = null;
// Performance metrics
ttftMs = $state<number | null>(null); // Time to first token in ms
@@ -1814,9 +1815,11 @@ class AppStore {
return;
}
this.currentAbortController = new AbortController();
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
signal: this.currentAbortController.signal,
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
@@ -1930,6 +1933,7 @@ class AppStore {
"Unknown error",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2066,6 +2070,10 @@ class AppStore {
assistantMessageId: string,
errorPrefix = "Failed to get response",
): void {
// Don't show error for user-initiated abort (stop button)
if (error instanceof DOMException && error.name === "AbortError") {
return;
}
if (this.conversationExists(targetConversationId)) {
this.updateConversationMessage(
targetConversationId,
@@ -2107,6 +2115,17 @@ class AppStore {
return null;
}
/**
* Stop the current generation by aborting the HTTP connection.
* This triggers backend cancellation via the mechanism in PR #1276.
*/
stopGeneration() {
if (this.currentAbortController) {
this.currentAbortController.abort();
this.currentAbortController = null;
}
}
/**
* Send a message to the LLM and stream the response
*/
@@ -2255,11 +2274,13 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
this.currentAbortController = new AbortController();
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
signal: this.currentAbortController.signal,
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
@@ -2410,6 +2431,7 @@ class AppStore {
"Failed to get response",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2514,11 +2536,13 @@ class AppStore {
};
}
this.currentAbortController = new AbortController();
const response = await fetch("/v1/images/generations", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
signal: this.currentAbortController.signal,
body: JSON.stringify(requestBody),
});
@@ -2667,6 +2691,7 @@ class AppStore {
"Failed to generate image",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2788,8 +2813,10 @@ class AppStore {
);
}
this.currentAbortController = new AbortController();
const apiResponse = await fetch("/v1/images/edits", {
method: "POST",
signal: this.currentAbortController.signal,
body: formData,
});
@@ -2899,6 +2926,7 @@ class AppStore {
"Failed to edit image",
);
} finally {
this.currentAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -3039,6 +3067,7 @@ export const hasStartedChat = () => appStore.hasStartedChat;
export const messages = () => appStore.messages;
export const currentResponse = () => appStore.currentResponse;
export const isLoading = () => appStore.isLoading;
export const stopGeneration = () => appStore.stopGeneration();
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;

View File

@@ -1,58 +0,0 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
# libblas-dev/liblapack-dev/liblapacke-dev are required by MLX CPU backend on Linux
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
libblas-dev \
liblapack-dev \
liblapacke-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
# Must be done BEFORE uv sync so any source builds also get the fix
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
CMD [".venv/bin/exo", "-v"]

View File

@@ -1,195 +0,0 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
# Skip build if the image was pre-built (e.g. in CI with buildx cache)
proc = await asyncio.create_subprocess_exec(
"docker",
"image",
"inspect",
"exo-e2e:latest",
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await proc.wait()
if proc.returncode == 0:
print(" Using pre-built image (exo-e2e:latest)")
return
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")

View File

@@ -1,20 +0,0 @@
services:
exo-node-1:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]

View File

@@ -1,77 +0,0 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests with '# slow' on the first line of their docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
if "--update-snapshots" in sys.argv:
os.environ["UPDATE_SNAPSHOTS"] = "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,69 +0,0 @@
"""Snapshot testing infrastructure for E2E tests.
Provides deterministic regression testing by comparing inference output
against saved snapshots. On first run, snapshots are created automatically.
Set UPDATE_SNAPSHOTS=1 to regenerate snapshots when output intentionally changes.
Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
since floating-point results differ between CPU architectures.
"""
import difflib
import json
import os
import platform
from pathlib import Path
ARCH = platform.machine()
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH
def assert_snapshot(
name: str,
content: str,
metadata: dict,
) -> None:
"""Compare content against a saved snapshot, or create one if missing.
Args:
name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
content: The actual inference output to compare.
metadata: Additional context stored alongside content (model, seed, etc.).
Not used for comparison -- purely documentary.
Raises:
AssertionError: If content doesn't match the saved snapshot.
Environment:
UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
"""
snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
update = os.environ.get("UPDATE_SNAPSHOTS") == "1"
if snapshot_file.exists() and not update:
snapshot = json.loads(snapshot_file.read_text())
expected = snapshot["content"]
if content != expected:
diff = "\n".join(
difflib.unified_diff(
expected.splitlines(),
content.splitlines(),
fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
tofile="actual",
lineterm="",
)
)
raise AssertionError(
f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
f"{diff}\n\n"
f"Expected: {expected!r}\n"
f"Actual: {content!r}\n\n"
f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py"
)
print(f" Output matches snapshot ({ARCH}/{snapshot_file.name})")
else:
SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
snapshot_data = {**metadata, "arch": ARCH, "content": content}
snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
action = "Updated" if update else "Created"
print(f" {action} snapshot: {ARCH}/{snapshot_file.name}")

View File

@@ -1,22 +0,0 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,60 +0,0 @@
"""Test: Deterministic inference output (snapshot test).
Sends a chat completion request with a fixed seed,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Uses MLX CPU backend in Docker on x86 Linux.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="inference_snapshot",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,47 +0,0 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,58 +0,0 @@
"""Test: Code generation snapshot.
Verifies deterministic output for a code generation prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = (
"Write a Python function to reverse a string. Only output the code, no explanation."
)
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_code_gen") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_code_gen",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_code_gen")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,63 +0,0 @@
"""Test: Edge case snapshots.
Verifies deterministic output for edge-case prompts: single word input,
special characters, and unicode.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
MAX_TOKENS = 32
CASES = [
("edge_single_word", "Hi"),
("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
]
async def main():
async with Cluster("snapshot_edge") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
for snapshot_name, prompt in CASES:
print(f" [{snapshot_name}] Sending: {prompt!r}")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": prompt}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{snapshot_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": prompt,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_edge")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,56 +0,0 @@
"""Test: Longer output snapshot.
Verifies deterministic output with a higher max_tokens (128).
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "Explain how a binary search algorithm works."
MAX_TOKENS = 128
async def main():
async with Cluster("snapshot_long_output") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_long_output",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_long_output")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,72 +0,0 @@
"""Test: Multi-model snapshot tests.
slow
Verifies deterministic output across different model architectures to catch
model-specific regressions. Each model uses its own snapshot file.
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
SEED = 42
PROMPT = "What is the capital of France?"
MAX_TOKENS = 32
MODELS = [
"mlx-community/SmolLM2-135M-Instruct",
"mlx-community/Llama-3.2-1B-Instruct-4bit",
"mlx-community/gemma-2-2b-it-4bit",
]
async def main():
async with Cluster("snapshot_multi_model") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
for model in MODELS:
short_name = (
model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
)
snapshot_name = f"snapshot_multi_{short_name}"
print(f" Launching model {model}...")
await cluster.place_model(model)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=model,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{short_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": model,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print(f" [{short_name}] PASSED")
print("PASSED: snapshot_multi_model")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,56 +0,0 @@
"""Test: Reasoning/math snapshot.
Verifies deterministic output for a simple reasoning prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_reasoning") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_reasoning",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_reasoning")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,32 +0,0 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v

View File

@@ -132,7 +132,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/SmolLM2-135M-Instruct"
n_layers = 30
hidden_size = 576
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "SmolLM2 135M"
capabilities = ["text"]
[storage_size]
in_bytes = 269060381

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/gemma-2-2b-it-4bit"
n_layers = 26
hidden_size = 2304
supports_tensor = false
tasks = ["TextGeneration"]
family = "gemma2"
quantization = "4bit"
base_model = "Gemma 2 2B"
capabilities = ["text"]
[storage_size]
in_bytes = 1492755242

View File

@@ -176,7 +176,7 @@ async def generate_chat_stream(
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ChatCompletionResponse:
) -> AsyncGenerator[str]:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
@@ -223,7 +223,7 @@ async def collect_chat_response(
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
yield ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
@@ -241,4 +241,5 @@ async def collect_chat_response(
finish_reason=finish_reason,
)
],
)
).model_dump_json()
return

View File

@@ -125,6 +125,7 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
TextGeneration,
)
@@ -540,16 +541,14 @@ class API:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
@@ -644,11 +643,14 @@ class API:
"X-Accel-Buffering": "no",
},
)
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -664,8 +666,7 @@ class API:
command = TextGeneration(task_params=task_params)
await self._send(command)
response = await self._collect_text_generation_with_stats(command.command_id)
return response
return await self._collect_text_generation_with_stats(command.command_id)
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
@@ -883,6 +884,11 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -964,6 +970,11 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))

View File

@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
TextGeneration,
@@ -39,6 +40,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
TracesMerged,
@@ -279,7 +281,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
@@ -299,7 +301,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -309,7 +311,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -319,6 +321,18 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -327,10 +341,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time

View File

@@ -22,9 +22,15 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
@@ -186,6 +192,7 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []
@@ -201,6 +208,18 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
events.append(
InstanceDeleted(
instance_id=instance_id,

View File

@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
target_instances = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1

View File

@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -89,6 +93,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
| LoadModel
| StartWarmup
| TextGeneration
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -64,8 +64,6 @@ from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -83,30 +81,6 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -591,3 +565,23 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -33,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -224,15 +225,22 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await runner.start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
runner.shutdown()
case CancelTask(
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -270,18 +278,18 @@ class Worker:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
await self._start_runner_task(modified_task)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
await self._start_runner_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:
await self.runners[
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
@@ -320,8 +328,6 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ConnectToGroup,
CreateRunner,
DownloadModel,
@@ -53,13 +54,14 @@ def plan(
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
_cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
)
@@ -270,7 +272,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -284,7 +286,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
received = len(input_chunk_buffer.get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -292,16 +294,33 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner_id, runner in runners.items():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id,
cancelled_task_id=task.task_id,
runner_id=runner_id,
)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -1,5 +1,6 @@
import base64
import json
import math
import resource
import time
from collections.abc import Generator
@@ -88,6 +89,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -112,6 +114,7 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
@@ -129,11 +132,15 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
model: Model | DistributedImageModel | None = None
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -146,6 +153,7 @@ def main(
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -191,7 +199,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -203,7 +211,7 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
model = initialize_image_model(bound_instance)
image_model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -211,8 +219,6 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -224,16 +230,31 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
t = time.perf_counter()
toks = warmup_inference(
model=model,
model=inference_model,
tokenizer=tokenizer,
group=group,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / (time.perf_counter() - t)), 100
)
if group is not None:
check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]), group=group
)
).item()
)
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
@@ -241,8 +262,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -262,9 +283,9 @@ def main(
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert model and not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
assert check_for_cancel_every
try:
_check_for_debug_prompts(task_params)
@@ -274,7 +295,7 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
model=inference_model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -299,11 +320,11 @@ def main(
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -316,7 +337,18 @@ def main(
)
completion_tokens = 0
tokens_since_last_cancel_check = 0
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
completion_tokens += 1
@@ -388,7 +420,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -401,7 +433,9 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
is_primary_output = _is_primary_output_node(shard_metadata)
if is_primary_output:
@@ -451,7 +485,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -464,7 +498,9 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
if _is_primary_output_node(shard_metadata):
match response:
case PartialImageResponse():
@@ -530,7 +566,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc

View File

@@ -47,9 +47,11 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -60,8 +62,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -69,6 +71,7 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -83,6 +86,7 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -97,6 +101,7 @@ class RunnerSupervisor:
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
@@ -112,14 +117,6 @@ class RunnerSupervisor:
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
@@ -141,6 +138,13 @@ class RunnerSupervisor:
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
async def _forward_events(self):
with self._ev_recv as events:
try:

View File

@@ -2,6 +2,7 @@
from collections.abc import Iterable
from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
@@ -19,6 +20,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
@@ -113,6 +115,8 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
@@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
event_sender = EventCollector()
with task_sender:
@@ -174,7 +179,7 @@ def _run(tasks: Iterable[Task]):
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver) # pyright: ignore[reportArgumentType]
return event_sender.events