five billion percent better shutdown handling

This commit is contained in:
Evan Quiney
2025-11-11 17:43:53 +00:00
committed by GitHub
parent aa519b8c03
commit 364087b91f
6 changed files with 113 additions and 78 deletions

View File

@@ -1,3 +1,4 @@
import signal
import argparse
import multiprocessing as mp
from dataclasses import dataclass
@@ -5,9 +6,9 @@ from typing import Self
import anyio
from anyio.abc import TaskGroup
from loguru import logger
from pydantic import PositiveInt
from exo.shared.logging import logger
import exo.routing.topics as topics
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
@@ -101,6 +102,7 @@ class Node:
async def run(self):
async with anyio.create_task_group() as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
self._tg = tg
tg.start_soon(self.router.run)
tg.start_soon(self.worker.run)
@@ -112,13 +114,21 @@ class Node:
tg.start_soon(self._elect_loop)
tg.start_soon(self._listen_for_kill_command)
def shutdown(self):
assert self._tg
# if this is our second call to shutdown, just sys.exit
if self._tg.cancel_scope.cancel_called:
import sys
sys.exit(1)
self._tg.cancel_scope.cancel()
async def _listen_for_kill_command(self):
assert self._tg
with self.router.receiver(topics.COMMANDS) as commands:
async for command in commands:
match command.command:
case KillCommand():
self._tg.cancel_scope.cancel()
self.shutdown()
case _:
pass
@@ -198,6 +208,7 @@ class Node:
def main():
args = Args.parse()
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
@@ -205,7 +216,7 @@ def main():
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
logger_cleanup()

View File

@@ -45,6 +45,9 @@ class RunnerRunning(BaseRunnerStatus):
pass
class RunnerShutdown(BaseRunnerStatus):
pass
class RunnerFailed(BaseRunnerStatus):
error_message: str | None = None
@@ -56,6 +59,7 @@ RunnerStatus = (
| RunnerWarmingUp
| RunnerReady
| RunnerRunning
| RunnerShutdown
| RunnerFailed
)

View File

@@ -136,7 +136,13 @@ class MpSender[T]:
self._state.buffer.put(MP_END_OF_STREAM)
self._state.buffer.close()
# == context manager support ==#
# == unique to Mp channels ==
def join(self) -> None:
"""Ensure any queued messages are resolved before continuing"""
assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
self._state.buffer.join_thread()
# == context manager support ==
def __enter__(self) -> Self:
return self
@@ -172,7 +178,8 @@ class MpReceiver[T]:
if item is MP_END_OF_STREAM:
self.close()
raise EndOfStream
return item # pyright: ignore[reportReturnType]
assert not isinstance(item, _MpEndOfStream)
return item
except Empty:
raise WouldBlock from None
except ValueError as e:
@@ -187,8 +194,10 @@ class MpReceiver[T]:
if item is MP_END_OF_STREAM:
self.close()
raise EndOfStream from None
return item # pyright: ignore[reportReturnType]
assert not isinstance(item, _MpEndOfStream)
return item
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
@@ -197,7 +206,13 @@ class MpReceiver[T]:
self._state.closed.set()
self._state.buffer.close()
# == iterator support ==#
# == unique to Mp channels ==
def join(self) -> None:
"""Block until all enqueued messages are drained off our side of the buffer"""
assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
self._state.buffer.join_thread()
# == iterator support ==
def __iter__(self) -> Self:
return self
@@ -207,7 +222,7 @@ class MpReceiver[T]:
except EndOfStream:
raise StopIteration from None
# == async iterator support ==#
# == async iterator support ==
def __aiter__(self) -> Self:
return self
@@ -217,7 +232,7 @@ class MpReceiver[T]:
except EndOfStream:
raise StopAsyncIteration from None
# == context manager support ==#
# == context manager support ==
def __enter__(self) -> Self:
return self

View File

@@ -126,7 +126,7 @@ class Worker:
self.local_event_sender.close()
self.command_sender.close()
for runner in self.runners.values():
await runner.shutdown()
runner.shutdown()
async def _event_applier(self):
with self.global_event_receiver as events:
@@ -211,13 +211,9 @@ class Worker:
task, initial_progress
)
case Shutdown(runner_id=runner_id):
await self.runners[runner_id].shutdown()
del self.runners[runner_id]
await self.runners.pop(runner_id).start_task(task)
case task:
runner = self.runners[self._task_to_runner_id(task)]
event = anyio.Event()
await runner.start_task(task, event)
await event.wait()
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
if self._tg:

