mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 23:06:23 -05:00
Compare commits
3 Commits
fix-instan
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad60cf09fa | ||
|
|
ebe02c4722 | ||
|
|
bf034b61e4 |
@@ -1,5 +1,3 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
@@ -35,8 +33,7 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
NodeDisconnected,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TraceEventData,
|
||||
@@ -92,6 +89,8 @@ class Master:
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
self._last_checked_indices: dict[NodeId, int] = {}
|
||||
self._stale_cycles: dict[NodeId, int] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -358,12 +357,25 @@ class Master:
|
||||
)
|
||||
break
|
||||
|
||||
# time out dead nodes
|
||||
for node_id, time in self.state.last_seen.items():
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
if now - time > timedelta(seconds=30):
|
||||
logger.info(f"Manually removing node {node_id} due to inactivity")
|
||||
await self.event_sender.send(NodeTimedOut(node_id=node_id))
|
||||
# Check which nodes have gone stale
|
||||
current_indices = dict(self.state.last_event_index_by_node)
|
||||
for node_id in list(current_indices):
|
||||
last_checked = self._last_checked_indices.get(node_id, -1)
|
||||
if current_indices[node_id] == last_checked:
|
||||
self._stale_cycles[node_id] = self._stale_cycles.get(node_id, 0) + 1
|
||||
else:
|
||||
self._stale_cycles.pop(node_id, None)
|
||||
|
||||
self._last_checked_indices = current_indices
|
||||
|
||||
# Disconnect nodes stale for >= 3 consecutive cycles (~30s)
|
||||
for node_id, cycles in list(self._stale_cycles.items()):
|
||||
if cycles >= 3:
|
||||
logger.info(
|
||||
f"Removing node {node_id}: no events for {cycles} plan cycles"
|
||||
)
|
||||
del self._stale_cycles[node_id]
|
||||
await self.event_sender.send(NodeDisconnected(node_id=node_id))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
@@ -387,10 +399,6 @@ class Master:
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
|
||||
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
|
||||
if isinstance(event, NodeGatheredInfo):
|
||||
event.when = str(datetime.now(tz=timezone.utc))
|
||||
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Sequence
|
||||
|
||||
import anyio
|
||||
@@ -85,7 +84,6 @@ async def test_master():
|
||||
session=session_id,
|
||||
event=(
|
||||
NodeGatheredInfo(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
info=MemoryUsage(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import copy
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -12,9 +11,9 @@ from exo.shared.types.events import (
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
NodeDisconnected,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
@@ -72,8 +71,8 @@ def event_apply(event: Event, state: State) -> State:
|
||||
return apply_instance_created(event, state)
|
||||
case InstanceDeleted():
|
||||
return apply_instance_deleted(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodeDisconnected():
|
||||
return apply_node_disconnected(event, state)
|
||||
case NodeDownloadProgress():
|
||||
return apply_node_download_progress(event, state)
|
||||
case NodeGatheredInfo():
|
||||
@@ -104,7 +103,13 @@ def apply(state: State, event: IndexedEvent) -> State:
|
||||
)
|
||||
assert state.last_event_applied_idx == event.idx - 1
|
||||
new_state: State = event_apply(event.event, state)
|
||||
return new_state.model_copy(update={"last_event_applied_idx": event.idx})
|
||||
update: dict[str, object] = {"last_event_applied_idx": event.idx}
|
||||
if isinstance(event.event, NodeGatheredInfo):
|
||||
update["last_event_index_by_node"] = {
|
||||
**new_state.last_event_index_by_node,
|
||||
event.event.node_id: event.idx,
|
||||
}
|
||||
return new_state.model_copy(update=update)
|
||||
|
||||
|
||||
def apply_node_download_progress(event: NodeDownloadProgress, state: State) -> State:
|
||||
@@ -208,11 +213,13 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
return state.model_copy(update={"runners": new_runners})
|
||||
|
||||
|
||||
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
def apply_node_disconnected(event: NodeDisconnected, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
topology.remove_node(event.node_id)
|
||||
last_seen = {
|
||||
key: value for key, value in state.last_seen.items() if key != event.node_id
|
||||
last_event_index_by_node = {
|
||||
key: value
|
||||
for key, value in state.last_event_index_by_node.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
downloads = {
|
||||
key: value for key, value in state.downloads.items() if key != event.node_id
|
||||
@@ -262,7 +269,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
update={
|
||||
"downloads": downloads,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
"last_event_index_by_node": last_event_index_by_node,
|
||||
"node_identities": node_identities,
|
||||
"node_memory": node_memory,
|
||||
"node_disk": node_disk,
|
||||
@@ -283,10 +290,6 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
|
||||
# Build update dict with only the mappings that change
|
||||
update: dict[str, object] = {
|
||||
"last_seen": {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
},
|
||||
"topology": topology,
|
||||
}
|
||||
|
||||
|
||||
@@ -77,14 +77,12 @@ class RunnerDeleted(BaseEvent):
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class NodeTimedOut(BaseEvent):
|
||||
class NodeDisconnected(BaseEvent):
|
||||
node_id: NodeId
|
||||
|
||||
|
||||
# TODO: bikeshed this name
|
||||
class NodeGatheredInfo(BaseEvent):
|
||||
node_id: NodeId
|
||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||
info: GatheredInfo
|
||||
|
||||
|
||||
@@ -143,7 +141,7 @@ Event = (
|
||||
| InstanceDeleted
|
||||
| RunnerStatusUpdated
|
||||
| RunnerDeleted
|
||||
| NodeTimedOut
|
||||
| NodeDisconnected
|
||||
| NodeGatheredInfo
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_serializer, field_validator
|
||||
@@ -44,7 +43,7 @@ class State(CamelCaseModel):
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
last_seen: Mapping[NodeId, datetime] = {}
|
||||
last_event_index_by_node: Mapping[NodeId, int] = {}
|
||||
topology: Topology = Field(default_factory=Topology)
|
||||
last_event_applied_idx: int = Field(default=-1, ge=-1)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from random import random
|
||||
from typing import Iterator
|
||||
|
||||
@@ -123,7 +122,6 @@ class Worker:
|
||||
await self.event_sender.send(
|
||||
NodeGatheredInfo(
|
||||
node_id=self.node_id,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
info=info,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user