mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-05 11:43:17 -05:00
Compare commits
2 Commits
alexcheema
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
572e647908 | ||
|
|
e59ebd986d |
@@ -118,9 +118,10 @@
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit (self'.packages) metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
default = self'.packages.exo;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -53,11 +53,10 @@ class DownloadCoordinator:
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
|
||||
@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
@@ -136,7 +135,6 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -148,6 +146,8 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
|
||||
@@ -1320,29 +1320,40 @@ class API:
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
shutdown_ev = anyio.Event()
|
||||
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
tg.start_soon(self.run_api, shutdown_ev)
|
||||
try:
|
||||
await anyio.sleep_forever()
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
shutdown_ev.set()
|
||||
finally:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
cfg.bind = f"0.0.0.0:{self.port}"
|
||||
cfg.bind = [f"0.0.0.0:{self.port}"]
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = None
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
cfg,
|
||||
shutdown_trigger=lambda: anyio.sleep_forever(),
|
||||
shutdown_trigger=ev.wait,
|
||||
)
|
||||
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -96,16 +96,18 @@ class Master:
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
|
||||
@@ -9,6 +9,7 @@ from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -146,18 +147,21 @@ class Router:
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
with move_on_after(1, shield=True):
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -166,12 +170,12 @@ class Router:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
logger.info(f"Subscribed to {topic}")
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
logger.info(f"Unsubscribed from {topic}")
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
|
||||
@@ -86,28 +86,29 @@ class Election:
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election finished")
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
finally:
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election shutdown")
|
||||
|
||||
async def elect(self, em: ElectionMessage) -> None:
|
||||
logger.debug(f"Electing: {em}")
|
||||
|
||||
@@ -194,9 +194,10 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
return await to_thread.run_sync(
|
||||
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -98,21 +98,23 @@ class Worker:
|
||||
info_send, info_recv = channel[GatheredInfo]()
|
||||
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
finally:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
logger.info("Stopping Worker")
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
|
||||
@@ -8,10 +8,8 @@ import anyio
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
to_thread,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
@@ -49,7 +47,6 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
@@ -93,28 +90,29 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._forward_events)
|
||||
await self._forward_events()
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
await to_thread.run_sync(self.runner_process.join, 30)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
@@ -122,10 +120,6 @@ class RunnerSupervisor:
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
|
||||
@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
done
|
||||
wait
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
@@ -34,21 +34,13 @@ reset=$'\e[0m'
|
||||
i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
{
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
|
||||
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
} <<'EOF'
|
||||
set -euo pipefail
|
||||
cd exo
|
||||
git fetch -q origin
|
||||
git checkout -q "$1"
|
||||
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
|
||||
EOF
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..."
|
||||
until curl -sf "http://$host:52415/models"; do sleep 1; done
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
done
|
||||
wait
|
||||
|
||||
Reference in New Issue
Block a user