Compare commits

..

6 Commits

Author SHA1 Message Date
Sami Khan
bdb9fbc8c0 Merge branch 'main' into sami/flash 2026-01-14 08:10:51 +05:00
Sami Khan
8c7180810c type checking 2026-01-14 07:15:45 +05:00
Sami Khan
318c6e000b code cleanup 2026-01-14 04:56:59 +05:00
Jake Hillion
e4a01e2b0e chore(deps): nix lock file maintenance
Update nix flake inputs. Add a second input as Swift is currently broken
in nixpkgs on Linux for `swift-format` as we want `nix fmt` to continue
being reproducible everywhere.
2026-01-13 19:57:14 +01:00
Sami Khan
2d45544da0 use rsh server instead of ssh 2026-01-13 02:46:25 +05:00
Sami Khan
7cbafa768a flash+exo 2026-01-12 10:26:16 +05:00
27 changed files with 1154 additions and 1136 deletions

View File

@@ -276,23 +276,24 @@ class BatchGenerator:
logprobs: mx.array
finish_reason: Optional[str]
unprocessed_prompts: List[Any]
def __init__(
self,
model: nn.Module,
model,
max_tokens: int = ...,
stop_tokens: Optional[set[int]] = ...,
stop_tokens: Optional[set] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
) -> List[int]: ...
def stats(self) -> BatchStats: ...
def next(self) -> List[Response]: ...
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
def batch_generate(
model,

View File

@@ -39,18 +39,12 @@ class StreamingDetokenizer:
"""
__slots__ = ...
tokens: list[int]
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
@property
def text(self) -> str:
"""The full text decoded so far."""
...
@property
def last_segment(self) -> str:
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
...
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -114,7 +108,6 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int

41
flake.lock generated
View File

@@ -8,11 +8,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"owner": "nix-community",
"repo": "fenix",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"type": "github"
},
"original": {
@@ -43,11 +43,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
@@ -57,22 +57,39 @@
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
},
"original": {
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
"root": {
"inputs": {
"fenix": "fenix",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"rev": "725349602e525df37f377701e001fe8aab807878",
"type": "github"
},
"original": {
@@ -89,11 +106,11 @@
]
},
"locked": {
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"type": "github"
},
"original": {

View File

@@ -18,6 +18,9 @@
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};
nixConfig = {
@@ -39,9 +42,11 @@
];
perSystem =
{ config, inputs', pkgs, lib, ... }:
{ config, inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
@@ -60,7 +65,10 @@
enable = true;
includes = [ "*.ts" ];
};
swift-format.enable = true;
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
};
};

View File

@@ -29,6 +29,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"
# dependencies only required for development
[dependency-groups]

View File

@@ -15,6 +15,7 @@ import exo.routing.topics as topics
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
from exo.rsh.server import RSH_PORT, run_rsh_server
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
from exo.shared.logging import logger_cleanup, logger_setup
@@ -113,6 +114,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
# Start RSH server for remote execution (used by MPI)
tg.start_soon(run_rsh_server, RSH_PORT)
def shutdown(self):
# if this is our second call to shutdown, just sys.exit

View File

@@ -1,6 +1,6 @@
import time
from collections.abc import AsyncGenerator
from typing import cast
from typing import Any, cast
import anyio
from anyio import create_task_group
@@ -51,7 +51,9 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
StopFLASH,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -60,7 +62,12 @@ from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -178,6 +185,10 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# FLASH simulation endpoints
self.app.post("/flash/launch")(self.launch_flash)
self.app.delete("/flash/{instance_id}")(self.stop_flash)
self.app.get("/flash/instances")(self.list_flash_instances)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -622,6 +633,86 @@ class API:
]
)
async def launch_flash(
self,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict[str, str]:
"""Launch a FLASH MPI simulation across the cluster.
Args:
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await self._send(command)
return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}
async def stop_flash(self, instance_id: InstanceId) -> dict[str, str]:
"""Stop a running FLASH simulation."""
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
instance = self.state.instances[instance_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)
command = StopFLASH(instance_id=instance_id)
await self._send(command)
return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}
async def list_flash_instances(self) -> list[dict[str, Any]]:
"""List all FLASH simulation instances."""
flash_instances: list[dict[str, Any]] = []
for instance_id, instance in self.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses: dict[str, str | None] = {}
for (
node_id,
runner_id,
) in instance.shard_assignments.node_to_runner.items():
runner_status = self.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)
flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -8,6 +8,7 @@ from exo.master.placement import (
add_instance_to_placements,
delete_instance,
get_transition_events,
place_flash_instance,
place_instance,
)
from exo.shared.apply import apply
@@ -16,8 +17,10 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
RequestEventLog,
StopFLASH,
TaskFinished,
TestCommand,
)
@@ -173,6 +176,26 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case LaunchFLASH():
placement = place_flash_instance(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case StopFLASH():
# Reuse delete_instance logic to stop FLASH simulation
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -17,20 +17,24 @@ from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
LaunchFLASH,
PlaceInstance,
)
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
def random_ephemeral_port() -> int:
@@ -165,6 +169,9 @@ def place_instance(
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
case InstanceMeta.FLASH:
# FLASH instances are handled by place_flash_instance()
raise ValueError("FLASH instances should use place_flash_instance()")
return target_instances
@@ -180,6 +187,148 @@ def delete_instance(
raise ValueError(f"Instance {command.instance_id} not found")
def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.
Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
all_nodes = list(topology.list_nodes())
if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)
# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]
logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)
# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)
for i, node_info in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_info.node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict[NodeId, list[Host]] = {}
# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_info in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_info.node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(
f"FLASH placement: node {node_info.node_id} (rank {i}) -> IP {explicit_hosts[i]}"
)
else:
logger.warning(
f"Not enough hosts provided for node {i}, using localhost"
)
hosts_by_node[node_info.node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(
f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
)
else:
# Try to get IPs from topology edges
for node_info in selected_nodes:
node_hosts: list[Host] = []
# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for _, edge_data in topology.out_edges(node_info.node_id):
if hasattr(edge_data, "send_back_multiaddr"):
# Extract IP from multiaddr like /ip4/192.168.1.100/tcp/52415
multiaddr = str(edge_data.send_back_multiaddr)
if "/ip4/" in multiaddr:
parts = multiaddr.split("/")
try:
ip_idx = parts.index("ip4") + 1
ip = parts[ip_idx]
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith(
"127."
):
node_hosts.append(Host(ip=ip, port=0))
break
except (ValueError, IndexError):
pass
# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_info.node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))
hosts_by_node[node_info.node_id] = node_hosts
total_ranks = len(selected_nodes) * command.ranks_per_node
# Determine coordinator IP - first node's first host IP
first_node_id: NodeId = next(iter(hosts_by_node.keys()))
coordinator_ip: str = (
hosts_by_node[first_node_id][0].ip
if hosts_by_node[first_node_id]
else "127.0.0.1"
)
target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)
logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")
return target_instances
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],

