mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-19 12:15:07 -04:00
Andrei/mp capture stdio (#2056)
## Motivation

Process-isolated runner crashes and C-extension failures can write directly to fd-level stdout/stderr, bypassing Python/loguru. We need to capture that output per runner process without polluting the main process or other workers, and without breaking operation when the parent stdio is detached.

## Changes

- Added `AsyncProcess`, a spawn-only multiprocessing wrapper that redirects child stdout/stderr to pipes and exposes them as in-memory `Receiver[bytes]` streams
- Replaced the runner supervisor's raw `multiprocessing.Process` usage with `AsyncProcess`
- Added `--no-stdio`, which redirects stdin/stdout/stderr to `/dev/null` after logging is configured
- Disabled verbose MLX
- Added tests covering stdio capture, child crashes, repeated bad children, SIGTERM/SIGKILL shutdown escalation, stdio detachment, and spawning captured children from a stdio-detached parent

## Why It Works

The parent can redirect its own stdio fds to `/dev/null`, while `AsyncProcess` installs fresh pipe fds over fd 1 and 2 inside each spawned child. That keeps stdio-detached parents quiet while preserving per-runner stdout/stderr capture. Runner shutdown is still bounded: a SIGTERM grace period first, then SIGKILL escalation if needed.

Next direction: the runner supervisor currently drains captured output and logs it as stdout/debug and stderr/warning. This should be split into more useful process-isolated error reporting instead of just log forwarding (a best-effort regex match on errors to obtain a "reason" string).

## Test Plan

### Manual Testing

Ran on 4 Mac Minis in a Thunderbolt 4 ring and confirmed that each runner's stdout/stderr contents are being captured.

### Automated Testing

- Added async-process tests for fd-level stdout/stderr capture, Python traceback capture, bounded-buffer output, child `exit`/abort, parent stdio preservation, fd leak checks, spawn-context mp channels, and SIGTERM/SIGKILL shutdown behavior
- Added stdio-detach tests proving that stdio detaches to `/dev/null`, that a stdio-detached parent can still spawn and capture a child, and that the same stdio-detached parent can spawn and capture multiple children sequentially
- Updated runner-supervisor tests for the new `AsyncProcess.exitcode` path
This commit is contained in:
@@ -23,6 +23,7 @@ from exo.shared.election import Election, ElectionResult
|
||||
from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.daemon import detach_stdio_to_devnull
|
||||
from exo.utils.pidfile import PidfileLockError, acquire_exo_pidfile
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
@@ -282,6 +283,10 @@ def main():
|
||||
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
if args.no_stdio:
|
||||
detach_stdio_to_devnull()
|
||||
logger.info("Detached stdio to /dev/null")
|
||||
|
||||
logger.info(f"{'=' * 40}")
|
||||
logger.info(f"Starting EXO | pid={os.getpid()}")
|
||||
logger.info(f"{'=' * 40}")
|
||||
@@ -330,6 +335,7 @@ class Args(FrozenModel):
|
||||
offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true"
|
||||
no_batch: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
no_stdio: bool = False
|
||||
bootstrap_peers: list[str] = []
|
||||
libp2p_port: int
|
||||
|
||||
@@ -389,6 +395,11 @@ class Args(FrozenModel):
|
||||
action="store_true",
|
||||
help="Disable continuous batching, use sequential generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stdio",
|
||||
action="store_true",
|
||||
help="Detach stdin/stdout/stderr to /dev/null after logging is configured",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bootstrap-peers",
|
||||
type=lambda s: [p for p in s.split(",") if p],
|
||||
|
||||
290
src/exo/utils/async_process.py
Normal file
290
src/exo/utils/async_process.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import faulthandler
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.resource_sharer import DupFd
|
||||
from typing import final
|
||||
|
||||
from anyio import (
|
||||
TASK_STATUS_IGNORED,
|
||||
BrokenResourceError,
|
||||
CancelScope,
|
||||
ClosedResourceError,
|
||||
Event,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep,
|
||||
wait_readable,
|
||||
)
|
||||
from anyio.abc import TaskStatus
|
||||
from loguru import logger
|
||||
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
|
||||
_STDOUT_FD = 1
|
||||
_STDERR_FD = 2
|
||||
_READ_CHUNK_SIZE = 64 * 1024
|
||||
_TERMINATE_GRACE_SECONDS = 10.0
|
||||
_TERMINATE_RETRY_GRACE_SECONDS = 2.0
|
||||
_TERMINATE_ATTEMPTS = 10
|
||||
_KILL_GRACE_SECONDS = 5.0
|
||||
|
||||
|
||||
@final
class AsyncProcess:
    """Run a callable in a child process and capture its stdout/stderr.

    The child is created with ``multiprocessing.Process``. Before ``target``
    runs, the child installs the write ends of two fresh pipes over fd 1 and
    fd 2, so output written at the file-descriptor level is captured too. The
    parent drains both pipes into in-memory channels that are exposed through
    the ``stdout`` and ``stderr`` properties.

    An instance is one-shot: ``run`` may be called at most once.
    """

    def __init__(
        self,
        target: Callable[..., object] | None = None,
        name: str | None = None,
        args: Iterable[object] = (),
        kwargs: Mapping[str, object] | None = None,
        *,
        daemon: bool | None = None,
    ) -> None:
        """Record the launch configuration; nothing is spawned until ``run``.

        Args:
            target: Callable executed in the child; ``None`` runs nothing.
            name: Name forwarded to ``multiprocessing.Process``.
            args: Positional arguments forwarded to ``target``.
            kwargs: Keyword arguments forwarded to ``target``.
            daemon: Daemon flag forwarded to ``multiprocessing.Process``.
        """
        # setup state
        self._target = target
        self._name = name
        self._args = args
        self._kwargs = kwargs
        self._daemon = daemon

        # lifecycle state
        self._process: BaseProcess | None = None
        self._pid: int | None = None
        self._stdout_tx, self._stdout_rx = channel[bytes]()
        self._stderr_tx, self._stderr_rx = channel[bytes]()
        self._started = Event()
        self._done = Event()
        self._run_cancel_scope: CancelScope | None = None
        self._start_error: BaseException | None = None
        self._exitcode: int | None = None

    async def run(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
        """Spawn the child, pump its output, and wait until it exits.

        ``task_status.started()`` is called once the child has a pid and both
        drain tasks have been scheduled. However this coroutine ends (normal
        exit, error, or cancellation), the child is terminated if it is still
        alive, the pipe fds and channel senders are closed, and the "done"
        event that ``stop`` waits on is set.

        Raises:
            RuntimeError: If ``run`` was already called on this instance, or
                if the started process reports no pid.
        """
        if self._run_cancel_scope is not None or self._done.is_set():
            raise RuntimeError("process has already been started")

        stdout_read_fd: int | None = None
        stdout_write_fd: int | None = None
        stderr_read_fd: int | None = None
        stderr_write_fd: int | None = None

        def cleanup_stdio_fd() -> None:
            # _close_fd returns None, so each name is reset after closing and a
            # second cleanup pass cannot close a recycled descriptor.
            nonlocal stdout_read_fd, stdout_write_fd, stderr_read_fd, stderr_write_fd
            stdout_read_fd = _close_fd(stdout_read_fd)
            stdout_write_fd = _close_fd(stdout_write_fd)
            stderr_read_fd = _close_fd(stderr_read_fd)
            stderr_write_fd = _close_fd(stderr_write_fd)

        try:
            with CancelScope() as run_cancel_scope:
                # stop() cancels this scope to request shutdown.
                self._run_cancel_scope = run_cancel_scope
                stdout_read_fd, stdout_write_fd = os.pipe()
                stderr_read_fd, stderr_write_fd = os.pipe()

                process = mp.Process(
                    target=_run_with_captured_stdio,
                    name=self._name,
                    args=(
                        DupFd(stdout_write_fd),
                        DupFd(stderr_write_fd),
                        self._target,
                        *self._args,
                    ),
                    kwargs={} if self._kwargs is None else self._kwargs,
                    daemon=self._daemon,
                )
                process.start()
                pid = process.pid
                if pid is None:
                    raise RuntimeError("started process has no pid")

                # important to close parent write-side FD to prevent hangs
                stdout_write_fd = _close_fd(stdout_write_fd)
                stderr_write_fd = _close_fd(stderr_write_fd)

                self._process = process
                self._pid = pid
                self._started.set()

                async with create_task_group() as tg:
                    # Each drain task takes ownership of its read fd and closes
                    # it itself, so the local name is cleared right away.
                    tg.start_soon(_drain_fd, stdout_read_fd, self._stdout_tx)
                    stdout_read_fd = None
                    tg.start_soon(_drain_fd, stderr_read_fd, self._stderr_tx)
                    stderr_read_fd = None
                    task_status.started()
                    await self.wait()
        except BaseException as exc:
            if not self._started.is_set():
                # Wake anyone blocked in wait() and let them see the failure.
                self._start_error = exc
                self._started.set()
            raise
        finally:
            try:
                # Shielded so that a cancelled run() still gets to reap the child.
                with CancelScope(shield=True):
                    await self._terminate_if_still_alive()
            finally:
                try:
                    cleanup_stdio_fd()
                    for tx in (self._stdout_tx, self._stderr_tx):
                        with contextlib.suppress(Exception):
                            await tx.aclose()
                    if self._process is not None:
                        # Cache the exit status while the handle is still
                        # usable: once close() succeeds, reading the handle's
                        # exitcode raises ValueError, so an uncached status
                        # would be lost and wait() would poll forever.
                        self._exitcode = self.exitcode
                        with contextlib.suppress(ValueError):
                            self._process.close()
                finally:
                    # Always publish completion, even if cleanup above failed;
                    # otherwise stop() would wait on the event forever.
                    self._run_cancel_scope = None
                    self._done.set()

    async def stop(self) -> None:
        """Request shutdown of a running process and wait until ``run`` is done.

        Raises:
            RuntimeError: If ``run`` has not been called yet.
        """
        if self._run_cancel_scope is None and not self._done.is_set():
            raise RuntimeError("process has not been started")
        if self._run_cancel_scope is not None:
            self._run_cancel_scope.cancel()
        await self._done.wait()

    async def aclose(self) -> None:
        """Alias for ``stop``."""
        await self.stop()

    async def wait(self) -> int:
        """Wait for the child to exit and return its exit code.

        Blocks until ``run`` has either started the child or failed to; in the
        latter case the start-up exception is re-raised here. The exit status
        is then polled every 10 ms.
        """
        if self._exitcode is not None:
            return self._exitcode

        await self._started.wait()
        if self._start_error is not None:
            raise self._start_error
        assert self._process is not None

        while True:
            exitcode = self.exitcode
            if exitcode is not None:
                return exitcode
            await sleep(0.01)

    @property
    def pid(self) -> int:
        """Pid of the child process.

        Raises:
            RuntimeError: If the child has not been started.
        """
        if self._pid is None:
            raise RuntimeError("process has not been started")
        return self._pid

    @property
    def exitcode(self) -> int | None:
        """Exit code of the child, or ``None`` if it is not known (yet).

        The first non-``None`` value read from the process handle is cached so
        it stays available after the handle has been closed.
        """
        if self._exitcode is not None:
            return self._exitcode
        if self._process is None:
            return None

        # A closed handle raises ValueError; treat that as "unknown".
        with contextlib.suppress(ValueError):
            exitcode = self._process.exitcode
            if exitcode is not None:
                self._exitcode = exitcode
                return exitcode
        return None

    def is_alive(self) -> bool:
        """Return whether the child is running; ``False`` if never started or closed."""
        if self._process is None:
            return False

        with contextlib.suppress(ValueError):
            return self._process.is_alive()
        return False

    # TODO: maybe in the future if needed, create stdin that is also installed,
    # and a ByteSendStream handle is provided for it :)

    @property
    def stdout(self) -> Receiver[bytes]:
        """Receiver yielding the byte chunks drained from the child's stdout pipe."""
        return self._stdout_rx

    @property
    def stderr(self) -> Receiver[bytes]:
        """Receiver yielding the byte chunks drained from the child's stderr pipe."""
        return self._stderr_rx

    async def _terminate_if_still_alive(self) -> None:
        """Stop a child that is still running, escalating until it is gone.

        Calls ``terminate()`` up to ``_TERMINATE_ATTEMPTS`` times (a longer
        grace period after the first call, a shorter one after each retry),
        then calls ``kill()`` repeatedly until the child has exited.
        """
        process = self._process
        if process is None:
            return

        if self.exitcode is not None:
            return

        with contextlib.suppress(ValueError):
            if not process.is_alive():
                return

        logger.warning("Child process didn't shut down successfully, terminating")
        process.terminate()
        with move_on_after(_TERMINATE_GRACE_SECONDS):
            await self.wait()

        if self.exitcode is not None or not process.is_alive():
            logger.warning("Terminated nicely in the first attempt!")
            return

        # The first attempt was made above, so retries are numbered from 2.
        for attempt in range(2, _TERMINATE_ATTEMPTS + 1):
            process.terminate()
            with move_on_after(_TERMINATE_RETRY_GRACE_SECONDS):
                await self.wait()

            if self.exitcode is not None or not process.is_alive():
                logger.warning(f"That took {attempt} attempts :)")
                return

        logger.critical("Child process didn't respond to SIGTERM, killing")
        j = 0
        # NOTE(review): this loop has no upper bound; it only ends once the
        # child is observed to have exited.
        while True:
            process.kill()
            with move_on_after(_KILL_GRACE_SECONDS):
                await self.wait()
            j += 1
            if self.exitcode is not None or not process.is_alive():
                break
        logger.warning(f"That took {j} attempts :(")
|
||||
|
||||
|
||||
# Spawn-mode multiprocessing requires a module-level target that can be pickled.
|
||||
def _run_with_captured_stdio(
    stdout: DupFd,
    stderr: DupFd,
    target: Callable[..., object] | None,
    *target_args: object,
    **target_kwargs: object,
) -> None:
    """Child-side entry point: install the pipes over fd 1/2, then run ``target``.

    ``stdout`` and ``stderr`` are detached into raw descriptors, duplicated
    onto the standard output/error descriptors, and the detached originals are
    closed again unless they already are one of those two descriptors.
    ``faulthandler`` is then enabled on ``sys.stderr`` for all threads before
    ``target`` (if any) is called with the remaining arguments.
    """
    pipe_out = stdout.detach()
    pipe_err = stderr.detach()
    standard_fds = (_STDOUT_FD, _STDERR_FD)

    try:
        for source_fd, standard_fd in zip((pipe_out, pipe_err), standard_fds):
            os.dup2(source_fd, standard_fd)
    finally:
        # The duplicates on fd 1/2 are all the child needs from here on.
        for detached_fd in (pipe_out, pipe_err):
            if detached_fd in standard_fds:
                continue
            _close_fd(detached_fd)

    faulthandler.enable(file=sys.stderr, all_threads=True)
    if target is None:
        return
    target(*target_args, **target_kwargs)
|
||||
|
||||
|
||||
async def _drain_fd(fd: int, tx: Sender[bytes]) -> None:
    """Forward everything readable from ``fd`` to ``tx`` until end-of-file.

    Reads up to ``_READ_CHUNK_SIZE`` bytes at a time. Broken-pipe and
    closed/broken channel errors end the loop quietly. In every case ``fd`` is
    closed and ``tx`` is closed before returning.
    """
    try:
        while True:
            await wait_readable(fd)
            if not (chunk := os.read(fd, _READ_CHUNK_SIZE)):
                # An empty read means the write side has been closed.
                break
            await tx.send(chunk)
    except (BrokenPipeError, BrokenResourceError, ClosedResourceError):
        pass
    finally:
        _close_fd(fd)
        await tx.aclose()
|
||||
|
||||
|
||||
def _close_fd(fd: int | None) -> None:
    """Close ``fd`` if it is not ``None``, ignoring any ``OSError``.

    Always returns ``None``, so callers can write ``fd = _close_fd(fd)`` to
    close a descriptor and forget it in one step.
    """
    if fd is not None:
        try:
            os.close(fd)
        except OSError:
            # Already closed or otherwise invalid: nothing left to release.
            pass
|
||||
28
src/exo/utils/daemon.py
Normal file
28
src/exo/utils/daemon.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
_STDIN_FD = 0
|
||||
_STDOUT_FD = 1
|
||||
_STDERR_FD = 2
|
||||
|
||||
|
||||
def detach_stdio_to_devnull() -> None:
    """Redirect process stdio file descriptors to /dev/null.

    Python-level buffers of the current and original stdout/stderr objects are
    flushed first so pending text goes to the old destination. Then fd 0 is
    replaced by a read-only ``/dev/null`` descriptor and fd 1 / fd 2 by
    write-only ones.

    Raises:
        OSError: If ``/dev/null`` cannot be opened or duplicated. Descriptors
            opened up to that point are closed again before the error
            propagates.
    """
    for stream in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__):
        if stream is None:
            continue
        try:
            stream.flush()
        except (OSError, ValueError):
            # The stream is already closed or its destination is gone; there
            # is nothing left to flush and detaching should still proceed.
            pass

    opened_fds: list[int] = []
    try:
        # Open inside the try block so a failure part-way through does not
        # leak the descriptors that were already opened.
        for flags in (os.O_RDONLY, os.O_WRONLY, os.O_WRONLY):
            opened_fds.append(os.open(os.devnull, flags))
        stdin_fd, stdout_fd, stderr_fd = opened_fds

        # dup2 closes the target fd first, but leaves the source fd open.
        os.dup2(stdin_fd, _STDIN_FD)
        os.dup2(stdout_fd, _STDOUT_FD)
        os.dup2(stderr_fd, _STDERR_FD)
    finally:
        for fd in opened_fds:
            # If a standard fd was closed beforehand, os.open may have handed
            # back that very number; it must then stay open.
            if fd not in (_STDIN_FD, _STDOUT_FD, _STDERR_FD):
                os.close(fd)
|
||||
515
src/exo/utils/tests/test_async_process.py
Normal file
515
src/exo/utils/tests/test_async_process.py
Normal file
@@ -0,0 +1,515 @@
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from types import FrameType
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
from _pytest.capture import CaptureFixture
|
||||
from anyio import EndOfStream, create_task_group, fail_after
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
import exo.utils.async_process as async_process
|
||||
from exo.utils.async_process import (
|
||||
AsyncProcess,
|
||||
)
|
||||
from exo.utils.channels import MpSender, Receiver, mp_channel
|
||||
|
||||
|
||||
def _write_to_stdio(prefix: str, *, stderr_suffix: str) -> None:
    """Emit one line per stream through ``print`` and one through raw fd writes."""
    python_stdout_line = f"{prefix}: python stdout"
    python_stderr_line = f"{prefix}: python stderr {stderr_suffix}"
    print(python_stdout_line)
    print(python_stderr_line, file=sys.stderr)

    fd_lines = (
        (1, f"{prefix}: fd stdout\n"),
        (2, f"{prefix}: fd stderr {stderr_suffix}\n"),
    )
    for fd, line in fd_lines:
        os.write(fd, line.encode())
|
||||
|
||||
|
||||
def _write_large_output() -> None:
    """Write one fixed payload to fd 1 and another to fd 2."""
    for fd, payload in ((1, b"stdout-0123456789"), (2, b"stderr-0123456789")):
        os.write(fd, payload)
|
||||
|
||||
|
||||
def _write_all(fd: int, data: bytes) -> None:
    """Write all of ``data`` to ``fd``, continuing after partial writes."""
    view = memoryview(data)
    offset = 0
    while offset < len(view):
        # os.write reports how many bytes were accepted; resume from there.
        offset += os.write(fd, view[offset:])
|
||||
|
||||
|
||||
def _write_large_exact_output(size: int) -> None:
    """Write a labelled run of ``size`` filler bytes to fd 1, then to fd 2."""
    for fd, label, filler in ((1, b"stdout:", b"x"), (2, b"stderr:", b"y")):
        _write_all(fd, label + (filler * size))
|
||||
|
||||
|
||||
def _raise_after_stderr_write() -> None:
    """Write a marker line to fd 2 and then raise ``RuntimeError``."""
    marker = b"stderr before exception\n"
    os.write(2, marker)
    raise RuntimeError("child boom")
|
||||
|
||||
|
||||
def _exit_after_stdio_write(prefix: str, exitcode: int) -> None:
    """Write a marker line to fd 1 and fd 2, then leave via ``os._exit``."""
    marker_lines = (
        (1, f"{prefix}: stdout before _exit\n"),
        (2, f"{prefix}: stderr before _exit\n"),
    )
    for fd, line in marker_lines:
        os.write(fd, line.encode())
    os._exit(exitcode)
|
||||
|
||||
|
||||
def _abort_after_stdio_write(prefix: str) -> None:
    """Write a marker line to fd 1 and fd 2, then call ``os.abort``."""
    marker_lines = (
        (1, f"{prefix}: stdout before abort\n"),
        (2, f"{prefix}: stderr before abort\n"),
    )
    for fd, line in marker_lines:
        os.write(fd, line.encode())
    os.abort()
|
||||
|
||||
|
||||
def _close_stdio_and_exit() -> None:
    """Close fd 1 and fd 2, then exit with status 0 via ``os._exit``."""
    for fd in (1, 2):
        os.close(fd)
    os._exit(0)
|
||||
|
||||
|
||||
def _exit_on_sigterm(exitcode: int) -> None:
    """Announce readiness on fd 1, then idle until SIGTERM ends the process.

    The installed handler leaves through ``os._exit(exitcode)``.
    """

    def exit_immediately(_signum: int, _frame: FrameType | None) -> None:
        os._exit(exitcode)

    signal.signal(signal.SIGTERM, exit_immediately)
    # Written only after the handler is in place, so a reader that sees this
    # line knows SIGTERM will be handled.
    os.write(1, b"sigterm-ready\n")
    poll_interval = 0.1
    while True:
        time.sleep(poll_interval)
|
||||
|
||||
|
||||
def _exit_after_repeated_sigterm(required_count: int, exitcode: int) -> None:
    """Idle until SIGTERM has arrived ``required_count`` times, then exit.

    Earlier SIGTERMs are only counted. Readiness is announced on fd 1 once the
    handler is installed.
    """
    # One-element list used as a mutable counter shared with the handler.
    received = [0]

    def count_sigterm(_signum: int, _frame: FrameType | None) -> None:
        received[0] += 1
        if received[0] < required_count:
            return
        os._exit(exitcode)

    signal.signal(signal.SIGTERM, count_sigterm)
    os.write(1, b"sigterm-ready\n")
    poll_interval = 0.1
    while True:
        time.sleep(poll_interval)
|
||||
|
||||
|
||||
def _ignore_sigterm_forever() -> None:
    """Set SIGTERM to be ignored, announce readiness on fd 1, and idle forever."""
    signal.signal(signal.SIGTERM, signal.SIG_IGN)
    ready_line = b"sigterm-ready\n"
    os.write(1, ready_line)
    poll_interval = 0.1
    while True:
        time.sleep(poll_interval)
|
||||
|
||||
|
||||
def _sleep_forever() -> None:
    """Never return; sleep in 0.1 second steps."""
    poll_interval = 0.1
    while True:
        time.sleep(poll_interval)
|
||||
|
||||
|
||||
def _send_over_mp_channel(send: MpSender[str]) -> None:
    """Send a single greeting through ``send`` and close it."""
    greeting = "hello from child"
    send.send(greeting)
    send.close()
|
||||
|
||||
|
||||
def _mlx_force_oom(size: int = 40_000) -> None:
    """
    Try to exhaust memory in MLX by chaining large matrix operations.

    Two ``size`` x ``size`` float32 matrices are created on the GPU device and
    evaluated, then combined through three matrix products and a sigmoid. A
    marker line is printed before the work starts and after it finishes.
    """
    print("CHILD: start")

    mx.set_default_device(mx.gpu)
    shape = (size, size)
    left = mx.random.uniform(shape=shape, dtype=mx.float32)
    right = mx.random.uniform(shape=shape, dtype=mx.float32)
    mx.eval(left, right)

    product = mx.matmul(left, right)
    left_chain = mx.matmul(left, product)
    right_chain = mx.matmul(right, product)
    activated = mx.sigmoid(left_chain + right_chain)
    mx.eval(activated)

    print("CHILD: end")
|
||||
|
||||
|
||||
async def _collect_stream(
    stream: Receiver[bytes],
    output: bytearray,
) -> None:
    """Append every chunk received from ``stream`` to ``output`` until it ends."""
    try:
        while True:
            output.extend(await stream.receive())
    except EndOfStream:
        # The sending side was closed: everything has been collected.
        return
|
||||
|
||||
|
||||
async def _collect_process_output(
    process: AsyncProcess,
) -> tuple[int, bytes, bytes]:
    """Wait for ``process`` while gathering its output.

    Returns:
        ``(exitcode, stdout_bytes, stderr_bytes)``.

    Raises:
        RuntimeError: If no exit code was obtained.
    """
    collected_out = bytearray()
    collected_err = bytearray()
    exitcode: int | None = None

    async with create_task_group() as collectors:
        collectors.start_soon(_collect_stream, process.stdout, collected_out)
        collectors.start_soon(_collect_stream, process.stderr, collected_err)
        exitcode = await process.wait()

    if exitcode is None:
        raise RuntimeError("process exited without a return code")
    return exitcode, bytes(collected_out), bytes(collected_err)
|
||||
|
||||
|
||||
def _fd_identity(fd: int) -> tuple[int, int]:
    """Return the ``(st_dev, st_ino)`` pair reported by ``os.fstat`` for ``fd``."""
    info = os.fstat(fd)
    device, inode = info.st_dev, info.st_ino
    return device, inode
|
||||
|
||||
|
||||
def _fd_count() -> int | None:
|
||||
for fd_dir in ("/proc/self/fd", "/dev/fd"):
|
||||
with contextlib.suppress(OSError):
|
||||
return len(os.listdir(fd_dir))
|
||||
return None
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _started_process(process: AsyncProcess) -> AsyncIterator[None]:
    """Run ``process`` in a task group for the duration of the ``async with`` body.

    The body starts once ``process.run`` has signalled start-up, and
    ``process.stop()`` is awaited when the body finishes, whether or not it
    raised.
    """
    async with create_task_group() as runner:
        await runner.start(process.run)
        try:
            yield
        finally:
            await process.stop()
|
||||
|
||||
|
||||
async def _run_and_collect(
    target: Callable[..., object] | None,
    *,
    args: tuple[object, ...] = (),
    kwargs: dict[str, object] | None = None,
) -> tuple[int, bytes, bytes]:
    """Run ``target`` in an ``AsyncProcess`` and return ``(exitcode, stdout, stderr)``."""
    process = AsyncProcess(target, args=args, kwargs=kwargs)
    async with _started_process(process):
        outcome = await _collect_process_output(process)
    return outcome
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_spawn_process_captures_stdout_and_stderr_separately(
    capfd: CaptureFixture[str],
) -> None:
    """Child output lands on the matching captured stream, not on the parent's stdio."""
    process = AsyncProcess(
        _write_to_stdio,
        args=("child",),
        kwargs={"stderr_suffix": "error"},
    )
    async with _started_process(process):
        exitcode, raw_stdout, raw_stderr = await _collect_process_output(process)

    parent_output = capfd.readouterr()
    child_stdout = raw_stdout.decode("utf-8", errors="replace")
    child_stderr = raw_stderr.decode("utf-8", errors="replace")

    assert exitcode == 0
    for expected in ("child: python stdout", "child: fd stdout"):
        assert expected in child_stdout
    for expected in ("child: python stderr error", "child: fd stderr error"):
        assert expected in child_stderr
    for parent_stream in (parent_output.out, parent_output.err):
        assert "child:" not in parent_stream
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_process_with_no_target_exits_successfully() -> None:
    """A process without a target exits with 0 and produces no output."""
    exitcode, captured_stdout, captured_stderr = await _run_and_collect(None)

    assert (exitcode, captured_stdout, captured_stderr) == (0, b"", b"")
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_output_receivers_and_wait_are_safe_immediately_after_run_starts() -> (
    None
):
    """Output can be collected and waited on right after ``run`` reports start-up."""
    process = AsyncProcess(
        _write_to_stdio,
        args=("immediate",),
        kwargs={"stderr_suffix": "error"},
    )
    outcome: tuple[int, bytes, bytes] | None = None

    async with create_task_group() as runner:
        await runner.start(process.run)
        try:
            outcome = await _collect_process_output(process)
        finally:
            await process.stop()

    assert outcome is not None
    exitcode, captured_stdout, captured_stderr = outcome
    assert exitcode == 0
    assert b"immediate: fd stdout\n" in captured_stdout
    assert b"immediate: fd stderr error\n" in captured_stderr
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_stop_before_run_raises() -> None:
    """Stopping a process that was never run raises ``RuntimeError``."""
    never_started = AsyncProcess(
        _write_to_stdio,
        args=("never",),
        kwargs={"stderr_suffix": "run"},
    )

    assert not never_started.is_alive()
    with pytest.raises(RuntimeError, match="process has not been started"):
        await never_started.stop()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_process_run_is_one_shot() -> None:
    """A second ``run`` on the same instance raises ``RuntimeError``."""
    one_shot = AsyncProcess(None)

    await one_shot.run()

    with pytest.raises(RuntimeError, match="process has already been started"):
        await one_shot.run()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_process_started_with_task_group_start_can_stop_immediately() -> None:
    """A child that never exits on its own is stopped within two seconds."""
    sleeper = AsyncProcess(_sleep_forever)

    async with create_task_group() as runner:
        await runner.start(sleeper.run)
        assert sleeper.is_alive()
        with fail_after(2):
            await sleeper.stop()

    assert not sleeper.is_alive()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_stdout_receiver_yields_bytes_chunks() -> None:
    """A chunk received directly plus the remaining output equals what was written."""
    process = AsyncProcess(_write_large_output)

    async with _started_process(process):
        leading_chunk = await process.stdout.receive()
        exitcode, trailing_stdout, captured_stderr = await _collect_process_output(
            process
        )

    assert exitcode == 0
    assert leading_chunk + trailing_stdout == b"stdout-0123456789"
    assert captured_stderr == b"stderr-0123456789"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_output_can_be_read_after_process_exits() -> None:
    """Captured output stays readable after the child has exited, then the streams end."""
    process = AsyncProcess(_write_large_output)

    async with create_task_group() as runner:
        await runner.start(process.run)
        assert await process.wait() == 0

    expectations = (
        (process.stdout, b"stdout-0123456789"),
        (process.stderr, b"stderr-0123456789"),
    )
    for receiver, payload in expectations:
        assert await receiver.receive() == payload
    for receiver, _payload in expectations:
        with pytest.raises(EndOfStream):
            await receiver.receive()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_large_stdout_and_stderr_are_not_lost() -> None:
    """One MiB written to each stream is captured byte for byte."""
    size = 1024 * 1024
    expected_stdout = b"stdout:" + (b"x" * size)
    expected_stderr = b"stderr:" + (b"y" * size)

    exitcode, captured_stdout, captured_stderr = await _run_and_collect(
        _write_large_exact_output,
        args=(size,),
    )

    assert exitcode == 0
    assert captured_stdout == expected_stdout
    assert captured_stderr == expected_stderr
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_child_exception_traceback_is_captured_from_stderr() -> None:
    """An uncaught child exception gives exit code 1 and its text on captured stderr."""
    process = AsyncProcess(_raise_after_stderr_write)

    async with _started_process(process):
        exitcode, _, raw_stderr = await _collect_process_output(process)

    assert exitcode == 1
    captured_stderr = raw_stderr.decode("utf-8", errors="replace")
    for expected in ("stderr before exception", "RuntimeError: child boom"):
        assert expected in captured_stderr
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_repeated_bad_children_do_not_pollute_or_replace_parent_stdio(
    capfd: CaptureFixture[str],
) -> None:
    """Failing children leave the parent's stdio objects, fds, and output untouched."""
    initial_stdout = sys.stdout
    initial_stderr = sys.stderr
    initial_stdout_identity = _fd_identity(1)
    initial_stderr_identity = _fd_identity(2)

    cases: tuple[tuple[Callable[..., object], tuple[object, ...]], ...] = (
        (_raise_after_stderr_write, ()),
        (_exit_after_stdio_write, ("exit-child", 17)),
        (_abort_after_stdio_write, ("abort-child",)),
    )

    for iteration in range(3):
        for target, args in cases:
            exitcode, stdout, stderr = await _run_and_collect(target, args=args)

            assert exitcode != 0
            if target is _exit_after_stdio_write:
                assert stdout == b"exit-child: stdout before _exit\n"
                assert stderr == b"exit-child: stderr before _exit\n"
            elif target is _abort_after_stdio_write:
                assert b"abort-child: stdout before abort\n" in stdout
                assert b"abort-child: stderr before abort\n" in stderr
                assert exitcode == -signal.SIGABRT
            else:
                assert stdout == b""
                assert b"stderr before exception\n" in stderr
                assert b"RuntimeError: child boom" in stderr

        # The parent must still be able to write after each round of children.
        print(f"parent stdout still works {iteration}")
        print(f"parent stderr still works {iteration}", file=sys.stderr)

    parent_output = capfd.readouterr()

    assert sys.stdout is initial_stdout
    assert sys.stderr is initial_stderr
    assert _fd_identity(1) == initial_stdout_identity
    assert _fd_identity(2) == initial_stderr_identity
    for round_number in (0, 2):
        assert f"parent stdout still works {round_number}" in parent_output.out
    for round_number in (0, 2):
        assert f"parent stderr still works {round_number}" in parent_output.err
    for child_marker in ("exit-child:", "abort-child:"):
        assert child_marker not in parent_output.out
        assert child_marker not in parent_output.err
    assert "child boom" not in parent_output.err
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_child_can_close_stdio_without_corrupting_parent_stdio(
    capfd: CaptureFixture[str],
) -> None:
    """A child closing its own fd 1/2 does not affect the parent's fd 1/2."""
    identities_before = (_fd_identity(1), _fd_identity(2))

    exitcode, captured_stdout, captured_stderr = await _run_and_collect(
        _close_stdio_and_exit
    )
    os.write(1, b"parent stdout after child closed stdio\n")
    os.write(2, b"parent stderr after child closed stdio\n")
    parent_output = capfd.readouterr()

    assert exitcode == 0
    assert captured_stdout == b""
    assert captured_stderr == b""
    assert _fd_identity(1) == identities_before[0]
    assert _fd_identity(2) == identities_before[1]
    assert "parent stdout after child closed stdio" in parent_output.out
    assert "parent stderr after child closed stdio" in parent_output.err
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_repeated_crashing_children_do_not_grow_parent_fd_table() -> None:
    """Twenty crashing children add at most two entries to the parent's fd table."""
    # Warm-up run so one-time descriptors are not counted as growth.
    await _run_and_collect(_exit_after_stdio_write, args=("warmup", 23))
    baseline = _fd_count()
    if baseline is None:
        pytest.skip("fd table count is not available on this platform")

    for iteration in range(20):
        label = f"fd-child-{iteration}"
        exitcode, captured_stdout, captured_stderr = await _run_and_collect(
            _exit_after_stdio_write,
            args=(label, 31),
        )

        assert exitcode == 31
        assert captured_stdout == f"fd-child-{iteration}: stdout before _exit\n".encode()
        assert captured_stderr == f"fd-child-{iteration}: stderr before _exit\n".encode()

    final_count = _fd_count()
    assert final_count is not None
    assert final_count <= baseline + 2
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_stop_allows_child_to_exit_after_sigterm() -> None:
    """A child exiting from its SIGTERM handler reports that handler's exit code."""
    handler_exitcode = 43
    process = AsyncProcess(_exit_on_sigterm, args=(handler_exitcode,))

    async with _started_process(process):
        ready_line = await process.stdout.receive()
        assert ready_line == b"sigterm-ready\n"

        with fail_after(2):
            await process.stop()

    assert process.exitcode == handler_exitcode
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_stop_retries_sigterm_before_sigkill(monkeypatch: MonkeyPatch) -> None:
    """A child that needs three SIGTERMs still exits with its own exit code."""
    # Shrink both grace periods so the retries happen quickly.
    for constant_name in ("_TERMINATE_GRACE_SECONDS", "_TERMINATE_RETRY_GRACE_SECONDS"):
        monkeypatch.setattr(async_process, constant_name, 0.01)
    process = AsyncProcess(_exit_after_repeated_sigterm, args=(3, 44))

    async with _started_process(process):
        ready_line = await process.stdout.receive()
        assert ready_line == b"sigterm-ready\n"

        with fail_after(2):
            await process.stop()

    assert process.exitcode == 44
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_stop_escalates_to_sigkill_when_child_ignores_sigterm(
    monkeypatch: MonkeyPatch,
) -> None:
    """A child ignoring SIGTERM ends up killed and reports ``-SIGKILL``."""
    shortened_graces = {
        "_TERMINATE_GRACE_SECONDS": 0.1,
        "_TERMINATE_RETRY_GRACE_SECONDS": 0.01,
    }
    for constant_name, seconds in shortened_graces.items():
        monkeypatch.setattr(async_process, constant_name, seconds)
    process = AsyncProcess(_ignore_sigterm_forever)

    async with _started_process(process):
        ready_line = await process.stdout.receive()
        assert ready_line == b"sigterm-ready\n"

        with fail_after(3):
            await process.stop()

    assert process.exitcode == -signal.SIGKILL
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_process_can_use_mp_channel_with_global_spawn_context() -> None:
    """A child can send a message back through an ``mp_channel`` passed as an argument."""
    child_end, parent_end = mp_channel[str]()
    process = AsyncProcess(_send_over_mp_channel, args=(child_end,))

    async with _started_process(process):
        with fail_after(2):
            message = await parent_end.receive_async()
            assert message == "hello from child"
            assert await process.wait() == 0

    with contextlib.suppress(Exception):
        parent_end.close()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.skip(reason="manual MLX OOM isolation check")
async def test_death(capsys: CaptureFixture[str]) -> None:
    """Manual check: run the MLX memory stress child and print what was captured."""
    with capsys.disabled():
        process = AsyncProcess(_mlx_force_oom)
        child_stdout = b""
        child_stderr = b""
        async with _started_process(process):
            _, child_stdout, child_stderr = await _collect_process_output(process)

        print("PARENT: done")

        decoded_stdout = child_stdout.decode("utf-8", errors="replace")
        decoded_stderr = child_stderr.decode("utf-8", errors="replace")
        print("CHILD out:", decoded_stdout)
        print("CHILD err:", decoded_stderr, "hello :)")
|
||||
168
src/exo/utils/tests/test_daemon.py
Normal file
168
src/exo/utils/tests/test_daemon.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import contextlib
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from anyio import EndOfStream, create_task_group, fail_after
|
||||
|
||||
from exo.utils.async_process import AsyncProcess
|
||||
from exo.utils.channels import MpReceiver, MpSender, Receiver, mp_channel
|
||||
from exo.utils.daemon import detach_stdio_to_devnull
|
||||
|
||||
|
||||
def _write_before_and_after_detach() -> None:
    """Write to fds 1 and 2, call ``detach_stdio_to_devnull()``, then write again.

    Used as a child-process target: the test expects only the "before" lines
    to show up in the captured output.
    """
    before_detach = ((1, b"before stdout\n"), (2, b"before stderr\n"))
    after_detach = ((1, b"after stdout\n"), (2, b"after stderr\n"))

    for fd, payload in before_detach:
        os.write(fd, payload)

    detach_stdio_to_devnull()

    for fd, payload in after_detach:
        os.write(fd, payload)
|
||||
|
||||
|
||||
def _write_grandchild_stdio(label: str) -> None:
    """Write one *label*-prefixed line to fd 1 and one to fd 2.

    The writes go straight to the file descriptors via ``os.write`` rather
    than through ``sys.stdout`` / ``sys.stderr``.
    """
    stdout_line = f"{label} stdout\n".encode()
    stderr_line = f"{label} stderr\n".encode()
    os.write(1, stdout_line)
    os.write(2, stderr_line)
|
||||
|
||||
|
||||
async def _spawn_grandchild_and_report(
    result_sender: MpSender[tuple[int, bytes, bytes]],
    label: str,
) -> None:
    """Run one captured child for *label* and send its result over *result_sender*.

    The sender is closed after the single result has been sent.
    """
    outcome = await _collect_spawned_child(label)
    result_sender.send(outcome)
    result_sender.close()
|
||||
|
||||
|
||||
async def _collect_spawned_child(label: str) -> tuple[int, bytes, bytes]:
    """Spawn ``_write_grandchild_stdio(label)`` and return ``(exitcode, stdout, stderr)``."""
    child = AsyncProcess(_write_grandchild_stdio, args=(label,))
    async with _started_process(child):
        captured = await _collect_process_output(child)
    return captured
|
||||
|
||||
|
||||
def _detach_stdio_then_spawn_captured_child(
    result_sender: MpSender[tuple[int, bytes, bytes]],
) -> None:
    """Detach this process's stdio, then spawn one captured child and report it.

    Process target for the "stdio-detached parent" test: the child's
    ``(exitcode, stdout, stderr)`` is sent back through *result_sender*.
    """

    async def report_single_grandchild() -> None:
        await _spawn_grandchild_and_report(result_sender, "grandchild")

    detach_stdio_to_devnull()
    anyio.run(report_single_grandchild)
|
||||
|
||||
|
||||
def _detach_stdio_then_spawn_captured_children_sequentially(
    result_sender: MpSender[list[tuple[int, bytes, bytes]]],
) -> None:
    """Detach this process's stdio, then run five captured children one after another.

    The list of ``(exitcode, stdout, stderr)`` results is sent through
    *result_sender*, which is closed afterwards.
    """

    async def run_children() -> list[tuple[int, bytes, bytes]]:
        # Each await completes before the next child is spawned.
        return [
            await _collect_spawned_child(f"grandchild-{index}") for index in range(5)
        ]

    detach_stdio_to_devnull()
    all_results = anyio.run(run_children)
    result_sender.send(all_results)
    result_sender.close()
|
||||
|
||||
|
||||
async def _collect_stream(stream: Receiver[bytes], output: bytearray) -> None:
    """Append every chunk received from *stream* to *output* until ``EndOfStream``."""
    try:
        while True:
            chunk = await stream.receive()
            output.extend(chunk)
    except EndOfStream:
        # The stream is exhausted; everything received is already in *output*.
        return
|
||||
|
||||
|
||||
async def _collect_process_output(
    process: AsyncProcess,
) -> tuple[int, bytes, bytes]:
    """Drain *process*'s stdout and stderr while waiting for it to exit.

    Returns ``(exitcode, stdout_bytes, stderr_bytes)``.

    Raises:
        RuntimeError: if no exit code was recorded by the time the task group
            finished.
    """
    out_buffer, err_buffer = bytearray(), bytearray()
    exitcodes: list[int] = []

    async with create_task_group() as collect_group:
        # One collector task per stream, running alongside the wait below.
        for stream, buffer in (
            (process.stdout, out_buffer),
            (process.stderr, err_buffer),
        ):
            collect_group.start_soon(_collect_stream, stream, buffer)
        exitcodes.append(await process.wait())

    if exitcodes:
        return exitcodes[0], bytes(out_buffer), bytes(err_buffer)
    raise RuntimeError("process exited without a return code")
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _started_process(process: AsyncProcess) -> AsyncIterator[None]:
    """Run *process* in a task group for the duration of the ``async with`` body.

    ``process.stop()`` is always awaited when the body exits, whether it
    finished normally or raised.
    """
    async with create_task_group() as runner_group:
        # Hand ``process.run`` to the task group and wait for ``start`` to return.
        await runner_group.start(process.run)
        try:
            yield
        finally:
            await process.stop()
|
||||
|
||||
|
||||
async def _run_process_and_receive[T](
|
||||
process: AsyncProcess,
|
||||
recv: MpReceiver[T],
|
||||
*,
|
||||
timeout: float,
|
||||
) -> tuple[int, T]:
|
||||
async with _started_process(process):
|
||||
with fail_after(timeout):
|
||||
result = await recv.receive_async()
|
||||
exitcode = await process.wait()
|
||||
|
||||
return exitcode, result
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_detach_stdio_to_devnull_redirects_stdio_away_from_capture() -> None:
    """Only output written before ``detach_stdio_to_devnull()`` reaches the capture."""
    child = AsyncProcess(_write_before_and_after_detach)

    async with _started_process(child):
        captured = await _collect_process_output(child)

    # (exitcode, stdout, stderr): the "after" lines must be absent.
    assert captured == (0, b"before stdout\n", b"before stderr\n")
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_detached_stdio_process_can_spawn_and_capture_child_stdio() -> None:
    """A process that detached its own stdio can still capture a child's stdout/stderr."""
    send, recv = mp_channel[tuple[int, bytes, bytes]]()
    detached_parent = AsyncProcess(
        _detach_stdio_then_spawn_captured_child, args=(send,)
    )

    try:
        parent_exitcode, grandchild_report = await _run_process_and_receive(
            detached_parent, recv, timeout=5
        )
    finally:
        recv.close()

    assert parent_exitcode == 0
    # The report is the grandchild's (exitcode, stdout, stderr).
    assert grandchild_report == (0, b"grandchild stdout\n", b"grandchild stderr\n")
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_detached_stdio_process_can_spawn_captured_children_sequentially() -> (
    None
):
    """A stdio-detached process can spawn and capture five children, one after another."""
    send, recv = mp_channel[list[tuple[int, bytes, bytes]]]()
    detached_parent = AsyncProcess(
        _detach_stdio_then_spawn_captured_children_sequentially,
        args=(send,),
    )

    try:
        parent_exitcode, reports = await _run_process_and_receive(
            detached_parent, recv, timeout=10
        )
    finally:
        recv.close()

    # One (exitcode, stdout, stderr) entry per child, in spawn order.
    expected: list[tuple[int, bytes, bytes]] = []
    for index in range(5):
        expected.append(
            (
                0,
                f"grandchild-{index} stdout\n".encode(),
                f"grandchild-{index} stderr\n".encode(),
            )
        )

    assert parent_exitcode == 0
    assert reports == expected
|
||||
@@ -115,7 +115,8 @@ def mlx_distributed_init(
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_RING_VERBOSE"] = "1"
|
||||
# os.environ["MLX_RING_VERBOSE"] = "1" # NOTE: we don't use it enough to care (turn on again if need to)
|
||||
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
case MlxJacclInstance(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import contextlib
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
@@ -8,7 +7,7 @@ import anyio
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
to_thread,
|
||||
EndOfStream,
|
||||
)
|
||||
from loguru import logger
|
||||
|
||||
@@ -41,7 +40,8 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
from exo.utils.async_process import AsyncProcess
|
||||
from exo.utils.channels import MpReceiver, MpSender, Receiver, Sender, mp_channel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
@@ -53,7 +53,7 @@ DECODE_TIMEOUT_SECONDS = 5
|
||||
class RunnerSupervisor:
|
||||
shard_metadata: ShardMetadata
|
||||
bound_instance: BoundInstance
|
||||
runner_process: mp.Process
|
||||
runner_process: AsyncProcess
|
||||
initialize_timeout: float
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
@@ -81,7 +81,7 @@ class RunnerSupervisor:
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = mp.Process(
|
||||
runner_process = AsyncProcess(
|
||||
target=entrypoint,
|
||||
args=(
|
||||
bound_instance,
|
||||
@@ -109,9 +109,25 @@ class RunnerSupervisor:
|
||||
return self
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
# start the process itself
|
||||
await tg.start(self.runner_process.run)
|
||||
|
||||
# start tasks to drain/collect stdout/stderr into usable errors
|
||||
#
|
||||
# TODO: right now it logs them as warnings, but in the future they should be split
|
||||
# into being logged AND a separate task which tries to best-effort figure out cause
|
||||
# of error and package into error enum, which then is used by rest of app to act on it;
|
||||
# inferring what the error is would be done by pattern-matching in the text for things
|
||||
# e.g. certain VLLM error codes and so on
|
||||
tg.start_soon(
|
||||
self._forward_runner_output, "stdout", self.runner_process.stdout
|
||||
)
|
||||
tg.start_soon(
|
||||
self._forward_runner_output, "stderr", self.runner_process.stderr
|
||||
)
|
||||
|
||||
tg.start_soon(self._watch_runner)
|
||||
tg.start_soon(self._forward_events)
|
||||
finally:
|
||||
@@ -129,41 +145,11 @@ class RunnerSupervisor:
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._cancel_sender.close()
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
|
||||
if self.runner_process.is_alive():
|
||||
logger.warning(
|
||||
"Runner process didn't shutdown succesfully, terminating"
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.runner_process.stop()
|
||||
logger.info(
|
||||
f"Runner process successfully terminated: {self.runner_process.exitcode}"
|
||||
)
|
||||
self.runner_process.terminate()
|
||||
self.runner_process.join(timeout=10)
|
||||
|
||||
if not self.runner_process.is_alive():
|
||||
logger.warning("Terminated nicely in the first attempt!")
|
||||
|
||||
else:
|
||||
# Try really hard to terminate
|
||||
for i in range(2, 11):
|
||||
self.runner_process.terminate()
|
||||
self.runner_process.join(timeout=2)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.warning(f"That took {i} attempts :)")
|
||||
break
|
||||
# Try even harder to kill
|
||||
else:
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGTERM, killing"
|
||||
)
|
||||
j = 0
|
||||
while self.runner_process.is_alive():
|
||||
j += 1
|
||||
self.runner_process.kill()
|
||||
self.runner_process.join(timeout=5)
|
||||
logger.warning(f"That took {j} attempts :(")
|
||||
else:
|
||||
logger.info("Runner process succesfully terminated")
|
||||
|
||||
self.runner_process.close()
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_tasks()
|
||||
@@ -249,13 +235,33 @@ class RunnerSupervisor:
|
||||
if not self.runner_process.is_alive():
|
||||
await self._check_runner(RuntimeError("Runner found to be dead"))
|
||||
|
||||
async def _forward_runner_output(
    self,
    stream_name: str,
    stream: Receiver[bytes],
) -> None:
    """Forward captured runner output from *stream* to the logger.

    Chunks are decoded as UTF-8 (undecodable bytes are replaced), stripped of
    trailing whitespace, and logged: at warning level when *stream_name* is
    ``"stderr"``, at debug level otherwise. Chunks that are empty after
    stripping are skipped.

    Decoding is incremental, so a multi-byte character split across two
    chunks is reassembled instead of turning into replacement characters.

    Returns when the stream raises ``EndOfStream``, ``ClosedResourceError``
    or ``BrokenResourceError``.
    """
    # Function-scope import: keeps this method self-contained without
    # touching the module's import block.
    import codecs

    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    is_stderr = stream_name == "stderr"

    finished = False
    while not finished:
        try:
            text = decoder.decode(await stream.receive())
        except (EndOfStream, ClosedResourceError, BrokenResourceError):
            # Flush bytes still buffered from an incomplete trailing sequence
            # so they are logged (as replacement characters) and not lost.
            text = decoder.decode(b"", final=True)
            finished = True

        message = text.rstrip()
        if not message:
            continue
        if is_stderr:
            logger.warning(f"Runner stderr: {message}")
        else:
            logger.debug(f"Runner stdout: {message}")
|
||||
|
||||
async def _check_runner(self, e: Exception) -> None:
|
||||
if not self._cancel_watch_runner.cancel_called:
|
||||
self._cancel_watch_runner.cancel()
|
||||
logger.info("Checking runner's status")
|
||||
if self.runner_process.is_alive():
|
||||
logger.info("Runner was found to be alive, attempting to join process")
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
logger.info("Runner was found to be alive, stopping process")
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.runner_process.stop()
|
||||
rc = self.runner_process.exitcode
|
||||
logger.info(f"Runner exited with exit code {rc}")
|
||||
if rc == 0:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import multiprocessing as mp
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
@@ -16,6 +15,7 @@ from exo.shared.types.text_generation import (
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerFailed, RunnerId
|
||||
from exo.utils.async_process import AsyncProcess
|
||||
from exo.utils.channels import channel, mp_channel
|
||||
from exo.worker.runner.supervisor import RunnerSupervisor
|
||||
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
@@ -24,23 +24,11 @@ from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
class _DeadProcess:
|
||||
exitcode = -6
|
||||
|
||||
def start(self) -> None:
|
||||
return None
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
return False
|
||||
|
||||
def join(self, _timeout: float | None = None) -> None:
|
||||
return None
|
||||
|
||||
def terminate(self) -> None:
|
||||
return None
|
||||
|
||||
def kill(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.anyio
|
||||
async def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> None:
|
||||
event_sender, event_receiver = channel[Event]()
|
||||
task_sender, _ = mp_channel[Task]()
|
||||
@@ -57,7 +45,7 @@ async def test_check_runner_emits_error_chunk_for_inflight_text_generation() ->
|
||||
supervisor = RunnerSupervisor(
|
||||
shard_metadata=bound_instance.bound_shard,
|
||||
bound_instance=bound_instance,
|
||||
runner_process=cast("mp.Process", cast(object, _DeadProcess())),
|
||||
runner_process=cast(AsyncProcess, cast(object, _DeadProcess())),
|
||||
initialize_timeout=400,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
|
||||
Reference in New Issue
Block a user