diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index 9bb597cb..2f7a0276 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -1,6 +1,7 @@ import copy from collections.abc import Mapping, Sequence from datetime import datetime +from typing import cast from loguru import logger @@ -11,10 +12,8 @@ from exo.shared.types.events import ( IndexedEvent, InstanceCreated, InstanceDeleted, - NodeCreated, NodeDownloadProgress, - NodeMemoryMeasured, - NodePerformanceMeasured, + NodeGatheredInfo, NodeTimedOut, RunnerDeleted, RunnerStatusUpdated, @@ -27,13 +26,16 @@ from exo.shared.types.events import ( TopologyEdgeCreated, TopologyEdgeDeleted, ) -from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile +from exo.shared.types.profiling import NodePerformanceProfile from exo.shared.types.state import State from exo.shared.types.tasks import Task, TaskId, TaskStatus -from exo.shared.types.topology import NodeInfo +# from exo.shared.types.topology import NodeInfo from exo.shared.types.worker.downloads import DownloadProgress from exo.shared.types.worker.instances import Instance, InstanceId from exo.shared.types.worker.runners import RunnerId, RunnerStatus +from exo.utils.info_gatherer.info_gatherer import ( +MacmonMetrics , MemoryUsage, NetworkInterfaceInfo, TBIdentifier, TBConnection, NodeConfig, MiscData, StaticNodeInformation + ) def event_apply(event: Event, state: State) -> State: @@ -47,16 +49,12 @@ def event_apply(event: Event, state: State) -> State: return apply_instance_created(event, state) case InstanceDeleted(): return apply_instance_deleted(event, state) - case NodeCreated(): - return apply_topology_node_created(event, state) case NodeTimedOut(): return apply_node_timed_out(event, state) - case NodePerformanceMeasured(): - return apply_node_performance_measured(event, state) case NodeDownloadProgress(): return apply_node_download_progress(event, state) - case NodeMemoryMeasured(): - return 
apply_node_memory_measured(event, state) + case NodeGatheredInfo(): + return apply_node_gathered_info(event, state) case RunnerDeleted(): return apply_runner_deleted(event, state) case RunnerStatusUpdated(): @@ -188,7 +186,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State: def apply_node_timed_out(event: NodeTimedOut, state: State) -> State: - topology = copy.copy(state.topology) + topology = copy.deepcopy(state.topology) state.topology.remove_node(event.node_id) node_profiles = { key: value for key, value in state.node_profiles.items() if key != event.node_id @@ -205,101 +203,48 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State: ) -def apply_node_performance_measured( - event: NodePerformanceMeasured, state: State -) -> State: - new_profiles: Mapping[NodeId, NodePerformanceProfile] = { - **state.node_profiles, - event.node_id: event.node_profile, - } - last_seen: Mapping[NodeId, datetime] = { - **state.last_seen, - event.node_id: datetime.fromisoformat(event.when), - } - state = state.model_copy(update={"node_profiles": new_profiles}) - topology = copy.copy(state.topology) - # TODO: NodeCreated - if not topology.contains_node(event.node_id): - topology.add_node(NodeInfo(node_id=event.node_id)) - topology.update_node_profile(event.node_id, event.node_profile) - return state.model_copy( - update={ - "node_profiles": new_profiles, - "topology": topology, - "last_seen": last_seen, - } - ) +def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: + info = event.info + profile = state.node_profiles.get(event.node_id, NodePerformanceProfile()) + match info: + case MacmonMetrics(): + profile.system = info.system_profile + profile.memory = info.memory + case MemoryUsage(): + profile.memory = info + case NodeConfig(): + pass + case MiscData(): + profile.friendly_name = info.friendly_name + case StaticNodeInformation(): + profile.model_id = info.model + profile.chip_id = info.chip + # TODO: makes me slightly 
sad + case Sequence(): + if info != []: + match info[0]: + case NetworkInterfaceInfo(): + profile.network_interfaces = cast(Sequence[NetworkInterfaceInfo], info) + case TBIdentifier(): + profile.tb_interfaces = cast(Sequence[TBIdentifier], info) + case TBConnection(): + # TODO: + pass + last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)} + new_profiles = {**state.node_profiles, event.node_id: profile} + return state.model_copy(update={"node_profiles": new_profiles, "last_seen": last_seen}) -def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State: - existing = state.node_profiles.get(event.node_id) - topology = copy.copy(state.topology) - - if existing is None: - created = NodePerformanceProfile( - model_id="unknown", - chip_id="unknown", - friendly_name="Unknown", - memory=event.memory, - network_interfaces=[], - system=SystemPerformanceProfile( - # TODO: flops_fp16=0.0, - gpu_usage=0.0, - temp=0.0, - sys_power=0.0, - pcpu_usage=0.0, - ecpu_usage=0.0, - ane_power=0.0, - ), - ) - created_profiles: Mapping[NodeId, NodePerformanceProfile] = { - **state.node_profiles, - event.node_id: created, - } - last_seen: Mapping[NodeId, datetime] = { - **state.last_seen, - event.node_id: datetime.fromisoformat(event.when), - } - if not topology.contains_node(event.node_id): - topology.add_node(NodeInfo(node_id=event.node_id)) - # TODO: NodeCreated - topology.update_node_profile(event.node_id, created) - return state.model_copy( - update={ - "node_profiles": created_profiles, - "topology": topology, - "last_seen": last_seen, - } - ) - - updated = existing.model_copy(update={"memory": event.memory}) - updated_profiles: Mapping[NodeId, NodePerformanceProfile] = { - **state.node_profiles, - event.node_id: updated, - } - # TODO: NodeCreated - if not topology.contains_node(event.node_id): - topology.add_node(NodeInfo(node_id=event.node_id)) - topology.update_node_profile(event.node_id, updated) - return state.model_copy( - 
update={"node_profiles": updated_profiles, "topology": topology} - ) - - -def apply_topology_node_created(event: NodeCreated, state: State) -> State: - topology = copy.copy(state.topology) - topology.add_node(NodeInfo(node_id=event.node_id)) - return state.model_copy(update={"topology": topology}) def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State: - topology = copy.copy(state.topology) + topology = copy.deepcopy(state.topology) topology.add_connection(event.edge) return state.model_copy(update={"topology": topology}) def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State: - topology = copy.copy(state.topology) + topology = copy.deepcopy(state.topology) if not topology.contains_connection(event.edge): return state topology.remove_connection(event.edge) diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index 29b750ef..24fc5ec0 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -2,14 +2,14 @@ from datetime import datetime from pydantic import Field -from exo.shared.topology import Connection, NodePerformanceProfile +from exo.shared.topology import Connection from exo.shared.types.chunks import GenerationChunk from exo.shared.types.common import CommandId, Id, NodeId, SessionId -from exo.shared.types.profiling import MemoryPerformanceProfile from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.downloads import DownloadProgress from exo.shared.types.worker.instances import Instance, InstanceId from exo.shared.types.worker.runners import RunnerId, RunnerStatus +from exo.utils.info_gatherer.info_gatherer import GatheredInfo from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel @@ -76,25 +76,15 @@ class RunnerDeleted(BaseEvent): runner_id: RunnerId -# TODO -class NodeCreated(BaseEvent): - node_id: NodeId - - class NodeTimedOut(BaseEvent): node_id: NodeId -class NodePerformanceMeasured(BaseEvent): +# TODO: 
bikeshed this name +class NodeGatheredInfo(BaseEvent): node_id: NodeId when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device - node_profile: NodePerformanceProfile - - -class NodeMemoryMeasured(BaseEvent): - node_id: NodeId - when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device - memory: MemoryPerformanceProfile + info: GatheredInfo # NB: this model is UNTAGGED!!! beware of ser/de errors. class NodeDownloadProgress(BaseEvent): @@ -125,10 +115,8 @@ Event = ( | InstanceDeleted | RunnerStatusUpdated | RunnerDeleted - | NodeCreated | NodeTimedOut - | NodePerformanceMeasured - | NodeMemoryMeasured + | NodeGatheredInfo | NodeDownloadProgress | ChunkGenerated | TopologyEdgeCreated diff --git a/src/exo/shared/types/profiling.py b/src/exo/shared/types/profiling.py index 5ed6e0d4..779ac08c 100644 --- a/src/exo/shared/types/profiling.py +++ b/src/exo/shared/types/profiling.py @@ -1,12 +1,14 @@ +from collections.abc import Sequence from typing import Self import psutil from exo.shared.types.memory import Memory +from exo.shared.types.thunderbolt import TBIdentifier from exo.utils.pydantic_ext import CamelCaseModel -class MemoryPerformanceProfile(CamelCaseModel): +class MemoryUsage(CamelCaseModel): ram_total: Memory ram_available: Memory swap_total: Memory @@ -44,7 +46,6 @@ class SystemPerformanceProfile(CamelCaseModel): sys_power: float = 0.0 pcpu_usage: float = 0.0 ecpu_usage: float = 0.0 - ane_power: float = 0.0 class NetworkInterfaceInfo(CamelCaseModel): @@ -53,15 +54,14 @@ class NodePerformanceProfile(CamelCaseModel): - model_id: str - chip_id: str - friendly_name: str - memory: MemoryPerformanceProfile - network_interfaces: list[NetworkInterfaceInfo] = [] - system: SystemPerformanceProfile + model_id: str = "Unknown" + chip_id: str = "Unknown" + friendly_name: 
str = "Unknown" + memory: MemoryUsage = MemoryUsage.from_bytes(ram_total=0, ram_available=0, swap_total=0, swap_available=0) + network_interfaces: Sequence[NetworkInterfaceInfo] = [] + tb_interfaces: Sequence[TBIdentifier] = [] + system: SystemPerformanceProfile = SystemPerformanceProfile() class ConnectionProfile(CamelCaseModel): - throughput: float - latency: float - jitter: float + pass diff --git a/src/exo/shared/types/thunderbolt.py b/src/exo/shared/types/thunderbolt.py new file mode 100644 index 00000000..6def8a02 --- /dev/null +++ b/src/exo/shared/types/thunderbolt.py @@ -0,0 +1,64 @@ +import anyio +from pydantic import BaseModel + +from exo.utils.pydantic_ext import CamelCaseModel + + +class TBConnection(CamelCaseModel): + source: str + sink: str + + +class TBIdentifier(CamelCaseModel): + rdma_interface: str + domain_uuid: str + + +# Intentionally minimal, only collecting data we care about - there's a lot more + + +class TBReceptacleTag(BaseModel, extra="ignore"): + receptacle_id_key: str + + +class TBConnectivityItem(BaseModel, extra="ignore"): + domain_uuid_key: str | None + + +class TBConnectivityData(BaseModel, extra="ignore"): + domain_uuid_key: str | None + device_name_key: str + _items: list[TBConnectivityItem] | None + receptacle_1_tag: TBReceptacleTag + + def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None: + if self.domain_uuid_key is None: + return + tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}" + iface = f"rdma_{ifaces[tag]}" + return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key) + + def conn(self) -> TBConnection | None: + if self.domain_uuid_key is None or self._items is None: + return + + sink_key = next( + item.domain_uuid_key + for item in self._items + if item.domain_uuid_key is not None + ) + return TBConnection(source=self.domain_uuid_key, sink=sink_key) + + +class TBConnectivity(BaseModel): + SPThunderboltDataType: list[TBConnectivityData] + + @classmethod + async def gather(cls) -> 
list[TBConnectivityData] | None: + proc = await anyio.run_process( + ["system_profiler", "SPThunderboltDataType", "-json"], check=False + ) + if proc.returncode != 0: + return None + # Saving you from PascalCase while avoiding too much pydantic + return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType diff --git a/src/exo/shared/types/worker/resource_monitor.py b/src/exo/shared/types/worker/resource_monitor.py deleted file mode 100644 index b351963c..00000000 --- a/src/exo/shared/types/worker/resource_monitor.py +++ /dev/null @@ -1,43 +0,0 @@ -import asyncio -from abc import ABC, abstractmethod -from collections.abc import Coroutine -from typing import Callable - -from exo.shared.types.profiling import ( - MemoryPerformanceProfile, - SystemPerformanceProfile, -) - - -class ResourceCollector(ABC): - @abstractmethod - async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ... - - -class SystemResourceCollector(ResourceCollector): - async def collect(self) -> SystemPerformanceProfile: ... - - -class MemoryResourceCollector(ResourceCollector): - async def collect(self) -> MemoryPerformanceProfile: ... 
- - -class ResourceMonitor: - data_collectors: list[ResourceCollector] - effect_handlers: set[ - Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None] - ] - - async def _collect( - self, - ) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]: - tasks: list[ - Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile] - ] = [collector.collect() for collector in self.data_collectors] - return await asyncio.gather(*tasks) - - async def collect(self) -> None: - profiles = await self._collect() - for profile in profiles: - for effect_handler in self.effect_handlers: - effect_handler(profile) diff --git a/src/exo/utils/info_gatherer/__init__.py b/src/exo/utils/info_gatherer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/exo/utils/info_gatherer/info_gatherer.py b/src/exo/utils/info_gatherer/info_gatherer.py new file mode 100644 index 00000000..9fa44404 --- /dev/null +++ b/src/exo/utils/info_gatherer/info_gatherer.py @@ -0,0 +1,217 @@ +import os +import shutil +import sys +import tomllib +from dataclasses import dataclass, field +from subprocess import CalledProcessError +from typing import Self, cast +from collections.abc import Sequence + +import anyio +from anyio import create_task_group, open_process +from anyio.abc import TaskGroup +from anyio.streams.text import TextReceiveStream +from loguru import logger + +from exo.shared.constants import EXO_CONFIG_FILE +from exo.shared.types.memory import Memory +from exo.shared.types.profiling import ( + MemoryUsage, + NetworkInterfaceInfo, +) +from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier +from exo.utils.channels import Sender +from exo.utils.pydantic_ext import CamelCaseModel + +from .macmon import MacmonMetrics +from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces + +IS_DARWIN = sys.platform == "darwin" + + +class StaticNodeInformation(CamelCaseModel): + """Node information that should 
NEVER change, to be gathered once at startup""" + + model: str + chip: str + + @classmethod + async def gather(cls) -> Self: + model, chip = await get_model_and_chip() + return cls(model=model, chip=chip) + + +class NodeConfig(CamelCaseModel): + """Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there""" + + # TODO + + @classmethod + async def gather(cls) -> Self | None: + f = anyio.Path(EXO_CONFIG_FILE) + await f.touch(exist_ok=True) + async with await f.open("rb") as f: + try: + contents = (await f.read()).decode("utf-8") + data = tomllib.loads(contents) + return cls.model_validate(data) + except (tomllib.TOMLDecodeError, UnicodeDecodeError): + logger.warning("Invalid config file, skipping...") + return None + + +class MiscData(CamelCaseModel): + """Node information that may change that doesn't fall into the other categories""" + + friendly_name: str + + @classmethod + async def gather(cls) -> Self: + return cls(friendly_name=await get_friendly_name()) + + +async def _gather_iface_map() -> dict[str, str] | None: + proc = await anyio.run_process( + ["networksetup", "-listallhardwareports"], check=False + ) + if proc.returncode != 0: + return None + + ports: dict[str, str] = {} + port = "" + for line in proc.stdout.decode("utf-8").split("\n"): + if line.startswith("Hardware Port:"): + port = line.strip()[15:] + elif line.startswith("Device:"): + ports[port] = line.strip()[8:] + port = "" + if "" in ports: + del ports[""] + return ports + + +GatheredInfo = ( + MacmonMetrics + | MemoryUsage + | Sequence[NetworkInterfaceInfo] + | Sequence[TBIdentifier] + | Sequence[TBConnection] + | NodeConfig + | MiscData + | StaticNodeInformation +) + + +@dataclass +class InfoGatherer: + info_sender: Sender[GatheredInfo] + interface_watcher_interval: float | None = 10 + misc_poll_interval: float | None = 60 + system_profiler_interval: float | None = 5 if IS_DARWIN else None + 
memory_poll_rate: float | None = None if IS_DARWIN else 1 + macmon_interval: float | None = 1 if IS_DARWIN else None + _tg: TaskGroup = field(init=False, default_factory=create_task_group) + + async def run(self): + async with self._tg as tg: + if (macmon_path := shutil.which("macmon")) is not None: + tg.start_soon(self._monitor_macmon, macmon_path) + if IS_DARWIN: + tg.start_soon(self._monitor_system_profiler) + tg.start_soon(self._watch_system_info) + tg.start_soon(self._monitor_memory_usage) + tg.start_soon(self._monitor_misc) + + nc = await NodeConfig.gather() + if nc is not None: + await self.info_sender.send(nc) + sni = await StaticNodeInformation.gather() + await self.info_sender.send(sni) + + def shutdown(self): + self._tg.cancel_scope.cancel() + + async def _monitor_misc(self): + if self.misc_poll_interval is None: + return + while True: + await self.info_sender.send(await MiscData.gather()) + await anyio.sleep(self.misc_poll_interval) + + async def _monitor_system_profiler(self): + if self.system_profiler_interval is None: + return + iface_map = await _gather_iface_map() + if iface_map is None: + return + + old_idents = [] + old_conns = [] + while True: + data = await TBConnectivity.gather() + if data is None: + break + + idents = [it for i in data if (it := i.ident(iface_map)) is not None] + if idents != old_idents: + await self.info_sender.send(idents) + old_idents = idents + + conns = [it for i in data if (it := i.conn()) is not None] + if conns != old_conns: + await self.info_sender.send(conns) + old_conns = conns + + async def _monitor_memory_usage(self): + override_memory_env = os.getenv("OVERRIDE_MEMORY_MB") + override_memory: int | None = ( + Memory.from_mb(int(override_memory_env)).in_bytes + if override_memory_env + else None + ) + if self.memory_poll_rate is None: + return + while True: + await self.info_sender.send( + MemoryUsage.from_psutil(override_memory=override_memory) + ) + await anyio.sleep(self.memory_poll_rate) + + async def 
_watch_system_info(self): + if self.interface_watcher_interval is None: + return + old_nics = [] + while True: + nics = get_network_interfaces() + if nics != old_nics: + await self.info_sender.send(nics) + old_nics = nics + await anyio.sleep(self.interface_watcher_interval) + + async def _monitor_macmon(self, macmon_path: str): + if self.macmon_interval is None: + return + # macmon pipe --interval [interval in ms] + try: + async with await open_process( + [macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)] + ) as p: + logger.critical("MacMon closed stdout") + if not p.stdout: + return + async for text in TextReceiveStream(p.stdout): + await self.info_sender.send(MacmonMetrics.from_raw_json(text)) + + except CalledProcessError as e: + stderr_msg = "no stderr" + stderr_output = cast(bytes | str | None, e.stderr) + if stderr_output is not None: + stderr_msg = ( + stderr_output.decode() + if isinstance(stderr_output, bytes) + else str(stderr_output) + ) + logger.warning( + f"MacMon failed with return code {e.returncode}: {stderr_msg}" + ) + diff --git a/src/exo/utils/info_gatherer/macmon.py b/src/exo/utils/info_gatherer/macmon.py new file mode 100644 index 00000000..c2c3a1a0 --- /dev/null +++ b/src/exo/utils/info_gatherer/macmon.py @@ -0,0 +1,67 @@ +from typing import Self + +from pydantic import BaseModel + +from exo.utils.pydantic_ext import CamelCaseModel +from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile + + +class _TempMetrics(BaseModel, extra="ignore"): + """Temperature-related metrics returned by macmon.""" + + cpu_temp_avg: float + gpu_temp_avg: float + + +class _MemoryMetrics(BaseModel, extra="ignore"): + """Memory-related metrics returned by macmon.""" + + ram_total: int + ram_usage: int + swap_total: int + swap_usage: int + + +class RawMacmonMetrics(BaseModel, extra="ignore"): + """Complete set of metrics returned by macmon. + + Unknown fields are ignored for forward-compatibility. 
+ """ + + timestamp: str # ignored + temp: _TempMetrics + memory: _MemoryMetrics + ecpu_usage: tuple[int, float] # freq mhz, usage % + pcpu_usage: tuple[int, float] # freq mhz, usage % + gpu_usage: tuple[int, float] # freq mhz, usage % + all_power: float + ane_power: float + cpu_power: float + gpu_power: float + gpu_ram_power: float + ram_power: float + sys_power: float + + +class MacmonMetrics(CamelCaseModel): + system_profile: SystemPerformanceProfile + memory: MemoryUsage + + @classmethod + def from_raw(cls, raw: RawMacmonMetrics) -> Self: + return cls( + system_profile = SystemPerformanceProfile( + gpu_usage=raw.gpu_usage[1], + temp=raw.temp.gpu_temp_avg, + sys_power=raw.sys_power, + pcpu_usage= raw.pcpu_usage[1], + ecpu_usage= raw.ecpu_usage[1] + ), + memory=MemoryUsage.from_bytes( + ram_total= raw.memory.ram_total, ram_available=(raw.memory.ram_total - raw.memory.ram_usage), swap_total=raw.memory.swap_total, swap_available=(raw.memory.swap_total - raw.memory.swap_usage) + ), + ) + + @classmethod + def from_raw_json(cls, json: str) -> Self: + return cls.from_raw(RawMacmonMetrics.model_validate_json(json)) diff --git a/src/exo/worker/utils/net_profile.py b/src/exo/utils/info_gatherer/net_profile.py similarity index 100% rename from src/exo/worker/utils/net_profile.py rename to src/exo/utils/info_gatherer/net_profile.py diff --git a/src/exo/worker/utils/system_info.py b/src/exo/utils/info_gatherer/system_info.py similarity index 100% rename from src/exo/worker/utils/system_info.py rename to src/exo/utils/info_gatherer/system_info.py diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 8be6947f..b1f813f5 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -16,15 +16,13 @@ from exo.shared.types.events import ( ForwarderEvent, IndexedEvent, NodeDownloadProgress, - NodeMemoryMeasured, - NodePerformanceMeasured, + NodeGatheredInfo, TaskCreated, TaskStatusUpdated, TopologyEdgeCreated, TopologyEdgeDeleted, ) from exo.shared.types.multiaddr 
import Multiaddr -from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile from exo.shared.types.state import State from exo.shared.types.tasks import ( CreateRunner, @@ -44,14 +42,14 @@ from exo.shared.types.worker.runners import RunnerId from exo.shared.types.worker.shards import ShardMetadata from exo.utils.channels import Receiver, Sender, channel from exo.utils.event_buffer import OrderedBuffer +from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer +from exo.utils.info_gatherer.net_profile import check_reachable from exo.worker.download.download_utils import ( map_repo_download_progress_to_download_progress_data, ) from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader from exo.worker.plan import plan from exo.worker.runner.runner_supervisor import RunnerSupervisor -from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics -from exo.worker.utils.net_profile import check_reachable class Worker: @@ -85,7 +83,7 @@ class Worker: self.state: State = State() self.download_status: dict[ShardMetadata, DownloadProgress] = {} self.runners: dict[RunnerId, RunnerSupervisor] = {} - self._tg: TaskGroup | None = None + self._tg: TaskGroup = create_task_group() self._nack_cancel_scope: CancelScope | None = None self._nack_attempts: int = 0 @@ -97,37 +95,13 @@ class Worker: async def run(self): logger.info("Starting Worker") - # TODO: CLEANUP HEADER - async def resource_monitor_callback( - node_performance_profile: NodePerformanceProfile, - ) -> None: - await self.event_sender.send( - NodePerformanceMeasured( - node_id=self.node_id, - node_profile=node_performance_profile, - when=str(datetime.now(tz=timezone.utc)), - ), - ) + info_send, info_recv = channel[GatheredInfo]() + info_gatherer: InfoGatherer = InfoGatherer(info_send) - async def memory_monitor_callback( - memory_profile: MemoryPerformanceProfile, - ) -> None: - await self.event_sender.send( - 
NodeMemoryMeasured( - node_id=self.node_id, - memory=memory_profile, - when=str(datetime.now(tz=timezone.utc)), - ) - ) - - # END CLEANUP - - async with create_task_group() as tg: - self._tg = tg + async with self._tg as tg: + tg.start_soon(info_gatherer.run) + tg.start_soon(self._forward_info, info_recv) tg.start_soon(self.plan_step) - tg.start_soon(start_polling_node_metrics, resource_monitor_callback) - - tg.start_soon(start_polling_memory_metrics, memory_monitor_callback) tg.start_soon(self._connection_message_event_writer) tg.start_soon(self._resend_out_for_delivery) tg.start_soon(self._event_applier) @@ -140,6 +114,17 @@ class Worker: for runner in self.runners.values(): runner.shutdown() + async def _forward_info(self, recv: Receiver[GatheredInfo]): + with recv as info_stream: + async for info in info_stream: + await self.event_sender.send( + NodeGatheredInfo( + node_id=self.node_id, + when=str(datetime.now(tz=timezone.utc)), + info=info, + ) + ) + async def _event_applier(self): with self.global_event_receiver as events: async for f_event in events: @@ -159,7 +144,6 @@ class Worker: self._nack_cancel_scope is None or self._nack_cancel_scope.cancel_called ): - assert self._tg # Request the next index. 
self._tg.start_soon( self._nack_request, self.state.last_event_applied_idx + 1 @@ -248,8 +232,7 @@ class Worker: await self.runners[self._task_to_runner_id(task)].start_task(task) def shutdown(self): - if self._tg: - self._tg.cancel_scope.cancel() + self._tg.cancel_scope.cancel() def _task_to_runner_id(self, task: Task): instance = self.state.instances[task.instance_id] @@ -332,7 +315,6 @@ class Worker: event_sender=self.event_sender.clone(), ) self.runners[task.bound_instance.bound_runner_id] = runner - assert self._tg self._tg.start_soon(runner.run) return runner @@ -391,7 +373,6 @@ class Worker: last_progress_time = current_time() self.shard_downloader.on_progress(download_progress_callback) - assert self._tg self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata) async def _forward_events(self) -> None: diff --git a/src/exo/worker/utils/__init__.py b/src/exo/worker/utils/__init__.py deleted file mode 100644 index 9a94e028..00000000 --- a/src/exo/worker/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .profile import start_polling_memory_metrics, start_polling_node_metrics - -__all__ = [ - "start_polling_node_metrics", - "start_polling_memory_metrics", -]