Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-14 09:00:07 -05:00)

Compare commits: 2 commits on branch sami/flash by alexcheema

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | dba7557513 | |
| alexcheema | 16578d35a4 | |
@@ -276,24 +276,23 @@ class BatchGenerator:
logprobs: mx.array
finish_reason: Optional[str]

unprocessed_prompts: List[Any]

def __init__(
self,
model,
model: nn.Module,
max_tokens: int = ...,
stop_tokens: Optional[set] = ...,
stop_tokens: Optional[set[int]] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
) -> List[int]: ...
def stats(self) -> BatchStats: ...
def next(self) -> List[Response]: ...

def batch_generate(
model,

@@ -39,12 +39,18 @@ class StreamingDetokenizer:
"""

__slots__ = ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
tokens: list[int]
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self):
def text(self) -> str:
"""The full text decoded so far."""
...
@property
def last_segment(self) -> str:
"""Return the last segment of readable text since last time this property was accessed."""
...

class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer

@@ -108,6 +114,7 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
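The stub changes above describe the mlx_lm BatchGenerator interface that the new batch engine added later in this diff builds on. As a hedged sketch of how that interface is driven (the model id, sampler settings, and loop structure are assumptions; the Response fields uid, token, and finish_reason follow the usage in batch_engine.py further down):

```python
# Hedged sketch only: exercising the BatchGenerator interface from the stub above.
# The model id is a placeholder; mlx_lm.load / make_sampler usage is an assumption.
from mlx_lm import load
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # placeholder model
gen = BatchGenerator(
    model=model,
    stop_tokens=set(tokenizer.eos_token_ids or []),
    sampler=make_sampler(temp=0.7),
)
uids = gen.insert([tokenizer.encode("Hello")], max_tokens=[32])  # one uid per prompt
remaining = set(uids)
while remaining:
    for r in gen.next():                 # one Response per active request this step
        if r.finish_reason is not None:  # "stop" or "length"
            remaining.discard(r.uid)
```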
flake.lock (generated, 37)
@@ -8,11 +8,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"owner": "nix-community",
"repo": "fenix",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"type": "github"
},
"original": {

@@ -42,22 +42,6 @@
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",

@@ -68,8 +52,8 @@
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},

@@ -78,18 +62,17 @@
"fenix": "fenix",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "725349602e525df37f377701e001fe8aab807878",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"type": "github"
},
"original": {

@@ -106,11 +89,11 @@
]
},
"locked": {
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"type": "github"
},
"original": {
flake.nix (12)
@@ -18,9 +18,6 @@
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};

# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};

nixConfig = {

@@ -42,11 +39,9 @@
];

perSystem =
{ config, inputs', pkgs, lib, system, ... }:
{ config, inputs', pkgs, lib, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {

@@ -65,10 +60,7 @@
enable = true;
includes = [ "*.ts" ];
};
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
swift-format.enable = true;
};
};
@@ -29,7 +29,6 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"

# dependencies only required for development
[dependency-groups]
@@ -15,7 +15,6 @@ import exo.routing.topics as topics
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
from exo.rsh.server import RSH_PORT, run_rsh_server
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
from exo.shared.logging import logger_cleanup, logger_setup

@@ -114,8 +113,6 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
# Start RSH server for remote execution (used by MPI)
tg.start_soon(run_rsh_server, RSH_PORT)

def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -1,6 +1,6 @@
import time
from collections.abc import AsyncGenerator
from typing import Any, cast
from typing import cast

import anyio
from anyio import create_task_group

@@ -51,9 +51,7 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
StopFLASH,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId

@@ -62,12 +60,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
)
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel

@@ -185,10 +178,6 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# FLASH simulation endpoints
self.app.post("/flash/launch")(self.launch_flash)
self.app.delete("/flash/{instance_id}")(self.stop_flash)
self.app.get("/flash/instances")(self.list_flash_instances)

async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(

@@ -633,86 +622,6 @@ class API:
]
)

async def launch_flash(
self,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict[str, str]:
"""Launch a FLASH MPI simulation across the cluster.

Args:
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await self._send(command)

return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}

async def stop_flash(self, instance_id: InstanceId) -> dict[str, str]:
"""Stop a running FLASH simulation."""
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")

instance = self.state.instances[instance_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)

command = StopFLASH(instance_id=instance_id)
await self._send(command)

return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}

async def list_flash_instances(self) -> list[dict[str, Any]]:
"""List all FLASH simulation instances."""
flash_instances: list[dict[str, Any]] = []
for instance_id, instance in self.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses: dict[str, str | None] = {}
for (
node_id,
runner_id,
) in instance.shard_assignments.node_to_runner.items():
runner_status = self.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)

flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances

async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
@@ -8,7 +8,6 @@ from exo.master.placement import (
add_instance_to_placements,
delete_instance,
get_transition_events,
place_flash_instance,
place_instance,
)
from exo.shared.apply import apply

@@ -17,10 +16,8 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
RequestEventLog,
StopFLASH,
TaskFinished,
TestCommand,
)

@@ -176,26 +173,6 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case LaunchFLASH():
placement = place_flash_instance(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case StopFLASH():
# Reuse delete_instance logic to stop FLASH simulation
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -17,24 +17,20 @@ from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
LaunchFLASH,
PlaceInstance,
)
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
from exo.shared.types.worker.shards import Sharding

def random_ephemeral_port() -> int:

@@ -169,9 +165,6 @@ def place_instance(
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
case InstanceMeta.FLASH:
# FLASH instances are handled by place_flash_instance()
raise ValueError("FLASH instances should use place_flash_instance()")

return target_instances

@@ -187,148 +180,6 @@ def delete_instance(
raise ValueError(f"Instance {command.instance_id} not found")

def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.

Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))

all_nodes = list(topology.list_nodes())

if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)

# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]

logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)

# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}

# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)

for i, node_info in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_info.node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)

shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)

# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict[NodeId, list[Host]] = {}

# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_info in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_info.node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(
f"FLASH placement: node {node_info.node_id} (rank {i}) -> IP {explicit_hosts[i]}"
)
else:
logger.warning(
f"Not enough hosts provided for node {i}, using localhost"
)
hosts_by_node[node_info.node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(
f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
)
else:
# Try to get IPs from topology edges
for node_info in selected_nodes:
node_hosts: list[Host] = []

# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for _, edge_data in topology.out_edges(node_info.node_id):
if hasattr(edge_data, "send_back_multiaddr"):
# Extract IP from multiaddr like /ip4/192.168.1.100/tcp/52415
multiaddr = str(edge_data.send_back_multiaddr)
if "/ip4/" in multiaddr:
parts = multiaddr.split("/")
try:
ip_idx = parts.index("ip4") + 1
ip = parts[ip_idx]
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith(
"127."
):
node_hosts.append(Host(ip=ip, port=0))
break
except (ValueError, IndexError):
pass

# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_info.node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))

hosts_by_node[node_info.node_id] = node_hosts

total_ranks = len(selected_nodes) * command.ranks_per_node

# Determine coordinator IP - first node's first host IP
first_node_id: NodeId = next(iter(hosts_by_node.keys()))
coordinator_ip: str = (
hosts_by_node[first_node_id][0].ip
if hosts_by_node[first_node_id]
else "127.0.0.1"
)

target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)

logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")

return target_instances

def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
@@ -1,13 +0,0 @@
"""Exo RSH - Remote Shell for MPI without SSH.

This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:

1. Each Exo node runs a small HTTP server (RSH server) on port 52416
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target
4. The target executes the command and streams output back

Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""
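The removed module docstring above describes the exo-rsh mechanism. As a hedged sketch of the underlying HTTP exchange (endpoint, port, and JSON shape taken from the rsh client and server code in this diff; the target IP is a placeholder):

```python
# Hedged sketch: the HTTP exchange the removed exo-rsh client performed.
import json
from urllib.request import Request, urlopen

payload = json.dumps({"command": ["hostname"]}).encode("utf-8")
req = Request(
    "http://192.168.1.100:52416/execute",  # placeholder node IP; RSH_PORT = 52416
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urlopen(req, timeout=300) as resp:
    result = json.loads(resp.read().decode("utf-8"))
print(result["exit_code"], result["stdout"], result["stderr"])
```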
@@ -1,99 +0,0 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.

This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]

It connects to the target node's RSH server (port 52416) and executes the command.
"""

import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen

RSH_PORT = 52416

def resolve_hostname(hostname: str) -> str:
"""Resolve hostname to IP address."""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
# If resolution fails, try using the hostname directly
return hostname

def main():
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
# SSH options we might see: -x (disable X11), -o options, etc.
args = sys.argv[1:]

# Skip SSH-style options
hostname = None
command_start = 0

i = 0
while i < len(args):
arg = args[i]
if arg.startswith("-"):
# Skip option and its value if needed
if arg in ("-o", "-i", "-l", "-p", "-F"):
i += 2 # Skip option and its argument
continue
i += 1
continue
else:
# First non-option is the hostname
hostname = arg
command_start = i + 1
break
i += 1

if hostname is None or command_start >= len(args):
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
sys.exit(1)

command = args[command_start:]

# Resolve hostname to IP
ip = resolve_hostname(hostname)

# Make request to RSH server
url = f"http://{ip}:{RSH_PORT}/execute"
data = json.dumps({"command": command}).encode("utf-8")

try:
req = Request(url, data=data, headers={"Content-Type": "application/json"})
with urlopen(req, timeout=300) as response: # pyright: ignore[reportAny]
response_body: bytes = cast(bytes, response.read()) # pyright: ignore[reportAny]
result: dict[str, Any] = json.loads(response_body.decode("utf-8")) # pyright: ignore[reportAny]

# Output stdout/stderr
stdout: str = cast(str, result.get("stdout", ""))
stderr: str = cast(str, result.get("stderr", ""))
exit_code: int = cast(int, result.get("exit_code", 0))

if stdout:
sys.stdout.write(stdout)
sys.stdout.flush()
if stderr:
sys.stderr.write(stderr)
sys.stderr.flush()

sys.exit(exit_code)

except URLError as e:
print(
f"exo-rsh: Failed to connect to {hostname}:{RSH_PORT}: {e}", file=sys.stderr
)
sys.exit(255)
except Exception as e:
print(f"exo-rsh: Error: {e}", file=sys.stderr)
sys.exit(1)

if __name__ == "__main__":
main()
@@ -1,154 +0,0 @@
"""RSH Server - runs on each Exo node to accept remote execution requests."""

import asyncio
from typing import Optional

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from loguru import logger
from pydantic import BaseModel

RSH_PORT = 52416

class ExecuteRequest(BaseModel):
"""Request to execute a command."""

command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None

class ExecuteResponse(BaseModel):
"""Response from command execution."""

exit_code: int
stdout: str
stderr: str

def create_rsh_app() -> FastAPI:
"""Create the RSH FastAPI application."""
app = FastAPI(title="Exo RSH Server")

@app.get("/health")
async def health(): # pyright: ignore[reportUnusedFunction]
"""Health check endpoint."""
return {"status": "ok"}

@app.post("/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse: # pyright: ignore[reportUnusedFunction]
"""Execute a command and return the result."""
cmd_str = " ".join(request.command)
logger.info(f"RSH executing: {cmd_str}")

try:
# Build environment
import os

env = os.environ.copy()
if request.env:
env.update(request.env)

# Check if command contains shell metacharacters (semicolons, pipes, etc.)
# If so, run through shell. mpirun sends complex commands like:
# "VAR=value;export VAR;/path/to/prted --args"
needs_shell = any(c in cmd_str for c in ";|&$`")

if needs_shell:
# Run through shell
process = await asyncio.create_subprocess_shell(
cmd_str,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
else:
# Execute directly
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)

stdout, stderr = await process.communicate()
exit_code = process.returncode or 0

logger.info(f"RSH command completed with exit code {exit_code}")

return ExecuteResponse(
exit_code=exit_code,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)

except FileNotFoundError as e:
logger.error(f"RSH command not found: {e}")
return ExecuteResponse(
exit_code=127,
stdout="",
stderr=f"Command not found: {request.command[0]}",
)
except Exception as e:
logger.error(f"RSH execution error: {e}")
return ExecuteResponse(
exit_code=1,
stdout="",
stderr=str(e),
)

@app.post("/execute_streaming")
async def execute_streaming(request: ExecuteRequest): # pyright: ignore[reportUnusedFunction]
"""Execute a command and stream the output."""
logger.info(f"RSH streaming execute: {' '.join(request.command)}")

async def stream_output():
try:
env = None
if request.env:
import os

env = os.environ.copy()
env.update(request.env)

process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
cwd=request.cwd,
env=env,
)

if process.stdout:
async for line in process.stdout:
yield line

await process.wait()

except Exception as e:
yield f"Error: {e}\n".encode()

return StreamingResponse(
stream_output(),
media_type="application/octet-stream",
)

return app

async def run_rsh_server(port: int = RSH_PORT):
"""Run the RSH server."""
import anyio

app = create_rsh_app()
config = Config()
config.bind = [f"0.0.0.0:{port}"]
config.accesslog = None # Disable access logs for cleaner output

logger.info(f"Starting RSH server on port {port}")
await serve(app, config, shutdown_trigger=lambda: anyio.sleep_forever()) # type: ignore
@@ -35,26 +35,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId

class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""

simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""

class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""

instance_id: InstanceId

class TaskFinished(BaseCommand):
finished_command_id: CommandId

@@ -70,8 +50,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)
@@ -14,7 +14,6 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"

class BaseInstance(TaggedModel):

@@ -35,27 +34,8 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]

class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.

Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""

hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)

# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
Instance = MlxRingInstance | MlxJacclInstance

class BoundInstance(CamelCaseModel):
@@ -50,7 +50,9 @@ class RunnerReady(BaseRunnerStatus):

class RunnerRunning(BaseRunnerStatus):
pass
"""Runner is processing requests and can accept more (continuous batching)."""

active_requests: int = 0

class RunnerShuttingDown(BaseRunnerStatus):
src/exo/worker/engines/mlx/generator/batch_engine.py (new file, 251)
@@ -0,0 +1,251 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""

import time
from dataclasses import dataclass, field

import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper

from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger

@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""

command_id: CommandId
task_id: TaskId
uid: int # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)

@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""

command_id: CommandId
task_id: TaskId
response: GenerationResponse

class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""

def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}
self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []

self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1

sampler = make_sampler(temp=0.7, top_p=1.0)

eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])

self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)

logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)

def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion. Only rank 0 should call this.

In distributed mode, rank 0 receives tasks from the control plane and
queues them here. The actual insertion happens in sync_and_insert_pending()
which ensures all ranks insert the same requests together.
"""
assert self.rank == 0, "Only rank 0 should queue requests"
self._pending_inserts.append((command_id, task_id, task_params))
logger.info(f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})")

def sync_and_insert_pending(self) -> list[int]:
"""Sync pending inserts across ranks and insert them. Returns UIDs.

This method ensures all ranks insert the same requests in the same order.
In non-distributed mode, it simply inserts all pending requests.
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.
"""
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]

if not self.is_distributed:
# Non-distributed: just insert directly from pending
inserts_to_process = list(self._pending_inserts)
else:
# Distributed: broadcast pending inserts from rank 0 to all ranks
assert self.group is not None
pending_data = self._pending_inserts if self.rank == 0 else None
synced_data = share_object(pending_data, self.rank, self.group)

if synced_data is None:
self._pending_inserts.clear()
return []

inserts_to_process = synced_data

uids: list[int] = []
for cmd_id, task_id, params in inserts_to_process:
uid = self._do_insert(cmd_id, task_id, params)
uids.append(uid)

self._pending_inserts.clear()
return uids

def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> int:
"""Actually insert a request into BatchGenerator. No sync - called after sync."""
prompt_str = apply_chat_template(self.tokenizer, task_params)
tokens: list[int] = self.tokenizer.encode(prompt_str, add_special_tokens=False)
prompt_tokens = len(tokens)
max_tokens = task_params.max_tokens or self.max_tokens

uids = self.batch_gen.insert([tokens], max_tokens=[max_tokens])
uid = uids[0]
detokenizer = self.tokenizer.detokenizer

self.active_requests[uid] = ActiveRequest(
command_id=command_id,
task_id=task_id,
uid=uid,
detokenizer=detokenizer,
prompt_tokens=prompt_tokens,
)

logger.info(f"Inserted request {command_id} with uid={uid}, prompt_tokens={prompt_tokens}, max_tokens={max_tokens}")
return uid

def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Syncs completed UIDs across ranks if distributed."""
responses = self.batch_gen.next()
if not responses:
return []

results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []

for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue

req.tokens_generated += 1

# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment

stats: GenerationStats | None = None
finish_reason: FinishReason | None = None

raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment

elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0

stats = GenerationStats(
prompt_tps=0.0, # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)

if raw_finish_reason == "stop":
finish_reason = "stop"
elif raw_finish_reason == "length":
finish_reason = "length"
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"

uids_to_remove.append(uid) # Sync before removal
logger.info(f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}")

results.append(BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(text=text, token=token, finish_reason=finish_reason, stats=stats),
))

# Sync completed UIDs across ranks before removing
if self.is_distributed and uids_to_remove:
assert self.group is not None
uids_to_remove = share_object(uids_to_remove if self.rank == 0 else None, self.rank, self.group) or []

for uid in uids_to_remove:
if uid in self.active_requests:
del self.active_requests[uid]

return results

@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)

@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)

@property
def active_count(self) -> int:
return len(self.active_requests)

@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)

@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
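As a hedged sketch of how this new engine appears intended to be driven on a single rank (method and property names come from the class above; the model/tokenizer objects, ids, params, and the handle() callback are placeholders, since the actual runner wiring appears later in this diff):

```python
# Hedged sketch only: single-rank driving loop for BatchGenerationEngine.
engine = BatchGenerationEngine(model=model, tokenizer=tokenizer, group=None)

engine.queue_request(command_id, task_id, params)  # rank 0 queues incoming tasks
engine.sync_and_insert_pending()                   # inserts (broadcasts first when distributed)

while engine.has_active_requests:
    for tagged in engine.step():                   # one BatchedGenerationResponse per active request
        handle(tagged.command_id, tagged.response) # hypothetical forwarding of text/token chunks
```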
src/exo/worker/engines/mlx/generator/distributed_sync.py (new file, 73)
@@ -0,0 +1,73 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""

# pyright: reportAny=false

import pickle
from enum import IntEnum
from typing import TypeVar, cast

import mlx.core as mx

T = TypeVar("T")

class DistributedOp(IntEnum):
"""Operation codes for distributed synchronization.

Used to ensure all ranks in a distributed setup execute the same
operations in the same order, preventing collective mismatches.
"""

NOOP = 0 # No operation - rank 0 will block waiting for a task
INSERT = 1 # Insert pending requests into the batch
STEP = 2 # Run a decode step
SHUTDOWN = 3 # All ranks should exit

def sync_operation(
op: DistributedOp | None,
rank: int,
group: mx.distributed.Group,
) -> DistributedOp:
"""Broadcast operation code from rank 0 to all ranks.

This ensures all ranks execute the same operation, preventing
collective mismatches that would cause deadlocks.

Args:
op: The operation to perform (only rank 0's value is used)
rank: This process's rank
group: The distributed group

Returns:
The operation that all ranks should execute
"""
if rank == 0:
assert op is not None, "Rank 0 must provide an operation"
code = mx.array([int(op)], dtype=mx.int32)
else:
code = mx.array([0], dtype=mx.int32)

result = mx.distributed.all_sum(code, group=group)
mx.eval(result)
return DistributedOp(int(result.item()))

def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
if rank == 0:
if obj is None:
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
return None
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return obj
else:
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
if size == 0:
return None
data = mx.zeros(size, dtype=mx.uint8)
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
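A hedged sketch of the calling convention these helpers imply: every rank must enter the same collectives in the same order, with only rank 0 supplying real values (the group initialization and the payload below are placeholders):

```python
# Hedged sketch: all ranks run this same code each iteration; only rank 0 supplies values.
import mlx.core as mx

group = mx.distributed.init()  # backend/config assumed
rank = group.rank()

op = sync_operation(DistributedOp.INSERT if rank == 0 else None, rank, group)
if op == DistributedOp.INSERT:
    pending = ["request-a"] if rank == 0 else None  # placeholder payload
    pending = share_object(pending, rank, group)    # non-zero ranks receive rank 0's list
```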
@@ -164,11 +164,6 @@ def mlx_distributed_init(
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

case _:
raise ValueError(
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
)

logger.info(f"Rank {rank} mlx distributed initialization complete")

return group
@@ -21,12 +21,7 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,

@@ -55,11 +50,6 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
# Check for FLASH instance tasks first
flash_task = _plan_flash(runners, instances)
if flash_task is not None:
return flash_task

# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)

@@ -72,34 +62,6 @@ def plan(
)

def _plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.

FLASH instances have a simpler lifecycle:
- CreateRunner (handled by _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by _kill_runner)

This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance

# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue

# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)

return None

def _kill_runner(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],

@@ -152,10 +114,6 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
# FLASH instances don't need model downloads
if isinstance(runner.bound_instance.instance, FLASHInstance):
continue

model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status

@@ -319,12 +277,14 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
continue

# TODO: Check ordering aligns with MLX distributeds expectations.

if isinstance(runner.status, RunnerReady) and all(
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
@@ -4,11 +4,7 @@ import loguru

from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
MlxJacclInstance,
)
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender

@@ -21,27 +17,20 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

global logger
logger = _logger

# Route based on instance type
# Import main after setting global logger - this lets us just import logger from this module
try:
if isinstance(bound_instance.instance, FLASHInstance):
# FLASH MPI simulation runner
from exo.worker.runner.flash_runner import main
from exo.worker.runner.runner import main

main(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

from exo.worker.runner.runner import main

main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:
@@ -1,301 +0,0 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.

Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- Each Exo node runs an RSH server on port 52416 for remote execution
- Workers just report ready and wait
"""

import os
import shutil
import socket
import subprocess
import threading

from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger

# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"

# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path

def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1

def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip

def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.

Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False

if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host

# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split(".")[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass

# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host

def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, "w") as f:
for _node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, "r") as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path

def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
):
"""Main FLASH runner loop.

Coordinator: generates hostfile and runs mpirun (which SSHs to workers)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)

logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")

my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = my_rank == 0
coordinator_ip = get_coordinator_host(instance)

logger.info(
f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
)
logger.info(f"FLASH coordinator IP: {coordinator_ip}")

process: subprocess.Popen[bytes] | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False

event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)

def monitor_output(proc: subprocess.Popen[bytes]) -> None:
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded: str = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")

with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))

match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)

try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(
instance, instance.working_directory
)

iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np",
str(instance.total_ranks),
"--hostfile",
hostfile,
"--wdir",
instance.working_directory,
"--oversubscribe",
"--mca",
"btl",
"tcp,self",
"--mca",
"btl_tcp_if_include",
iface,
"--mca",
"oob_tcp_if_include",
iface,
"--mca",
"plm_rsh_no_tree_spawn",
"1",
]

# Use exo-rsh for remote execution (no SSH needed)
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])

cmd.append(instance.flash_executable_path)

logger.info(f"FLASH distributed launch: {' '.join(cmd)}")

process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)

monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()

current_status = RunnerRunning()
logger.info(
f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
)

else:
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
logger.info(
f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
)
current_status = RunnerRunning()

except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback

logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))

case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)

if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()

current_status = RunnerShutdown()

case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(
f"FLASH simulation failed with code {exit_code}"
)
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)

event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)

if isinstance(current_status, RunnerShutdown):
break

if process and process.poll() is None:
process.terminate()
process.wait(timeout=5)

logger.info("FLASH runner exiting")
@@ -1,6 +1,8 @@
import gc
import time

import mlx.core as mx
from anyio import WouldBlock

from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -21,9 +23,6 @@ from exo.shared.types.tasks import (
    TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
    GenerationResponse,
)
from exo.shared.types.worker.runners import (
    RunnerConnected,
    RunnerConnecting,
@@ -39,7 +38,9 @@ from exo.shared.types.worker.runners import (
    RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.generator.batch_engine import BatchGenerationEngine
from exo.worker.engines.mlx.generator.distributed_sync import DistributedOp, sync_operation
from exo.worker.engines.mlx.generator.generate import warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
    initialize_mlx,
    load_mlx_items,
@@ -48,6 +49,26 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger


def _determine_distributed_op(
    batch_engine: BatchGenerationEngine | None,
    pending_shutdown: Task | None,
    should_shutdown: bool,
) -> DistributedOp:
    """Determine what operation to perform next in distributed mode.

    Only rank 0's result matters - this gets broadcast to all ranks via sync_operation.
    """
    if should_shutdown:
        return DistributedOp.SHUTDOWN
    if pending_shutdown is not None and batch_engine is not None and not batch_engine.has_active_requests:
        return DistributedOp.SHUTDOWN
    if batch_engine is not None and batch_engine.has_pending_inserts:
        return DistributedOp.INSERT
    if batch_engine is not None and batch_engine.has_active_requests:
        return DistributedOp.STEP
    return DistributedOp.NOOP
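

# Not shown in this diff: DistributedOp and sync_operation come from
# exo.worker.engines.mlx.generator.distributed_sync. A minimal sketch of how
# rank 0's decision could be broadcast to every rank, assuming DistributedOp
# carries integer values and an mx.distributed all_sum collective is available
# (hypothetical, for illustration only):
#
#     def sync_operation(op, rank, group):
#         # Rank 0 contributes its decision; all other ranks contribute 0, so
#         # the all_sum result equals rank 0's value on every rank.
#         local = mx.array([op.value if rank == 0 and op is not None else 0], dtype=mx.int32)
#         return DistributedOp(int(mx.distributed.all_sum(local, group=group).item()))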


def main(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
@@ -69,142 +90,247 @@ def main(
    model = None
    tokenizer = None
    group = None
    batch_engine: BatchGenerationEngine | None = None
    pending_shutdown: Shutdown | None = None

    current_status: RunnerStatus = RunnerIdle()

    def send_status(status: RunnerStatus) -> None:
        event_sender.send(RunnerStatusUpdated(runner_id=runner_id, runner_status=status))

    logger.info("runner created")
    event_sender.send(
        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
    )
    with task_receiver as tasks:
        for task in tasks:
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
            event_sender.send(TaskAcknowledged(task_id=task.task_id))
            match task:
                case ConnectToGroup() if isinstance(
                    current_status, (RunnerIdle, RunnerFailed)
                ):
                    logger.info("runner connecting")
                    current_status = RunnerConnecting()
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    group = initialize_mlx(bound_instance)
                send_status(current_status)

                    logger.info("runner connected")
                    current_status = RunnerConnected()
    def handle_task(task: Task, is_deferred: bool = False) -> bool:
        nonlocal current_status, model, tokenizer, group, batch_engine, pending_shutdown

                # we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
                case LoadModel() if (
                    isinstance(current_status, RunnerConnected) and group is not None
                ) or (isinstance(current_status, RunnerIdle) and group is None):
                    current_status = RunnerLoading()
                    logger.info("runner loading")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
        # For Shutdown, check if we need to defer BEFORE sending Running/Acknowledged
        if isinstance(task, Shutdown) and not is_deferred:
            if batch_engine is not None and (batch_engine.has_active_requests or batch_engine.has_pending_inserts):
                logger.info("deferring shutdown until active requests complete")
                pending_shutdown = task
                return True

                    model, tokenizer = load_mlx_items(bound_instance, group)
        event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running))
        event_sender.send(TaskAcknowledged(task_id=task.task_id))

                    current_status = RunnerLoaded()
                    logger.info("runner loaded")
                case StartWarmup() if isinstance(current_status, RunnerLoaded):
                    assert model
                    assert tokenizer
                    current_status = RunnerWarmingUp()
                    logger.info("runner warming up")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
        match task:
            case ConnectToGroup() if isinstance(
                current_status, (RunnerIdle, RunnerFailed)
            ):
                logger.info("runner connecting")
                current_status = RunnerConnecting()
                send_status(current_status)
                group = initialize_mlx(bound_instance)

                    logger.info(f"warming up inference for instance: {instance}")
                    toks = warmup_inference(
                        model=model,
                        tokenizer=tokenizer,
                        # kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
                    )
                    logger.info(f"warmed up by generating {toks} tokens")
                    logger.info(
                        f"runner initialized in {time.time() - setup_start_time} seconds"
                    )
                    current_status = RunnerReady()
                    logger.info("runner ready")
                case ChatCompletion(task_params=task_params, command_id=command_id) if (
                    isinstance(current_status, RunnerReady)
                ):
                    assert model
                    assert tokenizer
                    logger.info(f"received chat request: {str(task)[:500]}")
                    current_status = RunnerRunning()
                    logger.info("runner running")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    assert task_params.messages[0].content is not None
                logger.info("runner connected")
                current_status = RunnerConnected()
                event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
                send_status(current_status)

            case LoadModel() if (
                isinstance(current_status, RunnerConnected) and group is not None
            ) or (isinstance(current_status, RunnerIdle) and group is None):
                current_status = RunnerLoading()
                logger.info("runner loading")
                send_status(current_status)

                model, tokenizer = load_mlx_items(bound_instance, group)

                current_status = RunnerLoaded()
                logger.info("runner loaded")
                event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
                send_status(current_status)

            case StartWarmup() if isinstance(current_status, RunnerLoaded):
                assert model is not None
                assert tokenizer is not None
                current_status = RunnerWarmingUp()
                logger.info("runner warming up")
                send_status(current_status)

                logger.info(f"warming up inference for instance: {instance}")
                toks = warmup_inference(model=model, tokenizer=tokenizer)
                logger.info(f"warmed up by generating {toks} tokens")
                logger.info(f"runner initialized in {time.time() - setup_start_time} seconds")

                batch_engine = BatchGenerationEngine(model=model, tokenizer=tokenizer, group=group)

                current_status = RunnerReady()
                logger.info("runner ready")
                event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
                send_status(current_status)

            case ChatCompletion(
                task_params=task_params, command_id=command_id
            ) if isinstance(current_status, (RunnerReady, RunnerRunning)):
                assert batch_engine is not None

                if task_params.messages and task_params.messages[0].content is not None:
                    _check_for_debug_prompts(task_params.messages[0].content)

                    # Generate responses using the actual MLX generation
                    for response in mlx_generate(
                        model=model,
                        tokenizer=tokenizer,
                        task=task_params,
                    ):
                        match response:
                            case GenerationResponse():
                                if shard_metadata.device_rank == 0:
                                    event_sender.send(
                                        ChunkGenerated(
                                            command_id=command_id,
                                            chunk=TokenChunk(
                                                idx=response.token,
                                                model=shard_metadata.model_meta.model_id,
                                                text=response.text,
                                                token_id=response.token,
                                                finish_reason=response.finish_reason,
                                                stats=response.stats,
                                            ),
                                        )
                                    )
                            # case TokenizedResponse():
                            # TODO: something here ig
                # Queue the request - actual insertion happens in sync_and_insert_pending()
                # In distributed mode, only rank 0 receives tasks from control plane
                batch_engine.queue_request(command_id=command_id, task_id=task.task_id, task_params=task_params)

                    current_status = RunnerReady()
                    logger.info("runner ready")
                case Shutdown():
                    current_status = RunnerShuttingDown()
                    logger.info("runner shutting down")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    current_status = RunnerShutdown()
                case _:
                    raise ValueError(
                        f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
                    )
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
            )
            event_sender.send(
                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
            )
            if isinstance(current_status, RunnerShutdown):
                del model, tokenizer, group
                mx.clear_cache()
                import gc
                # Status will be updated after actual insertion in the main loop
                # For now, set to RunnerRunning to indicate we're processing
                current_status = RunnerRunning(active_requests=batch_engine.active_count + batch_engine.pending_insert_count)
                send_status(current_status)

                gc.collect()
                break
            case Shutdown():
                current_status = RunnerShuttingDown()
                logger.info("runner shutting down")
                send_status(current_status)
                event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
                current_status = RunnerShutdown()
                send_status(current_status)
                return False

            case _:
                raise ValueError(
                    f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
                )

        return True

    with task_receiver as tasks:
        running = True
        is_rank_0 = shard_metadata.device_rank == 0
        is_distributed = group is not None and group.size() > 1

        while running:
            if is_distributed:
                assert group is not None
                assert batch_engine is not None

                # Distributed mode: synchronize operations across all ranks
                # Step 1: Only rank 0 checks for tasks and determines operation
                should_shutdown = False
                if is_rank_0:
                    while True:
                        try:
                            task = tasks.receive_nowait()
                            task_result = handle_task(task)
                            if not task_result:
                                should_shutdown = True
                                break
                        except WouldBlock:
                            break

                    op = _determine_distributed_op(batch_engine, pending_shutdown, should_shutdown)
                else:
                    op = None

                # Step 2: Sync operation across all ranks
                synced_op = sync_operation(op, shard_metadata.device_rank, group)

                # Step 3: All ranks execute the same operation
                match synced_op:
                    case DistributedOp.INSERT:
                        batch_engine.sync_and_insert_pending()
                        if is_rank_0:
                            current_status = RunnerRunning(active_requests=batch_engine.active_count)
                            send_status(current_status)

                    case DistributedOp.STEP:
                        for resp in batch_engine.step():
                            if is_rank_0:
                                event_sender.send(ChunkGenerated(
                                    command_id=resp.command_id,
                                    chunk=TokenChunk(
                                        idx=resp.response.token,
                                        model=shard_metadata.model_meta.model_id,
                                        text=resp.response.text,
                                        token_id=resp.response.token,
                                        finish_reason=resp.response.finish_reason,
                                        stats=resp.response.stats,
                                    ),
                                ))
                            if resp.response.finish_reason is not None:
                                if is_rank_0:
                                    event_sender.send(TaskStatusUpdated(task_id=resp.task_id, task_status=TaskStatus.Complete))

                        if is_rank_0:
                            if batch_engine.has_active_requests:
                                current_status = RunnerRunning(active_requests=batch_engine.active_count)
                            else:
                                current_status = RunnerReady()
                            send_status(current_status)

                    case DistributedOp.SHUTDOWN:
                        running = False
                        if is_rank_0 and pending_shutdown is not None:
                            handle_task(pending_shutdown, is_deferred=True)

                    case DistributedOp.NOOP:
                        # No work to do - all ranks poll together
                        # We can't have rank 0 block while others try to sync
                        if is_rank_0:
                            try:
                                task = tasks.receive_nowait()
                                task_result = handle_task(task)
                                if not task_result:
                                    # Will sync SHUTDOWN on next iteration
                                    pass
                            except WouldBlock:
                                pass
                        # All ranks: short sleep before looping back to sync
                        time.sleep(0.001)

            else:
                # Non-distributed mode: original logic with queue + insert
                while True:
                    try:
                        task = tasks.receive_nowait()
                        running = handle_task(task)
                        if not running:
                            break
                    except WouldBlock:
                        break

                if not running:
                    break

                # Insert any queued requests (non-distributed just inserts directly)
                # Status was already sent in handle_task when queueing
                if batch_engine is not None and batch_engine.has_pending_inserts:
                    batch_engine.sync_and_insert_pending()

                if batch_engine is not None and batch_engine.has_active_requests:
                    for resp in batch_engine.step():
                        if shard_metadata.device_rank == 0:
                            event_sender.send(ChunkGenerated(
                                command_id=resp.command_id,
                                chunk=TokenChunk(
                                    idx=resp.response.token,
                                    model=shard_metadata.model_meta.model_id,
                                    text=resp.response.text,
                                    token_id=resp.response.token,
                                    finish_reason=resp.response.finish_reason,
                                    stats=resp.response.stats,
                                ),
                            ))
                        if resp.response.finish_reason is not None:
                            event_sender.send(TaskStatusUpdated(task_id=resp.task_id, task_status=TaskStatus.Complete))

                    if batch_engine.has_active_requests:
                        current_status = RunnerRunning(active_requests=batch_engine.active_count)
                    else:
                        current_status = RunnerReady()
                    send_status(current_status)

                    # Process deferred shutdown after all requests complete
                    if pending_shutdown is not None and not batch_engine.has_active_requests and not batch_engine.has_pending_inserts:
                        running = handle_task(pending_shutdown, is_deferred=True)
                else:
                    task = tasks.receive()
                    running = handle_task(task)

    # Cleanup
    del model, tokenizer, group, batch_engine
    mx.clear_cache()
    gc.collect()


EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

@@ -105,7 +105,7 @@ class RunnerSupervisor:
            return

        # This is overkill but it's not technically bad, just unnecessary.
        logger.warning("Runner process didn't shutdown succesfully, terminating")
        logger.warning("Runner process didn't shutdown successfully, terminating")
        self.runner_process.terminate()
        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
@@ -128,9 +128,11 @@ class RunnerSupervisor:

    async def start_task(self, task: Task):
        if task.task_id in self.completed:
            logger.info(
                f"Skipping invalid task {task} as it has already been completed"
            )
            logger.info(f"Skipping task {task.task_id} - already completed")
            return
        if task.task_id in self.pending:
            logger.info(f"Skipping task {task.task_id} - already pending")
            return
        logger.info(f"Starting task {task}")
        event = anyio.Event()
        self.pending[task.task_id] = event
@@ -149,13 +151,17 @@ class RunnerSupervisor:
                if isinstance(event, RunnerStatusUpdated):
                    self.status = event.runner_status
                if isinstance(event, TaskAcknowledged):
                    self.pending.pop(event.task_id).set()
                    # Just set the event to unblock start_task, but keep in pending
                    # to prevent duplicate forwarding until completion
                    if event.task_id in self.pending:
                        self.pending[event.task_id].set()
                    continue
                if (
                    isinstance(event, TaskStatusUpdated)
                    and event.task_status == TaskStatus.Complete
                if isinstance(event, TaskStatusUpdated) and event.task_status in (
                    TaskStatus.Complete,
                    TaskStatus.TimedOut,
                    TaskStatus.Failed,
                ):
                    # If a task has just been completed, we should be working on it.
                    # If a task has just finished, we should be working on it.
                    assert isinstance(
                        self.status,
                        (
@@ -166,6 +172,8 @@ class RunnerSupervisor:
                            RunnerShuttingDown,
                        ),
                    )
                    # Now safe to remove from pending and add to completed
                    self.pending.pop(event.task_id, None)
                    self.completed.add(event.task_id)
                await self._event_sender.send(event)
        except (ClosedResourceError, BrokenResourceError) as e:

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
    bound_instance: BoundInstance
    status: RunnerStatus
    completed: set[TaskId] = field(default_factory=set)
    pending: dict[TaskId, object] = field(default_factory=dict)


class OtherTask(BaseTask):

@@ -0,0 +1,315 @@
"""
Tests for continuous batching behavior in the runner.

These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately
"""

# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false

from typing import Any
from unittest.mock import MagicMock

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
    Event,
    RunnerStatusUpdated,
    TaskStatusUpdated,
)
from exo.shared.types.tasks import (
    ChatCompletion,
    ChatCompletionTaskParams,
    ConnectToGroup,
    LoadModel,
    Shutdown,
    StartWarmup,
    Task,
    TaskId,
    TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
    BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance


class FakeBatchEngineWithTokens:
    """
    Fake batch engine that generates a specified number of tokens per request.

    This simulates realistic batch generation behavior where:
    - Requests are queued on insert
    - Each step() call generates one token for all active requests
    - Requests complete when they've generated all their tokens
    """

    def __init__(self, *_args: Any, **_kwargs: Any):
        self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
        self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
        self._uid_counter = 0
        self._tokens_per_request = 3  # Default: generate 3 tokens before completing
        self.rank = 0  # Fake rank for testing

    def queue_request(
        self,
        command_id: CommandId,
        task_id: TaskId,
        task_params: ChatCompletionTaskParams,
    ) -> None:
        """Queue a request for insertion."""
        self._pending_inserts.append((command_id, task_id, task_params))

    def sync_and_insert_pending(self) -> list[int]:
        """Insert all pending requests."""
        uids: list[int] = []
        for command_id, task_id, task_params in self._pending_inserts:
            uid = self._do_insert(command_id, task_id, task_params)
            uids.append(uid)
        self._pending_inserts.clear()
        return uids

    @property
    def has_pending_inserts(self) -> bool:
        return len(self._pending_inserts) > 0

    def _do_insert(
        self,
        command_id: CommandId,
        task_id: TaskId,
        task_params: ChatCompletionTaskParams | None,
    ) -> int:
        uid = self._uid_counter
        self._uid_counter += 1
        # Track: (command_id, task_id, tokens_generated, max_tokens)
        max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
        self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
        return uid

    def step(self) -> list[BatchedGenerationResponse]:
        results: list[BatchedGenerationResponse] = []
        uids_to_remove: list[int] = []

        for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
            self._active_requests.items()
        ):
            tokens_gen += 1
            finish_reason = "stop" if tokens_gen >= max_tokens else None
            text = f"token{tokens_gen}"

            if finish_reason:
                uids_to_remove.append(uid)
            else:
                self._active_requests[uid] = (
                    command_id,
                    task_id,
                    tokens_gen,
                    max_tokens,
                )

            results.append(
                BatchedGenerationResponse(
                    command_id=command_id,
                    task_id=task_id,
                    response=GenerationResponse(
                        token=tokens_gen,
                        text=text,
                        finish_reason=finish_reason,
                    ),
                )
            )

        for uid in uids_to_remove:
            del self._active_requests[uid]

        return results

    @property
    def has_active_requests(self) -> bool:
        return len(self._active_requests) > 0

    @property
    def active_count(self) -> int:
        return len(self._active_requests)

    @property
    def pending_insert_count(self) -> int:
        return len(self._pending_inserts)


def make_nothin[T, U, V](res: T):
    def nothin(*_1: U, **_2: V) -> T:
        return res

    return nothin


@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
    """Patch MLX dependencies and use FakeBatchEngineWithTokens."""
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
    monkeypatch.setattr(
        mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
    )
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
    monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)


def _run_with_tasks(tasks: list[Task]) -> list[Event]:
    """
    Run tasks through the runner, adding shutdown at the end.

    Tasks are sent in order, with shutdown sent last.
    The batch engine processes between task handling.
    """
    bound_instance = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        runner_id=RUNNER_1_ID,
        node_id=NodeId(NODE_A),
    )

    task_sender, task_receiver = mp_channel[Task]()
    event_sender, event_receiver = mp_channel[Event]()

    shutdown_task = Shutdown(
        task_id=TaskId("shutdown"),
        instance_id=INSTANCE_1_ID,
        runner_id=RUNNER_1_ID,
    )

    with task_sender, event_receiver:
        # Send all tasks including shutdown
        for t in tasks:
            task_sender.send(t)
        task_sender.send(shutdown_task)

        # Disable cleanup methods to prevent issues
        event_sender.close = lambda: None
        event_sender.join = lambda: None
        task_receiver.close = lambda: None
        task_receiver.join = lambda: None

        mlx_runner.main(bound_instance, event_sender, task_receiver)

    return event_receiver.collect()


INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)


def make_chat_task(
    task_id: str, command_id: str, max_tokens: int = 3
) -> ChatCompletion:
    return ChatCompletion(
        task_id=TaskId(task_id),
        command_id=CommandId(command_id),
        task_params=ChatCompletionTaskParams(
            model=str(MODEL_A_ID),
            messages=[ChatCompletionMessage(role="user", content="hello")],
            stream=True,
            max_tokens=max_tokens,
        ),
        instance_id=INSTANCE_1_ID,
    )


def test_single_request_generates_tokens(patch_batch_engine: None):
    """
    Verify a single request generates the expected tokens through the batch path.

    Note: With the current non-blocking design, shutdown is processed before
    batch steps run when all tasks are queued together. This test verifies
    the runner status reflects active requests.
    """
    chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

    # Find RunnerRunning status events - this shows the request was inserted
    running_events = [
        e
        for e in events
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerRunning)
    ]

    assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
    assert running_events[0].runner_status.active_requests == 1


def test_runner_status_reflects_active_requests(patch_batch_engine: None):
    """Verify RunnerRunning status includes active_requests count."""
    chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

    # Find RunnerRunning status events
    running_events = [
        e
        for e in events
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerRunning)
    ]

    assert len(running_events) > 0, "Expected at least one RunnerRunning event"
    assert running_events[0].runner_status.active_requests == 1


def test_chat_task_acknowledged(patch_batch_engine: None):
    """Verify chat completion task is acknowledged with proper status updates."""
    chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

    # Find the chat task status events
    chat_running = [
        e
        for e in events
        if isinstance(e, TaskStatusUpdated)
        and e.task_id == TaskId("chat1")
        and e.task_status == TaskStatus.Running
    ]

    assert len(chat_running) == 1, "Expected exactly one chat task Running status"


def test_multiple_requests_tracked(patch_batch_engine: None):
    """Verify multiple concurrent requests are tracked in active_requests."""
    chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
    chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])

    # Find RunnerRunning status events
    running_events = [
        e
        for e in events
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerRunning)
    ]

    # Should have at least 2 RunnerRunning events (one per request inserted)
    assert len(running_events) >= 2, f"Expected at least 2 RunnerRunning events, got {len(running_events)}"

    # First should have 1 active request, second should have 2
    assert running_events[0].runner_status.active_requests == 1
    assert running_events[1].runner_status.active_requests == 2
@@ -0,0 +1,79 @@
"""Test for distributed synchronization in batch generation.

These tests verify that all ranks in a distributed setup call the same
collective operations in the same order, preventing race conditions and deadlocks.
"""

import pytest

from exo.worker.engines.mlx.generator.distributed_sync import DistributedOp
from exo.worker.runner.runner import _determine_distributed_op


class FakeBatchEngine:
    """Minimal fake batch engine for testing _determine_distributed_op."""

    def __init__(self, has_pending_inserts: bool = False, has_active_requests: bool = False):
        self._has_pending_inserts = has_pending_inserts
        self._has_active_requests = has_active_requests

    @property
    def has_pending_inserts(self) -> bool:
        return self._has_pending_inserts

    @property
    def has_active_requests(self) -> bool:
        return self._has_active_requests


def test_distributed_sync_prevents_race_condition():
    """
    Test that the new architecture prevents the race condition.

    In the old code, each rank independently decided what operation to perform
    based on its local state. This could cause rank 0 to call insert_request()
    while rank 1 called step(), causing a collective mismatch.

    The fix ensures:
    1. Only rank 0 determines the operation (via _determine_distributed_op)
    2. Rank 0's decision is broadcast to all ranks (via sync_operation)
    3. All ranks then execute the same operation
    """
    # Scenario: rank 0 has pending inserts, rank 1 has active requests
    # In old code: rank 0 would INSERT, rank 1 would STEP -> race condition
    # In new code: rank 0's decision (INSERT) is broadcast to all ranks

    # Test _determine_distributed_op gives priority to INSERT over STEP
    batch_engine_with_inserts = FakeBatchEngine(has_pending_inserts=True, has_active_requests=True)
    op = _determine_distributed_op(batch_engine_with_inserts, pending_shutdown=None, should_shutdown=False)
    assert op == DistributedOp.INSERT, "INSERT should take priority when there are pending inserts"

    # Test step is used when there are active requests but no pending inserts
    batch_engine_with_active = FakeBatchEngine(has_pending_inserts=False, has_active_requests=True)
    op = _determine_distributed_op(batch_engine_with_active, pending_shutdown=None, should_shutdown=False)
    assert op == DistributedOp.STEP, "STEP should be used when there are active requests"

    # Test NOOP when nothing to do
    batch_engine_idle = FakeBatchEngine(has_pending_inserts=False, has_active_requests=False)
    op = _determine_distributed_op(batch_engine_idle, pending_shutdown=None, should_shutdown=False)
    assert op == DistributedOp.NOOP, "NOOP should be used when nothing to do"


def test_distributed_sync_shutdown_handling():
    """Test that shutdown is properly coordinated across ranks."""
    batch_engine = FakeBatchEngine(has_pending_inserts=False, has_active_requests=False)

    # Test direct shutdown request
    op = _determine_distributed_op(batch_engine, pending_shutdown=None, should_shutdown=True)
    assert op == DistributedOp.SHUTDOWN, "SHUTDOWN should be returned when should_shutdown is True"

    # Test pending shutdown with no active requests
    from unittest.mock import MagicMock
    pending_shutdown = MagicMock()
    op = _determine_distributed_op(batch_engine, pending_shutdown=pending_shutdown, should_shutdown=False)
    assert op == DistributedOp.SHUTDOWN, "SHUTDOWN should be returned when pending_shutdown and no active requests"

    # Test pending shutdown with active requests - should continue processing
    batch_engine_active = FakeBatchEngine(has_pending_inserts=False, has_active_requests=True)
    op = _determine_distributed_op(batch_engine_active, pending_shutdown=pending_shutdown, should_shutdown=False)
    assert op == DistributedOp.STEP, "Should continue STEP while requests are active even with pending shutdown"
@@ -1,11 +1,16 @@
# Check tasks are complete before runner is ever ready.

# pyright: reportAny=false

from collections.abc import Iterable
from typing import Callable
from typing import Any, Callable
from unittest.mock import MagicMock

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
    ChunkGenerated,
@@ -22,6 +27,7 @@ from exo.shared.types.tasks import (
    Shutdown,
    StartWarmup,
    Task,
    TaskId,
    TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
@@ -38,6 +44,9 @@ from exo.shared.types.worker.runners import (
    RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
    BatchedGenerationResponse,
)

from ...constants import (
    CHAT_COMPLETION_TASK_ID,
@@ -107,18 +116,85 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
        assert test_event == true_event, f"{test_event} != {true_event}"


class FakeBatchEngine:
    """
    Fake batch engine for testing.

    Queues requests on insert, returns one token per step.
    The runner's non-blocking loop drains all tasks before running batch steps,
    so this engine queues requests and has_active_requests returns True only
    after at least one request has been inserted.
    """

    def __init__(self, *_args: Any, **_kwargs: Any):
        self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
        self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
        self._uid_counter = 0
        self.rank = 0  # Fake rank for testing

    def queue_request(
        self,
        command_id: CommandId,
        task_id: TaskId,
        task_params: ChatCompletionTaskParams,
    ) -> None:
        """Queue a request for insertion."""
        self._pending_inserts.append((command_id, task_id, task_params))

    def sync_and_insert_pending(self) -> list[int]:
        """Insert all pending requests."""
        uids: list[int] = []
        for command_id, task_id, _task_params in self._pending_inserts:
            uid = self._uid_counter
            self._uid_counter += 1
            self._active_requests[uid] = (command_id, task_id)
            uids.append(uid)
        self._pending_inserts.clear()
        return uids

    @property
    def has_pending_inserts(self) -> bool:
        return len(self._pending_inserts) > 0

    def step(self) -> list[BatchedGenerationResponse]:
        results: list[BatchedGenerationResponse] = []
        # Process all active requests - return one token and complete
        for uid, (command_id, task_id) in list(self._active_requests.items()):
            results.append(
                BatchedGenerationResponse(
                    command_id=command_id,
                    task_id=task_id,
                    response=GenerationResponse(
                        token=0,
                        text="hi",
                        finish_reason="stop",
                    ),
                )
            )
            del self._active_requests[uid]
        return results

    @property
    def has_active_requests(self) -> bool:
        return len(self._active_requests) > 0

    @property
    def active_count(self) -> int:
        return len(self._active_requests)

    @property
    def pending_insert_count(self) -> int:
        return len(self._pending_inserts)


@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    # initialize_mlx returns a "group" equal to 1
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
    # initialize_mlx returns a fake "group" (non-None for state machine)
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock())))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)

    def fake_generate(*_1: object, **_2: object):
        yield GenerationResponse(token=0, text="hi", finish_reason="stop")

    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
    monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)


def _run(tasks: Iterable[Task]):
@@ -148,7 +224,8 @@ def _run(tasks: Iterable[Task]):
    return event_receiver.collect()


def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPatch):
    """Verify chat completion generates tokens, completes, and runner returns to Ready."""
    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])

    expected_chunk = ChunkGenerated(
@@ -191,7 +268,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
            ),
            TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
            RunnerStatusUpdated(
                runner_id=RUNNER_1_ID, runner_status=RunnerRunning(active_requests=1)
            ),
            expected_chunk,
            TaskStatusUpdated(
                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -206,7 +285,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
            TaskStatusUpdated(
                task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
            ),
            # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
        ],
    )