Andrei/mp capture stdio (#2056)

## Motivation

Process-isolated runner crashes and C-extension failures can write
directly to fd-level stdout/stderr, bypassing Python/loguru. We need to
capture that output per runner process without polluting the main
process or other workers, and without breaking operation when the parent
stdio is detached.

## Changes

- Added `AsyncProcess`, a spawn-only multiprocessing wrapper that
redirects child stdout/stderr to pipes and exposes them as in-memory
`Receiver[bytes]`s
- Replaced runner-supervisor's raw `multiprocessing.Process` usage with
`AsyncProcess`
- Added `--no-stdio`, redirecting stdin/stdout/stderr to `/dev/null`
after logging is configured
- Disabled verbose MLX ring logging (`MLX_RING_VERBOSE` is no longer set in `mlx_distributed_init`)
- Added tests covering stdio capture, child crashes, repeated bad
children, SIGTERM/SIGKILL shutdown escalation, stdio detachment, and
spawning captured children from a stdio-detached parent

## Why It Works

The parent can redirect its own stdio fds to `/dev/null`, while
`AsyncProcess` installs fresh pipe fds over fd 1 and 2 inside each
spawned child. That keeps stdio-detached parents quiet while preserving
per-runner stdout/stderr capture. Runner shutdown is still bounded:
SIGTERM grace first, then SIGKILL escalation if needed.

Next direction: the runner supervisor currently drains captured output
and logs it as stdout/debug and stderr/warning. This should be split
into more useful process-isolated error reporting instead of just log
forwarding (regex match on errors to obtain "reason" string, best
effort).

## Test Plan

### Manual Testing

Ran on 4 Mac Minis in a Thunderbolt 4 ring and confirmed that each runner's
stdout/stderr output is captured.

### Automated Testing

- Added async-process tests for fd-level stdout/stderr capture, Python
traceback capture, bounded-buffer output, child `exit`/abort, parent
stdio preservation, fd leak checks, spawn-context mp channels, and
SIGTERM/SIGKILL shutdown behavior
- Added stdio-detach tests proving stdio detaches to `/dev/null`, a
stdio-detached parent can still spawn and capture a child, and the same
stdio-detached parent can spawn/capture multiple children sequentially
- Updated runner-supervisor tests for the new `AsyncProcess.exitcode`
path
This commit is contained in:
Andrei Cravtov
2026-05-09 22:45:14 +01:00
committed by GitHub
parent ce37bdceb6
commit 45df74ba98
8 changed files with 1065 additions and 58 deletions

View File

@@ -23,6 +23,7 @@ from exo.shared.election import Election, ElectionResult
from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.daemon import detach_stdio_to_devnull
from exo.utils.pidfile import PidfileLockError, acquire_exo_pidfile
from exo.utils.pydantic_ext import FrozenModel
from exo.utils.task_group import TaskGroup
@@ -282,6 +283,10 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
if args.no_stdio:
detach_stdio_to_devnull()
logger.info("Detached stdio to /dev/null")
logger.info(f"{'=' * 40}")
logger.info(f"Starting EXO | pid={os.getpid()}")
logger.info(f"{'=' * 40}")
@@ -330,6 +335,7 @@ class Args(FrozenModel):
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
no_batch: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
no_stdio: bool = False
bootstrap_peers: list[str] = []
libp2p_port: int
@@ -389,6 +395,11 @@ class Args(FrozenModel):
action="store_true",
help="Disable continuous batching, use sequential generation",
)
parser.add_argument(
"--no-stdio",
action="store_true",
help="Detach stdin/stdout/stderr to /dev/null after logging is configured",
)
parser.add_argument(
"--bootstrap-peers",
type=lambda s: [p for p in s.split(",") if p],

View File

@@ -0,0 +1,290 @@
from __future__ import annotations
import contextlib
import faulthandler
import multiprocessing as mp
import os
import sys
from collections.abc import Callable, Iterable, Mapping
from multiprocessing.process import BaseProcess
from multiprocessing.resource_sharer import DupFd
from typing import final
from anyio import (
TASK_STATUS_IGNORED,
BrokenResourceError,
CancelScope,
ClosedResourceError,
Event,
create_task_group,
move_on_after,
sleep,
wait_readable,
)
from anyio.abc import TaskStatus
from loguru import logger
from exo.utils.channels import Receiver, Sender, channel
# File descriptor numbers that the child's captured pipes are installed over.
_STDOUT_FD = 1
_STDERR_FD = 2
# Maximum number of bytes requested per os.read() while draining a pipe.
_READ_CHUNK_SIZE = 64 * 1024
# Shutdown escalation used by AsyncProcess._terminate_if_still_alive.
# These are read as module globals at call time, so they can be patched.
# Seconds to wait for exit after the first terminate().
_TERMINATE_GRACE_SECONDS = 10.0
# Seconds to wait for exit after every further terminate().
_TERMINATE_RETRY_GRACE_SECONDS = 2.0
# Total number of terminate() calls (the first one included) before kill().
_TERMINATE_ATTEMPTS = 10
# Seconds to wait for exit after each kill().
_KILL_GRACE_SECONDS = 5.0
@final
class AsyncProcess:
    """Run a callable in a child process and capture its stdout/stderr.

    ``run()`` starts a ``multiprocessing.Process`` whose entry point is
    ``_run_with_captured_stdio`` and passes it the write ends of two fresh
    pipes. The read ends are drained by ``_drain_fd`` tasks into the byte
    channels exposed as ``stdout`` and ``stderr``. An instance is one-shot:
    ``run()`` raises if it is called a second time.
    """

    def __init__(
        self,
        target: Callable[..., object] | None = None,
        name: str | None = None,
        args: Iterable[object] = (),
        kwargs: Mapping[str, object] | None = None,
        *,
        daemon: bool | None = None,
    ) -> None:
        """Store the launch parameters; nothing is started until ``run()``.

        Args:
            target: Callable invoked in the child; ``None`` invokes nothing.
            name: Forwarded to ``mp.Process`` as ``name``.
            args: Positional arguments for ``target``.
            kwargs: Keyword arguments for ``target`` (``None`` means none).
            daemon: Forwarded to ``mp.Process`` as ``daemon``.
        """
        # setup state
        self._target = target
        self._name = name
        self._args = args
        self._kwargs = kwargs
        self._daemon = daemon
        # lifecycle state
        self._process: BaseProcess | None = None
        self._pid: int | None = None
        # sender halves are fed by the drain tasks, receiver halves are public
        self._stdout_tx, self._stdout_rx = channel[bytes]()
        self._stderr_tx, self._stderr_rx = channel[bytes]()
        # set once the child was started OR once startup failed (see run())
        self._started = Event()
        # set at the very end of run()'s cleanup
        self._done = Event()
        # non-None only while run() is executing; stop() cancels it
        self._run_cancel_scope: CancelScope | None = None
        # exception that aborted run() before the child was marked started
        self._start_error: BaseException | None = None
        # cached exit code, filled in lazily by the ``exitcode`` property
        self._exitcode: int | None = None

    async def run(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
        """Start the child, drain its output and block until it has exited.

        ``task_status.started()`` is called after the process has been started
        and both drain tasks have been scheduled. On every exit path (normal,
        error, cancellation) the child is terminated if it is still alive, the
        pipe fds and channel senders are closed, and ``_done`` is set.

        Raises:
            RuntimeError: If ``run()`` was already called on this instance,
                or the started process reports no pid.
        """
        if self._run_cancel_scope is not None or self._done.is_set():
            raise RuntimeError("process has already been started")
        # Each variable holds an fd this frame still owns; None = nothing to close.
        stdout_read_fd: int | None = None
        stdout_write_fd: int | None = None
        stderr_read_fd: int | None = None
        stderr_write_fd: int | None = None

        def cleanup_stdio_fd() -> None:
            # _close_fd returns None, so every variable is reset after closing
            nonlocal stdout_read_fd, stdout_write_fd, stderr_read_fd, stderr_write_fd
            stdout_read_fd = _close_fd(stdout_read_fd)
            stdout_write_fd = _close_fd(stdout_write_fd)
            stderr_read_fd = _close_fd(stderr_read_fd)
            stderr_write_fd = _close_fd(stderr_write_fd)

        try:
            with CancelScope() as run_cancel_scope:
                self._run_cancel_scope = run_cancel_scope
                stdout_read_fd, stdout_write_fd = os.pipe()
                stderr_read_fd, stderr_write_fd = os.pipe()
                # NOTE(review): DupFd presumably lets the write ends travel to
                # the child with the pickled arguments - confirm against
                # multiprocessing.resource_sharer.
                process = mp.Process(
                    target=_run_with_captured_stdio,
                    name=self._name,
                    args=(
                        DupFd(stdout_write_fd),
                        DupFd(stderr_write_fd),
                        self._target,
                        *self._args,
                    ),
                    kwargs={} if self._kwargs is None else self._kwargs,
                    daemon=self._daemon,
                )
                process.start()
                pid = process.pid
                if pid is None:
                    raise RuntimeError("started process has no pid")
                # important to close parent write-side FD to prevent hangs
                stdout_write_fd = _close_fd(stdout_write_fd)
                stderr_write_fd = _close_fd(stderr_write_fd)
                self._process = process
                self._pid = pid
                self._started.set()
                async with create_task_group() as tg:
                    # Each drain task closes its read fd itself, so the local
                    # variable is cleared to keep cleanup_stdio_fd from closing
                    # it a second time.
                    tg.start_soon(_drain_fd, stdout_read_fd, self._stdout_tx)
                    stdout_read_fd = None
                    tg.start_soon(_drain_fd, stderr_read_fd, self._stderr_tx)
                    stderr_read_fd = None
                    task_status.started()
                    await self.wait()
        except BaseException as exc:
            # Startup failed before the child was marked started: record the
            # error and release anyone blocked in wait() on ``_started``.
            if not self._started.is_set():
                self._start_error = exc
                self._started.set()
            raise
        finally:
            try:
                # shielded so the shutdown escalation still runs after cancellation
                with CancelScope(shield=True):
                    await self._terminate_if_still_alive()
            finally:
                cleanup_stdio_fd()
                # the drain tasks also close their sender; a second close is ignored
                for tx in (self._stdout_tx, self._stderr_tx):
                    with contextlib.suppress(Exception):
                        await tx.aclose()
                if self._process is not None:
                    with contextlib.suppress(ValueError):
                        self._process.close()
                self._run_cancel_scope = None
                self._done.set()

    async def stop(self) -> None:
        """Cancel ``run()`` (if it is executing) and wait for its cleanup.

        Raises:
            RuntimeError: If ``run()`` has never been called.
        """
        if self._run_cancel_scope is None and not self._done.is_set():
            raise RuntimeError("process has not been started")
        if self._run_cancel_scope is not None:
            self._run_cancel_scope.cancel()
        await self._done.wait()

    async def aclose(self) -> None:
        """Alias for ``stop()``."""
        await self.stop()

    async def wait(self) -> int:
        """Return the child's exit code, waiting for it if necessary.

        Blocks until ``run()`` has started the child (or failed to), then
        polls ``exitcode`` every 10 ms.

        Raises:
            BaseException: The exception that aborted ``run()`` before the
                child was marked started, re-raised as-is.
        """
        if self._exitcode is not None:
            return self._exitcode
        await self._started.wait()
        if self._start_error is not None:
            raise self._start_error
        assert self._process is not None
        while True:
            exitcode = self.exitcode
            if exitcode is not None:
                return exitcode
            await sleep(0.01)

    @property
    def pid(self) -> int:
        """Pid of the child; raises ``RuntimeError`` before the child was started."""
        if self._pid is None:
            raise RuntimeError("process has not been started")
        return self._pid

    @property
    def exitcode(self) -> int | None:
        """Exit code of the child, or ``None`` while it is unknown.

        The first non-``None`` value is cached, so it stays readable after
        ``run()`` has closed the process object.
        """
        if self._exitcode is not None:
            return self._exitcode
        if self._process is None:
            return None
        # NOTE(review): ValueError is presumably what a closed Process raises - confirm
        with contextlib.suppress(ValueError):
            exitcode = self._process.exitcode
            if exitcode is not None:
                self._exitcode = exitcode
                return exitcode
        return None

    def is_alive(self) -> bool:
        """Return whether the child is running; ``False`` if never started."""
        if self._process is None:
            return False
        with contextlib.suppress(ValueError):
            return self._process.is_alive()
        return False

    # TODO: maybe in the future if needed, create stdin that is also installed,
    # and a ByteSendStream handle is provided for it :)
    @property
    def stdout(self) -> Receiver[bytes]:
        """Receiver yielding the chunks read from the child's stdout pipe."""
        return self._stdout_rx

    @property
    def stderr(self) -> Receiver[bytes]:
        """Receiver yielding the chunks read from the child's stderr pipe."""
        return self._stderr_rx

    async def _terminate_if_still_alive(self) -> None:
        """Stop a still-running child: terminate(), retry, then kill().

        Returns immediately when no process exists or it has already exited.
        Otherwise calls ``terminate()`` up to ``_TERMINATE_ATTEMPTS`` times,
        waiting a bounded time after each call, and finally loops on
        ``kill()`` until the child is gone.
        """
        process = self._process
        if process is None:
            return
        if self.exitcode is not None:
            return
        with contextlib.suppress(ValueError):
            if not process.is_alive():
                return
        logger.warning("Child process didn't shut down successfully, terminating")
        process.terminate()
        with move_on_after(_TERMINATE_GRACE_SECONDS):
            await self.wait()
        if self.exitcode is not None or not process.is_alive():
            logger.warning("Terminated nicely in the first attempt!")
            return
        # attempt numbering starts at 2 because the first terminate() was sent above
        for attempt in range(2, _TERMINATE_ATTEMPTS + 1):
            process.terminate()
            with move_on_after(_TERMINATE_RETRY_GRACE_SECONDS):
                await self.wait()
            if self.exitcode is not None or not process.is_alive():
                logger.warning(f"That took {attempt} attempts :)")
                return
        logger.critical("Child process didn't respond to SIGTERM, killing")
        j = 0
        # no upper bound: keep killing until an exit code is observed
        while True:
            process.kill()
            with move_on_after(_KILL_GRACE_SECONDS):
                await self.wait()
            j += 1
            if self.exitcode is not None or not process.is_alive():
                break
        logger.warning(f"That took {j} attempts :(")
# Spawn-mode multiprocessing requires a module-level target that can be pickled.
def _run_with_captured_stdio(
    stdout: DupFd,
    stderr: DupFd,
    target: Callable[..., object] | None,
    /,
    *target_args: object,
    **target_kwargs: object,
) -> None:
    """Child-side entry point: install the pipes as fd 1/2, then call *target*.

    The first three parameters are positional-only. They are always supplied
    positionally through ``mp.Process(args=...)``, and keeping them out of the
    keyword namespace lets *target* receive keyword arguments that happen to
    be called ``stdout``, ``stderr`` or ``target`` instead of failing with
    "got multiple values for argument".

    Args:
        stdout: Handle whose ``detach()`` yields the stdout pipe's write fd.
        stderr: Handle whose ``detach()`` yields the stderr pipe's write fd.
        target: Callable to run; ``None`` runs nothing.
        *target_args: Positional arguments forwarded to *target*.
        **target_kwargs: Keyword arguments forwarded to *target*.
    """
    stdout_fd = stdout.detach()
    stderr_fd = stderr.detach()
    try:
        os.dup2(stdout_fd, _STDOUT_FD)
        os.dup2(stderr_fd, _STDERR_FD)
    finally:
        # The duplicates on fd 1/2 are the ones to keep; drop the detached
        # originals unless they already are fd 1 or fd 2.
        for fd in (stdout_fd, stderr_fd):
            if fd not in (_STDOUT_FD, _STDERR_FD):
                _close_fd(fd)
    # enabled after the redirect so fault dumps are written to the captured stderr
    faulthandler.enable(file=sys.stderr, all_threads=True)
    if target is not None:
        target(*target_args, **target_kwargs)
async def _drain_fd(fd: int, tx: Sender[bytes]) -> None:
    """Forward everything readable from *fd* to *tx* until end-of-file.

    Stops quietly on a broken pipe or when the channel is broken/closed.
    Always closes *fd* and then *tx* before returning.
    """
    try:
        with contextlib.suppress(
            BrokenPipeError, BrokenResourceError, ClosedResourceError
        ):
            while True:
                await wait_readable(fd)
                data = os.read(fd, _READ_CHUNK_SIZE)
                if not data:
                    # empty read == every write end of the pipe is closed
                    break
                await tx.send(data)
    finally:
        _close_fd(fd)
        await tx.aclose()
def _close_fd(fd: int | None) -> None:
if fd is None:
return
with contextlib.suppress(OSError):
os.close(fd)

28
src/exo/utils/daemon.py Normal file
View File

@@ -0,0 +1,28 @@
import os
import sys
_STDIN_FD = 0
_STDOUT_FD = 1
_STDERR_FD = 2


def detach_stdio_to_devnull() -> None:
    """Redirect process stdio file descriptors to /dev/null.

    Python-level streams are flushed first on a best-effort basis, then
    fd 0 is pointed at a read-only and fd 1/2 at write-only handles of
    ``os.devnull``.

    Raises:
        OSError: If ``os.devnull`` cannot be opened or ``dup2`` fails.
    """
    for stream in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__):
        if stream is None:
            continue
        try:
            stream.flush()
        except (OSError, ValueError):
            # A closed stream (ValueError) or a broken pipe (OSError) is the
            # situation detaching is meant to survive; the data is unsendable.
            pass

    targets = (_STDIN_FD, _STDOUT_FD, _STDERR_FD)
    opened: list[int] = []
    try:
        # Opened inside the try so a failure on a later open still releases
        # the descriptors obtained before it.
        for flags in (os.O_RDONLY, os.O_WRONLY, os.O_WRONLY):
            opened.append(os.open(os.devnull, flags))
        # dup2 closes the target fd first, but leaves the source fd open.
        for source_fd, target_fd in zip(opened, targets):
            os.dup2(source_fd, target_fd)
    finally:
        for fd in opened:
            # os.open may itself have returned 0/1/2 when that slot was free
            if fd not in targets:
                os.close(fd)

View File

@@ -0,0 +1,515 @@
import contextlib
import os
import signal
import sys
import time
from collections.abc import AsyncIterator, Callable
from types import FrameType
import mlx.core as mx
import pytest
from _pytest.capture import CaptureFixture
from anyio import EndOfStream, create_task_group, fail_after
from pytest import MonkeyPatch
import exo.utils.async_process as async_process
from exo.utils.async_process import (
AsyncProcess,
)
from exo.utils.channels import MpSender, Receiver, mp_channel
def _write_to_stdio(prefix: str, *, stderr_suffix: str) -> None:
    """Write one line per stream through print() and one through raw os.write()."""
    python_out = f"{prefix}: python stdout"
    python_err = f"{prefix}: python stderr {stderr_suffix}"
    print(python_out)
    print(python_err, file=sys.stderr)
    raw_out = f"{prefix}: fd stdout\n".encode()
    raw_err = f"{prefix}: fd stderr {stderr_suffix}\n".encode()
    os.write(1, raw_out)
    os.write(2, raw_err)
def _write_large_output() -> None:
    """Write a fixed marker to fd 1 and then a fixed marker to fd 2."""
    for fd, payload in ((1, b"stdout-0123456789"), (2, b"stderr-0123456789")):
        os.write(fd, payload)
def _write_all(fd: int, data: bytes) -> None:
    """Write all of *data* to *fd*, retrying after partial writes."""
    view = memoryview(data)
    offset = 0
    total = len(view)
    while offset < total:
        # os.write may accept fewer bytes than offered; continue from there
        offset += os.write(fd, view[offset:])
def _write_large_exact_output(size: int) -> None:
    """Write a tag followed by *size* filler bytes to stdout, then to stderr."""
    for fd, tag, filler in ((1, b"stdout:", b"x"), (2, b"stderr:", b"y")):
        _write_all(fd, tag + (filler * size))
def _raise_after_stderr_write() -> None:
    """Write a marker line to fd 2, then fail with ``RuntimeError``."""
    marker = b"stderr before exception\n"
    os.write(2, marker)
    failure = RuntimeError("child boom")
    raise failure
def _exit_after_stdio_write(prefix: str, exitcode: int) -> None:
    """Write a marker to fd 1 and fd 2, then leave via ``os._exit(exitcode)``."""
    out_line = f"{prefix}: stdout before _exit\n".encode()
    err_line = f"{prefix}: stderr before _exit\n".encode()
    os.write(1, out_line)
    os.write(2, err_line)
    os._exit(exitcode)
def _abort_after_stdio_write(prefix: str) -> None:
    """Write a marker to fd 1 and fd 2, then call ``os.abort()``."""
    out_line = f"{prefix}: stdout before abort\n".encode()
    err_line = f"{prefix}: stderr before abort\n".encode()
    os.write(1, out_line)
    os.write(2, err_line)
    os.abort()
def _close_stdio_and_exit() -> None:
    """Close fd 1 and fd 2 without writing anything, then ``os._exit(0)``."""
    for stdio_fd in (1, 2):
        os.close(stdio_fd)
    os._exit(0)
def _exit_on_sigterm(exitcode: int) -> None:
    """Report readiness on fd 1, then idle until SIGTERM exits with *exitcode*."""

    def _on_term(_signum: int, _frame: FrameType | None) -> None:
        os._exit(exitcode)

    # install the handler before announcing readiness
    signal.signal(signal.SIGTERM, _on_term)
    ready_line = b"sigterm-ready\n"
    os.write(1, ready_line)
    while True:
        time.sleep(0.1)
def _exit_after_repeated_sigterm(required_count: int, exitcode: int) -> None:
    """Report readiness, then exit with *exitcode* on the *required_count*-th SIGTERM."""
    # one-element list used as a mutable counter shared with the handler
    received = [0]

    def _on_term(_signum: int, _frame: FrameType | None) -> None:
        received[0] += 1
        if received[0] >= required_count:
            os._exit(exitcode)

    signal.signal(signal.SIGTERM, _on_term)
    ready_line = b"sigterm-ready\n"
    os.write(1, ready_line)
    while True:
        time.sleep(0.1)
def _ignore_sigterm_forever() -> None:
    """Set SIGTERM to ``SIG_IGN``, report readiness on fd 1 and idle forever."""
    signal.signal(signal.SIGTERM, signal.SIG_IGN)
    ready_line = b"sigterm-ready\n"
    os.write(1, ready_line)
    while True:
        time.sleep(0.1)
def _sleep_forever() -> None:
    """Never return; sleep in 0.1 s slices until the process is stopped."""
    nap_seconds = 0.1
    while True:
        time.sleep(nap_seconds)
def _send_over_mp_channel(send: MpSender[str]) -> None:
    """Send one greeting through *send*, then close it."""
    greeting = "hello from child"
    send.send(greeting)
    send.close()
def _mlx_force_oom(size: int = 40_000) -> None:
    """
    Try to exhaust memory in MLX with chained matmuls on ``size`` x ``size`` tensors.

    Prints a start marker before the work and an end marker if it survives.
    """
    print("CHILD: start")
    mx.set_default_device(mx.gpu)
    shape = (size, size)
    lhs = mx.random.uniform(shape=shape, dtype=mx.float32)
    rhs = mx.random.uniform(shape=shape, dtype=mx.float32)
    # materialize both inputs before building the larger graph
    mx.eval(lhs, rhs)
    product = mx.matmul(lhs, rhs)
    left_branch = mx.matmul(lhs, product)
    right_branch = mx.matmul(rhs, product)
    activated = mx.sigmoid(left_branch + right_branch)
    mx.eval(activated)
    print("CHILD: end")
async def _collect_stream(
    stream: Receiver[bytes],
    output: bytearray,
) -> None:
    """Append every chunk received from *stream* to *output* until ``EndOfStream``."""
    with contextlib.suppress(EndOfStream):
        while True:
            chunk = await stream.receive()
            output.extend(chunk)
async def _collect_process_output(
    process: AsyncProcess,
) -> tuple[int, bytes, bytes]:
    """Return ``(exitcode, stdout, stderr)`` once *process* exited and both streams ended."""
    out_buffer = bytearray()
    err_buffer = bytearray()
    exitcode: int | None = None
    async with create_task_group() as collectors:
        collectors.start_soon(_collect_stream, process.stdout, out_buffer)
        collectors.start_soon(_collect_stream, process.stderr, err_buffer)
        exitcode = await process.wait()
    if exitcode is None:
        raise RuntimeError("process exited without a return code")
    return exitcode, bytes(out_buffer), bytes(err_buffer)
def _fd_identity(fd: int) -> tuple[int, int]:
    """Return the ``(st_dev, st_ino)`` pair reported by ``os.fstat`` for *fd*."""
    info = os.fstat(fd)
    device, inode = info.st_dev, info.st_ino
    return device, inode
def _fd_count() -> int | None:
for fd_dir in ("/proc/self/fd", "/dev/fd"):
with contextlib.suppress(OSError):
return len(os.listdir(fd_dir))
return None
@contextlib.asynccontextmanager
async def _started_process(process: AsyncProcess) -> AsyncIterator[None]:
    """Run *process* in a task group for the ``async with`` body; stop it on exit."""
    async with create_task_group() as runner_group:
        # returns once process.run has reported task_status.started()
        await runner_group.start(process.run)
        try:
            yield None
        finally:
            await process.stop()
async def _run_and_collect(
    target: Callable[..., object] | None,
    *,
    args: tuple[object, ...] = (),
    kwargs: dict[str, object] | None = None,
) -> tuple[int, bytes, bytes]:
    """Run *target* in an ``AsyncProcess``; return ``(exitcode, stdout, stderr)``."""
    child = AsyncProcess(target, args=args, kwargs=kwargs)
    async with _started_process(child):
        collected = await _collect_process_output(child)
        return collected
@pytest.mark.anyio
async def test_spawn_process_captures_stdout_and_stderr_separately(
    capfd: CaptureFixture[str],
) -> None:
    """Child output reaches the matching receiver and none of it the parent's stdio."""
    child = AsyncProcess(
        _write_to_stdio,
        args=("child",),
        kwargs={"stderr_suffix": "error"},
    )
    async with _started_process(child):
        exitcode, raw_out, raw_err = await _collect_process_output(child)
    parent_capture = capfd.readouterr()
    out_text = raw_out.decode("utf-8", errors="replace")
    err_text = raw_err.decode("utf-8", errors="replace")

    assert exitcode == 0
    assert "child: python stdout" in out_text
    assert "child: fd stdout" in out_text
    assert "child: python stderr error" in err_text
    assert "child: fd stderr error" in err_text
    assert "child:" not in parent_capture.out
    assert "child:" not in parent_capture.err
@pytest.mark.anyio
async def test_process_with_no_target_exits_successfully() -> None:
    """A process without a target exits with 0 and produces no output."""
    outcome = await _run_and_collect(None)
    exitcode, stdout, stderr = outcome
    assert exitcode == 0
    assert stdout == b""
    assert stderr == b""
@pytest.mark.anyio
async def test_output_receivers_and_wait_are_safe_immediately_after_run_starts() -> (
    None
):
    """Receivers and ``wait()`` are usable as soon as ``start(process.run)`` returns."""
    child = AsyncProcess(
        _write_to_stdio,
        args=("immediate",),
        kwargs={"stderr_suffix": "error"},
    )
    collected: tuple[int, bytes, bytes] | None = None
    async with create_task_group() as runner_group:
        await runner_group.start(child.run)
        try:
            collected = await _collect_process_output(child)
        finally:
            await child.stop()

    assert collected is not None
    exitcode, stdout, stderr = collected
    assert exitcode == 0
    assert b"immediate: fd stdout\n" in stdout
    assert b"immediate: fd stderr error\n" in stderr
@pytest.mark.anyio
async def test_stop_before_run_raises() -> None:
    """``stop()`` on a never-started process raises ``RuntimeError``."""
    never_started = AsyncProcess(
        _write_to_stdio,
        args=("never",),
        kwargs={"stderr_suffix": "run"},
    )
    assert not never_started.is_alive()
    expected_message = "process has not been started"
    with pytest.raises(RuntimeError, match=expected_message):
        await never_started.stop()
@pytest.mark.anyio
async def test_process_run_is_one_shot() -> None:
    """A second ``run()`` on the same instance raises ``RuntimeError``."""
    one_shot = AsyncProcess(None)
    await one_shot.run()
    expected_message = "process has already been started"
    with pytest.raises(RuntimeError, match=expected_message):
        await one_shot.run()
@pytest.mark.anyio
async def test_process_started_with_task_group_start_can_stop_immediately() -> None:
    """A child started via ``task_group.start`` can be stopped right away."""
    sleeper = AsyncProcess(_sleep_forever)
    async with create_task_group() as runner_group:
        await runner_group.start(sleeper.run)
        assert sleeper.is_alive()
        # stop() must finish well within the two-second budget
        with fail_after(2):
            await sleeper.stop()
    assert not sleeper.is_alive()
@pytest.mark.anyio
async def test_stdout_receiver_yields_bytes_chunks() -> None:
    """A chunk taken directly from ``stdout`` plus the collected rest is the full output."""
    child = AsyncProcess(_write_large_output)
    async with _started_process(child):
        head = await child.stdout.receive()
        exitcode, tail, stderr = await _collect_process_output(child)
    assert exitcode == 0
    assert head + tail == b"stdout-0123456789"
    assert stderr == b"stderr-0123456789"
@pytest.mark.anyio
async def test_output_can_be_read_after_process_exits() -> None:
    """Output written before exit is still receivable after ``wait()`` returned."""
    child = AsyncProcess(_write_large_output)
    async with create_task_group() as runner_group:
        await runner_group.start(child.run)
        exitcode = await child.wait()
        assert exitcode == 0
        first_out = await child.stdout.receive()
        assert first_out == b"stdout-0123456789"
        first_err = await child.stderr.receive()
        assert first_err == b"stderr-0123456789"
        # both channels must report end-of-stream once drained
        with pytest.raises(EndOfStream):
            await child.stdout.receive()
        with pytest.raises(EndOfStream):
            await child.stderr.receive()
@pytest.mark.anyio
async def test_large_stdout_and_stderr_are_not_lost() -> None:
    """One MiB of filler per stream is captured byte-for-byte."""
    payload_size = 1024 * 1024
    expected_out = b"stdout:" + (b"x" * payload_size)
    expected_err = b"stderr:" + (b"y" * payload_size)
    exitcode, stdout, stderr = await _run_and_collect(
        _write_large_exact_output,
        args=(payload_size,),
    )
    assert exitcode == 0
    assert stdout == expected_out
    assert stderr == expected_err
@pytest.mark.anyio
async def test_child_exception_traceback_is_captured_from_stderr() -> None:
    """An uncaught child exception yields exit code 1 and its text on stderr."""
    child = AsyncProcess(_raise_after_stderr_write)
    async with _started_process(child):
        collected = await _collect_process_output(child)
    exitcode, _unused_stdout, raw_err = collected
    assert exitcode == 1
    err_text = raw_err.decode("utf-8", errors="replace")
    assert "stderr before exception" in err_text
    assert "RuntimeError: child boom" in err_text
@pytest.mark.anyio
async def test_repeated_bad_children_do_not_pollute_or_replace_parent_stdio(
    capfd: CaptureFixture[str],
) -> None:
    """Crashing children leave the parent's stream objects, fds and output untouched."""
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    fd1_before = _fd_identity(1)
    fd2_before = _fd_identity(2)
    scenarios: tuple[tuple[Callable[..., object], tuple[object, ...]], ...] = (
        (_raise_after_stderr_write, ()),
        (_exit_after_stdio_write, ("exit-child", 17)),
        (_abort_after_stdio_write, ("abort-child",)),
    )

    for iteration in range(3):
        for target, args in scenarios:
            exitcode, stdout, stderr = await _run_and_collect(
                target,
                args=args,
            )
            assert exitcode != 0
            if target is _exit_after_stdio_write:
                assert stdout == b"exit-child: stdout before _exit\n"
                assert stderr == b"exit-child: stderr before _exit\n"
            elif target is _abort_after_stdio_write:
                assert b"abort-child: stdout before abort\n" in stdout
                assert b"abort-child: stderr before abort\n" in stderr
                assert exitcode == -signal.SIGABRT
            else:
                assert stdout == b""
                assert b"stderr before exception\n" in stderr
                assert b"RuntimeError: child boom" in stderr
        # the parent's own stdio must keep working between rounds
        print(f"parent stdout still works {iteration}")
        print(f"parent stderr still works {iteration}", file=sys.stderr)

    captured = capfd.readouterr()
    assert sys.stdout is original_stdout
    assert sys.stderr is original_stderr
    assert _fd_identity(1) == fd1_before
    assert _fd_identity(2) == fd2_before
    assert "parent stdout still works 0" in captured.out
    assert "parent stdout still works 2" in captured.out
    assert "parent stderr still works 0" in captured.err
    assert "parent stderr still works 2" in captured.err
    assert "exit-child:" not in captured.out
    assert "exit-child:" not in captured.err
    assert "abort-child:" not in captured.out
    assert "abort-child:" not in captured.err
    assert "child boom" not in captured.err
@pytest.mark.anyio
async def test_child_can_close_stdio_without_corrupting_parent_stdio(
    capfd: CaptureFixture[str],
) -> None:
    """A child closing its own fd 1/2 leaves the parent's fd 1/2 intact and usable."""
    fd1_before = _fd_identity(1)
    fd2_before = _fd_identity(2)
    exitcode, stdout, stderr = await _run_and_collect(_close_stdio_and_exit)

    for fd, line in (
        (1, b"parent stdout after child closed stdio\n"),
        (2, b"parent stderr after child closed stdio\n"),
    ):
        os.write(fd, line)
    captured = capfd.readouterr()

    assert exitcode == 0
    assert stdout == b""
    assert stderr == b""
    assert _fd_identity(1) == fd1_before
    assert _fd_identity(2) == fd2_before
    assert "parent stdout after child closed stdio" in captured.out
    assert "parent stderr after child closed stdio" in captured.err
@pytest.mark.anyio
async def test_repeated_crashing_children_do_not_grow_parent_fd_table() -> None:
    """Twenty crashing children add at most two entries to the parent's fd table."""
    # warm-up run, so one-time fd allocations happen before the baseline is taken
    await _run_and_collect(_exit_after_stdio_write, args=("warmup", 23))
    baseline = _fd_count()
    if baseline is None:
        pytest.skip("fd table count is not available on this platform")

    for iteration in range(20):
        label = f"fd-child-{iteration}"
        exitcode, stdout, stderr = await _run_and_collect(
            _exit_after_stdio_write,
            args=(label, 31),
        )
        assert exitcode == 31
        assert stdout == f"fd-child-{iteration}: stdout before _exit\n".encode()
        assert stderr == f"fd-child-{iteration}: stderr before _exit\n".encode()

    final_count = _fd_count()
    assert final_count is not None
    assert final_count <= baseline + 2
@pytest.mark.anyio
async def test_stop_allows_child_to_exit_after_sigterm() -> None:
    """The exit code chosen by the child's SIGTERM handler is what ``stop()`` leaves."""
    child = AsyncProcess(_exit_on_sigterm, args=(43,))
    async with _started_process(child):
        ready = await child.stdout.receive()
        assert ready == b"sigterm-ready\n"
        with fail_after(2):
            await child.stop()
    assert child.exitcode == 43
@pytest.mark.anyio
async def test_stop_retries_sigterm_before_sigkill(monkeypatch: MonkeyPatch) -> None:
    """A child that needs three SIGTERMs still exits through its own handler."""
    # shrink both grace periods so the retries happen within the time budget
    for grace_name in ("_TERMINATE_GRACE_SECONDS", "_TERMINATE_RETRY_GRACE_SECONDS"):
        monkeypatch.setattr(async_process, grace_name, 0.01)
    child = AsyncProcess(_exit_after_repeated_sigterm, args=(3, 44))
    async with _started_process(child):
        ready = await child.stdout.receive()
        assert ready == b"sigterm-ready\n"
        with fail_after(2):
            await child.stop()
    assert child.exitcode == 44
@pytest.mark.anyio
async def test_stop_escalates_to_sigkill_when_child_ignores_sigterm(
    monkeypatch: MonkeyPatch,
) -> None:
    """A child ignoring SIGTERM ends up with exit code ``-SIGKILL``."""
    shortened_graces = (
        ("_TERMINATE_GRACE_SECONDS", 0.1),
        ("_TERMINATE_RETRY_GRACE_SECONDS", 0.01),
    )
    for grace_name, seconds in shortened_graces:
        monkeypatch.setattr(async_process, grace_name, seconds)
    child = AsyncProcess(_ignore_sigterm_forever)
    async with _started_process(child):
        ready = await child.stdout.receive()
        assert ready == b"sigterm-ready\n"
        with fail_after(3):
            await child.stop()
    assert child.exitcode == -signal.SIGKILL
@pytest.mark.anyio
async def test_process_can_use_mp_channel_with_global_spawn_context() -> None:
    """A child can send a message to the parent over an ``mp_channel``."""
    send, recv = mp_channel[str]()
    process = AsyncProcess(_send_over_mp_channel, args=(send,))
    try:
        async with _started_process(process):
            with fail_after(2):
                assert await recv.receive_async() == "hello from child"
                assert await process.wait() == 0
    finally:
        # Close the receiver even when an assertion or the timeout fails, so a
        # failing run does not leak the channel into later tests.
        with contextlib.suppress(Exception):
            recv.close()
@pytest.mark.anyio
@pytest.mark.skip(reason="manual MLX OOM isolation check")
async def test_death(capsys: CaptureFixture[str]) -> None:
    """Manual check: run the MLX OOM child and print what was captured from it."""
    with capsys.disabled():
        child = AsyncProcess(_mlx_force_oom)
        child_out = b""
        child_err = b""
        async with _started_process(child):
            _, child_out, child_err = await _collect_process_output(child)
        print("PARENT: done")
        out_text = child_out.decode("utf-8", errors="replace")
        err_text = child_err.decode("utf-8", errors="replace")
        print("CHILD out:", out_text)
        print("CHILD err:", err_text, "hello :)")

View File

@@ -0,0 +1,168 @@
import contextlib
import os
from collections.abc import AsyncIterator
import anyio
import pytest
from anyio import EndOfStream, create_task_group, fail_after
from exo.utils.async_process import AsyncProcess
from exo.utils.channels import MpReceiver, MpSender, Receiver, mp_channel
from exo.utils.daemon import detach_stdio_to_devnull
def _write_before_and_after_detach() -> None:
    """Write to fd 1/2, detach stdio to /dev/null, then write to fd 1/2 again."""
    for fd, line in ((1, b"before stdout\n"), (2, b"before stderr\n")):
        os.write(fd, line)
    detach_stdio_to_devnull()
    for fd, line in ((1, b"after stdout\n"), (2, b"after stderr\n")):
        os.write(fd, line)
def _write_grandchild_stdio(label: str) -> None:
    """Write one labelled line to fd 1 and one to fd 2."""
    out_line = f"{label} stdout\n".encode()
    err_line = f"{label} stderr\n".encode()
    os.write(1, out_line)
    os.write(2, err_line)
async def _spawn_grandchild_and_report(
    result_sender: MpSender[tuple[int, bytes, bytes]],
    label: str,
) -> None:
    """Run one captured grandchild, send its result through *result_sender*, close it."""
    outcome = await _collect_spawned_child(label)
    result_sender.send(outcome)
    result_sender.close()
async def _collect_spawned_child(label: str) -> tuple[int, bytes, bytes]:
    """Spawn a child writing *label* lines; return ``(exitcode, stdout, stderr)``."""
    child = AsyncProcess(_write_grandchild_stdio, args=(label,))
    async with _started_process(child):
        collected = await _collect_process_output(child)
        return collected
def _detach_stdio_then_spawn_captured_child(
    result_sender: MpSender[tuple[int, bytes, bytes]],
) -> None:
    """Detach this process's stdio, then spawn and report one captured grandchild."""
    detach_stdio_to_devnull()
    grandchild_label = "grandchild"
    anyio.run(_spawn_grandchild_and_report, result_sender, grandchild_label)
def _detach_stdio_then_spawn_captured_children_sequentially(
    result_sender: MpSender[list[tuple[int, bytes, bytes]]],
) -> None:
    """Detach stdio, run five captured grandchildren in sequence, send all results."""

    async def run_children() -> list[tuple[int, bytes, bytes]]:
        # each await completes before the next child is spawned
        return [
            await _collect_spawned_child(f"grandchild-{index}") for index in range(5)
        ]

    detach_stdio_to_devnull()
    all_results = anyio.run(run_children)
    result_sender.send(all_results)
    result_sender.close()
async def _collect_stream(stream: Receiver[bytes], output: bytearray) -> None:
    """Append every chunk received from *stream* to *output* until ``EndOfStream``."""
    with contextlib.suppress(EndOfStream):
        while True:
            chunk = await stream.receive()
            output.extend(chunk)
async def _collect_process_output(
    process: AsyncProcess,
) -> tuple[int, bytes, bytes]:
    """Return ``(exitcode, stdout, stderr)`` once *process* exited and both streams ended."""
    out_buffer = bytearray()
    err_buffer = bytearray()
    exitcode: int | None = None
    async with create_task_group() as collectors:
        collectors.start_soon(_collect_stream, process.stdout, out_buffer)
        collectors.start_soon(_collect_stream, process.stderr, err_buffer)
        exitcode = await process.wait()
    if exitcode is None:
        raise RuntimeError("process exited without a return code")
    return exitcode, bytes(out_buffer), bytes(err_buffer)
@contextlib.asynccontextmanager
async def _started_process(process: AsyncProcess) -> AsyncIterator[None]:
    """Run *process* in a task group for the ``async with`` body; stop it on exit."""
    async with create_task_group() as runner_group:
        # returns once process.run has reported task_status.started()
        await runner_group.start(process.run)
        try:
            yield None
        finally:
            await process.stop()
async def _run_process_and_receive[T](
process: AsyncProcess,
recv: MpReceiver[T],
*,
timeout: float,
) -> tuple[int, T]:
async with _started_process(process):
with fail_after(timeout):
result = await recv.receive_async()
exitcode = await process.wait()
return exitcode, result
@pytest.mark.anyio
async def test_detach_stdio_to_devnull_redirects_stdio_away_from_capture() -> None:
    """Only what the child wrote before detaching reaches the captured streams."""
    child = AsyncProcess(_write_before_and_after_detach)
    async with _started_process(child):
        collected = await _collect_process_output(child)
    exitcode, stdout, stderr = collected
    assert exitcode == 0
    assert stdout == b"before stdout\n"
    assert stderr == b"before stderr\n"
@pytest.mark.anyio
async def test_detached_stdio_process_can_spawn_and_capture_child_stdio() -> None:
    """A stdio-detached process can still capture the output of a child it spawns."""
    send, recv = mp_channel[tuple[int, bytes, bytes]]()
    detached_parent = AsyncProcess(
        _detach_stdio_then_spawn_captured_child, args=(send,)
    )
    try:
        parent_exitcode, grandchild_result = await _run_process_and_receive(
            detached_parent, recv, timeout=5
        )
    finally:
        recv.close()

    child_exitcode, child_stdout, child_stderr = grandchild_result
    assert parent_exitcode == 0
    assert child_exitcode == 0
    assert child_stdout == b"grandchild stdout\n"
    assert child_stderr == b"grandchild stderr\n"
@pytest.mark.anyio
async def test_detached_stdio_process_can_spawn_captured_children_sequentially() -> (
    None
):
    """A stdio-detached process can capture five children spawned one after another."""
    send, recv = mp_channel[list[tuple[int, bytes, bytes]]]()
    detached_parent = AsyncProcess(
        _detach_stdio_then_spawn_captured_children_sequentially,
        args=(send,),
    )
    try:
        parent_exitcode, results = await _run_process_and_receive(
            detached_parent, recv, timeout=10
        )
    finally:
        recv.close()

    assert parent_exitcode == 0
    expected = [
        (
            0,
            f"grandchild-{index} stdout\n".encode(),
            f"grandchild-{index} stderr\n".encode(),
        )
        for index in range(5)
    ]
    assert results == expected

View File

@@ -115,7 +115,8 @@ def mlx_distributed_init(
os.environ["MLX_HOSTFILE"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
# os.environ["MLX_RING_VERBOSE"] = "1" # NOTE: we don't use it enough to care (turn on again if need to)
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(

View File

@@ -1,5 +1,4 @@
import contextlib
import multiprocessing as mp
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -8,7 +7,7 @@ import anyio
from anyio import (
BrokenResourceError,
ClosedResourceError,
to_thread,
EndOfStream,
)
from loguru import logger
@@ -41,7 +40,8 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.utils.async_process import AsyncProcess
from exo.utils.channels import MpReceiver, MpSender, Receiver, Sender, mp_channel
from exo.utils.task_group import TaskGroup
from exo.worker.runner.bootstrap import entrypoint
@@ -53,7 +53,7 @@ DECODE_TIMEOUT_SECONDS = 5
class RunnerSupervisor:
shard_metadata: ShardMetadata
bound_instance: BoundInstance
runner_process: mp.Process
runner_process: AsyncProcess
initialize_timeout: float
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
@@ -81,7 +81,7 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = mp.Process(
runner_process = AsyncProcess(
target=entrypoint,
args=(
bound_instance,
@@ -109,9 +109,25 @@ class RunnerSupervisor:
return self
async def run(self):
self.runner_process.start()
try:
async with self._tg as tg:
# start the process itself
await tg.start(self.runner_process.run)
# start tasks to drain/collect stdout/stderr into usable errors
#
# TODO: right now it logs them as warnings, but in the future they should be split
# into being logged AND a separate task which tries to best-effort figure out cause
# of error and package into error enum, which then is used by rest of app to act on it;
# inferring what the error is would be done by pattern-matching in the text for things
# e.g. certain VLLM error codes and so on
tg.start_soon(
self._forward_runner_output, "stdout", self.runner_process.stdout
)
tg.start_soon(
self._forward_runner_output, "stderr", self.runner_process.stderr
)
tg.start_soon(self._watch_runner)
tg.start_soon(self._forward_events)
finally:
@@ -129,41 +145,11 @@ class RunnerSupervisor:
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.close()
await to_thread.run_sync(self.runner_process.join, 5)
if self.runner_process.is_alive():
logger.warning(
"Runner process didn't shutdown succesfully, terminating"
with anyio.CancelScope(shield=True):
await self.runner_process.stop()
logger.info(
f"Runner process successfully terminated: {self.runner_process.exitcode}"
)
self.runner_process.terminate()
self.runner_process.join(timeout=10)
if not self.runner_process.is_alive():
logger.warning("Terminated nicely in the first attempt!")
else:
# Try really hard to terminate
for i in range(2, 11):
self.runner_process.terminate()
self.runner_process.join(timeout=2)
if not self.runner_process.is_alive():
logger.warning(f"That took {i} attempts :)")
break
# Try even harder to kill
else:
logger.critical(
"Runner process didn't respond to SIGTERM, killing"
)
j = 0
while self.runner_process.is_alive():
j += 1
self.runner_process.kill()
self.runner_process.join(timeout=5)
logger.warning(f"That took {j} attempts :(")
else:
logger.info("Runner process succesfully terminated")
self.runner_process.close()
def shutdown(self):
self._tg.cancel_tasks()
@@ -249,13 +235,33 @@ class RunnerSupervisor:
if not self.runner_process.is_alive():
await self._check_runner(RuntimeError("Runner found to be dead"))
async def _forward_runner_output(
    self,
    stream_name: str,
    stream: Receiver[bytes],
) -> None:
    """Drain one captured runner stream and forward its text to the logger.

    Args:
        stream_name: Label of the stream being drained. ``"stderr"`` is
            logged at warning level; anything else is logged at debug level
            as stdout.
        stream: Receiver yielding raw byte chunks captured from the runner.

    Returns:
        None, once the stream raises ``EndOfStream``, ``ClosedResourceError``
        or ``BrokenResourceError``.
    """
    import codecs

    # Chunk boundaries are arbitrary, so a multi-byte UTF-8 character can be
    # split across two reads. A stateful decoder carries the partial sequence
    # over to the next chunk instead of turning both halves into U+FFFD.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    finished = False
    while not finished:
        try:
            chunk = await stream.receive()
        except (EndOfStream, ClosedResourceError, BrokenResourceError):
            # Stream is done: run one last pass so any bytes still buffered
            # in the decoder are flushed (as replacement characters).
            chunk, finished = b"", True
        message = decoder.decode(chunk, final=finished).rstrip()
        if not message:
            continue
        if stream_name == "stderr":
            logger.warning(f"Runner stderr: {message}")
        else:
            logger.debug(f"Runner stdout: {message}")
async def _check_runner(self, e: Exception) -> None:
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
logger.info("Checking runner's status")
if self.runner_process.is_alive():
logger.info("Runner was found to be alive, attempting to join process")
await to_thread.run_sync(self.runner_process.join, 5)
logger.info("Runner was found to be alive, stopping process")
with anyio.CancelScope(shield=True):
await self.runner_process.stop()
rc = self.runner_process.exitcode
logger.info(f"Runner exited with exit code {rc}")
if rc == 0:

View File

@@ -1,4 +1,3 @@
import multiprocessing as mp
from typing import cast
import anyio
@@ -16,6 +15,7 @@ from exo.shared.types.text_generation import (
)
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerFailed, RunnerId
from exo.utils.async_process import AsyncProcess
from exo.utils.channels import channel, mp_channel
from exo.worker.runner.supervisor import RunnerSupervisor
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
@@ -24,23 +24,11 @@ from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
# Minimal stand-in for a runner process that has already exited; the test
# below casts an instance to the supervisor's process type and passes it as
# `runner_process`.
class _DeadProcess:
# Exit code the stub reports (presumably mimics death by SIGABRT — confirm).
exitcode = -6
# No-op: there is no real process to start.
def start(self) -> None:
return None
# Always reports the process as not running.
def is_alive(self) -> bool:
return False
# No-op: returns immediately regardless of the timeout.
def join(self, _timeout: float | None = None) -> None:
return None
# No-op: nothing to signal.
def terminate(self) -> None:
return None
# No-op: nothing to signal.
def kill(self) -> None:
return None
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> None:
event_sender, event_receiver = channel[Event]()
task_sender, _ = mp_channel[Task]()
@@ -57,7 +45,7 @@ async def test_check_runner_emits_error_chunk_for_inflight_text_generation() ->
supervisor = RunnerSupervisor(
shard_metadata=bound_instance.bound_shard,
bound_instance=bound_instance,
runner_process=cast("mp.Process", cast(object, _DeadProcess())),
runner_process=cast(AsyncProcess, cast(object, _DeadProcess())),
initialize_timeout=400,
_ev_recv=ev_recv,
_task_sender=task_sender,