From 45df74ba98ec7f798b3bad9d2b4d6906a257fb33 Mon Sep 17 00:00:00 2001 From: Andrei Cravtov Date: Sat, 9 May 2026 22:45:14 +0100 Subject: [PATCH] Andrei/mp capture stdio (#2056) ## Motivation Process-isolated runner crashes and C-extension failures can write directly to fd-level stdout/stderr, bypassing Python/loguru. We need to capture that output per runner process without polluting the main process or other workers, and without breaking operation when the parent stdio is detached. ## Changes - Added `AsyncProcess`, a spawn-only multiprocessing wrapper that redirects child stdout/stderr to pipes and exposes them as in-memory `Receiver[bytes]`s - Replaced runner-supervisor's raw `multiprocessing.Process` usage with `AsyncProcess` - Added `--no-stdio`, redirecting stdin/stdout/stderr to `/dev/null` after logging is configured - Disabled verbose MLX ring logging (`MLX_RING_VERBOSE`) - Added tests covering stdio capture, child crashes, repeated bad children, SIGTERM/SIGKILL shutdown escalation, stdio detachment, and spawning captured children from a stdio-detached parent ## Why It Works The parent can redirect its own stdio fds to `/dev/null`, while `AsyncProcess` installs fresh pipe fds over fd 1 and 2 inside each spawned child. That keeps stdio-detached parents quiet while preserving per-runner stdout/stderr capture. Runner shutdown is still bounded: SIGTERM grace first, then SIGKILL escalation if needed. Next direction: the runner supervisor currently drains captured output and logs it as stdout/debug and stderr/warning. This should be split into more useful process-isolated error reporting instead of just log forwarding (best-effort regex matching on errors to obtain a "reason" string). ## Test Plan ### Manual Testing Ran on 4 Mac Minis in a Thunderbolt 4 ring and confirmed that each runner's stdout/stderr output is being captured. 
### Automated Testing - Added async-process tests for fd-level stdout/stderr capture, Python traceback capture, bounded-buffer output, child `exit`/abort, parent stdio preservation, fd leak checks, spawn-context mp channels, and SIGTERM/SIGKILL shutdown behavior - Added stdio-detach tests proving stdio detaches to `/dev/null`, a stdio-detached parent can still spawn and capture a child, and the same stdio-detached parent can spawn/capture multiple children sequentially - Updated runner-supervisor tests for the new `AsyncProcess.exitcode` path --- src/exo/main.py | 11 + src/exo/utils/async_process.py | 290 ++++++++++ src/exo/utils/daemon.py | 28 + src/exo/utils/tests/test_async_process.py | 515 ++++++++++++++++++ src/exo/utils/tests/test_daemon.py | 168 ++++++ src/exo/worker/engines/mlx/utils_mlx.py | 3 +- src/exo/worker/runner/supervisor.py | 90 +-- .../test_runner/test_runner_supervisor.py | 18 +- 8 files changed, 1065 insertions(+), 58 deletions(-) create mode 100644 src/exo/utils/async_process.py create mode 100644 src/exo/utils/daemon.py create mode 100644 src/exo/utils/tests/test_async_process.py create mode 100644 src/exo/utils/tests/test_daemon.py diff --git a/src/exo/main.py b/src/exo/main.py index 520f2e2fb..7419e6883 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -23,6 +23,7 @@ from exo.shared.election import Election, ElectionResult from exo.shared.logging import logger_cleanup, logger_setup from exo.shared.types.common import NodeId, SessionId from exo.utils.channels import Receiver, channel +from exo.utils.daemon import detach_stdio_to_devnull from exo.utils.pidfile import PidfileLockError, acquire_exo_pidfile from exo.utils.pydantic_ext import FrozenModel from exo.utils.task_group import TaskGroup @@ -282,6 +283,10 @@ def main(): # TODO: Refactor the current verbosity system logger_setup(EXO_LOG, args.verbosity) + if args.no_stdio: + detach_stdio_to_devnull() + logger.info("Detached stdio to /dev/null") + logger.info(f"{'=' * 40}") 
logger.info(f"Starting EXO | pid={os.getpid()}") logger.info(f"{'=' * 40}") @@ -330,6 +335,7 @@ class Args(FrozenModel): offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true" no_batch: bool = False fast_synch: bool | None = None # None = auto, True = force on, False = force off + no_stdio: bool = False bootstrap_peers: list[str] = [] libp2p_port: int @@ -389,6 +395,11 @@ class Args(FrozenModel): action="store_true", help="Disable continuous batching, use sequential generation", ) + parser.add_argument( + "--no-stdio", + action="store_true", + help="Detach stdin/stdout/stderr to /dev/null after logging is configured", + ) parser.add_argument( "--bootstrap-peers", type=lambda s: [p for p in s.split(",") if p], diff --git a/src/exo/utils/async_process.py b/src/exo/utils/async_process.py new file mode 100644 index 000000000..3866037dd --- /dev/null +++ b/src/exo/utils/async_process.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import contextlib +import faulthandler +import multiprocessing as mp +import os +import sys +from collections.abc import Callable, Iterable, Mapping +from multiprocessing.process import BaseProcess +from multiprocessing.resource_sharer import DupFd +from typing import final + +from anyio import ( + TASK_STATUS_IGNORED, + BrokenResourceError, + CancelScope, + ClosedResourceError, + Event, + create_task_group, + move_on_after, + sleep, + wait_readable, +) +from anyio.abc import TaskStatus +from loguru import logger + +from exo.utils.channels import Receiver, Sender, channel + +_STDOUT_FD = 1 +_STDERR_FD = 2 +_READ_CHUNK_SIZE = 64 * 1024 +_TERMINATE_GRACE_SECONDS = 10.0 +_TERMINATE_RETRY_GRACE_SECONDS = 2.0 +_TERMINATE_ATTEMPTS = 10 +_KILL_GRACE_SECONDS = 5.0 + + +@final +class AsyncProcess: + def __init__( + self, + target: Callable[..., object] | None = None, + name: str | None = None, + args: Iterable[object] = (), + kwargs: Mapping[str, object] | None = None, + *, + daemon: bool | None = None, + ) -> None: + # setup 
state + self._target = target + self._name = name + self._args = args + self._kwargs = kwargs + self._daemon = daemon + + # lifecycle state + self._process: BaseProcess | None = None + self._pid: int | None = None + self._stdout_tx, self._stdout_rx = channel[bytes]() + self._stderr_tx, self._stderr_rx = channel[bytes]() + self._started = Event() + self._done = Event() + self._run_cancel_scope: CancelScope | None = None + self._start_error: BaseException | None = None + self._exitcode: int | None = None + + async def run(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + if self._run_cancel_scope is not None or self._done.is_set(): + raise RuntimeError("process has already been started") + + stdout_read_fd: int | None = None + stdout_write_fd: int | None = None + stderr_read_fd: int | None = None + stderr_write_fd: int | None = None + + def cleanup_stdio_fd() -> None: + nonlocal stdout_read_fd, stdout_write_fd, stderr_read_fd, stderr_write_fd + stdout_read_fd = _close_fd(stdout_read_fd) + stdout_write_fd = _close_fd(stdout_write_fd) + stderr_read_fd = _close_fd(stderr_read_fd) + stderr_write_fd = _close_fd(stderr_write_fd) + + try: + with CancelScope() as run_cancel_scope: + self._run_cancel_scope = run_cancel_scope + stdout_read_fd, stdout_write_fd = os.pipe() + stderr_read_fd, stderr_write_fd = os.pipe() + + process = mp.Process( + target=_run_with_captured_stdio, + name=self._name, + args=( + DupFd(stdout_write_fd), + DupFd(stderr_write_fd), + self._target, + *self._args, + ), + kwargs={} if self._kwargs is None else self._kwargs, + daemon=self._daemon, + ) + process.start() + pid = process.pid + if pid is None: + raise RuntimeError("started process has no pid") + + # important to close parent write-side FD to prevent hangs + stdout_write_fd = _close_fd(stdout_write_fd) + stderr_write_fd = _close_fd(stderr_write_fd) + + self._process = process + self._pid = pid + self._started.set() + + async with create_task_group() as tg: + 
tg.start_soon(_drain_fd, stdout_read_fd, self._stdout_tx) + stdout_read_fd = None + tg.start_soon(_drain_fd, stderr_read_fd, self._stderr_tx) + stderr_read_fd = None + task_status.started() + await self.wait() + except BaseException as exc: + if not self._started.is_set(): + self._start_error = exc + self._started.set() + raise + finally: + try: + with CancelScope(shield=True): + await self._terminate_if_still_alive() + finally: + cleanup_stdio_fd() + for tx in (self._stdout_tx, self._stderr_tx): + with contextlib.suppress(Exception): + await tx.aclose() + if self._process is not None: + with contextlib.suppress(ValueError): + self._process.close() + self._run_cancel_scope = None + self._done.set() + + async def stop(self) -> None: + if self._run_cancel_scope is None and not self._done.is_set(): + raise RuntimeError("process has not been started") + if self._run_cancel_scope is not None: + self._run_cancel_scope.cancel() + await self._done.wait() + + async def aclose(self) -> None: + await self.stop() + + async def wait(self) -> int: + if self._exitcode is not None: + return self._exitcode + + await self._started.wait() + if self._start_error is not None: + raise self._start_error + assert self._process is not None + + while True: + exitcode = self.exitcode + if exitcode is not None: + return exitcode + await sleep(0.01) + + @property + def pid(self) -> int: + if self._pid is None: + raise RuntimeError("process has not been started") + return self._pid + + @property + def exitcode(self) -> int | None: + if self._exitcode is not None: + return self._exitcode + if self._process is None: + return None + + with contextlib.suppress(ValueError): + exitcode = self._process.exitcode + if exitcode is not None: + self._exitcode = exitcode + return exitcode + return None + + def is_alive(self) -> bool: + if self._process is None: + return False + + with contextlib.suppress(ValueError): + return self._process.is_alive() + return False + + # TODO: maybe in the future if needed, 
create stdin that is also installed, + # and a ByteSendStream handle is provided for it :) + + @property + def stdout(self) -> Receiver[bytes]: + return self._stdout_rx + + @property + def stderr(self) -> Receiver[bytes]: + return self._stderr_rx + + async def _terminate_if_still_alive(self) -> None: + process = self._process + if process is None: + return + + if self.exitcode is not None: + return + + with contextlib.suppress(ValueError): + if not process.is_alive(): + return + + logger.warning("Child process didn't shut down successfully, terminating") + process.terminate() + with move_on_after(_TERMINATE_GRACE_SECONDS): + await self.wait() + + if self.exitcode is not None or not process.is_alive(): + logger.warning("Terminated nicely in the first attempt!") + return + + for attempt in range(2, _TERMINATE_ATTEMPTS + 1): + process.terminate() + with move_on_after(_TERMINATE_RETRY_GRACE_SECONDS): + await self.wait() + + if self.exitcode is not None or not process.is_alive(): + logger.warning(f"That took {attempt} attempts :)") + return + + logger.critical("Child process didn't respond to SIGTERM, killing") + j = 0 + while True: + process.kill() + with move_on_after(_KILL_GRACE_SECONDS): + await self.wait() + j += 1 + if self.exitcode is not None or not process.is_alive(): + break + logger.warning(f"That took {j} attempts :(") + + +# Spawn-mode multiprocessing requires a module-level target that can be pickled. 
+def _run_with_captured_stdio( + stdout: DupFd, + stderr: DupFd, + target: Callable[..., object] | None, + *target_args: object, + **target_kwargs: object, +) -> None: + stdout_fd = stdout.detach() + stderr_fd = stderr.detach() + + try: + os.dup2(stdout_fd, _STDOUT_FD) + os.dup2(stderr_fd, _STDERR_FD) + finally: + for fd in (stdout_fd, stderr_fd): + if fd not in (_STDOUT_FD, _STDERR_FD): + _close_fd(fd) + + faulthandler.enable(file=sys.stderr, all_threads=True) + if target is not None: + target(*target_args, **target_kwargs) + + +async def _drain_fd(fd: int, tx: Sender[bytes]) -> None: + try: + while True: + await wait_readable(fd) + chunk = os.read(fd, _READ_CHUNK_SIZE) + if not chunk: + return + await tx.send(chunk) + except (BrokenPipeError, BrokenResourceError, ClosedResourceError): + pass + finally: + _close_fd(fd) + await tx.aclose() + + +def _close_fd(fd: int | None) -> None: + if fd is None: + return + with contextlib.suppress(OSError): + os.close(fd) diff --git a/src/exo/utils/daemon.py b/src/exo/utils/daemon.py new file mode 100644 index 000000000..7636d6808 --- /dev/null +++ b/src/exo/utils/daemon.py @@ -0,0 +1,28 @@ +import os +import sys + +_STDIN_FD = 0 +_STDOUT_FD = 1 +_STDERR_FD = 2 + + +def detach_stdio_to_devnull() -> None: + """Redirect process stdio file descriptors to /dev/null.""" + + for stream in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__): + if stream is not None: + stream.flush() + + stdin_fd = os.open(os.devnull, os.O_RDONLY) + stdout_fd = os.open(os.devnull, os.O_WRONLY) + stderr_fd = os.open(os.devnull, os.O_WRONLY) + + try: + # dup2 closes the target fd first, but leaves the source fd open. 
+ os.dup2(stdin_fd, _STDIN_FD) + os.dup2(stdout_fd, _STDOUT_FD) + os.dup2(stderr_fd, _STDERR_FD) + finally: + for fd in (stdin_fd, stdout_fd, stderr_fd): + if fd not in (_STDIN_FD, _STDOUT_FD, _STDERR_FD): + os.close(fd) diff --git a/src/exo/utils/tests/test_async_process.py b/src/exo/utils/tests/test_async_process.py new file mode 100644 index 000000000..0e275cfc0 --- /dev/null +++ b/src/exo/utils/tests/test_async_process.py @@ -0,0 +1,515 @@ +import contextlib +import os +import signal +import sys +import time +from collections.abc import AsyncIterator, Callable +from types import FrameType + +import mlx.core as mx +import pytest +from _pytest.capture import CaptureFixture +from anyio import EndOfStream, create_task_group, fail_after +from pytest import MonkeyPatch + +import exo.utils.async_process as async_process +from exo.utils.async_process import ( + AsyncProcess, +) +from exo.utils.channels import MpSender, Receiver, mp_channel + + +def _write_to_stdio(prefix: str, *, stderr_suffix: str) -> None: + print(f"{prefix}: python stdout") + print(f"{prefix}: python stderr {stderr_suffix}", file=sys.stderr) + os.write(1, f"{prefix}: fd stdout\n".encode()) + os.write(2, f"{prefix}: fd stderr {stderr_suffix}\n".encode()) + + +def _write_large_output() -> None: + os.write(1, b"stdout-0123456789") + os.write(2, b"stderr-0123456789") + + +def _write_all(fd: int, data: bytes) -> None: + remaining = memoryview(data) + while remaining: + written = os.write(fd, remaining) + remaining = remaining[written:] + + +def _write_large_exact_output(size: int) -> None: + _write_all(1, b"stdout:" + (b"x" * size)) + _write_all(2, b"stderr:" + (b"y" * size)) + + +def _raise_after_stderr_write() -> None: + os.write(2, b"stderr before exception\n") + raise RuntimeError("child boom") + + +def _exit_after_stdio_write(prefix: str, exitcode: int) -> None: + os.write(1, f"{prefix}: stdout before _exit\n".encode()) + os.write(2, f"{prefix}: stderr before _exit\n".encode()) + os._exit(exitcode) 
+ + +def _abort_after_stdio_write(prefix: str) -> None: + os.write(1, f"{prefix}: stdout before abort\n".encode()) + os.write(2, f"{prefix}: stderr before abort\n".encode()) + os.abort() + + +def _close_stdio_and_exit() -> None: + os.close(1) + os.close(2) + os._exit(0) + + +def _exit_on_sigterm(exitcode: int) -> None: + def handle_sigterm(_signum: int, _frame: FrameType | None) -> None: + os._exit(exitcode) + + signal.signal(signal.SIGTERM, handle_sigterm) + os.write(1, b"sigterm-ready\n") + while True: + time.sleep(0.1) + + +def _exit_after_repeated_sigterm(required_count: int, exitcode: int) -> None: + sigterm_count = 0 + + def handle_sigterm(_signum: int, _frame: FrameType | None) -> None: + nonlocal sigterm_count + sigterm_count += 1 + if sigterm_count >= required_count: + os._exit(exitcode) + + signal.signal(signal.SIGTERM, handle_sigterm) + os.write(1, b"sigterm-ready\n") + while True: + time.sleep(0.1) + + +def _ignore_sigterm_forever() -> None: + signal.signal(signal.SIGTERM, signal.SIG_IGN) + os.write(1, b"sigterm-ready\n") + while True: + time.sleep(0.1) + + +def _sleep_forever() -> None: + while True: + time.sleep(0.1) + + +def _send_over_mp_channel(send: MpSender[str]) -> None: + send.send("hello from child") + send.close() + + +def _mlx_force_oom(size: int = 40_000) -> None: + """ + Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations. 
+ """ + print("CHILD: start") + + mx.set_default_device(mx.gpu) + a = mx.random.uniform(shape=(size, size), dtype=mx.float32) + b = mx.random.uniform(shape=(size, size), dtype=mx.float32) + mx.eval(a, b) + c = mx.matmul(a, b) + d = mx.matmul(a, c) + e = mx.matmul(b, c) + f = mx.sigmoid(d + e) + mx.eval(f) + + print("CHILD: end") + + +async def _collect_stream( + stream: Receiver[bytes], + output: bytearray, +) -> None: + while True: + try: + output.extend(await stream.receive()) + except EndOfStream: + return + + +async def _collect_process_output( + process: AsyncProcess, +) -> tuple[int, bytes, bytes]: + stdout = bytearray() + stderr = bytearray() + exitcodes: list[int] = [] + + async with create_task_group() as task_group: + task_group.start_soon(_collect_stream, process.stdout, stdout) + task_group.start_soon(_collect_stream, process.stderr, stderr) + exitcodes.append(await process.wait()) + + if not exitcodes: + raise RuntimeError("process exited without a return code") + return exitcodes[0], bytes(stdout), bytes(stderr) + + +def _fd_identity(fd: int) -> tuple[int, int]: + fd_stat = os.fstat(fd) + return fd_stat.st_dev, fd_stat.st_ino + + +def _fd_count() -> int | None: + for fd_dir in ("/proc/self/fd", "/dev/fd"): + with contextlib.suppress(OSError): + return len(os.listdir(fd_dir)) + return None + + +@contextlib.asynccontextmanager +async def _started_process(process: AsyncProcess) -> AsyncIterator[None]: + async with create_task_group() as task_group: + await task_group.start(process.run) + try: + yield + finally: + await process.stop() + + +async def _run_and_collect( + target: Callable[..., object] | None, + *, + args: tuple[object, ...] 
= (), + kwargs: dict[str, object] | None = None, +) -> tuple[int, bytes, bytes]: + process = AsyncProcess( + target, + args=args, + kwargs=kwargs, + ) + async with _started_process(process): + return await _collect_process_output(process) + + +@pytest.mark.anyio +async def test_spawn_process_captures_stdout_and_stderr_separately( + capfd: CaptureFixture[str], +) -> None: + process = AsyncProcess( + _write_to_stdio, + args=("child",), + kwargs={"stderr_suffix": "error"}, + ) + async with _started_process(process): + exitcode, stdout_bytes, stderr_bytes = await _collect_process_output(process) + + parent_output = capfd.readouterr() + stdout = stdout_bytes.decode("utf-8", errors="replace") + stderr = stderr_bytes.decode("utf-8", errors="replace") + + assert exitcode == 0 + assert "child: python stdout" in stdout + assert "child: fd stdout" in stdout + assert "child: python stderr error" in stderr + assert "child: fd stderr error" in stderr + assert "child:" not in parent_output.out + assert "child:" not in parent_output.err + + +@pytest.mark.anyio +async def test_process_with_no_target_exits_successfully() -> None: + exitcode, stdout, stderr = await _run_and_collect(None) + + assert exitcode == 0 + assert stdout == b"" + assert stderr == b"" + + +@pytest.mark.anyio +async def test_output_receivers_and_wait_are_safe_immediately_after_run_starts() -> ( + None +): + process = AsyncProcess( + _write_to_stdio, + args=("immediate",), + kwargs={"stderr_suffix": "error"}, + ) + result: tuple[int, bytes, bytes] | None = None + + async with create_task_group() as task_group: + await task_group.start(process.run) + try: + result = await _collect_process_output(process) + finally: + await process.stop() + + assert result is not None + exitcode, stdout, stderr = result + assert exitcode == 0 + assert b"immediate: fd stdout\n" in stdout + assert b"immediate: fd stderr error\n" in stderr + + +@pytest.mark.anyio +async def test_stop_before_run_raises() -> None: + process = 
AsyncProcess( + _write_to_stdio, + args=("never",), + kwargs={"stderr_suffix": "run"}, + ) + + assert not process.is_alive() + with pytest.raises(RuntimeError, match="process has not been started"): + await process.stop() + + +@pytest.mark.anyio +async def test_process_run_is_one_shot() -> None: + process = AsyncProcess(None) + + await process.run() + + with pytest.raises(RuntimeError, match="process has already been started"): + await process.run() + + +@pytest.mark.anyio +async def test_process_started_with_task_group_start_can_stop_immediately() -> None: + process = AsyncProcess(_sleep_forever) + + async with create_task_group() as task_group: + await task_group.start(process.run) + assert process.is_alive() + with fail_after(2): + await process.stop() + + assert not process.is_alive() + + +@pytest.mark.anyio +async def test_stdout_receiver_yields_bytes_chunks() -> None: + process = AsyncProcess(_write_large_output) + + async with _started_process(process): + first_stdout = await process.stdout.receive() + exitcode, remaining_stdout, stderr = await _collect_process_output(process) + + assert exitcode == 0 + assert first_stdout + remaining_stdout == b"stdout-0123456789" + assert stderr == b"stderr-0123456789" + + +@pytest.mark.anyio +async def test_output_can_be_read_after_process_exits() -> None: + process = AsyncProcess(_write_large_output) + + async with create_task_group() as task_group: + await task_group.start(process.run) + assert await process.wait() == 0 + + assert await process.stdout.receive() == b"stdout-0123456789" + assert await process.stderr.receive() == b"stderr-0123456789" + with pytest.raises(EndOfStream): + await process.stdout.receive() + with pytest.raises(EndOfStream): + await process.stderr.receive() + + +@pytest.mark.anyio +async def test_large_stdout_and_stderr_are_not_lost() -> None: + size = 1024 * 1024 + exitcode, stdout, stderr = await _run_and_collect( + _write_large_exact_output, + args=(size,), + ) + + assert exitcode == 0 + 
assert stdout == b"stdout:" + (b"x" * size) + assert stderr == b"stderr:" + (b"y" * size) + + +@pytest.mark.anyio +async def test_child_exception_traceback_is_captured_from_stderr() -> None: + process = AsyncProcess(_raise_after_stderr_write) + + async with _started_process(process): + exitcode, _, stderr_bytes = await _collect_process_output(process) + + assert exitcode == 1 + stderr = stderr_bytes.decode("utf-8", errors="replace") + assert "stderr before exception" in stderr + assert "RuntimeError: child boom" in stderr + + +@pytest.mark.anyio +async def test_repeated_bad_children_do_not_pollute_or_replace_parent_stdio( + capfd: CaptureFixture[str], +) -> None: + stdout_object = sys.stdout + stderr_object = sys.stderr + stdout_identity = _fd_identity(1) + stderr_identity = _fd_identity(2) + + cases: tuple[tuple[Callable[..., object], tuple[object, ...]], ...] = ( + (_raise_after_stderr_write, ()), + (_exit_after_stdio_write, ("exit-child", 17)), + (_abort_after_stdio_write, ("abort-child",)), + ) + + for iteration in range(3): + for target, args in cases: + exitcode, stdout, stderr = await _run_and_collect( + target, + args=args, + ) + + assert exitcode != 0 + if target is _exit_after_stdio_write: + assert stdout == b"exit-child: stdout before _exit\n" + assert stderr == b"exit-child: stderr before _exit\n" + elif target is _abort_after_stdio_write: + assert b"abort-child: stdout before abort\n" in stdout + assert b"abort-child: stderr before abort\n" in stderr + assert exitcode == -signal.SIGABRT + else: + assert stdout == b"" + assert b"stderr before exception\n" in stderr + assert b"RuntimeError: child boom" in stderr + + print(f"parent stdout still works {iteration}") + print(f"parent stderr still works {iteration}", file=sys.stderr) + + parent_output = capfd.readouterr() + + assert sys.stdout is stdout_object + assert sys.stderr is stderr_object + assert _fd_identity(1) == stdout_identity + assert _fd_identity(2) == stderr_identity + assert "parent stdout 
still works 0" in parent_output.out + assert "parent stdout still works 2" in parent_output.out + assert "parent stderr still works 0" in parent_output.err + assert "parent stderr still works 2" in parent_output.err + assert "exit-child:" not in parent_output.out + assert "exit-child:" not in parent_output.err + assert "abort-child:" not in parent_output.out + assert "abort-child:" not in parent_output.err + assert "child boom" not in parent_output.err + + +@pytest.mark.anyio +async def test_child_can_close_stdio_without_corrupting_parent_stdio( + capfd: CaptureFixture[str], +) -> None: + stdout_identity = _fd_identity(1) + stderr_identity = _fd_identity(2) + + exitcode, stdout, stderr = await _run_and_collect(_close_stdio_and_exit) + os.write(1, b"parent stdout after child closed stdio\n") + os.write(2, b"parent stderr after child closed stdio\n") + parent_output = capfd.readouterr() + + assert exitcode == 0 + assert stdout == b"" + assert stderr == b"" + assert _fd_identity(1) == stdout_identity + assert _fd_identity(2) == stderr_identity + assert "parent stdout after child closed stdio" in parent_output.out + assert "parent stderr after child closed stdio" in parent_output.err + + +@pytest.mark.anyio +async def test_repeated_crashing_children_do_not_grow_parent_fd_table() -> None: + await _run_and_collect(_exit_after_stdio_write, args=("warmup", 23)) + before = _fd_count() + if before is None: + pytest.skip("fd table count is not available on this platform") + + for iteration in range(20): + exitcode, stdout, stderr = await _run_and_collect( + _exit_after_stdio_write, + args=(f"fd-child-{iteration}", 31), + ) + + assert exitcode == 31 + assert stdout == f"fd-child-{iteration}: stdout before _exit\n".encode() + assert stderr == f"fd-child-{iteration}: stderr before _exit\n".encode() + + after = _fd_count() + assert after is not None + assert after <= before + 2 + + +@pytest.mark.anyio +async def test_stop_allows_child_to_exit_after_sigterm() -> None: + process = 
AsyncProcess(_exit_on_sigterm, args=(43,)) + + async with _started_process(process): + assert await process.stdout.receive() == b"sigterm-ready\n" + + with fail_after(2): + await process.stop() + + assert process.exitcode == 43 + + +@pytest.mark.anyio +async def test_stop_retries_sigterm_before_sigkill(monkeypatch: MonkeyPatch) -> None: + monkeypatch.setattr(async_process, "_TERMINATE_GRACE_SECONDS", 0.01) + monkeypatch.setattr(async_process, "_TERMINATE_RETRY_GRACE_SECONDS", 0.01) + process = AsyncProcess(_exit_after_repeated_sigterm, args=(3, 44)) + + async with _started_process(process): + assert await process.stdout.receive() == b"sigterm-ready\n" + + with fail_after(2): + await process.stop() + + assert process.exitcode == 44 + + +@pytest.mark.anyio +async def test_stop_escalates_to_sigkill_when_child_ignores_sigterm( + monkeypatch: MonkeyPatch, +) -> None: + monkeypatch.setattr(async_process, "_TERMINATE_GRACE_SECONDS", 0.1) + monkeypatch.setattr(async_process, "_TERMINATE_RETRY_GRACE_SECONDS", 0.01) + process = AsyncProcess(_ignore_sigterm_forever) + + async with _started_process(process): + assert await process.stdout.receive() == b"sigterm-ready\n" + + with fail_after(3): + await process.stop() + + assert process.exitcode == -signal.SIGKILL + + +@pytest.mark.anyio +async def test_process_can_use_mp_channel_with_global_spawn_context() -> None: + send, recv = mp_channel[str]() + process = AsyncProcess(_send_over_mp_channel, args=(send,)) + + async with _started_process(process): + with fail_after(2): + assert await recv.receive_async() == "hello from child" + assert await process.wait() == 0 + + with contextlib.suppress(Exception): + recv.close() + + +@pytest.mark.anyio +@pytest.mark.skip(reason="manual MLX OOM isolation check") +async def test_death(capsys: CaptureFixture[str]) -> None: + with capsys.disabled(): + process = AsyncProcess(_mlx_force_oom) + stdout = b"" + stderr = b"" + async with _started_process(process): + _, stdout, stderr = await 
_collect_process_output(process) + + print("PARENT: done") + + print("CHILD out:", stdout.decode("utf-8", errors="replace")) + print("CHILD err:", stderr.decode("utf-8", errors="replace"), "hello :)") diff --git a/src/exo/utils/tests/test_daemon.py b/src/exo/utils/tests/test_daemon.py new file mode 100644 index 000000000..964afebf8 --- /dev/null +++ b/src/exo/utils/tests/test_daemon.py @@ -0,0 +1,168 @@ +import contextlib +import os +from collections.abc import AsyncIterator + +import anyio +import pytest +from anyio import EndOfStream, create_task_group, fail_after + +from exo.utils.async_process import AsyncProcess +from exo.utils.channels import MpReceiver, MpSender, Receiver, mp_channel +from exo.utils.daemon import detach_stdio_to_devnull + + +def _write_before_and_after_detach() -> None: + os.write(1, b"before stdout\n") + os.write(2, b"before stderr\n") + detach_stdio_to_devnull() + os.write(1, b"after stdout\n") + os.write(2, b"after stderr\n") + + +def _write_grandchild_stdio(label: str) -> None: + os.write(1, f"{label} stdout\n".encode()) + os.write(2, f"{label} stderr\n".encode()) + + +async def _spawn_grandchild_and_report( + result_sender: MpSender[tuple[int, bytes, bytes]], + label: str, +) -> None: + result_sender.send(await _collect_spawned_child(label)) + result_sender.close() + + +async def _collect_spawned_child(label: str) -> tuple[int, bytes, bytes]: + process = AsyncProcess(_write_grandchild_stdio, args=(label,)) + async with _started_process(process): + return await _collect_process_output(process) + + +def _detach_stdio_then_spawn_captured_child( + result_sender: MpSender[tuple[int, bytes, bytes]], +) -> None: + detach_stdio_to_devnull() + anyio.run(_spawn_grandchild_and_report, result_sender, "grandchild") + + +def _detach_stdio_then_spawn_captured_children_sequentially( + result_sender: MpSender[list[tuple[int, bytes, bytes]]], +) -> None: + async def run_children() -> list[tuple[int, bytes, bytes]]: + results: list[tuple[int, bytes, 
bytes]] = [] + for index in range(5): + results.append(await _collect_spawned_child(f"grandchild-{index}")) + return results + + detach_stdio_to_devnull() + result_sender.send(anyio.run(run_children)) + result_sender.close() + + +async def _collect_stream(stream: Receiver[bytes], output: bytearray) -> None: + while True: + try: + output.extend(await stream.receive()) + except EndOfStream: + return + + +async def _collect_process_output( + process: AsyncProcess, +) -> tuple[int, bytes, bytes]: + stdout = bytearray() + stderr = bytearray() + exitcodes: list[int] = [] + + async with create_task_group() as collect_group: + collect_group.start_soon(_collect_stream, process.stdout, stdout) + collect_group.start_soon(_collect_stream, process.stderr, stderr) + exitcodes.append(await process.wait()) + + if not exitcodes: + raise RuntimeError("process exited without a return code") + return exitcodes[0], bytes(stdout), bytes(stderr) + + +@contextlib.asynccontextmanager +async def _started_process(process: AsyncProcess) -> AsyncIterator[None]: + async with create_task_group() as task_group: + await task_group.start(process.run) + try: + yield + finally: + await process.stop() + + +async def _run_process_and_receive[T]( + process: AsyncProcess, + recv: MpReceiver[T], + *, + timeout: float, +) -> tuple[int, T]: + async with _started_process(process): + with fail_after(timeout): + result = await recv.receive_async() + exitcode = await process.wait() + + return exitcode, result + + +@pytest.mark.anyio +async def test_detach_stdio_to_devnull_redirects_stdio_away_from_capture() -> None: + process = AsyncProcess(_write_before_and_after_detach) + + async with _started_process(process): + exitcode, stdout, stderr = await _collect_process_output(process) + + assert exitcode == 0 + assert stdout == b"before stdout\n" + assert stderr == b"before stderr\n" + + +@pytest.mark.anyio +async def test_detached_stdio_process_can_spawn_and_capture_child_stdio() -> None: + send, recv = 
mp_channel[tuple[int, bytes, bytes]]() + process = AsyncProcess(_detach_stdio_then_spawn_captured_child, args=(send,)) + + try: + daemonized_parent_exitcode, result = await _run_process_and_receive( + process, recv, timeout=5 + ) + finally: + recv.close() + + child_exitcode, child_stdout, child_stderr = result + + assert daemonized_parent_exitcode == 0 + assert child_exitcode == 0 + assert child_stdout == b"grandchild stdout\n" + assert child_stderr == b"grandchild stderr\n" + + +@pytest.mark.anyio +async def test_detached_stdio_process_can_spawn_captured_children_sequentially() -> ( + None +): + send, recv = mp_channel[list[tuple[int, bytes, bytes]]]() + process = AsyncProcess( + _detach_stdio_then_spawn_captured_children_sequentially, + args=(send,), + ) + + try: + daemonized_parent_exitcode, results = await _run_process_and_receive( + process, recv, timeout=10 + ) + finally: + recv.close() + + assert daemonized_parent_exitcode == 0 + assert results == [ + ( + 0, + f"grandchild-{index} stdout\n".encode(), + f"grandchild-{index} stderr\n".encode(), + ) + for index in range(5) + ] diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 7021204ac..1dddad2ae 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -115,7 +115,8 @@ def mlx_distributed_init( os.environ["MLX_HOSTFILE"] = coordination_file os.environ["MLX_RANK"] = str(rank) - os.environ["MLX_RING_VERBOSE"] = "1" + # os.environ["MLX_RING_VERBOSE"] = "1" # NOTE: we don't use it enough to care (turn on again if need to) + group = mx.distributed.init(backend="ring", strict=True) case MlxJacclInstance( diff --git a/src/exo/worker/runner/supervisor.py b/src/exo/worker/runner/supervisor.py index 8a48c6bcd..bc90d4181 100644 --- a/src/exo/worker/runner/supervisor.py +++ b/src/exo/worker/runner/supervisor.py @@ -1,5 +1,4 @@ import contextlib -import multiprocessing as mp import signal from dataclasses import dataclass, field 
from typing import Self @@ -8,7 +7,7 @@ import anyio from anyio import ( BrokenResourceError, ClosedResourceError, - to_thread, + EndOfStream, ) from loguru import logger @@ -41,7 +40,8 @@ from exo.shared.types.worker.runners import ( RunnerWarmingUp, ) from exo.shared.types.worker.shards import ShardMetadata -from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel +from exo.utils.async_process import AsyncProcess +from exo.utils.channels import MpReceiver, MpSender, Receiver, Sender, mp_channel from exo.utils.task_group import TaskGroup from exo.worker.runner.bootstrap import entrypoint @@ -53,7 +53,7 @@ DECODE_TIMEOUT_SECONDS = 5 class RunnerSupervisor: shard_metadata: ShardMetadata bound_instance: BoundInstance - runner_process: mp.Process + runner_process: AsyncProcess initialize_timeout: float _ev_recv: MpReceiver[Event] _task_sender: MpSender[Task] @@ -81,7 +81,7 @@ class RunnerSupervisor: task_sender, task_recv = mp_channel[Task]() cancel_sender, cancel_recv = mp_channel[TaskId]() - runner_process = mp.Process( + runner_process = AsyncProcess( target=entrypoint, args=( bound_instance, @@ -109,9 +109,25 @@ class RunnerSupervisor: return self async def run(self): - self.runner_process.start() try: async with self._tg as tg: + # start the process itself + await tg.start(self.runner_process.run) + + # start tasks to drain/collect stdout/stderr into usable errors + # + # TODO: right now it logs them as warnings, but in the future they should be split + # into being logged AND a separate task which tries to best-effort figure out cause + # of error and package into error enum, which then is used by rest of app to act on it; + # inferring what the error is would be done by pattern-matching in the text for things + # e.g. 
certain VLLM error codes and so on + tg.start_soon( + self._forward_runner_output, "stdout", self.runner_process.stdout + ) + tg.start_soon( + self._forward_runner_output, "stderr", self.runner_process.stderr + ) + tg.start_soon(self._watch_runner) tg.start_soon(self._forward_events) finally: @@ -129,41 +145,11 @@ class RunnerSupervisor: with contextlib.suppress(ClosedResourceError): self._cancel_sender.close() - await to_thread.run_sync(self.runner_process.join, 5) - - if self.runner_process.is_alive(): - logger.warning( - "Runner process didn't shutdown succesfully, terminating" + with anyio.CancelScope(shield=True): + await self.runner_process.stop() + logger.info( + f"Runner process successfully terminated: {self.runner_process.exitcode}" ) - self.runner_process.terminate() - self.runner_process.join(timeout=10) - - if not self.runner_process.is_alive(): - logger.warning("Terminated nicely in the first attempt!") - - else: - # Try really hard to terminate - for i in range(2, 11): - self.runner_process.terminate() - self.runner_process.join(timeout=2) - if not self.runner_process.is_alive(): - logger.warning(f"That took {i} attempts :)") - break - # Try even harder to kill - else: - logger.critical( - "Runner process didn't respond to SIGTERM, killing" - ) - j = 0 - while self.runner_process.is_alive(): - j += 1 - self.runner_process.kill() - self.runner_process.join(timeout=5) - logger.warning(f"That took {j} attempts :(") - else: - logger.info("Runner process succesfully terminated") - - self.runner_process.close() def shutdown(self): self._tg.cancel_tasks() @@ -249,13 +235,33 @@ class RunnerSupervisor: if not self.runner_process.is_alive(): await self._check_runner(RuntimeError("Runner found to be dead")) + async def _forward_runner_output( + self, + stream_name: str, + stream: Receiver[bytes], + ) -> None: + while True: + try: + chunk = await stream.receive() + except (EndOfStream, ClosedResourceError, BrokenResourceError): + return + + message = 
chunk.decode("utf-8", errors="replace").rstrip() + if not message: + continue + if stream_name == "stderr": + logger.warning(f"Runner stderr: {message}") + else: + logger.debug(f"Runner stdout: {message}") + async def _check_runner(self, e: Exception) -> None: if not self._cancel_watch_runner.cancel_called: self._cancel_watch_runner.cancel() logger.info("Checking runner's status") if self.runner_process.is_alive(): - logger.info("Runner was found to be alive, attempting to join process") - await to_thread.run_sync(self.runner_process.join, 5) + logger.info("Runner was found to be alive, stopping process") + with anyio.CancelScope(shield=True): + await self.runner_process.stop() rc = self.runner_process.exitcode logger.info(f"Runner exited with exit code {rc}") if rc == 0: diff --git a/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py index 39c991a19..3ea7c261a 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py +++ b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py @@ -1,4 +1,3 @@ -import multiprocessing as mp from typing import cast import anyio @@ -16,6 +15,7 @@ from exo.shared.types.text_generation import ( ) from exo.shared.types.worker.instances import BoundInstance, InstanceId from exo.shared.types.worker.runners import RunnerFailed, RunnerId +from exo.utils.async_process import AsyncProcess from exo.utils.channels import channel, mp_channel from exo.worker.runner.supervisor import RunnerSupervisor from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance @@ -24,23 +24,11 @@ from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance class _DeadProcess: exitcode = -6 - def start(self) -> None: - return None - def is_alive(self) -> bool: return False - def join(self, _timeout: float | None = None) -> None: - return None - def terminate(self) -> None: - return None - - def kill(self) -> 
None: - return None - - -@pytest.mark.asyncio +@pytest.mark.anyio async def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> None: event_sender, event_receiver = channel[Event]() task_sender, _ = mp_channel[Task]() @@ -57,7 +45,7 @@ async def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> supervisor = RunnerSupervisor( shard_metadata=bound_instance.bound_shard, bound_instance=bound_instance, - runner_process=cast("mp.Process", cast(object, _DeadProcess())), + runner_process=cast(AsyncProcess, cast(object, _DeadProcess())), initialize_timeout=400, _ev_recv=ev_recv, _task_sender=task_sender,