Compare commits

...

2 Commits

Author SHA1 Message Date
Sami Khan  2d45544da0  use rsh server instead of ssh  2026-01-13 02:46:25 +05:00
Sami Khan  7cbafa768a  flash+exo  2026-01-12 10:26:16 +05:00
13 changed files with 880 additions and 14 deletions

View File

@@ -41,6 +41,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"
# dependencies only required for development
[dependency-groups]

View File

@@ -13,6 +13,7 @@ from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.rsh.server import run_rsh_server, RSH_PORT
from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
@@ -112,6 +113,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
# Start RSH server for remote execution (used by MPI)
tg.start_soon(run_rsh_server, RSH_PORT)
def shutdown(self):
# if this is our second call to shutdown, just sys.exit

View File

@@ -1,3 +1,5 @@
import os
import subprocess
import time
from collections.abc import AsyncGenerator
from typing import cast
@@ -51,7 +53,9 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
StopFLASH,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -60,7 +64,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -178,6 +182,10 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# FLASH simulation endpoints
self.app.post("/flash/launch")(self.launch_flash)
self.app.delete("/flash/{instance_id}")(self.stop_flash)
self.app.get("/flash/instances")(self.list_flash_instances)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -622,6 +630,83 @@ class API:
]
)
async def launch_flash(
self,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict:
"""Launch a FLASH MPI simulation across the cluster.
Args:
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await self._send(command)
return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}
async def stop_flash(self, instance_id: InstanceId) -> dict:
"""Stop a running FLASH simulation."""
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
instance = self.state.instances[instance_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)
command = StopFLASH(instance_id=instance_id)
await self._send(command)
return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}
async def list_flash_instances(self) -> list[dict]:
"""List all FLASH simulation instances."""
flash_instances = []
for instance_id, instance in self.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses = {}
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
runner_status = self.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)
flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
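
A minimal client sketch for the new FLASH endpoints follows (the API address and all values are assumptions; since launch_flash declares scalar parameters, FastAPI reads them from the query string):

# Hedged sketch, not part of this diff: drive the new /flash endpoints with urllib.
import json
from urllib.parse import urlencode
from urllib.request import Request, urlopen

BASE = "http://localhost:52415"  # assumed master API address

params = urlencode({
    "simulation_name": "sedov",                    # example values only
    "flash_executable_path": "/opt/flash/flash4",
    "working_directory": "/tmp/flash_run",
    "ranks_per_node": 2,
    "min_nodes": 2,
    "hosts": "s14,james21-1",
})
with urlopen(Request(f"{BASE}/flash/launch?{params}", method="POST")) as resp:
    print(json.load(resp))  # {"message": "FLASH launch command received", ...}

with urlopen(f"{BASE}/flash/instances") as resp:
    instances = json.load(resp)

if instances:
    instance_id = instances[0]["instance_id"]
    with urlopen(Request(f"{BASE}/flash/{instance_id}", method="DELETE")) as resp:
        print(json.load(resp))  # stop acknowledgement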

View File

@@ -8,6 +8,7 @@ from exo.master.placement import (
add_instance_to_placements,
delete_instance,
get_transition_events,
place_flash_instance,
place_instance,
)
from exo.shared.apply import apply
@@ -16,8 +17,10 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
RequestEventLog,
StopFLASH,
TaskFinished,
TestCommand,
)
@@ -173,6 +176,26 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case LaunchFLASH():
placement = place_flash_instance(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case StopFLASH():
# Reuse delete_instance logic to stop FLASH simulation
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -17,19 +17,24 @@ from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
LaunchFLASH,
PlaceInstance,
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import NodeInfo
from exo.shared.types.common import Host, NodeId
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.shared.types.worker.shards import Sharding
@@ -180,6 +185,138 @@ def delete_instance(
raise ValueError(f"Instance {command.instance_id} not found")
def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.
Unlike MLX instances, which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
all_nodes = list(topology.list_nodes())
if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)
# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]
logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)
# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict = {}
# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)
for i, node_info in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_info.node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict = {}
# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_info in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_info.node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(f"FLASH placement: node {node_info.node_id} (rank {i}) -> IP {explicit_hosts[i]}")
else:
logger.warning(f"Not enough hosts provided for node {i}, using localhost")
hosts_by_node[node_info.node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}")
else:
# Try to get IPs from topology edges
for node_info in selected_nodes:
node_hosts: list[Host] = []
# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for _, edge_data in topology.out_edges(node_info.node_id):
if hasattr(edge_data, "send_back_multiaddr"):
# Extract IP from multiaddr like /ip4/192.168.1.100/tcp/52415
multiaddr = str(edge_data.send_back_multiaddr)
if "/ip4/" in multiaddr:
parts = multiaddr.split("/")
try:
ip_idx = parts.index("ip4") + 1
ip = parts[ip_idx]
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith("127."):
node_hosts.append(Host(ip=ip, port=0))
break
except (ValueError, IndexError):
pass
# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_info.node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))
hosts_by_node[node_info.node_id] = node_hosts
total_ranks = len(selected_nodes) * command.ranks_per_node
# Determine coordinator IP - first node's first host IP
first_node_id = list(hosts_by_node.keys())[0]
coordinator_ip = hosts_by_node[first_node_id][0].ip if hosts_by_node[first_node_id] else "127.0.0.1"
target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)
logger.info(
f"Created FLASH instance {instance_id} with {total_ranks} total ranks"
)
return target_instances
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],

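To make the placement result concrete: with hosts="s14,james21-1", min_nodes=2, and ranks_per_node=2, the resulting FLASHInstance would carry roughly the following fields (node IDs are placeholders):

# Rough data shape only; NodeId values are placeholders, not real identifiers.
hosts_by_node = {
    NodeId("node-a"): [Host(ip="s14", port=0)],        # rank 0, coordinator
    NodeId("node-b"): [Host(ip="james21-1", port=0)],  # rank 1
}
coordinator_ip = "s14"  # first node's first host
total_ranks = 4         # 2 nodes * ranks_per_node=2
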
13  src/exo/rsh/__init__.py  Normal file
View File

@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.
This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works as follows:
1. Each Exo node runs a small HTTP server (the RSH server) on port 52416
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", the client HTTP POSTs the command to the target node's RSH server
4. The target executes the command and returns its output and exit code
Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""

91  src/exo/rsh/client.py  Normal file
View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.
This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]
It connects to the target node's RSH server (port 52416) and executes the command.
"""
import socket
import sys
from urllib.request import urlopen, Request
from urllib.error import URLError
import json
RSH_PORT = 52416
def resolve_hostname(hostname: str) -> str:
"""Resolve hostname to IP address."""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
# If resolution fails, try using the hostname directly
return hostname
def main():
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
# SSH options we might see: -x (disable X11), -o options, etc.
args = sys.argv[1:]
# Skip SSH-style options
hostname = None
command_start = 0
i = 0
while i < len(args):
arg = args[i]
if arg.startswith("-"):
# Skip option and its value if needed
if arg in ("-o", "-i", "-l", "-p", "-F"):
i += 2 # Skip option and its argument
continue
i += 1
continue
else:
# First non-option is the hostname
hostname = arg
command_start = i + 1
break
i += 1
if hostname is None or command_start >= len(args):
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
sys.exit(1)
command = args[command_start:]
# Resolve hostname to IP
ip = resolve_hostname(hostname)
# Make request to RSH server
url = f"http://{ip}:{RSH_PORT}/execute"
data = json.dumps({"command": command}).encode("utf-8")
try:
req = Request(url, data=data, headers={"Content-Type": "application/json"})
with urlopen(req, timeout=300) as response:
result = json.loads(response.read().decode("utf-8"))
# Output stdout/stderr
if result.get("stdout"):
sys.stdout.write(result["stdout"])
sys.stdout.flush()
if result.get("stderr"):
sys.stderr.write(result["stderr"])
sys.stderr.flush()
sys.exit(result.get("exit_code", 0))
except URLError as e:
print(f"exo-rsh: Failed to connect to {hostname}:{RSH_PORT}: {e}", file=sys.stderr)
sys.exit(255)
except Exception as e:
print(f"exo-rsh: Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

149  src/exo/rsh/server.py  Normal file
View File

@@ -0,0 +1,149 @@
"""RSH Server - runs on each Exo node to accept remote execution requests."""
import asyncio
import subprocess
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from loguru import logger
from pydantic import BaseModel
RSH_PORT = 52416
class ExecuteRequest(BaseModel):
"""Request to execute a command."""
command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None
class ExecuteResponse(BaseModel):
"""Response from command execution."""
exit_code: int
stdout: str
stderr: str
def create_rsh_app() -> FastAPI:
"""Create the RSH FastAPI application."""
app = FastAPI(title="Exo RSH Server")
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "ok"}
@app.post("/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse:
"""Execute a command and return the result."""
cmd_str = " ".join(request.command)
logger.info(f"RSH executing: {cmd_str}")
try:
# Build environment
import os
env = os.environ.copy()
if request.env:
env.update(request.env)
# Check if command contains shell metacharacters (semicolons, pipes, etc.)
# If so, run through shell. mpirun sends complex commands like:
# "VAR=value;export VAR;/path/to/prted --args"
needs_shell = any(c in cmd_str for c in ";|&$`")
if needs_shell:
# Run through shell
process = await asyncio.create_subprocess_shell(
cmd_str,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
else:
# Execute directly
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
stdout, stderr = await process.communicate()
exit_code = process.returncode or 0
logger.info(f"RSH command completed with exit code {exit_code}")
return ExecuteResponse(
exit_code=exit_code,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)
except FileNotFoundError as e:
logger.error(f"RSH command not found: {e}")
return ExecuteResponse(
exit_code=127,
stdout="",
stderr=f"Command not found: {request.command[0]}",
)
except Exception as e:
logger.error(f"RSH execution error: {e}")
return ExecuteResponse(
exit_code=1,
stdout="",
stderr=str(e),
)
@app.post("/execute_streaming")
async def execute_streaming(request: ExecuteRequest):
"""Execute a command and stream the output."""
logger.info(f"RSH streaming execute: {' '.join(request.command)}")
async def stream_output():
try:
env = None
if request.env:
import os
env = os.environ.copy()
env.update(request.env)
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
cwd=request.cwd,
env=env,
)
if process.stdout:
async for line in process.stdout:
yield line
await process.wait()
except Exception as e:
yield f"Error: {e}\n".encode()
return StreamingResponse(
stream_output(),
media_type="application/octet-stream",
)
return app
async def run_rsh_server(port: int = RSH_PORT):
"""Run the RSH server."""
app = create_rsh_app()
config = Config()
config.bind = [f"0.0.0.0:{port}"]
config.accesslog = None # Disable access logs for cleaner output
logger.info(f"Starting RSH server on port {port}")
await serve(app, config) # type: ignore
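
The server can also be run standalone when testing exo-rsh outside a full node; a minimal sketch (in normal operation Node starts it via run_rsh_server in its task group):

# Standalone test harness sketch; production nodes start this automatically.
import asyncio
from exo.rsh.server import run_rsh_server, RSH_PORT

if __name__ == "__main__":
    asyncio.run(run_rsh_server(RSH_PORT))

With that running, "exo-rsh localhost hostname" should print this machine's hostname and exit with the command's status.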

View File

@@ -35,6 +35,26 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""
instance_id: InstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -50,6 +70,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -14,6 +14,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -34,8 +35,27 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.
Unlike MLX instances, which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
class BoundInstance(CamelCaseModel):

View File

@@ -21,7 +21,7 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -50,6 +50,11 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
# Check for FLASH instance tasks first
flash_task = _plan_flash(runners, instances)
if flash_task is not None:
return flash_task
# Python's short-circuiting `or` evaluates these alternatives sequentially.
return (
_kill_runner(runners, all_runners, instances)
@@ -62,6 +67,34 @@ def plan(
)
def _plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.
FLASH instances have a simpler lifecycle:
- CreateRunner (handled by _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by _kill_runner)
This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue
# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
return None
def _kill_runner(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -114,6 +147,10 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
# FLASH instances don't need model downloads
if isinstance(runner.bound_instance.instance, FLASHInstance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status

View File

@@ -4,7 +4,7 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -17,20 +17,27 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
global logger
logger = _logger
# Import main after setting global logger - this lets us just import logger from this module
# Route based on instance type
try:
from exo.worker.runner.runner import main
if isinstance(bound_instance.instance, FLASHInstance):
# FLASH MPI simulation runner
from exo.worker.runner.flash_runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -0,0 +1,278 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- Each Exo node runs an RSH server on port 52416 for remote execution
- Workers just report ready and wait
"""
import os
import shutil
import socket
import subprocess
import threading
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger
# Find mpirun in PATH, falling back to the default Homebrew location
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
# exo-rsh is installed as a console script by the exo package
EXO_RSH_PATH = shutil.which("exo-rsh")
if not EXO_RSH_PATH:
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip
def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.
Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False
if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host
# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split('.')[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass
# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, 'w') as f:
for node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, 'r') as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
):
"""Main FLASH runner loop.
Coordinator: generates hostfile and runs mpirun (which spawns on workers via exo-rsh)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = (my_rank == 0)
coordinator_ip = get_coordinator_host(instance)
logger.info(f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}")
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
process: subprocess.Popen | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
def monitor_output(proc: subprocess.Popen):
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(instance, instance.working_directory)
iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np", str(instance.total_ranks),
"--hostfile", hostfile,
"--wdir", instance.working_directory,
"--oversubscribe",
"--mca", "btl", "tcp,self",
"--mca", "btl_tcp_if_include", iface,
"--mca", "oob_tcp_if_include", iface,
"--mca", "plm_rsh_no_tree_spawn", "1",
]
# Use exo-rsh for remote execution (no SSH needed)
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
cmd.append(instance.flash_executable_path)
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()
current_status = RunnerRunning()
logger.info(f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks")
else:
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
logger.info(f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh")
current_status = RunnerRunning()
except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback
logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))
case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()
current_status = RunnerShutdown()
case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(f"FLASH simulation failed with code {exit_code}")
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
if process and process.poll() is None:
process.terminate()
process.wait(timeout=5)
logger.info("FLASH runner exiting")