View File

@@ -1,7 +1,6 @@
import time
from exo.engines.mlx.utils_mlx import (
mx_barrier,
initialize_mlx,
mlx_force_oom,
)
@@ -35,8 +34,9 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
RunnerShutdown
)
from exo.utils.channels import MpReceiver, MpSender
from exo.utils.channels import MpReceiver, MpSender, ClosedResourceError
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.generate import mlx_generate, warmup_inference
@@ -197,6 +197,12 @@ def main(
)
)
case Shutdown():
logger.info("runner shutting down")
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
break
case _:
raise ValueError("Received task outside of state machine")
@@ -205,7 +211,13 @@ def main(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerShutdown()
)
)
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {runner_id} crashed with critical exception {e}"
@@ -216,3 +228,9 @@ def main(
runner_status=RunnerFailed(error_message=str(e)),
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")

View File

@@ -1,18 +1,14 @@
import contextlib
import signal
import sys
from dataclasses import dataclass, field
from multiprocessing import Process
from typing import Self
import anyio
import psutil
from anyio import (
BrokenResourceError,
ClosedResourceError,
EndOfStream,
create_task_group,
current_time,
to_thread,
)
from anyio.abc import TaskGroup
@@ -22,7 +18,6 @@ from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerError,
RunnerFailed,
RunnerStatus,
RunnerWaitingForModel,
@@ -30,9 +25,6 @@ from exo.shared.types.worker.runners import (
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.worker.runner.bootstrap import entrypoint
from exo.worker.runner.utils import (
get_weights_size,
)
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -64,20 +56,12 @@ class RunnerSupervisor:
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
""" --- not doing this for now
with tempfile.NamedTemporaryFile(
prefix="child_stderr_", suffix=".log", delete=False
) as tmp:
err_path = tmp.name
"""
runner_process = Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
# err_path,
logger,
),
daemon=True,
@@ -107,42 +91,55 @@ class RunnerSupervisor:
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join)
await to_thread.run_sync(self.runner_process.join, 30)
if not self.runner_process.is_alive():
return
async def start_task(self, task: Task, event: anyio.Event):
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGKILL. System resources may have leaked")
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
event = anyio.Event()
self.pending[task.task_id] = event
self._task_sender.send(task)
try:
self._task_sender.send(task)
except ClosedResourceError:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
async def _forward_events(self):
with self._ev_recv as events:
while True:
try:
event = await events.receive_async()
except (ClosedResourceError, BrokenResourceError, EndOfStream):
await self._check_runner()
break
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
continue
await self._event_sender.send(event)
try:
async for event in events:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
continue
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
async def shutdown(self) -> None:
assert self._tg
self._tg.cancel_scope.cancel()
required_memory_bytes = get_weights_size(self.shard_metadata).in_bytes
start_time = current_time()
while True:
available_memory_bytes = psutil.virtual_memory().available
if available_memory_bytes >= required_memory_bytes:
break
if current_time() - start_time > 30.0:
logger.warning("Runner memory not released after 30 seconds - exiting")
break
await anyio.sleep(1)
def __del__(self) -> None:
if self.runner_process.is_alive():
@@ -150,19 +147,13 @@ class RunnerSupervisor:
with contextlib.suppress(ValueError):
self.runner_process.kill()
async def _check_runner(self) -> RunnerError | None:
async def _check_runner(self, e: Exception) -> None:
if self.runner_process.is_alive():
await to_thread.run_sync(self.runner_process.join, 1)
rc = self.runner_process.exitcode
if rc == 0:
logger.warning("Runner closed communication without terminating process")
""" --- not doing this anymore
try:
with open(self.err_path, "r", errors="replace") as f:
captured = f.read()
finally:
with contextlib.suppress(OSError):
os.unlink(self.err_path)
"""
#
return
if isinstance(rc, int) and rc < 0:
sig = -rc
@@ -173,7 +164,7 @@ class RunnerSupervisor:
else:
cause = f"exitcode={rc}"
logger.opt(exception=sys.exception()).error(f"Runner terminated ({cause})")
logger.opt(exception=e).error(f"Runner terminated ({cause})")
await self._event_sender.send(
RunnerStatusUpdated(
@@ -181,4 +172,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
await self.shutdown()
self.shutdown()