13
src/exo/rsh/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.
This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:
1. Each Exo node runs a small HTTP server (RSH server) on port 52416
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target
4. The target executes the command and streams output back
Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""

99
src/exo/rsh/client.py Normal file
View File

@@ -0,0 +1,99 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.
This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]
It connects to the target node's RSH server (port 52416) and executes the command.
"""
import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen
RSH_PORT = 52416
def resolve_hostname(hostname: str) -> str:
"""Resolve hostname to IP address."""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
# If resolution fails, try using the hostname directly
return hostname
def main():
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
# SSH options we might see: -x (disable X11), -o options, etc.
args = sys.argv[1:]
# Skip SSH-style options
hostname = None
command_start = 0
i = 0
while i < len(args):
arg = args[i]
if arg.startswith("-"):
# Skip option and its value if needed
if arg in ("-o", "-i", "-l", "-p", "-F"):
i += 2 # Skip option and its argument
continue
i += 1
continue
else:
# First non-option is the hostname
hostname = arg
command_start = i + 1
break
i += 1
if hostname is None or command_start >= len(args):
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
sys.exit(1)
command = args[command_start:]
# Resolve hostname to IP
ip = resolve_hostname(hostname)
# Make request to RSH server
url = f"http://{ip}:{RSH_PORT}/execute"
data = json.dumps({"command": command}).encode("utf-8")
try:
req = Request(url, data=data, headers={"Content-Type": "application/json"})
with urlopen(req, timeout=300) as response: # pyright: ignore[reportAny]
response_body: bytes = cast(bytes, response.read()) # pyright: ignore[reportAny]
result: dict[str, Any] = json.loads(response_body.decode("utf-8")) # pyright: ignore[reportAny]
# Output stdout/stderr
stdout: str = cast(str, result.get("stdout", ""))
stderr: str = cast(str, result.get("stderr", ""))
exit_code: int = cast(int, result.get("exit_code", 0))
if stdout:
sys.stdout.write(stdout)
sys.stdout.flush()
if stderr:
sys.stderr.write(stderr)
sys.stderr.flush()
sys.exit(exit_code)
except URLError as e:
print(
f"exo-rsh: Failed to connect to {hostname}:{RSH_PORT}: {e}", file=sys.stderr
)
sys.exit(255)
except Exception as e:
print(f"exo-rsh: Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

154
src/exo/rsh/server.py Normal file
View File

@@ -0,0 +1,154 @@
"""RSH Server - runs on each Exo node to accept remote execution requests."""
import asyncio
from typing import Optional
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from loguru import logger
from pydantic import BaseModel
RSH_PORT = 52416
class ExecuteRequest(BaseModel):
"""Request to execute a command."""
command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None
class ExecuteResponse(BaseModel):
"""Response from command execution."""
exit_code: int
stdout: str
stderr: str
def create_rsh_app() -> FastAPI:
"""Create the RSH FastAPI application."""
app = FastAPI(title="Exo RSH Server")
@app.get("/health")
async def health(): # pyright: ignore[reportUnusedFunction]
"""Health check endpoint."""
return {"status": "ok"}
@app.post("/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse: # pyright: ignore[reportUnusedFunction]
"""Execute a command and return the result."""
cmd_str = " ".join(request.command)
logger.info(f"RSH executing: {cmd_str}")
try:
# Build environment
import os
env = os.environ.copy()
if request.env:
env.update(request.env)
# Check if command contains shell metacharacters (semicolons, pipes, etc.)
# If so, run through shell. mpirun sends complex commands like:
# "VAR=value;export VAR;/path/to/prted --args"
needs_shell = any(c in cmd_str for c in ";|&$`")
if needs_shell:
# Run through shell
process = await asyncio.create_subprocess_shell(
cmd_str,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
else:
# Execute directly
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
stdout, stderr = await process.communicate()
exit_code = process.returncode or 0
logger.info(f"RSH command completed with exit code {exit_code}")
return ExecuteResponse(
exit_code=exit_code,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)
except FileNotFoundError as e:
logger.error(f"RSH command not found: {e}")
return ExecuteResponse(
exit_code=127,
stdout="",
stderr=f"Command not found: {request.command[0]}",
)
except Exception as e:
logger.error(f"RSH execution error: {e}")
return ExecuteResponse(
exit_code=1,
stdout="",
stderr=str(e),
)
@app.post("/execute_streaming")
async def execute_streaming(request: ExecuteRequest): # pyright: ignore[reportUnusedFunction]
"""Execute a command and stream the output."""
logger.info(f"RSH streaming execute: {' '.join(request.command)}")
async def stream_output():
try:
env = None
if request.env:
import os
env = os.environ.copy()
env.update(request.env)
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
cwd=request.cwd,
env=env,
)
if process.stdout:
async for line in process.stdout:
yield line
await process.wait()
except Exception as e:
yield f"Error: {e}\n".encode()
return StreamingResponse(
stream_output(),
media_type="application/octet-stream",
)
return app
async def run_rsh_server(port: int = RSH_PORT):
"""Run the RSH server."""
import anyio
app = create_rsh_app()
config = Config()
config.bind = [f"0.0.0.0:{port}"]
config.accesslog = None # Disable access logs for cleaner output
logger.info(f"Starting RSH server on port {port}")
await serve(app, config, shutdown_trigger=lambda: anyio.sleep_forever()) # type: ignore

View File

@@ -35,6 +35,26 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""
instance_id: InstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -50,6 +70,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -14,6 +14,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -34,8 +35,27 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.
Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
class BoundInstance(CamelCaseModel):

View File

@@ -50,9 +50,7 @@ class RunnerReady(BaseRunnerStatus):
class RunnerRunning(BaseRunnerStatus):
"""Runner is processing requests and can accept more (continuous batching)."""
active_requests: int = 0
pass
class RunnerShuttingDown(BaseRunnerStatus):

View File

@@ -1,251 +0,0 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
import time
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger
@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""
command_id: CommandId
task_id: TaskId
uid: int # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)
@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""
command_id: CommandId
task_id: TaskId
response: GenerationResponse
class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""
def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}
self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1
sampler = make_sampler(temp=0.7, top_p=1.0)
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)
logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion. Only rank 0 should call this.
In distributed mode, rank 0 receives tasks from the control plane and
queues them here. The actual insertion happens in sync_and_insert_pending()
which ensures all ranks insert the same requests together.
"""
assert self.rank == 0, "Only rank 0 should queue requests"
self._pending_inserts.append((command_id, task_id, task_params))
logger.info(f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})")
def sync_and_insert_pending(self) -> list[int]:
"""Sync pending inserts across ranks and insert them. Returns UIDs.
This method ensures all ranks insert the same requests in the same order.
In non-distributed mode, it simply inserts all pending requests.
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.
"""
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]
if not self.is_distributed:
# Non-distributed: just insert directly from pending
inserts_to_process = list(self._pending_inserts)
else:
# Distributed: broadcast pending inserts from rank 0 to all ranks
assert self.group is not None
pending_data = self._pending_inserts if self.rank == 0 else None
synced_data = share_object(pending_data, self.rank, self.group)
if synced_data is None:
self._pending_inserts.clear()
return []
inserts_to_process = synced_data
uids: list[int] = []
for cmd_id, task_id, params in inserts_to_process:
uid = self._do_insert(cmd_id, task_id, params)
uids.append(uid)
self._pending_inserts.clear()
return uids
def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> int:
"""Actually insert a request into BatchGenerator. No sync - called after sync."""
prompt_str = apply_chat_template(self.tokenizer, task_params)
tokens: list[int] = self.tokenizer.encode(prompt_str, add_special_tokens=False)
prompt_tokens = len(tokens)
max_tokens = task_params.max_tokens or self.max_tokens
uids = self.batch_gen.insert([tokens], max_tokens=[max_tokens])
uid = uids[0]
detokenizer = self.tokenizer.detokenizer
self.active_requests[uid] = ActiveRequest(
command_id=command_id,
task_id=task_id,
uid=uid,
detokenizer=detokenizer,
prompt_tokens=prompt_tokens,
)
logger.info(f"Inserted request {command_id} with uid={uid}, prompt_tokens={prompt_tokens}, max_tokens={max_tokens}")
return uid
def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Syncs completed UIDs across ranks if distributed."""
responses = self.batch_gen.next()
if not responses:
return []
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []
for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue
req.tokens_generated += 1
# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment
stats: GenerationStats | None = None
finish_reason: FinishReason | None = None
raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment
elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
stats = GenerationStats(
prompt_tps=0.0, # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)
if raw_finish_reason == "stop":
finish_reason = "stop"
elif raw_finish_reason == "length":
finish_reason = "length"
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"
uids_to_remove.append(uid) # Sync before removal
logger.info(f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}")
results.append(BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(text=text, token=token, finish_reason=finish_reason, stats=stats),
))
# Sync completed UIDs across ranks before removing
if self.is_distributed and uids_to_remove:
assert self.group is not None
uids_to_remove = share_object(uids_to_remove if self.rank == 0 else None, self.rank, self.group) or []
for uid in uids_to_remove:
if uid in self.active_requests:
del self.active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)
@property
def active_count(self) -> int:
return len(self.active_requests)
@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)

View File

@@ -1,73 +0,0 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
# pyright: reportAny=false
import pickle
from enum import IntEnum
from typing import TypeVar, cast
import mlx.core as mx
T = TypeVar("T")
class DistributedOp(IntEnum):
"""Operation codes for distributed synchronization.
Used to ensure all ranks in a distributed setup execute the same
operations in the same order, preventing collective mismatches.
"""
NOOP = 0 # No operation - rank 0 will block waiting for a task
INSERT = 1 # Insert pending requests into the batch
STEP = 2 # Run a decode step
SHUTDOWN = 3 # All ranks should exit
def sync_operation(
op: DistributedOp | None,
rank: int,
group: mx.distributed.Group,
) -> DistributedOp:
"""Broadcast operation code from rank 0 to all ranks.
This ensures all ranks execute the same operation, preventing
collective mismatches that would cause deadlocks.
Args:
op: The operation to perform (only rank 0's value is used)
rank: This process's rank
group: The distributed group
Returns:
The operation that all ranks should execute
"""
if rank == 0:
assert op is not None, "Rank 0 must provide an operation"
code = mx.array([int(op)], dtype=mx.int32)
else:
code = mx.array([0], dtype=mx.int32)
result = mx.distributed.all_sum(code, group=group)
mx.eval(result)
return DistributedOp(int(result.item()))
def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
if rank == 0:
if obj is None:
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
return None
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return obj
else:
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
if size == 0:
return None
data = mx.zeros(size, dtype=mx.uint8)
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))

View File

@@ -164,6 +164,11 @@ def mlx_distributed_init(
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
case _:
raise ValueError(
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
)
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group

View File

@@ -21,7 +21,12 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -50,6 +55,11 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
# Check for FLASH instance tasks first
flash_task = _plan_flash(runners, instances)
if flash_task is not None:
return flash_task
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
@@ -62,6 +72,34 @@ def plan(
)
def _plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.
FLASH instances have a simpler lifecycle:
- CreateRunner (handled by _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by _kill_runner)
This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue
# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
return None
def _kill_runner(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -114,6 +152,10 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
# FLASH instances don't need model downloads
if isinstance(runner.bound_instance.instance, FLASHInstance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
@@ -277,14 +319,12 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

@@ -4,7 +4,11 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
MlxJacclInstance,
)
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -17,20 +21,27 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
global logger
logger = _logger
# Import main after setting global logger - this lets us just import logger from this module
# Route based on instance type
try:
from exo.worker.runner.runner import main
if isinstance(bound_instance.instance, FLASHInstance):
# FLASH MPI simulation runner
from exo.worker.runner.flash_runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -0,0 +1,301 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- Each Exo node runs an RSH server on port 52416 for remote execution
- Workers just report ready and wait
"""
import os
import shutil
import socket
import subprocess
import threading
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger
# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip
def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.
Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False
if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host
# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split(".")[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass
# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, "w") as f:
for _node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, "r") as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
):
"""Main FLASH runner loop.
Coordinator: generates hostfile and runs mpirun (which SSHs to workers)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = my_rank == 0
coordinator_ip = get_coordinator_host(instance)
logger.info(
f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
)
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
process: subprocess.Popen[bytes] | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
def monitor_output(proc: subprocess.Popen[bytes]) -> None:
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded: str = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(
instance, instance.working_directory
)
iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np",
str(instance.total_ranks),
"--hostfile",
hostfile,
"--wdir",
instance.working_directory,
"--oversubscribe",
"--mca",
"btl",
"tcp,self",
"--mca",
"btl_tcp_if_include",
iface,
"--mca",
"oob_tcp_if_include",
iface,
"--mca",
"plm_rsh_no_tree_spawn",
"1",
]
# Use exo-rsh for remote execution (no SSH needed)
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
cmd.append(instance.flash_executable_path)
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()
current_status = RunnerRunning()
logger.info(
f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
)
else:
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
logger.info(
f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
)
current_status = RunnerRunning()
except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback
logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))
case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()
current_status = RunnerShutdown()
case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(
f"FLASH simulation failed with code {exit_code}"
)
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
if process and process.poll() is None:
process.terminate()
process.wait(timeout=5)
logger.info("FLASH runner exiting")

View File

@@ -1,8 +1,6 @@
import gc
import time
import mlx.core as mx
from anyio import WouldBlock
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -23,6 +21,9 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -38,9 +39,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.batch_engine import BatchGenerationEngine
from exo.worker.engines.mlx.generator.distributed_sync import DistributedOp, sync_operation
from exo.worker.engines.mlx.generator.generate import warmup_inference
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
@@ -49,26 +48,6 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger
def _determine_distributed_op(
batch_engine: BatchGenerationEngine | None,
pending_shutdown: Task | None,
should_shutdown: bool,
) -> DistributedOp:
"""Determine what operation to perform next in distributed mode.
Only rank 0's result matters - this gets broadcast to all ranks via sync_operation.
"""
if should_shutdown:
return DistributedOp.SHUTDOWN
if pending_shutdown is not None and batch_engine is not None and not batch_engine.has_active_requests:
return DistributedOp.SHUTDOWN
if batch_engine is not None and batch_engine.has_pending_inserts:
return DistributedOp.INSERT
if batch_engine is not None and batch_engine.has_active_requests:
return DistributedOp.STEP
return DistributedOp.NOOP
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -90,247 +69,142 @@ def main(
model = None
tokenizer = None
group = None
batch_engine: BatchGenerationEngine | None = None
pending_shutdown: Shutdown | None = None
current_status: RunnerStatus = RunnerIdle()
def send_status(status: RunnerStatus) -> None:
event_sender.send(RunnerStatusUpdated(runner_id=runner_id, runner_status=status))
logger.info("runner created")
send_status(current_status)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
def handle_task(task: Task, is_deferred: bool = False) -> bool:
nonlocal current_status, model, tokenizer, group, batch_engine, pending_shutdown
logger.info("runner connected")
current_status = RunnerConnected()
# For Shutdown, check if we need to defer BEFORE sending Running/Acknowledged
if isinstance(task, Shutdown) and not is_deferred:
if batch_engine is not None and (batch_engine.has_active_requests or batch_engine.has_pending_inserts):
logger.info("deferring shutdown until active requests complete")
pending_shutdown = task
return True
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running))
event_sender.send(TaskAcknowledged(task_id=task.task_id))
model, tokenizer = load_mlx_items(bound_instance, group)
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
send_status(current_status)
group = initialize_mlx(bound_instance)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
logger.info("runner connected")
current_status = RunnerConnected()
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
send_status(current_status)
model, tokenizer = load_mlx_items(bound_instance, group)
current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model is not None
assert tokenizer is not None
current_status = RunnerWarmingUp()
logger.info("runner warming up")
send_status(current_status)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(model=model, tokenizer=tokenizer)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(f"runner initialized in {time.time() - setup_start_time} seconds")
batch_engine = BatchGenerationEngine(model=model, tokenizer=tokenizer, group=group)
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, (RunnerReady, RunnerRunning)):
assert batch_engine is not None
if task_params.messages and task_params.messages[0].content is not None:
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Queue the request - actual insertion happens in sync_and_insert_pending()
# In distributed mode, only rank 0 receives tasks from control plane
batch_engine.queue_request(command_id=command_id, task_id=task.task_id, task_params=task_params)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
# Status will be updated after actual insertion in the main loop
# For now, set to RunnerRunning to indicate we're processing
current_status = RunnerRunning(active_requests=batch_engine.active_count + batch_engine.pending_insert_count)
send_status(current_status)
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
mx.clear_cache()
import gc
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
send_status(current_status)
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
current_status = RunnerShutdown()
send_status(current_status)
return False
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
return True
with task_receiver as tasks:
running = True
is_rank_0 = shard_metadata.device_rank == 0
is_distributed = group is not None and group.size() > 1
while running:
if is_distributed:
assert group is not None
assert batch_engine is not None
# Distributed mode: synchronize operations across all ranks
# Step 1: Only rank 0 checks for tasks and determines operation
should_shutdown = False
if is_rank_0:
while True:
try:
task = tasks.receive_nowait()
task_result = handle_task(task)
if not task_result:
should_shutdown = True
break
except WouldBlock:
break
op = _determine_distributed_op(batch_engine, pending_shutdown, should_shutdown)
else:
op = None
# Step 2: Sync operation across all ranks
synced_op = sync_operation(op, shard_metadata.device_rank, group)
# Step 3: All ranks execute the same operation
match synced_op:
case DistributedOp.INSERT:
batch_engine.sync_and_insert_pending()
if is_rank_0:
current_status = RunnerRunning(active_requests=batch_engine.active_count)
send_status(current_status)
case DistributedOp.STEP:
for resp in batch_engine.step():
if is_rank_0:
event_sender.send(ChunkGenerated(
command_id=resp.command_id,
chunk=TokenChunk(
idx=resp.response.token,
model=shard_metadata.model_meta.model_id,
text=resp.response.text,
token_id=resp.response.token,
finish_reason=resp.response.finish_reason,
stats=resp.response.stats,
),
))
if resp.response.finish_reason is not None:
if is_rank_0:
event_sender.send(TaskStatusUpdated(task_id=resp.task_id, task_status=TaskStatus.Complete))
if is_rank_0:
if batch_engine.has_active_requests:
current_status = RunnerRunning(active_requests=batch_engine.active_count)
else:
current_status = RunnerReady()
send_status(current_status)
case DistributedOp.SHUTDOWN:
running = False
if is_rank_0 and pending_shutdown is not None:
handle_task(pending_shutdown, is_deferred=True)
case DistributedOp.NOOP:
# No work to do - all ranks poll together
# We can't have rank 0 block while others try to sync
if is_rank_0:
try:
task = tasks.receive_nowait()
task_result = handle_task(task)
if not task_result:
# Will sync SHUTDOWN on next iteration
pass
except WouldBlock:
pass
# All ranks: short sleep before looping back to sync
time.sleep(0.001)
else:
# Non-distributed mode: original logic with queue + insert
while True:
try:
task = tasks.receive_nowait()
running = handle_task(task)
if not running:
break
except WouldBlock:
break
if not running:
break
# Insert any queued requests (non-distributed just inserts directly)
# Status was already sent in handle_task when queueing
if batch_engine is not None and batch_engine.has_pending_inserts:
batch_engine.sync_and_insert_pending()
if batch_engine is not None and batch_engine.has_active_requests:
for resp in batch_engine.step():
if shard_metadata.device_rank == 0:
event_sender.send(ChunkGenerated(
command_id=resp.command_id,
chunk=TokenChunk(
idx=resp.response.token,
model=shard_metadata.model_meta.model_id,
text=resp.response.text,
token_id=resp.response.token,
finish_reason=resp.response.finish_reason,
stats=resp.response.stats,
),
))
if resp.response.finish_reason is not None:
event_sender.send(TaskStatusUpdated(task_id=resp.task_id, task_status=TaskStatus.Complete))
if batch_engine.has_active_requests:
current_status = RunnerRunning(active_requests=batch_engine.active_count)
else:
current_status = RunnerReady()
send_status(current_status)
# Process deferred shutdown after all requests complete
if pending_shutdown is not None and not batch_engine.has_active_requests and not batch_engine.has_pending_inserts:
running = handle_task(pending_shutdown, is_deferred=True)
else:
task = tasks.receive()
running = handle_task(task)
# Cleanup
del model, tokenizer, group, batch_engine
mx.clear_cache()
gc.collect()
gc.collect()
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

View File

@@ -105,7 +105,7 @@ class RunnerSupervisor:
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown successfully, terminating")
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
@@ -128,11 +128,9 @@ class RunnerSupervisor:
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(f"Skipping task {task.task_id} - already completed")
return
if task.task_id in self.pending:
logger.info(f"Skipping task {task.task_id} - already pending")
return
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -151,17 +149,13 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
# Just set the event to unblock start_task, but keep in pending
# to prevent duplicate forwarding until completion
if event.task_id in self.pending:
self.pending[event.task_id].set()
self.pending.pop(event.task_id).set()
continue
if isinstance(event, TaskStatusUpdated) and event.task_status in (
TaskStatus.Complete,
TaskStatus.TimedOut,
TaskStatus.Failed,
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
):
# If a task has just finished, we should be working on it.
# If a task has just been completed, we should be working on it.
assert isinstance(
self.status,
(
@@ -172,8 +166,6 @@ class RunnerSupervisor:
RunnerShuttingDown,
),
)
# Now safe to remove from pending and add to completed
self.pending.pop(event.task_id, None)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:

View File

@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):

View File

@@ -1,315 +0,0 @@
"""
Tests for continuous batching behavior in the runner.
These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately
"""
# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false
from typing import Any
from unittest.mock import MagicMock
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
class FakeBatchEngineWithTokens:
"""
Fake batch engine that generates a specified number of tokens per request.
This simulates realistic batch generation behavior where:
- Requests are queued on insert
- Each step() call generates one token for all active requests
- Requests complete when they've generated all their tokens
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
self._uid_counter = 0
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, task_params in self._pending_inserts:
uid = self._do_insert(command_id, task_id, task_params)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams | None,
) -> int:
uid = self._uid_counter
self._uid_counter += 1
# Track: (command_id, task_id, tokens_generated, max_tokens)
max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
return uid
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish_reason = "stop" if tokens_gen >= max_tokens else None
text = f"token{tokens_gen}"
if finish_reason:
uids_to_remove.append(uid)
else:
self._active_requests[uid] = (
command_id,
task_id,
tokens_gen,
max_tokens,
)
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=text,
finish_reason=finish_reason,
),
)
)
for uid in uids_to_remove:
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
"""
Run tasks through the runner, adding shutdown at the end.
Tasks are sent in order, with shutdown sent last.
The batch engine processes between task handling.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
shutdown_task = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
with task_sender, event_receiver:
# Send all tasks including shutdown
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown_task)
# Disable cleanup methods to prevent issues
event_sender.close = lambda: None
event_sender.join = lambda: None
task_receiver.close = lambda: None
task_receiver.join = lambda: None
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_receiver.collect()
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> ChatCompletion:
return ChatCompletion(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)
def test_single_request_generates_tokens(patch_batch_engine: None):
"""
Verify a single request generates the expected tokens through the batch path.
Note: With the current non-blocking design, shutdown is processed before
batch steps run when all tasks are queued together. This test verifies
the runner status reflects active requests.
"""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events - this shows the request was inserted
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
"""Verify RunnerRunning status includes active_requests count."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_chat_task_acknowledged(patch_batch_engine: None):
"""Verify chat completion task is acknowledged with proper status updates."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find the chat task status events
chat_running = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Running
]
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
def test_multiple_requests_tracked(patch_batch_engine: None):
"""Verify multiple concurrent requests are tracked in active_requests."""
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
# Should have at least 2 RunnerRunning events (one per request inserted)
assert len(running_events) >= 2, f"Expected at least 2 RunnerRunning events, got {len(running_events)}"
# First should have 1 active request, second should have 2
assert running_events[0].runner_status.active_requests == 1
assert running_events[1].runner_status.active_requests == 2

View File

@@ -1,79 +0,0 @@
"""Test for distributed synchronization in batch generation.
These tests verify that all ranks in a distributed setup call the same
collective operations in the same order, preventing race conditions and deadlocks.
"""
import pytest
from exo.worker.engines.mlx.generator.distributed_sync import DistributedOp
from exo.worker.runner.runner import _determine_distributed_op
class FakeBatchEngine:
"""Minimal fake batch engine for testing _determine_distributed_op."""
def __init__(self, has_pending_inserts: bool = False, has_active_requests: bool = False):
self._has_pending_inserts = has_pending_inserts
self._has_active_requests = has_active_requests
@property
def has_pending_inserts(self) -> bool:
return self._has_pending_inserts
@property
def has_active_requests(self) -> bool:
return self._has_active_requests
def test_distributed_sync_prevents_race_condition():
"""
Test that the new architecture prevents the race condition.
In the old code, each rank independently decided what operation to perform
based on its local state. This could cause rank 0 to call insert_request()
while rank 1 called step(), causing a collective mismatch.
The fix ensures:
1. Only rank 0 determines the operation (via _determine_distributed_op)
2. Rank 0's decision is broadcast to all ranks (via sync_operation)
3. All ranks then execute the same operation
"""
# Scenario: rank 0 has pending inserts, rank 1 has active requests
# In old code: rank 0 would INSERT, rank 1 would STEP -> race condition
# In new code: rank 0's decision (INSERT) is broadcast to all ranks
# Test _determine_distributed_op gives priority to INSERT over STEP
batch_engine_with_inserts = FakeBatchEngine(has_pending_inserts=True, has_active_requests=True)
op = _determine_distributed_op(batch_engine_with_inserts, pending_shutdown=None, should_shutdown=False)
assert op == DistributedOp.INSERT, "INSERT should take priority when there are pending inserts"
# Test step is used when there are active requests but no pending inserts
batch_engine_with_active = FakeBatchEngine(has_pending_inserts=False, has_active_requests=True)
op = _determine_distributed_op(batch_engine_with_active, pending_shutdown=None, should_shutdown=False)
assert op == DistributedOp.STEP, "STEP should be used when there are active requests"
# Test NOOP when nothing to do
batch_engine_idle = FakeBatchEngine(has_pending_inserts=False, has_active_requests=False)
op = _determine_distributed_op(batch_engine_idle, pending_shutdown=None, should_shutdown=False)
assert op == DistributedOp.NOOP, "NOOP should be used when nothing to do"
def test_distributed_sync_shutdown_handling():
"""Test that shutdown is properly coordinated across ranks."""
batch_engine = FakeBatchEngine(has_pending_inserts=False, has_active_requests=False)
# Test direct shutdown request
op = _determine_distributed_op(batch_engine, pending_shutdown=None, should_shutdown=True)
assert op == DistributedOp.SHUTDOWN, "SHUTDOWN should be returned when should_shutdown is True"
# Test pending shutdown with no active requests
from unittest.mock import MagicMock
pending_shutdown = MagicMock()
op = _determine_distributed_op(batch_engine, pending_shutdown=pending_shutdown, should_shutdown=False)
assert op == DistributedOp.SHUTDOWN, "SHUTDOWN should be returned when pending_shutdown and no active requests"
# Test pending shutdown with active requests - should continue processing
batch_engine_active = FakeBatchEngine(has_pending_inserts=False, has_active_requests=True)
op = _determine_distributed_op(batch_engine_active, pending_shutdown=pending_shutdown, should_shutdown=False)
assert op == DistributedOp.STEP, "Should continue STEP while requests are active even with pending shutdown"

View File

@@ -1,16 +1,11 @@
# Check tasks are complete before runner is ever ready.
# pyright: reportAny=false
from collections.abc import Iterable
from typing import Any, Callable
from unittest.mock import MagicMock
from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
@@ -27,7 +22,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
@@ -44,9 +38,6 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from ...constants import (
CHAT_COMPLETION_TASK_ID,
@@ -116,85 +107,18 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
assert test_event == true_event, f"{test_event} != {true_event}"
class FakeBatchEngine:
"""
Fake batch engine for testing.
Queues requests on insert, returns one token per step.
The runner's non-blocking loop drains all tasks before running batch steps,
so this engine queues requests and has_active_requests returns True only
after at least one request has been inserted.
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
self._uid_counter = 0
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, _task_params in self._pending_inserts:
uid = self._uid_counter
self._uid_counter += 1
self._active_requests[uid] = (command_id, task_id)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
# Process all active requests - return one token and complete
for uid, (command_id, task_id) in list(self._active_requests.items()):
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=0,
text="hi",
finish_reason="stop",
),
)
)
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a fake "group" (non-None for state machine)
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock())))
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
def _run(tasks: Iterable[Task]):
@@ -224,8 +148,7 @@ def _run(tasks: Iterable[Task]):
return event_receiver.collect()
def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPatch):
"""Verify chat completion generates tokens, completes, and runner returns to Ready."""
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
expected_chunk = ChunkGenerated(
@@ -268,9 +191,7 @@ def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPat
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerRunning(active_requests=1)
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -285,6 +206,7 @@ def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPat
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)