Compare commits

...

1 Commits

Author SHA1 Message Date
Sami Khan
7cbafa768a flash+exo 2026-01-12 10:26:16 +05:00
8 changed files with 613 additions and 14 deletions

View File

@@ -1,3 +1,5 @@
import os
import subprocess
import time
from collections.abc import AsyncGenerator
from typing import cast
@@ -51,7 +53,9 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
StopFLASH,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -60,7 +64,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -178,6 +182,10 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# FLASH simulation endpoints
self.app.post("/flash/launch")(self.launch_flash)
self.app.delete("/flash/{instance_id}")(self.stop_flash)
self.app.get("/flash/instances")(self.list_flash_instances)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -622,6 +630,83 @@ class API:
]
)
async def launch_flash(
self,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict:
"""Launch a FLASH MPI simulation across the cluster.
Args:
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await self._send(command)
return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}
async def stop_flash(self, instance_id: InstanceId) -> dict:
"""Stop a running FLASH simulation."""
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
instance = self.state.instances[instance_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)
command = StopFLASH(instance_id=instance_id)
await self._send(command)
return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}
async def list_flash_instances(self) -> list[dict]:
"""List all FLASH simulation instances."""
flash_instances = []
for instance_id, instance in self.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses = {}
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
runner_status = self.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)
flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -8,6 +8,7 @@ from exo.master.placement import (
add_instance_to_placements,
delete_instance,
get_transition_events,
place_flash_instance,
place_instance,
)
from exo.shared.apply import apply
@@ -16,8 +17,10 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
RequestEventLog,
StopFLASH,
TaskFinished,
TestCommand,
)
@@ -173,6 +176,26 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case LaunchFLASH():
placement = place_flash_instance(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case StopFLASH():
# Reuse delete_instance logic to stop FLASH simulation
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -17,19 +17,24 @@ from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
LaunchFLASH,
PlaceInstance,
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import NodeInfo
from exo.shared.types.common import Host, NodeId
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.shared.types.worker.shards import Sharding
@@ -180,6 +185,138 @@ def delete_instance(
raise ValueError(f"Instance {command.instance_id} not found")
def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.
Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
all_nodes = list(topology.list_nodes())
if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)
# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]
logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)
# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict = {}
# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)
for i, node_info in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_info.node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict = {}
# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_info in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_info.node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(f"FLASH placement: node {node_info.node_id} (rank {i}) -> IP {explicit_hosts[i]}")
else:
logger.warning(f"Not enough hosts provided for node {i}, using localhost")
hosts_by_node[node_info.node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}")
else:
# Try to get IPs from topology edges
for node_info in selected_nodes:
node_hosts: list[Host] = []
# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for _, edge_data in topology.out_edges(node_info.node_id):
if hasattr(edge_data, "send_back_multiaddr"):
# Extract IP from multiaddr like /ip4/192.168.1.100/tcp/52415
multiaddr = str(edge_data.send_back_multiaddr)
if "/ip4/" in multiaddr:
parts = multiaddr.split("/")
try:
ip_idx = parts.index("ip4") + 1
ip = parts[ip_idx]
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith("127."):
node_hosts.append(Host(ip=ip, port=0))
break
except (ValueError, IndexError):
pass
# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_info.node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))
hosts_by_node[node_info.node_id] = node_hosts
total_ranks = len(selected_nodes) * command.ranks_per_node
# Determine coordinator IP - first node's first host IP
first_node_id = list(hosts_by_node.keys())[0]
coordinator_ip = hosts_by_node[first_node_id][0].ip if hosts_by_node[first_node_id] else "127.0.0.1"
target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)
logger.info(
f"Created FLASH instance {instance_id} with {total_ranks} total ranks"
)
return target_instances
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],

View File

@@ -35,6 +35,26 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""
instance_id: InstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -50,6 +70,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -14,6 +14,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -34,8 +35,27 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.
Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
class BoundInstance(CamelCaseModel):

View File

@@ -21,7 +21,7 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -50,6 +50,11 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
# Check for FLASH instance tasks first
flash_task = _plan_flash(runners, instances)
if flash_task is not None:
return flash_task
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
@@ -62,6 +67,34 @@ def plan(
)
def _plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.
FLASH instances have a simpler lifecycle:
- CreateRunner (handled by _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by _kill_runner)
This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue
# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
return None
def _kill_runner(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -114,6 +147,10 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
# FLASH instances don't need model downloads
if isinstance(runner.bound_instance.instance, FLASHInstance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status

View File

@@ -4,7 +4,7 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -17,20 +17,27 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
global logger
logger = _logger
# Import main after setting global logger - this lets us just import logger from this module
# Route based on instance type
try:
from exo.worker.runner.runner import main
if isinstance(bound_instance.instance, FLASHInstance):
# FLASH MPI simulation runner
from exo.worker.runner.flash_runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -0,0 +1,268 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses SSH (keys already set up) to spawn on remote nodes
- Workers just report ready and wait
"""
import os
import shutil
import socket
import subprocess
import threading
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger
# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip
def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.
Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False
if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host
# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split('.')[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass
# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, 'w') as f:
for node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, 'r') as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
):
"""Main FLASH runner loop.
Coordinator: generates hostfile and runs mpirun (which SSHs to workers)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = (my_rank == 0)
coordinator_ip = get_coordinator_host(instance)
logger.info(f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}")
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
process: subprocess.Popen | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
def monitor_output(proc: subprocess.Popen):
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(instance, instance.working_directory)
iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np", str(instance.total_ranks),
"--hostfile", hostfile,
"--wdir", instance.working_directory,
"--oversubscribe",
"--mca", "btl", "tcp,self",
"--mca", "btl_tcp_if_include", iface,
"--mca", "oob_tcp_if_include", iface,
"--mca", "plm_rsh_no_tree_spawn", "1",
instance.flash_executable_path,
]
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()
current_status = RunnerRunning()
logger.info(f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks")
else:
# Worker: mpirun on coordinator will SSH to spawn processes here
logger.info(f"Worker {my_rank}: Ready for mpirun to spawn processes")
current_status = RunnerRunning()
except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback
logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))
case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()
current_status = RunnerShutdown()
case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(f"FLASH simulation failed with code {exit_code}")
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
if process and process.poll() is None:
process.terminate()
process.wait(timeout=5)
logger.info("FLASH runner exiting")