mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
five billion percent better shutdown handling
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import signal
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
from dataclasses import dataclass
|
||||
@@ -5,9 +6,9 @@ from typing import Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
from pydantic import PositiveInt
|
||||
|
||||
from exo.shared.logging import logger
|
||||
import exo.routing.topics as topics
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
@@ -101,6 +102,7 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with anyio.create_task_group() as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
self._tg = tg
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.worker.run)
|
||||
@@ -112,13 +114,21 @@ class Node:
|
||||
tg.start_soon(self._elect_loop)
|
||||
tg.start_soon(self._listen_for_kill_command)
|
||||
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
if self._tg.cancel_scope.cancel_called:
|
||||
import sys
|
||||
sys.exit(1)
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _listen_for_kill_command(self):
|
||||
assert self._tg
|
||||
with self.router.receiver(topics.COMMANDS) as commands:
|
||||
async for command in commands:
|
||||
match command.command:
|
||||
case KillCommand():
|
||||
self._tg.cancel_scope.cancel()
|
||||
self.shutdown()
|
||||
case _:
|
||||
pass
|
||||
|
||||
@@ -198,6 +208,7 @@ class Node:
|
||||
|
||||
def main():
|
||||
args = Args.parse()
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
@@ -205,7 +216,7 @@ def main():
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
|
||||
logger.info("EXO Shutdown complete")
|
||||
logger_cleanup()
|
||||
|
||||
|
||||
|
||||
@@ -45,6 +45,9 @@ class RunnerRunning(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShutdown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
class RunnerFailed(BaseRunnerStatus):
|
||||
error_message: str | None = None
|
||||
|
||||
@@ -56,6 +59,7 @@ RunnerStatus = (
|
||||
| RunnerWarmingUp
|
||||
| RunnerReady
|
||||
| RunnerRunning
|
||||
| RunnerShutdown
|
||||
| RunnerFailed
|
||||
)
|
||||
|
||||
|
||||
@@ -136,7 +136,13 @@ class MpSender[T]:
|
||||
self._state.buffer.put(MP_END_OF_STREAM)
|
||||
self._state.buffer.close()
|
||||
|
||||
# == context manager support ==#
|
||||
# == unique to Mp channels ==
|
||||
def join(self) -> None:
|
||||
"""Ensure any queued messages are resolved before continuing"""
|
||||
assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
|
||||
self._state.buffer.join_thread()
|
||||
|
||||
# == context manager support ==
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
@@ -172,7 +178,8 @@ class MpReceiver[T]:
|
||||
if item is MP_END_OF_STREAM:
|
||||
self.close()
|
||||
raise EndOfStream
|
||||
return item # pyright: ignore[reportReturnType]
|
||||
assert not isinstance(item, _MpEndOfStream)
|
||||
return item
|
||||
except Empty:
|
||||
raise WouldBlock from None
|
||||
except ValueError as e:
|
||||
@@ -187,8 +194,10 @@ class MpReceiver[T]:
|
||||
if item is MP_END_OF_STREAM:
|
||||
self.close()
|
||||
raise EndOfStream from None
|
||||
return item # pyright: ignore[reportReturnType]
|
||||
assert not isinstance(item, _MpEndOfStream)
|
||||
return item
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
|
||||
@@ -197,7 +206,13 @@ class MpReceiver[T]:
|
||||
self._state.closed.set()
|
||||
self._state.buffer.close()
|
||||
|
||||
# == iterator support ==#
|
||||
# == unique to Mp channels ==
|
||||
def join(self) -> None:
|
||||
"""Block until all enqueued messages are drained off our side of the buffer"""
|
||||
assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
|
||||
self._state.buffer.join_thread()
|
||||
|
||||
# == iterator support ==
|
||||
def __iter__(self) -> Self:
|
||||
return self
|
||||
|
||||
@@ -207,7 +222,7 @@ class MpReceiver[T]:
|
||||
except EndOfStream:
|
||||
raise StopIteration from None
|
||||
|
||||
# == async iterator support ==#
|
||||
# == async iterator support ==
|
||||
def __aiter__(self) -> Self:
|
||||
return self
|
||||
|
||||
@@ -217,7 +232,7 @@ class MpReceiver[T]:
|
||||
except EndOfStream:
|
||||
raise StopAsyncIteration from None
|
||||
|
||||
# == context manager support ==#
|
||||
# == context manager support ==
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class Worker:
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
await runner.shutdown()
|
||||
runner.shutdown()
|
||||
|
||||
async def _event_applier(self):
|
||||
with self.global_event_receiver as events:
|
||||
@@ -211,13 +211,9 @@ class Worker:
|
||||
task, initial_progress
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
await self.runners[runner_id].shutdown()
|
||||
del self.runners[runner_id]
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
case task:
|
||||
runner = self.runners[self._task_to_runner_id(task)]
|
||||
event = anyio.Event()
|
||||
await runner.start_task(task, event)
|
||||
await event.wait()
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
if self._tg:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import time
|
||||
|
||||
from exo.engines.mlx.utils_mlx import (
|
||||
mx_barrier,
|
||||
initialize_mlx,
|
||||
mlx_force_oom,
|
||||
)
|
||||
@@ -35,8 +34,9 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
RunnerWarmingUp,
|
||||
RunnerShutdown
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.utils.channels import MpReceiver, MpSender, ClosedResourceError
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
from exo.worker.runner.generate import mlx_generate, warmup_inference
|
||||
|
||||
@@ -197,6 +197,12 @@ def main(
|
||||
)
|
||||
)
|
||||
case Shutdown():
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
break
|
||||
case _:
|
||||
raise ValueError("Received task outside of state machine")
|
||||
@@ -205,7 +211,13 @@ def main(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=RunnerShutdown()
|
||||
)
|
||||
)
|
||||
except ClosedResourceError:
|
||||
logger.warning("runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Runner {runner_id} crashed with critical exception {e}"
|
||||
@@ -216,3 +228,9 @@ def main(
|
||||
runner_status=RunnerFailed(error_message=str(e)),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import contextlib
|
||||
import signal
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import Process
|
||||
from typing import Self
|
||||
|
||||
import anyio
|
||||
import psutil
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
EndOfStream,
|
||||
create_task_group,
|
||||
current_time,
|
||||
to_thread,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -22,7 +18,6 @@ from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerError,
|
||||
RunnerFailed,
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
@@ -30,9 +25,6 @@ from exo.shared.types.worker.runners import (
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
from exo.worker.runner.utils import (
|
||||
get_weights_size,
|
||||
)
|
||||
|
||||
PREFILL_TIMEOUT_SECONDS = 60
|
||||
DECODE_TIMEOUT_SECONDS = 5
|
||||
@@ -64,20 +56,12 @@ class RunnerSupervisor:
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
|
||||
""" --- not doing this for now
|
||||
with tempfile.NamedTemporaryFile(
|
||||
prefix="child_stderr_", suffix=".log", delete=False
|
||||
) as tmp:
|
||||
err_path = tmp.name
|
||||
|
||||
"""
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
args=(
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
# err_path,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -107,42 +91,55 @@ class RunnerSupervisor:
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self.runner_process.kill()
|
||||
await to_thread.run_sync(self.runner_process.join)
|
||||
await to_thread.run_sync(self.runner_process.join, 30)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
async def start_task(self, task: Task, event: anyio.Event):
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGKILL. System resources may have leaked")
|
||||
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
self._task_sender.send(task)
|
||||
try:
|
||||
self._task_sender.send(task)
|
||||
except ClosedResourceError:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
while True:
|
||||
try:
|
||||
event = await events.receive_async()
|
||||
except (ClosedResourceError, BrokenResourceError, EndOfStream):
|
||||
await self._check_runner()
|
||||
break
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
await self._event_sender.send(event)
|
||||
try:
|
||||
async for event in events:
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
required_memory_bytes = get_weights_size(self.shard_metadata).in_bytes
|
||||
start_time = current_time()
|
||||
while True:
|
||||
available_memory_bytes = psutil.virtual_memory().available
|
||||
if available_memory_bytes >= required_memory_bytes:
|
||||
break
|
||||
if current_time() - start_time > 30.0:
|
||||
logger.warning("Runner memory not released after 30 seconds - exiting")
|
||||
break
|
||||
await anyio.sleep(1)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.runner_process.is_alive():
|
||||
@@ -150,19 +147,13 @@ class RunnerSupervisor:
|
||||
with contextlib.suppress(ValueError):
|
||||
self.runner_process.kill()
|
||||
|
||||
async def _check_runner(self) -> RunnerError | None:
|
||||
async def _check_runner(self, e: Exception) -> None:
|
||||
if self.runner_process.is_alive():
|
||||
await to_thread.run_sync(self.runner_process.join, 1)
|
||||
rc = self.runner_process.exitcode
|
||||
if rc == 0:
|
||||
logger.warning("Runner closed communication without terminating process")
|
||||
|
||||
""" --- not doing this anymore
|
||||
try:
|
||||
with open(self.err_path, "r", errors="replace") as f:
|
||||
captured = f.read()
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(self.err_path)
|
||||
"""
|
||||
#
|
||||
return
|
||||
|
||||
if isinstance(rc, int) and rc < 0:
|
||||
sig = -rc
|
||||
@@ -173,7 +164,7 @@ class RunnerSupervisor:
|
||||
else:
|
||||
cause = f"exitcode={rc}"
|
||||
|
||||
logger.opt(exception=sys.exception()).error(f"Runner terminated ({cause})")
|
||||
logger.opt(exception=e).error(f"Runner terminated ({cause})")
|
||||
|
||||
await self._event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
@@ -181,4 +172,4 @@ class RunnerSupervisor:
|
||||
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
|
||||
)
|
||||
)
|
||||
await self.shutdown()
|
||||
self.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user