From 8799c288b0da65e0ec9fcf28768c2c2022a78159 Mon Sep 17 00:00:00 2001 From: Arbion Halili <99731180+ToxicPine@users.noreply.github.com> Date: Mon, 14 Jul 2025 21:09:08 +0100 Subject: [PATCH] BROKEN: work thus far --- .zed/settings.json | 29 +++ flake.nix | 13 ++ master/env.py | 5 + master/event_routing.py | 163 ++++++++++++++ master/logging.py | 95 ++++++++ master/main.py | 273 ++++++++++++++++++++++- master/pyproject.toml | 7 +- networking/src/networking/_core.pyi | 1 - shared/constants.py | 20 +- shared/graphs/networkx.py | 221 +++++++++++++++++++ shared/logger.py | 30 ++- shared/logging/common.py | 18 ++ shared/openai.py | 9 +- shared/pyproject.toml | 2 + shared/types/events/chunks.py | 20 +- shared/types/events/common.py | 280 +++++++++++------------- shared/types/events/events.py | 172 ++++++--------- shared/types/events/registry.py | 133 +++++++++++ shared/types/events/sanity_checking.py | 68 ++++++ shared/types/graphs/common.py | 11 +- shared/types/models/metadata.py | 3 +- shared/types/models/sources.py | 14 +- shared/types/networking/topology.py | 45 +--- shared/types/states/master.py | 58 +++-- shared/types/states/shared.py | 17 +- shared/types/states/worker.py | 4 +- shared/types/tasks/common.py | 73 +++--- shared/types/worker/downloads.py | 9 +- shared/types/worker/instances.py | 21 +- shared/types/worker/resource_monitor.py | 39 +--- shared/types/worker/runners.py | 21 +- shared/types/worker/shards.py | 20 +- uv.lock | 106 +++++++-- worker/logging.py | 13 ++ 34 files changed, 1516 insertions(+), 497 deletions(-) create mode 100644 .zed/settings.json create mode 100644 master/env.py create mode 100644 master/event_routing.py create mode 100644 master/logging.py create mode 100644 shared/graphs/networkx.py create mode 100644 shared/logging/common.py create mode 100644 shared/types/events/registry.py create mode 100644 shared/types/events/sanity_checking.py create mode 100644 worker/logging.py diff --git a/.zed/settings.json b/.zed/settings.json new file mode 100644 index 00000000..f885d7e7 --- /dev/null +++ b/.zed/settings.json @@ -0,0 +1,29 @@ +// Folder-specific settings +// +// For a full list of overridable settings, and general information on folder-specific settings, +// see the documentation: https://zed.dev/docs/configuring-zed#settings-files +{ + "lsp": { + "nix_python": { + "binary": { + "path": "nix", + "arguments": [ + "run", + "--quiet", + "--no-warn-dirty", + "--no-allow-import-from-derivation", + "--print-build-logs", + "never", + "${projectRoot}#python-lsp", + "--", + "--stdio" + ] + } + } + }, + "languages": { + "Python": { + "language_servers": ["nix_python"] + } + } +} diff --git a/flake.nix b/flake.nix index a97b3f63..006af63c 100644 --- a/flake.nix +++ b/flake.nix @@ -25,9 +25,22 @@ pkgs.rustc pkgs.cargo pkgs.basedpyright + pkgs.ruff ]; }; } ); + + apps = forAllSystems (system: + let + pkgs = import nixpkgs { inherit system; }; + in + { + python-lsp = { + type = "app"; + program = "${pkgs.basedpyright}/bin/basedpyright-langserver"; + }; + } + ); }; } \ No newline at end of file diff --git a/master/env.py b/master/env.py new file mode 100644 index 00000000..dadeee5f --- /dev/null +++ b/master/env.py @@ -0,0 +1,5 @@ +from shared.env import BaseEnv + + +class MasterEnvironmentSchema(BaseEnv): + pass diff --git a/master/event_routing.py b/master/event_routing.py new file mode 100644 index 00000000..697e0000 --- /dev/null +++ b/master/event_routing.py @@ -0,0 +1,163 @@ +from enum import StrEnum +from typing import List, LiteralString, Protocol, Literal +from 
logging import Logger + +from shared.types.events.common import ( + EffectHandler, + EventCategories, + EventCategory, + Event, + EventCategoryEnum, + EventFromEventLog, + EventFetcherProtocol, + State, + Apply, +) +from asyncio import Lock, Queue, Task, gather, create_task +from typing import Any, Type, TypedDict +from collections.abc import Mapping +from shared.logger import log +from shared.constants import EXO_ERROR_REPORTING_MESSAGE +from master.logging import ( + StateUpdateLoopAlreadyRunningLogEntry, + StateUpdateLoopStartedLogEntry, + StateUpdateLoopNotRunningLogEntry, + StateUpdateLoopStoppedLogEntry, + StateUpdateErrorLogEntry, + StateUpdateEffectHandlerErrorLogEntry, +) + +class QueueMapping(TypedDict): + MutatesTaskState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]] + MutatesControlPlaneState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]] + MutatesDataPlaneState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]] + MutatesInstanceState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]] + MutatesNodePerformanceState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]] + +def check_keys_in_map_match_enum_values[TEnum: StrEnum]( + mapping_type: Type[Mapping[Any, Any]], + enum: Type[TEnum], +) -> None: + mapping_keys = set(mapping_type.__annotations__.keys()) + category_values = set(e.value for e in enum) + assert mapping_keys == category_values, ( + f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}" + ) + +check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum) + +class AsyncUpdateStateFromEvents[EventCategoryT: EventCategory](Protocol): + """Protocol for services that manage a specific state domain.""" + + _task: Task[None] | None + _logger: Logger + _apply: Apply[EventCategoryT] + _default_effects: List[EffectHandler[EventCategoryT]] + extra_effects: List[EffectHandler[EventCategoryT]] + state: State[EventCategoryT] + queue: Queue[EventFromEventLog[EventCategoryT]] + lock: Lock + + def __init__( + self, + state: State[EventCategoryT], + queue: Queue[EventFromEventLog[EventCategoryT]], + extra_effects: List[EffectHandler[EventCategoryT]], + logger: Logger, + ) -> None: + """Initialise the service with its event queue.""" + self.state = state + self.queue = queue + self.extra_effects = extra_effects + self._logger = logger + self._task = None + + async def read_state(self) -> State[EventCategoryT]: + """Get a thread-safe snapshot of this service's state domain.""" + return self.state.model_copy(deep=True) + + @property + def is_running(self) -> bool: + """Check if the service's event loop is running.""" + return self._task is not None and not self._task.done() + + async def start(self) -> None: + """Start the service's event loop.""" + if self.is_running: + log(self._logger, StateUpdateLoopAlreadyRunningLogEntry()) + raise RuntimeError("State Update Loop Already Running") + log(self._logger, StateUpdateLoopStartedLogEntry()) + self._task = create_task(self._event_loop()) + + async def stop(self) -> None: + """Stop the service's event loop.""" + if not self.is_running: + log(self._logger, StateUpdateLoopNotRunningLogEntry()) + raise RuntimeError("State Update Loop Not Running") + + assert self._task is not None, ( + f"{EXO_ERROR_REPORTING_MESSAGE()}" + "BUG: is_running is True but _task is None, this should never happen!" 
+        )
+        self._task.cancel()
+        log(self._logger, StateUpdateLoopStoppedLogEntry())
+
+    async def _event_loop(self) -> None:
+        """Event loop for the service."""
+        while True:
+            event = await self.queue.get()
+            previous_state = self.state.model_copy(deep=True)
+            try:
+                async with self.lock:
+                    updated_state = self._apply(
+                        self.state,
+                        event,
+                    )
+                    self.state = updated_state
+            except Exception as e:
+                log(self._logger, StateUpdateErrorLogEntry(error=e))
+                raise e
+            try:
+                for effect_handler in self._default_effects + self.extra_effects:
+                    effect_handler((previous_state, event), updated_state)
+            except Exception as e:
+                log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e))
+                raise e
+
+
+class EventRouter:
+    """Routes events to appropriate services based on event categories."""
+
+    queue_map: QueueMapping
+    event_fetcher: EventFetcherProtocol[EventCategory]
+    _logger: Logger
+
+    async def _get_queue_by_category[T: EventCategory](
+        self, category: T
+    ) -> Queue[Event[T]]:
+        """Get the queue for a given category."""
+        category_str: str = category.value
+        queue: Queue[Event[T]] = self.queue_map[category_str]
+        return queue
+
+    async def _process_events[T: EventCategory](self, category: T) -> None:
+        """Process events for a given domain."""
+        queue: Queue[Event[T]] = await self._get_queue_by_category(category)
+        events_to_process: list[Event[T]] = []
+        while not queue.empty():
+            events_to_process.append(await queue.get())
+        for event_to_process in events_to_process:
+            await self.queue_map[category].put(event_to_process)
+        return None
+
+    async def _submit_events(self, events: list[Event[EventCategory | EventCategories]]) -> None:
+        """Route multiple events to their appropriate services."""
+        for event in events:
+            for category in event.event_category:
+                await self.queue_map[category].put(event)
+
+        await gather(
+            *[self._process_events(category) for category in self.queue_map.keys()]
+        )
+
+    async def _get_events_to_process(self) -> list[Event[EventCategories]]:
+        """Get events to process from the event fetcher."""
+        raise NotImplementedError  # TODO: pull new events via self.event_fetcher
diff --git a/master/logging.py b/master/logging.py
new file mode 100644
index 00000000..1300ca06
--- /dev/null
+++ b/master/logging.py
@@ -0,0 +1,95 @@
+from typing import Literal
+from collections.abc import Set
+
+from shared.logging.common import LogEntry, LogEntryType
+
+
+class MasterUninitializedLogEntry(LogEntry[Literal["master_uninitialized"]]):
+    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
+    entry_type: Literal["master_uninitialized"] = "master_uninitialized"
+    message: str = "No master state found, creating new one."
+
+
+class MasterCommandReceivedLogEntry(LogEntry[Literal["master_command_received"]]):
+    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
+    entry_type: Literal["master_command_received"] = "master_command_received"
+    command_name: str
+
+
+class MasterInvalidCommandReceivedLogEntry(
+    LogEntry[Literal["master_invalid_command_received"]]
+):
+    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
+    entry_type: Literal["master_invalid_command_received"] = (
+        "master_invalid_command_received"
+    )
+    command_name: str
+
+
+class EventCategoryUnknownLogEntry(LogEntry[Literal["event_category_unknown"]]):
+    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
+    entry_type: Literal["event_category_unknown"] = "event_category_unknown"
+    event_category: str
+    message: str = "Event Category Unknown, Skipping Event."
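+
+
+# Illustrative sketch (the logger name and field value here are assumptions, not
+# part of this module): entries like the ones above are emitted through
+# `shared.logger.log`, which serialises them to JSON for the queue handlers
+# wired up in `master.main`:
+#
+#     from logging import getLogger
+#     from shared.logger import log
+#
+#     logger = getLogger("master")
+#     log(logger, EventCategoryUnknownLogEntry(event_category="NoSuchCategory"))
+#     # -> one INFO record whose message is the entry's model_dump_json()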
+ + +class StateUpdateLoopAlreadyRunningLogEntry( + LogEntry[Literal["state_update_loop_already_running"]] +): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["state_update_loop_already_running"] = ( + "state_update_loop_already_running" + ) + message: str = "State Update Loop Already Running" + + +class StateUpdateLoopStartedLogEntry(LogEntry[Literal["state_update_loop_started"]]): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["state_update_loop_started"] = "state_update_loop_started" + message: str = "State Update Loop Started" + + +class StateUpdateLoopNotRunningLogEntry( + LogEntry[Literal["state_update_loop_not_running"]] +): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["state_update_loop_not_running"] = ( + "state_update_loop_not_running" + ) + message: str = "State Update Loop Not Running" + + +class StateUpdateLoopStoppedLogEntry(LogEntry[Literal["state_update_loop_stopped"]]): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["state_update_loop_stopped"] = "state_update_loop_stopped" + message: str = "State Update Loop Stopped" + + +class StateUpdateErrorLogEntry(LogEntry[Literal["state_update_error"]]): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["state_update_error"] = "state_update_error" + error: Exception + + +class StateUpdateEffectHandlerErrorLogEntry( + LogEntry[Literal["state_update_effect_handler_error"]] +): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["state_update_effect_handler_error"] = ( + "state_update_effect_handler_error" + ) + error: Exception + + +MasterLogEntries = ( + MasterUninitializedLogEntry + | MasterCommandReceivedLogEntry + | MasterInvalidCommandReceivedLogEntry + | EventCategoryUnknownLogEntry + | StateUpdateLoopAlreadyRunningLogEntry + | StateUpdateLoopStartedLogEntry + | StateUpdateLoopNotRunningLogEntry + | StateUpdateLoopStoppedLogEntry + | StateUpdateErrorLogEntry + | StateUpdateEffectHandlerErrorLogEntry +) diff --git a/master/main.py b/master/main.py index f1c6bd53..bf7cd59c 100644 --- a/master/main.py +++ b/master/main.py @@ -1,6 +1,271 @@ -def main(): - print("Hello from master!") +from fastapi import FastAPI, Response +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field, TypeAdapter +from logging import Logger + +from shared.types.events.common import Event, EventCategories, EventFetcherProtocol, EventPublisher, State +from shared.logger import ( + configure_logger, + LogEntryType, + FilterLogByType, + create_queue_listener, + attach_to_queue, +) +from shared.types.worker.common import InstanceId +from shared.types.worker.instances import Instance +from shared.types.models.common import ModelId +from shared.types.models.model import ModelInfo +from shared.types.states.master import MasterState +from shared.constants import EXO_MASTER_STATE +from contextlib import asynccontextmanager +from logging import LogRecord +from typing import Annotated, Literal +from master.env import MasterEnvironmentSchema +from master.logging import ( + MasterUninitializedLogEntry, + MasterCommandReceivedLogEntry, + MasterInvalidCommandReceivedLogEntry, +) +from master.event_routing import AsyncUpdateStateFromEvents +from shared.logger import log +from asyncio import Lock, Task, CancelledError, Queue, create_task +from enum import Enum -if __name__ == "__main__": - main() +# Restore State +def 
get_master_state(logger: Logger) -> MasterState: + if EXO_MASTER_STATE.exists(): + with open(EXO_MASTER_STATE, "r") as f: + return MasterState.model_validate_json(f.read()) + else: + log(logger, MasterUninitializedLogEntry()) + return MasterState() + + +# FastAPI Dependencies +def check_env_vars_defined(data: object, logger: Logger) -> MasterEnvironmentSchema: + if not isinstance(data, MasterEnvironmentSchema): + raise RuntimeError("Environment Variables Not Found") + return data + + +def get_master_state_dependency(data: object, logger: Logger) -> MasterState: + if not isinstance(data, MasterState): + raise RuntimeError("Master State Not Found") + return data + + +class BaseExternalCommand[T: str](BaseModel): + command_type: T + + +class ChatCompletionNonStreamingCommand( + BaseExternalCommand[Literal["chat_completion_non_streaming"]] +): + command_type: Literal["chat_completion_non_streaming"] = ( + "chat_completion_non_streaming" + ) + + +ExternalCommand = Annotated[ + ChatCompletionNonStreamingCommand, Field(discriminator="command_type") +] +ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand) + + +class MasterBackgroundServices(str, Enum): + MAIN_LOOP = "main_loop" + +class StateManager[T: EventCategories]: + state: State[T] + queue: Queue[Event[T]] + manager: AsyncUpdateStateFromEvents[T] + + def __init__( + self, + state: State[T], + queue: Queue[Event[T]], + ) -> None: + ... + +class MasterStateManager: + """Thread-safe manager for MasterState with independent event loop.""" + + def __init__( + self, + initial_state: MasterState, + event_processor: EventFetcherProtocol[EventCategories], + event_publisher: EventPublisher[EventCategories], + logger: Logger, + ): + self._state = initial_state + self._state_lock = Lock() + self._command_queue: Queue[ExternalCommand] = Queue() + self._services: dict[MasterBackgroundServices, Task[None]] = {} + self._logger = logger + + async def read_state(self) -> MasterState: + """Get a thread-safe snapshot of the current state.""" + async with self._state_lock: + return self._state.model_copy(deep=True) + + async def send_command( + self, command: ExternalCommand + ) -> Response | StreamingResponse: + """Send a command to the background event loop.""" + if self._services[MasterBackgroundServices.MAIN_LOOP]: + self._command_queue.put(command) + return Response(status_code=200) + else: + raise RuntimeError("State manager is not running") + + async def start(self) -> None: + """Start the background event loop.""" + for service in MasterBackgroundServices: + match service: + case MasterBackgroundServices.MAIN_LOOP: + if self._services[service]: + raise RuntimeError("State manager is already running") + self._services[service]: Task[None] = create_task(self._event_loop()) + log(self._logger, MasterStateManagerStartedLogEntry()) + case _: + raise ValueError(f"Unknown service: {service}") + + async def stop(self) -> None: + """Stop the background event loop and persist state.""" + if not self._services[MasterBackgroundServices.MAIN_LOOP]: + raise RuntimeError("State manager is not running") + + for service in self._services.values(): + service.cancel() + try: + await service + except CancelledError: + pass + + log(self._logger, MasterStateManagerStoppedLogEntry()) + + async def _event_loop(self) -> None: + """Independent event loop for processing commands and mutating state.""" + while True: + try: + async with self._state_lock: + match EventCategories: + case EventCategories.InstanceEventTypes: + events_one = 
self._event_processor.get_events_to_apply( + self._state.data_plane_network_state + ) + case EventCategories.InstanceStateEventTypes: + events_one = self._event_processor.get_events_to_apply( + self._state.control_plane_network_state + ) + case _: + raise ValueError( + f"Unknown event category: {event_category}" + ) + command = self._command_queue.get(timeout=5.0) + match command: + case ChatCompletionNonStreamingCommand(): + log( + self._logger, + MasterCommandReceivedLogEntry( + command_name=command.command_type + ), + ) + case _: + log( + self._logger, + MasterInvalidCommandReceivedLogEntry( + command_name=command.command_type + ), + ) + except CancelledError: + break + except Exception as e: + log(self._logger, MasterStateManagerErrorLogEntry(error=str(e))) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger = configure_logger("master") + + telemetry_queue: Queue[LogRecord] = Queue() + metrics_queue: Queue[LogRecord] = Queue() + cluster_queue: Queue[LogRecord] = Queue() + + attach_to_queue( + logger, + [ + FilterLogByType(log_types={LogEntryType.telemetry}), + ], + telemetry_queue, + ) + attach_to_queue( + logger, + [ + FilterLogByType(log_types={LogEntryType.metrics}), + ], + metrics_queue, + ) + attach_to_queue( + logger, + [ + FilterLogByType(log_types={LogEntryType.cluster}), + ], + cluster_queue, + ) + + # TODO: Add handlers + telemetry_listener = create_queue_listener(telemetry_queue, []) + metrics_listener = create_queue_listener(metrics_queue, []) + cluster_listener = create_queue_listener(cluster_queue, []) + + telemetry_listener.start() + metrics_listener.start() + cluster_listener.start() + + initial_state = get_master_state(logger) + app.state.master_state_manager = MasterStateManager(initial_state, logger) + await app.state.master_state_manager.start() + + yield + + await app.state.master_state_manager.stop() + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/topology/control_plane") +def get_control_plane_topology(): + return {"message": "Hello, World!"} + + +@app.get("/topology/data_plane") +def get_data_plane_topology(): + return {"message": "Hello, World!"} + + +@app.get("/instances/list") +def list_instances(): + return {"message": "Hello, World!"} + + +@app.post("/instances/create") +def create_instance(model_id: ModelId) -> InstanceId: ... + + +@app.get("/instance/{instance_id}/read") +def get_instance(instance_id: InstanceId) -> Instance: ... + + +@app.delete("/instance/{instance_id}/delete") +def remove_instance(instance_id: InstanceId) -> None: ... + + +@app.get("/model/{model_id}/metadata") +def get_model_data(model_id: ModelId) -> ModelInfo: ... + + +@app.post("/model/{model_id}/instances") +def get_instances_by_model(model_id: ModelId) -> list[Instance]: ... 
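Illustrative usage (not part of the patch): FastAPI's test client can exercise the
endpoints above and, used as a context manager, also runs the lifespan hooks; the
import path and the chosen endpoint are assumptions.

    from fastapi.testclient import TestClient
    from master.main import app

    with TestClient(app) as client:  # entering the context runs lifespan startup
        print(client.get("/instances/list").json())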
diff --git a/master/pyproject.toml b/master/pyproject.toml
index 8410b18f..b8912679 100644
--- a/master/pyproject.toml
+++ b/master/pyproject.toml
@@ -4,7 +4,10 @@ version = "0.1.0"
 description = "Master service for the Exo project"
 readme = "README.md"
 requires-python = ">=3.13"
-dependencies = ["exo-shared"]
+dependencies = [
+    "exo-shared",
+    "fastapi>=0.116.0",
+]
 
 [build-system]
 requires = ["hatchling"]
@@ -21,4 +24,4 @@ exclude = ["*.md", "pyproject.toml"]
 [tool.hatch.build.targets.sdist]
 packages = []
 include = ["*"]
-exclude = ["*.md", "pyproject.toml"]
\ No newline at end of file
+exclude = ["*.md", "pyproject.toml"]
diff --git a/networking/src/networking/_core.pyi b/networking/src/networking/_core.pyi
index d52129eb..e69de29b 100644
--- a/networking/src/networking/_core.pyi
+++ b/networking/src/networking/_core.pyi
@@ -1 +0,0 @@
-def hello_from_bin() -> str: ...
diff --git a/shared/constants.py b/shared/constants.py
index 5410f899..82ffd6c1 100644
--- a/shared/constants.py
+++ b/shared/constants.py
@@ -1,11 +1,30 @@
 from pathlib import Path
+import inspect
 
 EXO_HOME = Path.home() / ".exo"
 
 EXO_EVENT_DB = EXO_HOME / "event_db.sqlite3"
-EXO_MASTER_CONFIG = EXO_HOME / "master.json"
-EXO_WORKER_CONFIG = EXO_HOME / "worker.json"
+EXO_MASTER_STATE = EXO_HOME / "master_state.json"
+EXO_WORKER_STATE = EXO_HOME / "worker_state.json"
 EXO_MASTER_LOG = EXO_HOME / "master.log"
 EXO_WORKER_LOG = EXO_HOME / "worker.log"
 
 EXO_WORKER_KEYRING_FILE = EXO_HOME / "worker_keyring"
 EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
+
+
+# Little helper to get the name of the module that raised the error.
+# `depth` counts call frames: 1 is the direct caller of this function.
+def get_caller_module_name(depth: int = 1) -> str:
+    frm = inspect.stack()[depth]
+    mod = inspect.getmodule(frm[0])
+    if mod is None:
+        return "UNKNOWN MODULE"
+    return mod.__name__
+
+
+# A def rather than a lambda, so the frame walk above can skip this wrapper's
+# own frame and report the module that actually raised the error.
+def EXO_ERROR_REPORTING_MESSAGE() -> str:
+    return (
+        "THIS IS A BUG IN THE EXO SOFTWARE, PLEASE REPORT IT AT https://github.com/exo-explore/exo/\n"
+        f"The module that raised the error was: {get_caller_module_name(depth=2)}"
+    )
diff --git a/shared/graphs/networkx.py b/shared/graphs/networkx.py
new file mode 100644
index 00000000..0ab7ee81
--- /dev/null
+++ b/shared/graphs/networkx.py
@@ -0,0 +1,221 @@
+from typing import Set, Mapping
+from dataclasses import dataclass
+from pydantic import TypeAdapter
+
+import rustworkx as rx
+
+from shared.types.graphs.common import (
+    Edge,
+    EdgeData,
+    MutableGraphProtocol,
+    Vertex,
+    VertexData,
+    EdgeIdT,
+    VertexIdT,
+    EdgeTypeT,
+    VertexTypeT,
+)
+
+
+@dataclass(frozen=True)
+class _VertexWrapper[VertexTypeT, VertexIdT]:
+    """Internal wrapper to store vertex ID alongside vertex data."""
+
+    vertex_id: VertexIdT
+    vertex_data: VertexData[VertexTypeT]
+
+
+@dataclass(frozen=True)
+class _EdgeWrapper[EdgeTypeT, EdgeIdT]:
+    """Internal wrapper to store edge ID alongside edge data."""
+
+    edge_id: EdgeIdT
+    edge_data: EdgeData[EdgeTypeT]
+
+
+class NetworkXGraph(MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]):
+    edge_base: TypeAdapter[EdgeTypeT]
+    vertex_base: TypeAdapter[VertexTypeT]
+
+    _graph: rx.PyDiGraph[
+        _VertexWrapper[VertexTypeT, VertexIdT], _EdgeWrapper[EdgeTypeT, EdgeIdT]
+    ]
+    _vertex_id_to_index: dict[VertexIdT, int]
+    _edge_id_to_endpoints: dict[EdgeIdT, tuple[int, int]]
+
+    def __init__(
+        self, edge_base: TypeAdapter[EdgeTypeT], vertex_base: TypeAdapter[VertexTypeT]
+    ) -> None:
+        self.edge_base = edge_base
+        self.vertex_base = vertex_base
+        self._graph = rx.PyDiGraph()
+        self._vertex_id_to_index = {}
+        self._edge_id_to_endpoints = {}
+
+    ###
+    # GraphProtocol methods
### + + def list_edges(self) -> Set[EdgeIdT]: + return {edge.edge_id for edge in self._graph.edges()} + + def list_vertices(self) -> Set[VertexIdT]: + return {node.vertex_id for node in self._graph.nodes()} + + def get_vertices_from_edges( + self, edges: Set[EdgeIdT] + ) -> Mapping[EdgeIdT, Set[VertexIdT]]: + result: dict[EdgeIdT, Set[VertexIdT]] = {} + + for edge_id in edges: + if edge_id in self._edge_id_to_endpoints: + u_idx, v_idx = self._edge_id_to_endpoints[edge_id] + u_data = self._graph.get_node_data(u_idx) + v_data = self._graph.get_node_data(v_idx) + result[edge_id] = {u_data.vertex_id, v_data.vertex_id} + + return result + + def get_edges_from_vertices( + self, vertices: Set[VertexIdT] + ) -> Mapping[VertexIdT, Set[EdgeIdT]]: + result: dict[VertexIdT, Set[EdgeIdT]] = {} + + for vertex_id in vertices: + if vertex_id in self._vertex_id_to_index: + vertex_idx = self._vertex_id_to_index[vertex_id] + edge_ids: Set[EdgeIdT] = set() + + # Get outgoing edges + for _, _, edge_data in self._graph.out_edges(vertex_idx): + edge_ids.add(edge_data.edge_id) + + # Get incoming edges + for _, _, edge_data in self._graph.in_edges(vertex_idx): + edge_ids.add(edge_data.edge_id) + + result[vertex_id] = edge_ids + + return result + + def get_edge_data( + self, edges: Set[EdgeIdT] + ) -> Mapping[EdgeIdT, EdgeData[EdgeTypeT]]: + result: dict[EdgeIdT, EdgeData[EdgeTypeT]] = {} + + for edge_id in edges: + if edge_id in self._edge_id_to_endpoints: + u_idx, v_idx = self._edge_id_to_endpoints[edge_id] + edge_wrapper = self._graph.get_edge_data(u_idx, v_idx) + result[edge_id] = edge_wrapper.edge_data + + return result + + def get_vertex_data( + self, vertices: Set[VertexIdT] + ) -> Mapping[VertexIdT, VertexData[VertexTypeT]]: + result: dict[VertexIdT, VertexData[VertexTypeT]] = {} + + for vertex_id in vertices: + if vertex_id in self._vertex_id_to_index: + vertex_idx = self._vertex_id_to_index[vertex_id] + vertex_wrapper = self._graph.get_node_data(vertex_idx) + result[vertex_id] = vertex_wrapper.vertex_data + + return result + + ### + # MutableGraphProtocol methods + ### + + def check_edges_exists(self, edge_id: EdgeIdT) -> bool: + return edge_id in self._edge_id_to_endpoints + + def check_vertex_exists(self, vertex_id: VertexIdT) -> bool: + return vertex_id in self._vertex_id_to_index + + def _add_edge(self, edge_id: EdgeIdT, edge_data: EdgeData[EdgeTypeT]) -> None: + # This internal method is not used in favor of a safer `attach_edge` implementation. + raise NotImplementedError( + "Use attach_edge to add edges. The internal _add_edge protocol method is flawed." 
+ ) + + def _add_vertex( + self, vertex_id: VertexIdT, vertex_data: VertexData[VertexTypeT] + ) -> None: + if vertex_id not in self._vertex_id_to_index: + wrapper = _VertexWrapper(vertex_id=vertex_id, vertex_data=vertex_data) + idx = self._graph.add_node(wrapper) + self._vertex_id_to_index[vertex_id] = idx + + def _remove_edge(self, edge_id: EdgeIdT) -> None: + if edge_id in self._edge_id_to_endpoints: + u_idx, v_idx = self._edge_id_to_endpoints[edge_id] + self._graph.remove_edge(u_idx, v_idx) + del self._edge_id_to_endpoints[edge_id] + else: + raise ValueError(f"Edge with id {edge_id} not found.") + + def _remove_vertex(self, vertex_id: VertexIdT) -> None: + if vertex_id in self._vertex_id_to_index: + vertex_idx = self._vertex_id_to_index[vertex_id] + + # Remove any edges connected to this vertex from our mapping + edges_to_remove: list[EdgeIdT] = [] + for edge_id, (u_idx, v_idx) in self._edge_id_to_endpoints.items(): + if u_idx == vertex_idx or v_idx == vertex_idx: + edges_to_remove.append(edge_id) + + for edge_id in edges_to_remove: + del self._edge_id_to_endpoints[edge_id] + + # Remove the vertex from the graph + self._graph.remove_node(vertex_idx) + del self._vertex_id_to_index[vertex_id] + else: + raise ValueError(f"Vertex with id {vertex_id} not found.") + + def attach_edge( + self, + edge: Edge[EdgeTypeT, EdgeIdT, VertexIdT], + extra_vertex: Vertex[VertexTypeT, EdgeIdT, VertexIdT] | None = None, + ) -> None: + """ + Attaches an edge to the graph, overriding the default protocol implementation. + + This implementation corrects a flaw in the protocol's `_add_edge` + signature and provides more intuitive behavior when connecting existing vertices. + """ + base_vertex_id, target_vertex_id = edge.edge_vertices + + if not self.check_vertex_exists(base_vertex_id): + raise ValueError(f"Base vertex {base_vertex_id} does not exist.") + + target_vertex_exists = self.check_vertex_exists(target_vertex_id) + + if not target_vertex_exists: + if extra_vertex is None: + raise ValueError( + f"Target vertex {target_vertex_id} does not exist and no `extra_vertex` was provided." + ) + if extra_vertex.vertex_id != target_vertex_id: + raise ValueError( + f"The ID of `extra_vertex` ({extra_vertex.vertex_id}) does not match " + f"the target vertex ID of the edge ({target_vertex_id})." + ) + self._add_vertex(extra_vertex.vertex_id, extra_vertex.vertex_data) + elif extra_vertex is not None: + raise ValueError( + f"Target vertex {target_vertex_id} already exists, but `extra_vertex` was provided." 
+ ) + + # Get the internal indices + base_idx = self._vertex_id_to_index[base_vertex_id] + target_idx = self._vertex_id_to_index[target_vertex_id] + + # Create edge wrapper and add to graph + edge_wrapper = _EdgeWrapper(edge_id=edge.edge_id, edge_data=edge.edge_data) + self._graph.add_edge(base_idx, target_idx, edge_wrapper) + + # Store the mapping + self._edge_id_to_endpoints[edge.edge_id] = (base_idx, target_idx) diff --git a/shared/logger.py b/shared/logger.py index 659f551e..eff188c6 100644 --- a/shared/logger.py +++ b/shared/logger.py @@ -1,31 +1,31 @@ import logging import logging.handlers from collections.abc import Sequence, Set -from enum import Enum from queue import Queue -from pydantic import BaseModel from rich.logging import RichHandler +from typing import Annotated +from pydantic import Field, TypeAdapter -class LogEntryType(str, Enum): - telemetry = "telemetry" - metrics = "metrics" - cluster = "cluster" +from shared.logging.common import LogEntryType +from master.logging import MasterLogEntries +from worker.logging import WorkerLogEntries + +LogEntries = Annotated[ + MasterLogEntries | WorkerLogEntries, Field(discriminator="entry_type") +] +LogParser: TypeAdapter[LogEntries] = TypeAdapter(LogEntries) -class LogEntry(BaseModel): - event_type: Set[LogEntryType] - - -class LogFilterByType(logging.Filter): +class FilterLogByType(logging.Filter): def __init__(self, log_types: Set[LogEntryType]): super().__init__() self.log_types = log_types def filter(self, record: logging.LogRecord) -> bool: message = record.getMessage() - LogEntry.model_validate_json(message) + LogParser.validate_json(message) return True @@ -79,3 +79,9 @@ def create_queue_listener( log_queue, *effect_handlers, respect_handler_level=True ) return listener + + +def log( + logger: logging.Logger, log_entry: LogEntries, log_level: int = logging.INFO +) -> None: + logger.log(log_level, log_entry.model_dump_json()) diff --git a/shared/logging/common.py b/shared/logging/common.py new file mode 100644 index 00000000..215068c9 --- /dev/null +++ b/shared/logging/common.py @@ -0,0 +1,18 @@ +from enum import Enum +from typing import Generic, TypeVar +from pydantic import BaseModel + +from collections.abc import Set + +LogEntryTypeT = TypeVar("LogEntryTypeT", bound=str) + + +class LogEntryType(str, Enum): + telemetry = "telemetry" + metrics = "metrics" + cluster = "cluster" + + +class LogEntry(BaseModel, Generic[LogEntryTypeT]): + entry_destination: Set[LogEntryType] + entry_type: LogEntryTypeT diff --git a/shared/openai.py b/shared/openai.py index 0a0a546f..ed651356 100644 --- a/shared/openai.py +++ b/shared/openai.py @@ -13,8 +13,11 @@ else: FinishReason: TypeAlias = Literal[ "stop", "length", "tool_calls", "content_filter", "function_call" ] -assert ( - get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"] == FinishReason -), "Upstream changed Choice.finish_reason; update FinishReason alias." + +if TYPE_CHECKING: + assert ( + get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"] + == FinishReason + ), "Upstream changed Choice.finish_reason; update FinishReason alias." 
__all__ = ["types", "chat", "FinishReason"] diff --git a/shared/pyproject.toml b/shared/pyproject.toml index d4ee919e..6602478a 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -5,11 +5,13 @@ description = "Shared utilities for the Exo project" readme = "README.md" requires-python = ">=3.13" dependencies = [ + "networkx>=3.5", "openai>=1.93.0", "pathlib>=1.0.1", "protobuf>=6.31.1", "pydantic>=2.11.7", "rich>=14.0.0", + "rustworkx>=0.16.0", ] [build-system] diff --git a/shared/types/events/chunks.py b/shared/types/events/chunks.py index e75d6e1e..ed52b008 100644 --- a/shared/types/events/chunks.py +++ b/shared/types/events/chunks.py @@ -1,28 +1,22 @@ from enum import Enum -from typing import Annotated, Generic, Literal, TypeVar +from typing import Annotated, Literal -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +# from openai.types.chat.chat_completion import ChatCompletion +# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from pydantic import BaseModel, Field, TypeAdapter from shared.openai import FinishReason from shared.types.models.common import ModelId from shared.types.tasks.common import TaskId -OpenAIResponse = ( - ChatCompletion | ChatCompletionChunk -) ## Currently we only support chat completions - class ChunkType(str, Enum): token = "token" image = "image" -ChunkT = TypeVar("ChunkT", bound=ChunkType) - - -class BaseChunk(BaseModel, Generic[ChunkT]): +class BaseChunk[ChunkTypeT: ChunkType](BaseModel): + chunk_type: ChunkTypeT task_id: TaskId idx: int model: ModelId @@ -59,6 +53,10 @@ class ImageChunk(BaseChunk[ChunkType.image]): GenerationChunk = Annotated[TokenChunk | ImageChunk, Field(discriminator="chunk_type")] GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(GenerationChunk) +## OpenAIResponse = ( +## ChatCompletion | ChatCompletionChunk +## ) ## Currently we only support chat completions + # my_chunk: dict[str, Any] = TokenChunk( # task_id=TaskId('nicerid'), # idx=0, diff --git a/shared/types/events/common.py b/shared/types/events/common.py index 6e5f78cf..a0abc252 100644 --- a/shared/types/events/common.py +++ b/shared/types/events/common.py @@ -1,17 +1,24 @@ -from enum import Enum, auto +from enum import Enum, StrEnum from typing import ( Annotated, - Callable, - Generic, - Protocol, - Sequence, - Tuple, - TypeVar, + Any, + FrozenSet, + Literal, + NamedTuple, + cast, ) -from pydantic import BaseModel, Field, TypeAdapter, model_validator +import annotated_types + +from shared.types.events.sanity_checking import ( + check_event_type_union_is_consistent_with_registry, + assert_literal_union_covers_enum, +) + +from pydantic import BaseModel, Field, model_validator from shared.types.common import NewUUID, NodeId +from typing import Callable, Sequence, Protocol class EventId(NewUUID): @@ -22,6 +29,8 @@ class TimerId(NewUUID): pass +# Here are all the unique kinds of events that can be sent over the network. +# I've defined them in different enums for clarity, but they're all part of the same set of possible events. 
 class MLXEventTypes(str, Enum):
     MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
     MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"
@@ -29,7 +38,7 @@ class TaskEventTypes(str, Enum):
     TaskCreated = "TaskCreated"
-    TaskUpdated = "TaskUpdated"
+    TaskStateUpdated = "TaskStateUpdated"
     TaskDeleted = "TaskDeleted"
 
 
@@ -40,22 +49,20 @@ class StreamingEventTypes(str, Enum):
 class InstanceEventTypes(str, Enum):
     InstanceCreated = "InstanceCreated"
     InstanceDeleted = "InstanceDeleted"
-    InstanceToBeReplacedAtomically = "InstanceToBeReplacedAtomically"
     InstanceReplacedAtomically = "InstanceReplacedAtomically"
-    InstanceStatusUpdated = "InstanceStatusUpdated"
 
 
 class InstanceStateEventTypes(str, Enum):
-    InstanceRunnerStateUpdated = "InstanceRunnerStateUpdated"
+    InstanceSagaRunnerStateUpdated = "InstanceSagaRunnerStateUpdated"
 
 
 class NodePerformanceEventTypes(str, Enum):
-    NodePerformanceProfiled = "NodePerformanceProfiled"
+    NodePerformanceMeasured = "NodePerformanceMeasured"
 
 
 class DataPlaneEventTypes(str, Enum):
     DataPlaneEdgeCreated = "DataPlaneEdgeCreated"
-    DataPlaneEdgeProfiled = "DataPlaneEdgeProfiled"
+    DataPlaneEdgeReplacedAtomically = "DataPlaneEdgeReplacedAtomically"
     DataPlaneEdgeDeleted = "DataPlaneEdgeDeleted"
 
 
@@ -70,168 +77,132 @@ class TimerEventTypes(str, Enum):
     TimerFired = "TimerFired"
 
 
-class ResourceEventTypes(str, Enum):
-    ResourceProfiled = "ResourceProfiled"
+# Registry of all event type enums
+EVENT_TYPE_ENUMS = [
+    TaskEventTypes,
+    StreamingEventTypes,
+    InstanceEventTypes,
+    InstanceStateEventTypes,
+    NodePerformanceEventTypes,
+    DataPlaneEventTypes,
+    ControlPlaneEventTypes,
+    TimerEventTypes,
+    MLXEventTypes,
+]
 
 
-class EventCategories(str, Enum):
-    TaskEventTypes = auto()
-    StreamingEventTypes = auto()
-    InstanceEventTypes = auto()
-    InstanceStateEventTypes = auto()
-    NodePerformanceEventTypes = auto()
-    ControlPlaneEventTypes = auto()
-    DataPlaneEventTypes = auto()
-    TimerEventTypes = auto()
-    MLXEventTypes = auto()
-
-
-PossibleEventOfEventTypeT = TypeVar("PossibleEventOfEventTypeT", bound=Enum)
-
-# T=(A|B) <: U=(A|B|C) ==> Event[A|B] <: Event[A|B|C]
-CategoryOfEventsT_cov = TypeVar(
-    name="CategoryOfEventsT_cov", bound=EventCategories, covariant=True
-)
-CategoryOfEventsT_con = TypeVar(
-    name="CategoryOfEventsT_con", bound=EventCategories, contravariant=True
-)
-CategoryOfEventsT_inv = TypeVar(
-    name="CategoryOfEventsT_inv",
-    bound=EventCategories,
-    covariant=False,
-    contravariant=False,
-)
+# Here's the set of all possible events.
+EventTypes = ( + TaskEventTypes + | StreamingEventTypes + | InstanceEventTypes + | InstanceStateEventTypes + | NodePerformanceEventTypes + | ControlPlaneEventTypes + | DataPlaneEventTypes + | TimerEventTypes + | MLXEventTypes ) -class Event(BaseModel, Generic[PossibleEventOfEventTypeT]): - event_type: PossibleEventOfEventTypeT - event_category: EventCategories +check_event_type_union_is_consistent_with_registry(EVENT_TYPE_ENUMS, EventTypes) + + +class EventCategoryEnum(StrEnum): + MutatesTaskState = "MutatesTaskState" + MutatesInstanceState = "MutatesInstanceState" + MutatesNodePerformanceState = "MutatesNodePerformanceState" + MutatesControlPlaneState = "MutatesControlPlaneState" + MutatesDataPlaneState = "MutatesDataPlaneState" + + +EventCategory = ( + Literal[EventCategoryEnum.MutatesControlPlaneState] + | Literal[EventCategoryEnum.MutatesTaskState] + | Literal[EventCategoryEnum.MutatesInstanceState] + | Literal[EventCategoryEnum.MutatesNodePerformanceState] + | Literal[EventCategoryEnum.MutatesDataPlaneState] +) + +EventCategories = FrozenSet[EventCategory] + +assert_literal_union_covers_enum(EventCategory, EventCategoryEnum) + +class Event[SetMembersT: EventCategories | EventCategory](BaseModel): + event_type: EventTypes + event_category: SetMembersT event_id: EventId - def check_origin_id(self, origin_id: NodeId) -> bool: - return True + def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool: ... -class TaskEvent(Event[TaskEventTypes]): - event_type: TaskEventTypes - - -class InstanceEvent(Event[InstanceEventTypes]): - event_type: InstanceEventTypes - - -class InstanceStateEvent(Event[InstanceStateEventTypes]): - event_type: InstanceStateEventTypes - - -class MLXEvent(Event[MLXEventTypes]): - event_type: MLXEventTypes - - -class NodePerformanceEvent(Event[NodePerformanceEventTypes]): - event_type: NodePerformanceEventTypes - - -class ControlPlaneEvent(Event[ControlPlaneEventTypes]): - event_type: ControlPlaneEventTypes - - -class StreamingEvent(Event[StreamingEventTypes]): - event_type: StreamingEventTypes - - -class DataPlaneEvent(Event[DataPlaneEventTypes]): - event_type: DataPlaneEventTypes - - -class TimerEvent(Event[TimerEventTypes]): - event_type: TimerEventTypes - - -class ResourceEvent(Event[ResourceEventTypes]): - event_type: ResourceEventTypes - - -class WrappedMessage(BaseModel, Generic[PossibleEventOfEventTypeT]): - message: Event[PossibleEventOfEventTypeT] - origin_id: NodeId +class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel): + event: Event[SetMembersT] + origin: NodeId + idx_in_log: int = Field(gt=0) @model_validator(mode="after") - def check_origin_id(self) -> "WrappedMessage[PossibleEventOfEventTypeT]": - if self.message.check_origin_id(self.origin_id): + def check_event_was_sent_by_correct_node( + self, + ) -> "EventFromEventLog[SetMembersT]": + if self.event.check_event_was_sent_by_correct_node(self.origin): return self raise ValueError("Invalid Event: Origin ID Does Not Match") -class PersistedEvent(BaseModel, Generic[PossibleEventOfEventTypeT]): - event: Event[PossibleEventOfEventTypeT] - sequence_number: int = Field(gt=0) +def narrow_event_type[T: EventCategory]( + event: Event[EventCategories], + target_category: T, +) -> Event[T]: + if target_category not in event.event_category: + raise ValueError(f"Event Does Not Contain Target Category {target_category}") + + narrowed_event = event.model_copy(update={"event_category": {target_category}}) + return cast(Event[T], narrowed_event) -class State(BaseModel, 
Generic[CategoryOfEventsT_cov]): - event_category: CategoryOfEventsT_cov - sequence_number: int = Field(default=0, ge=0) +class State[EventCategoryT: EventCategory](BaseModel): + event_category: EventCategoryT + last_event_applied_idx: int = Field(default=0, ge=0) -AnnotatedEventType = Annotated[ - Event[EventCategories], Field(discriminator="event_category") +# Definitions for Type Variables +type Saga[EventCategoryT: EventCategory] = Callable[ + [State[EventCategoryT], EventFromEventLog[EventCategoryT]], + Sequence[Event[EventCategories]], ] -EventTypeParser: TypeAdapter[AnnotatedEventType] = TypeAdapter(AnnotatedEventType) - - -# it's not possible to enforce this at compile time, so we have to do it at runtime -def mock_todo[T](something: T | None) -> T: ... - - -def apply( - state: State[CategoryOfEventsT_inv], event: Event[CategoryOfEventsT_inv] -) -> State[CategoryOfEventsT_inv]: ... - - -# T=(A|B) <: U=(A|B|C) ==> Apply[A|B] <: Apply[A|B|C] -SagaApplicator = Callable[ - [State[CategoryOfEventsT_inv], Event[CategoryOfEventsT_inv]], - Sequence[Event[CategoryOfEventsT_inv]], +type Apply[EventCategoryT: EventCategory] = Callable[ + [State[EventCategoryT], EventFromEventLog[EventCategoryT]], + State[EventCategoryT], ] -Saga = Callable[ - [State[CategoryOfEventsT_inv], Event[CategoryOfEventsT_inv]], - Sequence[Event[CategoryOfEventsT_inv]], + + +class StateAndEvent[EventCategoryT: EventCategory](NamedTuple): + state: State[EventCategoryT] + event: EventFromEventLog[EventCategoryT] + + +type EffectHandler[EventCategoryT: EventCategory] = Callable[ + [StateAndEvent[EventCategoryT], State[EventCategoryT]], None ] -Apply = Callable[ - [State[CategoryOfEventsT_inv], Event[CategoryOfEventsT_inv]], - State[CategoryOfEventsT_inv], -] -StateAndEvent = Tuple[State[CategoryOfEventsT_inv], Event[CategoryOfEventsT_inv]] -EffectHandler = Callable[ - [StateAndEvent[CategoryOfEventsT_inv], State[CategoryOfEventsT_inv]], None -] -EventPublisher = Callable[[Event[CategoryOfEventsT_inv]], None] +type EventPublisher = Callable[[Event[Any]], None] -class MutableState[EventCategoryT: EventCategories](Protocol): - def apply( - self, - event: Event[EventCategoryT], - applicator: Apply[EventCategoryT], - effect_handlers: Sequence[EffectHandler[EventCategoryT]], - ) -> None: ... - - -class EventOutbox(Protocol): +# A component that can publish events +class EventPublisherProtocol(Protocol): def send(self, events: Sequence[Event[EventCategories]]) -> None: ... -# -# T=[A|B] <: U=[A|B|C] => EventProcessor[A|B] :> EventProcessor[A|B|C] -# -class EventProcessor[EventCategoryT: EventCategories](Protocol): +# A component that can fetch events to apply +class EventFetcherProtocol[EventCategoryT: EventCategory](Protocol): def get_events_to_apply( self, state: State[EventCategoryT] ) -> Sequence[Event[EventCategoryT]]: ... 
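+
+# Illustrative sketch of how these aliases compose (the names `my_apply`,
+# `my_saga` and `publish` are assumptions, not part of this module): an `Apply`
+# folds one event-log entry into a `State`, while a `Saga` reacts to the same
+# entry by returning follow-up events for an `EventPublisher` to fan out.
+#
+#     def my_apply(state, logged_event):
+#         # fold: record how far into the log this state domain has read
+#         return state.model_copy(
+#             update={"last_event_applied_idx": logged_event.idx_in_log}
+#         )
+#
+#     def my_saga(state, logged_event):
+#         return []  # trivial saga: no follow-up events
+#
+#     effects = get_effects_from_sagas([my_saga], publish)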
-def get_saga_effect_handler[EventCategoryT: EventCategories]( - saga: Saga[EventCategoryT], event_publisher: EventPublisher[EventCategoryT] +# A component that can get the effect handler for a saga +def get_saga_effect_handler[EventCategoryT: EventCategory]( + saga: Saga[EventCategoryT], event_publisher: EventPublisher ) -> EffectHandler[EventCategoryT]: def effect_handler(state_and_event: StateAndEvent[EventCategoryT]) -> None: trigger_state, trigger_event = state_and_event @@ -241,14 +212,16 @@ def get_saga_effect_handler[EventCategoryT: EventCategories]( return lambda state_and_event, _: effect_handler(state_and_event) -def get_effects_from_sagas[EventCategoryT: EventCategories]( +def get_effects_from_sagas[EventCategoryT: EventCategory]( sagas: Sequence[Saga[EventCategoryT]], - event_publisher: EventPublisher[EventCategoryT], + event_publisher: EventPublisher, ) -> Sequence[EffectHandler[EventCategoryT]]: return [get_saga_effect_handler(saga, event_publisher) for saga in sagas] -IdemKeyGenerator = Callable[[State[CategoryOfEventsT_cov], int], Sequence[EventId]] +type IdemKeyGenerator[EventCategoryT: EventCategory] = Callable[ + [State[EventCategoryT], int], Sequence[EventId] +] class CommandId(NewUUID): @@ -261,14 +234,15 @@ class CommandTypes(str, Enum): Delete = "Delete" -class Command[EventCategoryT: EventCategories, CommandType: CommandTypes](BaseModel): +class Command[ + EventCategoryT: EventCategories | EventCategory, + CommandType: CommandTypes, +](BaseModel): command_type: CommandType command_id: CommandId -CommandTypeT = TypeVar("CommandTypeT", bound=CommandTypes, covariant=True) - -Decide = Callable[ - [State[CategoryOfEventsT_cov], Command[CategoryOfEventsT_cov, CommandTypeT]], - Sequence[Event[CategoryOfEventsT_cov]], +type Decide[EventCategoryT: EventCategory, CommandTypeT: CommandTypes] = Callable[ + [State[EventCategoryT], Command[EventCategoryT, CommandTypeT]], + Sequence[Event[EventCategoryT]], ] diff --git a/shared/types/events/events.py b/shared/types/events/events.py index 1f6422c8..fbb19798 100644 --- a/shared/types/events/events.py +++ b/shared/types/events/events.py @@ -1,33 +1,22 @@ from __future__ import annotations -from typing import Any, Literal, Tuple - -from pydantic import BaseModel +from typing import Literal, Tuple from shared.types.common import NodeId from shared.types.events.common import ( - ControlPlaneEvent, + Event, + EventTypes, + EventCategoryEnum, ControlPlaneEventTypes, - DataPlaneEvent, DataPlaneEventTypes, - InstanceEvent, InstanceEventTypes, - InstanceStateEvent, InstanceStateEventTypes, - MLXEvent, MLXEventTypes, - NodePerformanceEvent, NodePerformanceEventTypes, - ResourceEvent, - ResourceEventTypes, - StreamingEvent, StreamingEventTypes, - TaskEvent, TaskEventTypes, - TimerEvent, - TimerEventTypes, - TimerId, ) +from shared.types.events.chunks import GenerationChunk from shared.types.networking.control_plane import ( ControlPlaneEdgeId, ControlPlaneEdgeType, @@ -37,149 +26,132 @@ from shared.types.networking.data_plane import ( DataPlaneEdgeId, DataPlaneEdgeProfile, ) -from shared.types.profiling.common import NodePerformanceProfile, ProfiledResourceName +from shared.types.profiling.common import NodePerformanceProfile from shared.types.tasks.common import ( - TaskData, TaskId, + TaskParams, TaskState, - TaskStatusIncompleteType, + TaskStatusOtherType, TaskStatusType, TaskType, ) from shared.types.worker.common import InstanceId, NodeStatus -from shared.types.worker.instances import InstanceData, InstanceStatus +from 
shared.types.worker.instances import InstanceParams, TypeOfInstance
 from shared.types.worker.runners import RunnerId, RunnerState, RunnerStateType
 
-
-class TimerData(BaseModel):
-    timer_id: TimerId
+MLXEvent = Event[
+    frozenset(
+        {
+            EventCategoryEnum.MutatesTaskState,
+            EventCategoryEnum.MutatesControlPlaneState,
+        }
+    )
+]
+TaskEvent = Event[EventCategoryEnum.MutatesTaskState]
+InstanceEvent = Event[EventCategoryEnum.MutatesInstanceState]
+ControlPlaneEvent = Event[EventCategoryEnum.MutatesControlPlaneState]
+DataPlaneEvent = Event[EventCategoryEnum.MutatesDataPlaneState]
+NodePerformanceEvent = Event[EventCategoryEnum.MutatesNodePerformanceState]
 
 
-class TaskCreated[TaskTypeT: TaskType](TaskEvent):
-    event_type: TaskEventTypes = TaskEventTypes.TaskCreated
+class TaskCreated(Event[EventCategoryEnum.MutatesTaskState]):
+    event_type: EventTypes = TaskEventTypes.TaskCreated
     task_id: TaskId
-    task_data: TaskData[TaskTypeT]
-    task_state: TaskState[Literal[TaskStatusIncompleteType.Pending], TaskTypeT]
+    task_params: TaskParams[TaskType]
+    task_state: TaskState[Literal[TaskStatusOtherType.Pending], TaskType]
     on_instance: InstanceId
 
 
-class TaskUpdated[TaskTypeT: TaskType](TaskEvent):
-    event_type: TaskEventTypes = TaskEventTypes.TaskUpdated
-    task_id: TaskId
-    update_data: TaskState[TaskStatusType, TaskTypeT]
-
-
-class TaskDeleted(TaskEvent):
-    event_type: TaskEventTypes = TaskEventTypes.TaskDeleted
+# Covers cancellation of a task; non-cancelled tasks persist.
+class TaskDeleted(Event[EventCategoryEnum.MutatesTaskState]):
+    event_type: EventTypes = TaskEventTypes.TaskDeleted
     task_id: TaskId
 
 
-class InstanceCreated(InstanceEvent):
-    event_type: InstanceEventTypes = InstanceEventTypes.InstanceCreated
+class TaskStateUpdated(Event[EventCategoryEnum.MutatesTaskState]):
+    event_type: EventTypes = TaskEventTypes.TaskStateUpdated
+    task_state: TaskState[TaskStatusType, TaskType]
+
+
+class InstanceCreated(Event[EventCategoryEnum.MutatesInstanceState]):
+    event_type: EventTypes = InstanceEventTypes.InstanceCreated
     instance_id: InstanceId
-    instance_data: InstanceData
-    target_status: InstanceStatus
+    instance_params: InstanceParams
+    instance_type: TypeOfInstance
 
 
-class InstanceDeleted(InstanceEvent):
-    event_type: InstanceEventTypes = InstanceEventTypes.InstanceDeleted
+class InstanceDeleted(Event[EventCategoryEnum.MutatesInstanceState]):
+    event_type: EventTypes = InstanceEventTypes.InstanceDeleted
     instance_id: InstanceId
-
-class InstanceStatusUpdated(InstanceEvent):
-    event_type: InstanceEventTypes = InstanceEventTypes.InstanceStatusUpdated
-    instance_id: InstanceId
-    instance_status: InstanceStatus
+    transition: Tuple[InstanceId, InstanceId]
 
 
-class InstanceRunnerStateUpdated(InstanceStateEvent):
-    event_type: InstanceStateEventTypes = (
-        InstanceStateEventTypes.InstanceRunnerStateUpdated
-    )
+class InstanceReplacedAtomically(Event[EventCategoryEnum.MutatesInstanceState]):
+    event_type: EventTypes = InstanceEventTypes.InstanceReplacedAtomically
+    instance_to_replace: InstanceId
+    new_instance_id: InstanceId
+    new_instance_params: InstanceParams
+    new_instance_type: TypeOfInstance
+
+
+class InstanceSagaRunnerStateUpdated(Event[EventCategoryEnum.MutatesInstanceState]):
+    event_type: EventTypes = InstanceStateEventTypes.InstanceSagaRunnerStateUpdated
     instance_id: InstanceId
     state_update: Tuple[RunnerId, RunnerState[RunnerStateType]]
 
 
-class InstanceToBeReplacedAtomically(InstanceEvent):
-    event_type: InstanceEventTypes = InstanceEventTypes.InstanceToBeReplacedAtomically
-    transition: Tuple[InstanceId, InstanceId]
-
-
-class InstanceReplacedAtomically(InstanceEvent):
-    event_type: InstanceEventTypes = InstanceEventTypes.InstanceReplacedAtomically
-    transition: Tuple[InstanceId, InstanceId]
-
-
-class MLXInferenceSagaPrepare(MLXEvent):
-    event_type: MLXEventTypes = MLXEventTypes.MLXInferenceSagaPrepare
+class MLXInferenceSagaPrepare(Event[EventCategoryEnum.MutatesTaskState]):
+    event_type: EventTypes = MLXEventTypes.MLXInferenceSagaPrepare
     task_id: TaskId
     instance_id: InstanceId
 
 
-class MLXInferenceSagaStartPrepare(MLXEvent):
-    event_type: MLXEventTypes = MLXEventTypes.MLXInferenceSagaStartPrepare
+class MLXInferenceSagaStartPrepare(Event[EventCategoryEnum.MutatesTaskState]):
+    event_type: EventTypes = MLXEventTypes.MLXInferenceSagaStartPrepare
     task_id: TaskId
     instance_id: InstanceId
 
 
-class NodePerformanceProfiled(NodePerformanceEvent):
-    event_type: NodePerformanceEventTypes = (
-        NodePerformanceEventTypes.NodePerformanceProfiled
-    )
+class NodePerformanceMeasured(Event[EventCategoryEnum.MutatesNodePerformanceState]):
+    event_type: EventTypes = NodePerformanceEventTypes.NodePerformanceMeasured
     node_id: NodeId
    node_profile: NodePerformanceProfile
 
 
-class WorkerConnected(ControlPlaneEvent):
-    event_type: ControlPlaneEventTypes = ControlPlaneEventTypes.WorkerConnected
+class WorkerConnected(Event[EventCategoryEnum.MutatesControlPlaneState]):
+    event_type: EventTypes = ControlPlaneEventTypes.WorkerConnected
     edge: DataPlaneEdge
 
 
-class WorkerStatusUpdated(ControlPlaneEvent):
-    event_type: ControlPlaneEventTypes = ControlPlaneEventTypes.WorkerStatusUpdated
+class WorkerStatusUpdated(Event[EventCategoryEnum.MutatesControlPlaneState]):
+    event_type: EventTypes = ControlPlaneEventTypes.WorkerStatusUpdated
     node_id: NodeId
     node_state: NodeStatus
 
 
-class WorkerDisconnected(ControlPlaneEvent):
-    event_type: ControlPlaneEventTypes = ControlPlaneEventTypes.WorkerConnected
+class WorkerDisconnected(Event[EventCategoryEnum.MutatesControlPlaneState]):
+    event_type: EventTypes = ControlPlaneEventTypes.WorkerDisconnected
     vertex_id: ControlPlaneEdgeId
 
 
-class ChunkGenerated(StreamingEvent):
-    event_type: StreamingEventTypes = StreamingEventTypes.ChunkGenerated
+class ChunkGenerated(Event[EventCategoryEnum.MutatesTaskState]):
+    event_type: EventTypes = StreamingEventTypes.ChunkGenerated
     task_id: TaskId
-    instance_id: InstanceId
-    chunk: Any
+    chunk: GenerationChunk
 
 
-class DataPlaneEdgeCreated(DataPlaneEvent):
-    event_type: DataPlaneEventTypes = DataPlaneEventTypes.DataPlaneEdgeCreated
+class DataPlaneEdgeCreated(Event[EventCategoryEnum.MutatesDataPlaneState]):
+    event_type: EventTypes = DataPlaneEventTypes.DataPlaneEdgeCreated
     vertex: ControlPlaneEdgeType
 
 
-class DataPlaneEdgeProfiled(DataPlaneEvent):
-    event_type: DataPlaneEventTypes = DataPlaneEventTypes.DataPlaneEdgeProfiled
+class DataPlaneEdgeReplacedAtomically(Event[EventCategoryEnum.MutatesDataPlaneState]):
+    event_type: EventTypes = DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically
     edge_id: DataPlaneEdgeId
     edge_profile: DataPlaneEdgeProfile
 
 
-class DataPlaneEdgeDeleted(DataPlaneEvent):
-    event_type: DataPlaneEventTypes = DataPlaneEventTypes.DataPlaneEdgeDeleted
+class DataPlaneEdgeDeleted(Event[EventCategoryEnum.MutatesDataPlaneState]):
+    event_type: EventTypes = DataPlaneEventTypes.DataPlaneEdgeDeleted
     edge_id: DataPlaneEdgeId
-
-
-class TimerScheduled(TimerEvent):
-    event_type: TimerEventTypes = TimerEventTypes.TimerCreated
-    timer_data: TimerData
-
-
-class TimerFired(TimerEvent):
-    event_type: TimerEventTypes = TimerEventTypes.TimerFired
-    timer_data: 
TimerData - - -class ResourceProfiled(ResourceEvent): - event_type: ResourceEventTypes = ResourceEventTypes.ResourceProfiled - resource_name: ProfiledResourceName - resource_profile: NodePerformanceProfile diff --git a/shared/types/events/registry.py b/shared/types/events/registry.py new file mode 100644 index 00000000..79d7616e --- /dev/null +++ b/shared/types/events/registry.py @@ -0,0 +1,133 @@ +from typing import Any, Mapping, Type, get_args +from types import UnionType +from shared.constants import EXO_ERROR_REPORTING_MESSAGE +from shared.types.events.common import ( + Event, + EventTypes, + TaskEventTypes, + InstanceEventTypes, + NodePerformanceEventTypes, + ControlPlaneEventTypes, + StreamingEventTypes, + DataPlaneEventTypes, + MLXEventTypes, + InstanceStateEventTypes, +) +from shared.types.events.events import ( + TaskCreated, + TaskStateUpdated, + TaskDeleted, + InstanceCreated, + InstanceDeleted, + InstanceReplacedAtomically, + InstanceSagaRunnerStateUpdated, + NodePerformanceMeasured, + WorkerConnected, + WorkerStatusUpdated, + WorkerDisconnected, + ChunkGenerated, + DataPlaneEdgeCreated, + DataPlaneEdgeReplacedAtomically, + DataPlaneEdgeDeleted, + MLXInferenceSagaPrepare, + MLXInferenceSagaStartPrepare, +) +from pydantic import TypeAdapter +from typing import Annotated +from pydantic import Field +from shared.types.events.common import EventCategories + +""" +class EventTypeNames(StrEnum): + TaskEventType = auto() + InstanceEvent = auto() + NodePerformanceEvent = auto() + ControlPlaneEvent = auto() + StreamingEvent = auto() + DataPlaneEvent = auto() + TimerEvent = auto() + MLXEvent = auto() + +check_event_categories_are_defined_for_all_event_types(EVENT_TYPE_ENUMS, EventTypeNames) +""" + +EventRegistry: Mapping[EventTypes, Type[Any]] = { + TaskEventTypes.TaskCreated: TaskCreated, + TaskEventTypes.TaskStateUpdated: TaskStateUpdated, + TaskEventTypes.TaskDeleted: TaskDeleted, + InstanceEventTypes.InstanceCreated: InstanceCreated, + InstanceEventTypes.InstanceDeleted: InstanceDeleted, + InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically, + InstanceStateEventTypes.InstanceSagaRunnerStateUpdated: InstanceSagaRunnerStateUpdated, + NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured, + ControlPlaneEventTypes.WorkerConnected: WorkerConnected, + ControlPlaneEventTypes.WorkerStatusUpdated: WorkerStatusUpdated, + ControlPlaneEventTypes.WorkerDisconnected: WorkerDisconnected, + StreamingEventTypes.ChunkGenerated: ChunkGenerated, + DataPlaneEventTypes.DataPlaneEdgeCreated: DataPlaneEdgeCreated, + DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically: DataPlaneEdgeReplacedAtomically, + DataPlaneEventTypes.DataPlaneEdgeDeleted: DataPlaneEdgeDeleted, + MLXEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare, + MLXEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare, +} + + +# Sanity Check. +def check_registry_has_all_event_types() -> None: + event_types: tuple[EventTypes, ...] 
= get_args(EventTypes)
+    all_event_type_members = {
+        member for enum_class in event_types for member in enum_class
+    }
+    missing_event_types = all_event_type_members - set(EventRegistry.keys())
+
+    assert not missing_event_types, (
+        f"{EXO_ERROR_REPORTING_MESSAGE()}"
+        f"There's an event missing from the registry: {missing_event_types}"
+    )
+
+
+def check_union_of_all_events_is_consistent_with_registry(
+    registry: Mapping[EventTypes, Type[Any]], union_type: UnionType
+) -> None:
+    event_classes_in_registry = set(registry.values())
+    event_classes_in_union = set(get_args(union_type))
+    missing_from_union = event_classes_in_registry - event_classes_in_union
+
+    assert not missing_from_union, (
+        f"{EXO_ERROR_REPORTING_MESSAGE()}"
+        f"Event classes in registry are missing from all_events union: {missing_from_union}"
+    )
+
+    extra_in_union = event_classes_in_union - event_classes_in_registry
+
+    assert not extra_in_union, (
+        f"{EXO_ERROR_REPORTING_MESSAGE()}"
+        f"Event classes in all_events union are missing from registry: {extra_in_union}"
+    )
+
+
+AllEvents = (
+    TaskCreated
+    | TaskStateUpdated
+    | TaskDeleted
+    | InstanceCreated
+    | InstanceDeleted
+    | InstanceReplacedAtomically
+    | InstanceSagaRunnerStateUpdated
+    | NodePerformanceMeasured
+    | WorkerConnected
+    | WorkerStatusUpdated
+    | WorkerDisconnected
+    | ChunkGenerated
+    | DataPlaneEdgeCreated
+    | DataPlaneEdgeReplacedAtomically
+    | DataPlaneEdgeDeleted
+    | MLXInferenceSagaPrepare
+    | MLXInferenceSagaStartPrepare
+)
+
+# Run the sanity checks
+check_registry_has_all_event_types()
+check_union_of_all_events_is_consistent_with_registry(EventRegistry, AllEvents)
+
+
+_EventType = Annotated[AllEvents, Field(discriminator="event_type")]
+EventParser: TypeAdapter[Event[EventCategories]] = TypeAdapter(_EventType)
diff --git a/shared/types/events/sanity_checking.py b/shared/types/events/sanity_checking.py
new file mode 100644
index 00000000..4387a52c
--- /dev/null
+++ b/shared/types/events/sanity_checking.py
@@ -0,0 +1,68 @@
+from typing import LiteralString, Sequence, Set, Any, Type, get_args
+from types import UnionType
+from enum import Enum, StrEnum
+
+from shared.constants import EXO_ERROR_REPORTING_MESSAGE
+
+
+def check_event_type_union_is_consistent_with_registry(
+    event_type_enums: Sequence[Type[Enum]], event_types: UnionType
+) -> None:
+    """Assert that every enum value from event_type_enums is covered by the event_types union."""
+
+    event_types_inferred_from_union = set(get_args(event_types))
+
+    event_types_inferred_from_registry = [
+        member for enum_class in event_type_enums for member in enum_class
+    ]
+
+    # Check that each registry value belongs to one of the types in the union
+    for tag_of_event_type in event_types_inferred_from_registry:
+        event_type = type(tag_of_event_type)
+        assert event_type in event_types_inferred_from_union, (
+            f"{EXO_ERROR_REPORTING_MESSAGE()}"
+            f"There's a mismatch between the registry of event types and the union of possible event types. "
+            f"The enum value {tag_of_event_type} for type {event_type} is not covered by {event_types_inferred_from_union}."
+ ) + + +def check_event_categories_are_defined_for_all_event_types( + event_definitions: Sequence[Type[Enum]], event_categories: Type[StrEnum] +) -> None: + """Assert that the event category names are consistent with the event type enums.""" + + expected_category_tags: list[str] = [ + enum_class.__name__ for enum_class in event_definitions + ] + tag_of_event_categories: list[str] = list(event_categories.__members__.values()) + assert tag_of_event_categories == expected_category_tags, ( + f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"The values of the enum EventCategories are not named after the event type enums." + f"These are the missing categories: {set(expected_category_tags) - set(tag_of_event_categories)}" + f"These are the extra categories: {set(tag_of_event_categories) - set(expected_category_tags)}" + ) + + +def assert_literal_union_covers_enum[TEnum: StrEnum]( + literal_union: UnionType, + enum_type: Type[TEnum], +) -> None: + enum_values: Set[Any] = {member.value for member in enum_type} + + def _flatten(tp: UnionType) -> Set[Any]: + values: Set[Any] = set() + args: tuple[LiteralString, ...] = get_args(tp) + for arg in args: + payloads: tuple[TEnum, ...] = get_args(arg) + for payload in payloads: + values.add(payload.value) + return values + + literal_values: Set[Any] = _flatten(literal_union) + + assert enum_values == literal_values, ( + f"{EXO_ERROR_REPORTING_MESSAGE()}" + f"The values of the enum {enum_type} are not covered by the literal union {literal_union}.\n" + f"These are the missing values: {enum_values - literal_values}\n" + f"These are the extra values: {literal_values - enum_values}\n" + ) diff --git a/shared/types/graphs/common.py b/shared/types/graphs/common.py index b43581fa..d87fcace 100644 --- a/shared/types/graphs/common.py +++ b/shared/types/graphs/common.py @@ -41,8 +41,8 @@ class Edge( class GraphData(BaseModel, Generic[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]): - edges: Mapping[EdgeIdT, EdgeData[EdgeTypeT]] - vertices: Mapping[VertexIdT, VertexData[VertexTypeT]] + edges: Mapping[EdgeIdT, EdgeData[EdgeTypeT]] = {} + vertices: Mapping[VertexIdT, VertexData[VertexTypeT]] = {} class GraphProtocol(Protocol, Generic[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]): @@ -111,11 +111,12 @@ class MutableGraphProtocol(GraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, Vertex class Graph( - BaseModel, Generic[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT], - GraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT], + MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT], ): - graph_data: GraphData[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT] + graph_data: GraphData[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT] = GraphData[ + EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT + ]() # the first element in the return value is the filtered graph; the second is the diff --git a/shared/types/models/metadata.py b/shared/types/models/metadata.py index 1d42d3dc..1c0015e9 100644 --- a/shared/types/models/metadata.py +++ b/shared/types/models/metadata.py @@ -1,8 +1,9 @@ -from typing import Annotated +from typing import Annotated, final from pydantic import BaseModel, PositiveInt +@final class ModelMetadata(BaseModel): pretty_name: str storage_size_kilobytes: Annotated[int, PositiveInt] diff --git a/shared/types/models/sources.py b/shared/types/models/sources.py index 8f636a26..a3712bff 100644 --- a/shared/types/models/sources.py +++ b/shared/types/models/sources.py @@ -1,27 +1,26 @@ from enum import Enum -from typing import Annotated, Any, Generic, Literal, TypeVar, Union, final +from typing 
import Annotated, Any, Literal, Union, final from pydantic import AnyHttpUrl, BaseModel, Field, TypeAdapter from shared.types.models.common import ModelId +@final class SourceType(str, Enum): HuggingFace = "HuggingFace" GitHub = "GitHub" +@final class SourceFormatType(str, Enum): HuggingFaceTransformers = "HuggingFaceTransformers" -T = TypeVar("T", bound=SourceType) -S = TypeVar("S", bound=SourceFormatType) - RepoPath = Annotated[str, Field(pattern=r"^[^/]+/[^/]+$")] -class BaseModelSource(BaseModel, Generic[T, S]): +class BaseModelSource[T: SourceType, S: SourceFormatType](BaseModel): model_uuid: ModelId source_type: T source_format: S @@ -50,15 +49,16 @@ class HuggingFaceModelSource( @final -class GitHubModelSource(BaseModelSource[SourceType.GitHub, S]): +class GitHubModelSource(BaseModelSource[SourceType.GitHub, SourceFormatType]): source_type: Literal[SourceType.GitHub] = SourceType.GitHub + source_format: SourceFormatType source_data: GitHubModelSourceData _ModelSource = Annotated[ Union[ HuggingFaceModelSource, - GitHubModelSource[SourceFormatType.HuggingFaceTransformers], + GitHubModelSource, ], Field(discriminator="source_type"), ] diff --git a/shared/types/networking/topology.py b/shared/types/networking/topology.py index 61e8900b..747358b9 100644 --- a/shared/types/networking/topology.py +++ b/shared/types/networking/topology.py @@ -1,72 +1,45 @@ from shared.types.common import NodeId -from shared.types.graphs.common import Graph, GraphData from shared.types.networking.control_plane import ControlPlaneEdgeId from shared.types.networking.data_plane import ( DataPlaneEdgeData, DataPlaneEdgeId, ) from shared.types.worker.common import NodeStatus +from shared.graphs.networkx import NetworkXGraph class DataPlaneTopology( - Graph[ + NetworkXGraph[ DataPlaneEdgeData, None, DataPlaneEdgeId, NodeId, ] ): - graph_data: GraphData[ - DataPlaneEdgeData, - None, - DataPlaneEdgeId, - NodeId, - ] + pass class OrphanedPartOfDataPlaneTopology( - Graph[ + NetworkXGraph[ DataPlaneEdgeData, None, DataPlaneEdgeId, NodeId, ] ): - graph_data: GraphData[ - DataPlaneEdgeData, - None, - DataPlaneEdgeId, - NodeId, - ] + pass -class ControlPlaneTopology( - Graph[ - None, - NodeStatus, - ControlPlaneEdgeId, - NodeId, - ] -): - graph_data: GraphData[ - None, - NodeStatus, - ControlPlaneEdgeId, - NodeId, - ] +class ControlPlaneTopology(NetworkXGraph[None, NodeStatus, ControlPlaneEdgeId, NodeId]): + pass class OrphanedPartOfControlPlaneTopology( - Graph[ + NetworkXGraph[ None, NodeStatus, ControlPlaneEdgeId, NodeId, ] ): - graph_data: GraphData[ - None, - NodeStatus, - ControlPlaneEdgeId, - NodeId, - ] + pass diff --git a/shared/types/states/master.py b/shared/types/states/master.py index e1233b11..b15417be 100644 --- a/shared/types/states/master.py +++ b/shared/types/states/master.py @@ -1,20 +1,22 @@ from collections.abc import Mapping, Sequence from enum import Enum from queue import Queue -from typing import Generic, TypeVar +from typing import Generic, Literal, TypeVar -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter +from shared.types.worker.common import NodeStatus from shared.types.common import NodeId from shared.types.events.common import ( Event, - EventCategories, + EventCategory, State, ) from shared.types.graphs.resource_graph import ResourceGraph from shared.types.networking.data_plane import ( DataPlaneEdge, DataPlaneEdgeId, + DataPlaneEdgeAdapter, ) from shared.types.networking.topology import ( ControlPlaneTopology, @@ -24,8 +26,8 @@ from 
shared.types.networking.topology import ( ) from shared.types.profiling.common import NodePerformanceProfile from shared.types.states.shared import SharedState -from shared.types.tasks.common import TaskData, TaskType -from shared.types.worker.instances import InstanceData, InstanceId +from shared.types.tasks.common import TaskParams, TaskType +from shared.types.worker.instances import InstanceParams, InstanceId class ExternalCommand(BaseModel): ... @@ -42,44 +44,56 @@ class CachePolicy(BaseModel, Generic[CachePolicyTypeT]): policy_type: CachePolicyTypeT -class NodePerformanceProfileState(State[EventCategories.NodePerformanceEventTypes]): +class NodePerformanceProfileState(State[EventCategory.MutatesNodePerformanceState]): node_profiles: Mapping[NodeId, NodePerformanceProfile] -class DataPlaneNetworkState(State[EventCategories.DataPlaneEventTypes]): - topology: DataPlaneTopology - history: Sequence[OrphanedPartOfDataPlaneTopology] +class DataPlaneNetworkState(State[EventCategory.MutatesDataPlaneState]): + event_category: Literal[EventCategory.MutatesDataPlaneState] = ( + EventCategory.MutatesDataPlaneState + ) + topology: DataPlaneTopology = DataPlaneTopology( + edge_base=DataPlaneEdgeAdapter, vertex_base=TypeAdapter(None) + ) + history: Sequence[OrphanedPartOfDataPlaneTopology] = [] def delete_edge(self, edge_id: DataPlaneEdgeId) -> None: ... def add_edge(self, edge: DataPlaneEdge) -> None: ... -class ControlPlaneNetworkState(State[EventCategories.ControlPlaneEventTypes]): - topology: ControlPlaneTopology - history: Sequence[OrphanedPartOfControlPlaneTopology] +class ControlPlaneNetworkState(State[EventCategory.MutatesControlPlaneState]): + event_category: Literal[EventCategory.MutatesControlPlaneState] = ( + EventCategory.MutatesControlPlaneState + ) + topology: ControlPlaneTopology = ControlPlaneTopology( + edge_base=TypeAdapter(None), vertex_base=TypeAdapter(NodeStatus) + ) + history: Sequence[OrphanedPartOfControlPlaneTopology] = [] def delete_edge(self, edge_id: DataPlaneEdgeId) -> None: ... def add_edge(self, edge: DataPlaneEdge) -> None: ... class MasterState(SharedState): - data_plane_network_state: DataPlaneNetworkState - control_plane_network_state: ControlPlaneNetworkState - job_inbox: Queue[TaskData[TaskType]] - job_outbox: Queue[TaskData[TaskType]] - cache_policy: CachePolicy[CachePolicyType] + data_plane_network_state: DataPlaneNetworkState = DataPlaneNetworkState() + control_plane_network_state: ControlPlaneNetworkState = ControlPlaneNetworkState() + job_inbox: Queue[TaskParams[TaskType]] = Queue() + job_outbox: Queue[TaskParams[TaskType]] = Queue() + cache_policy: CachePolicy[CachePolicyType] = CachePolicy[CachePolicyType]( + policy_type=CachePolicyType.KeepAll + ) def get_shard_assignments( inbox: Queue[ExternalCommand], outbox: Queue[ExternalCommand], resource_graph: ResourceGraph, - current_instances: Mapping[InstanceId, InstanceData], + current_instances: Mapping[InstanceId, InstanceParams], cache_policy: CachePolicy[CachePolicyType], -) -> Mapping[InstanceId, InstanceData]: ... +) -> Mapping[InstanceId, InstanceParams]: ... def get_transition_events( - current_instances: Mapping[InstanceId, InstanceData], - target_instances: Mapping[InstanceId, InstanceData], -) -> Sequence[Event[EventCategories]]: ... + current_instances: Mapping[InstanceId, InstanceParams], + target_instances: Mapping[InstanceId, InstanceParams], +) -> Sequence[Event[EventCategory]]: ... 
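+
+
+# Illustrative sketch of the intended reconciliation flow (the two functions
+# above are still stubs; `command_inbox`, `command_outbox`, `resource_graph`
+# and `current_instances` are hypothetical caller-supplied values):
+#
+#     target_instances = get_shard_assignments(
+#         command_inbox, command_outbox, resource_graph,
+#         current_instances, MasterState().cache_policy,
+#     )
+#     events = get_transition_events(current_instances, target_instances)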
diff --git a/shared/types/states/shared.py b/shared/types/states/shared.py index 75e3140e..4b1c6e4d 100644 --- a/shared/types/states/shared.py +++ b/shared/types/states/shared.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Sequence +from typing import Literal, Sequence from pydantic import BaseModel @@ -11,17 +11,22 @@ from shared.types.worker.instances import BaseInstance class KnownInstances(State[EventCategories.InstanceStateEventTypes]): - instances: Mapping[InstanceId, BaseInstance] + event_category: Literal[EventCategories.InstanceStateEventTypes] = ( + EventCategories.InstanceStateEventTypes + ) + instances: Mapping[InstanceId, BaseInstance] = {} class Tasks(State[EventCategories.TaskEventTypes]): - tasks: Mapping[TaskId, Task[TaskType, TaskStatusType]] + event_category: Literal[EventCategories.TaskEventTypes] = ( + EventCategories.TaskEventTypes + ) + tasks: Mapping[TaskId, Task[TaskType, TaskStatusType]] = {} class SharedState(BaseModel): - node_id: NodeId - known_instances: KnownInstances - compute_tasks: Tasks + known_instances: KnownInstances = KnownInstances() + compute_tasks: Tasks = Tasks() def get_node_id(self) -> NodeId: ... diff --git a/shared/types/states/worker.py b/shared/types/states/worker.py index 699ecb84..a57dcd06 100644 --- a/shared/types/states/worker.py +++ b/shared/types/states/worker.py @@ -2,14 +2,14 @@ from collections.abc import Mapping from shared.types.common import NodeId from shared.types.events.common import ( - EventCategories, + EventCategory, State, ) from shared.types.states.shared import SharedState from shared.types.worker.common import NodeStatus -class NodeStatusState(State[EventCategories.ControlPlaneEventTypes]): +class NodeStatusState(State[EventCategory.MutatesControlPlaneState]): node_status: Mapping[NodeId, NodeStatus] diff --git a/shared/types/tasks/common.py b/shared/types/tasks/common.py index 7e58c35f..648cc054 100644 --- a/shared/types/tasks/common.py +++ b/shared/types/tasks/common.py @@ -1,18 +1,18 @@ -from collections.abc import Mapping from enum import Enum -from typing import Annotated, Generic, Literal, TypeVar, Union +from typing import Annotated, Generic, Literal, TypeVar, Union, final import openai.types.chat as openai from pydantic import BaseModel, Field, TypeAdapter from shared.types.common import NewUUID -from shared.types.worker.common import InstanceId, RunnerId +from shared.types.worker.common import InstanceId class TaskId(NewUUID): pass +@final class TaskType(str, Enum): ChatCompletionNonStreaming = "ChatCompletionNonStreaming" ChatCompletionStreaming = "ChatCompletionStreaming" @@ -21,82 +21,68 @@ class TaskType(str, Enum): TaskTypeT = TypeVar("TaskTypeT", bound=TaskType, covariant=True) -class TaskData(BaseModel, Generic[TaskTypeT]): ... +class TaskParams(BaseModel, Generic[TaskTypeT]): ... 
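+
+
+# Sketch of the pattern used by the concrete subclasses below: each TaskParams
+# subclass pins its TaskType with a Literal `task_type` field so that a
+# discriminated union (see BaseTaskParser further down) can dispatch on it.
+# `SomeStreamingTask` is hypothetical, for illustration only:
+#
+#     class SomeStreamingTask(TaskParams[TaskType.ChatCompletionStreaming]):
+#         task_type: Literal[TaskType.ChatCompletionStreaming] = (
+#             TaskType.ChatCompletionStreaming
+#         )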
-class ChatCompletionNonStreamingTask(TaskData[TaskType.ChatCompletionNonStreaming]): +@final +class ChatCompletionNonStreamingTask(TaskParams[TaskType.ChatCompletionNonStreaming]): task_type: Literal[TaskType.ChatCompletionNonStreaming] = ( TaskType.ChatCompletionNonStreaming ) task_data: openai.completion_create_params.CompletionCreateParams -class ChatCompletionStreamingTask(TaskData[TaskType.ChatCompletionStreaming]): +@final +class ChatCompletionStreamingTask(TaskParams[TaskType.ChatCompletionStreaming]): task_type: Literal[TaskType.ChatCompletionStreaming] = ( TaskType.ChatCompletionStreaming ) task_data: openai.completion_create_params.CompletionCreateParams -class TaskStatusIncompleteType(str, Enum): - Pending = "Pending" - Running = "Running" +@final +class TaskStatusFailedType(str, Enum): Failed = "Failed" +@final class TaskStatusCompleteType(str, Enum): Complete = "Complete" -TaskStatusType = Union[TaskStatusIncompleteType, TaskStatusCompleteType] +@final +class TaskStatusOtherType(str, Enum): + Pending = "Pending" + Running = "Running" + + +TaskStatusType = TaskStatusCompleteType | TaskStatusFailedType | TaskStatusOtherType class TaskArtifact[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel): ... -class IncompleteTaskArtifact[TaskTypeT: TaskType]( - TaskArtifact[TaskTypeT, TaskStatusIncompleteType] -): +@final +class NoTaskArtifact[TaskTypeT: TaskType](TaskArtifact[TaskTypeT, TaskStatusOtherType]): pass -class TaskStatusUpdate[TaskStatusTypeT: TaskStatusType](BaseModel): - task_status: TaskStatusTypeT - - -class PendingTaskStatus(TaskStatusUpdate[TaskStatusIncompleteType.Pending]): - task_status: Literal[TaskStatusIncompleteType.Pending] = ( - TaskStatusIncompleteType.Pending - ) - - -class RunningTaskStatus(TaskStatusUpdate[TaskStatusIncompleteType.Running]): - task_status: Literal[TaskStatusIncompleteType.Running] = ( - TaskStatusIncompleteType.Running - ) - - -class CompletedTaskStatus(TaskStatusUpdate[TaskStatusCompleteType.Complete]): - task_status: Literal[TaskStatusCompleteType.Complete] = ( - TaskStatusCompleteType.Complete - ) - - -class FailedTaskStatus(TaskStatusUpdate[TaskStatusIncompleteType.Failed]): - task_status: Literal[TaskStatusIncompleteType.Failed] = ( - TaskStatusIncompleteType.Failed - ) - error_message: Mapping[RunnerId, str] +@final +class FailedTaskArtifact[TaskTypeT: TaskType]( + TaskArtifact[TaskTypeT, TaskStatusFailedType] +): + error_message: str +@final class TaskState[TaskStatusTypeT: TaskStatusType, TaskTypeT: TaskType](BaseModel): - task_status: TaskStatusUpdate[TaskStatusTypeT] + task_status: TaskStatusTypeT task_artifact: TaskArtifact[TaskTypeT, TaskStatusTypeT] class BaseTask[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel): task_type: TaskTypeT - task_data: TaskData[TaskTypeT] + task_params: TaskParams[TaskTypeT] task_state: TaskState[TaskStatusTypeT, TaskTypeT] on_instance: InstanceId @@ -109,11 +95,12 @@ BaseTaskAnnotated = Annotated[ Field(discriminator="task_type"), ] -BaseTaskValidator: TypeAdapter[BaseTask[TaskType, TaskStatusType]] = TypeAdapter( +BaseTaskParser: TypeAdapter[BaseTask[TaskType, TaskStatusType]] = TypeAdapter( BaseTaskAnnotated ) +@final class Task[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType]( BaseTask[TaskTypeT, TaskStatusTypeT] ): diff --git a/shared/types/worker/downloads.py b/shared/types/worker/downloads.py index c88b2d57..acc53650 100644 --- a/shared/types/worker/downloads.py +++ b/shared/types/worker/downloads.py @@ -2,11 +2,9 @@ from enum import Enum from typing import ( 
Annotated,
     Callable,
-    Generic,
     Literal,
     NewType,
     Sequence,
-    TypeVar,
     Union,
 )
 
@@ -30,10 +28,7 @@ class DownloadStatus(str, Enum):
     Failed = "Failed"
 
 
-DownloadStatusT = TypeVar("DownloadStatusT", bound=DownloadStatus)
-
-
-class BaseDownloadProgress(BaseModel, Generic[DownloadStatusT]):
+class BaseDownloadProgress[DownloadStatusT: DownloadStatus](BaseModel):
     node_id: NodeId
     download_status: DownloadStatusT
 
@@ -80,6 +75,6 @@ DownloadEffectHandler = Callable[
 def download_shard(
     model_id: ModelId,
     model_source: ModelSource,
-    shard_meta: ShardMetadata[PartitionStrategy],
+    shard_metadata: ShardMetadata[PartitionStrategy],
     effect_handlers: Sequence[DownloadEffectHandler],
 ) -> None: ...
diff --git a/shared/types/worker/instances.py b/shared/types/worker/instances.py
index f23b5807..42d23486 100644
--- a/shared/types/worker/instances.py
+++ b/shared/types/worker/instances.py
@@ -12,24 +12,27 @@ from shared.types.worker.runners import (
 )
 
 
-class InstanceStatus(str, Enum):
+class TypeOfInstance(str, Enum):
     ACTIVE = "active"
     INACTIVE = "inactive"
 
 
-class InstanceState(BaseModel):
-    runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
-
-
-class InstanceData(BaseModel):
+class InstanceParams(BaseModel):
     shard_assignments: ShardAssignments
 
 
 class BaseInstance(BaseModel):
-    instance_data: InstanceData
-    instance_state: InstanceState
-    instance_status: InstanceStatus
+    instance_params: InstanceParams
+    instance_type: TypeOfInstance
 
 
 class Instance(BaseInstance):
     instance_id: InstanceId
+
+
+class BaseInstanceSaga(BaseModel):
+    runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
+
+
+class InstanceSaga(BaseInstanceSaga):
+    instance_id: InstanceId
diff --git a/shared/types/worker/resource_monitor.py b/shared/types/worker/resource_monitor.py
index 96eba8d2..f45d943a 100644
--- a/shared/types/worker/resource_monitor.py
+++ b/shared/types/worker/resource_monitor.py
@@ -1,9 +1,8 @@
 import asyncio
 from abc import ABC, abstractmethod
 from collections.abc import Coroutine
-from typing import Callable, Set
+from typing import Callable, List, Set
 
-from shared.types.events.events import ResourceProfiled
 from shared.types.profiling.common import (
     MemoryPerformanceProfile,
     NodePerformanceProfile,
@@ -11,58 +10,44 @@ from shared.types.profiling.common import (
 )
 
 
-class EventLog:
-    def append(self, event: ResourceProfiled) -> None: ...
-
-
 class ResourceCollector(ABC):
     """
     Details a single resource (or resource type) that is being monitored by the resource monitor.
     """
 
-    def __init__(self, name: str):
-        self.name = name
+    name: str
 
     @abstractmethod
     async def collect(self) -> NodePerformanceProfile: ...
 
 
 class SystemResourceCollector(ResourceCollector):
-    def __init__(self):
-        super().__init__("system")
+    name = "system"
 
     @abstractmethod
     async def collect(self) -> SystemPerformanceProfile: ...
 
 
 class MemoryResourceCollector(ResourceCollector):
-    def __init__(self):
-        super().__init__("memory")
+    name = "memory"
 
     @abstractmethod
     async def collect(self) -> MemoryPerformanceProfile: ...
 
 
 class ResourceMonitor:
-    def __init__(
-        self,
-        collectors: list[ResourceCollector],
-        effect_handlers: Set[Callable[[NodePerformanceProfile], None]],
-    ):
-        self.effect_handlers: Set[Callable[[NodePerformanceProfile], None]] = (
-            effect_handlers
-        )
-        self.collectors: list[ResourceCollector] = collectors
+    data_collectors: List[ResourceCollector]
+    effect_handlers: Set[Callable[[NodePerformanceProfile], None]]
 
-        # Since there's no implementation, this breaks the typechecker.
-        # self.collectors: list[ResourceCollector] = [
-        #     SystemResourceCollector(),
-        #     MemoryResourceCollector(),
-        # ]
+    # Since there's no implementation, this breaks the typechecker.
+    # self.collectors: list[ResourceCollector] = [
+    #     SystemResourceCollector(),
+    #     MemoryResourceCollector(),
+    # ]
 
     async def _collect(self) -> list[NodePerformanceProfile]:
         tasks: list[Coroutine[None, None, NodePerformanceProfile]] = [
-            collector.collect() for collector in self.collectors
+            collector.collect() for collector in self.data_collectors
         ]
         return await asyncio.gather(*tasks)
diff --git a/shared/types/worker/runners.py b/shared/types/worker/runners.py
index c7528094..31bfa070 100644
--- a/shared/types/worker/runners.py
+++ b/shared/types/worker/runners.py
@@ -1,8 +1,8 @@
 from collections.abc import Mapping, Sequence
 from enum import Enum
-from typing import Generic, Literal, TypeVar
+from typing import Generic, Literal, TypeVar, Annotated
 
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, Field, TypeAdapter, model_validator
 
 from shared.types.common import NodeId
 from shared.types.models.common import ModelId
@@ -48,14 +48,15 @@ class FailedRunnerState(RunnerState[RunnerStateType.Failed]):
     error_message: str | None = None
 
 
-class RunnerData(BaseModel):
-    runner_id: RunnerId
-    runner_state: RunnerState[RunnerStateType] = RunnerState(
-        runner_state=RunnerStateType.Starting
-    )
-
-
-PartitionStrategyT = TypeVar(name="PartitionStrategyT", bound=PartitionStrategy)
+_RunnerState = Annotated[
+    RejectedRunnerState
+    | StartingRunnerState
+    | DownloadingRunnerState
+    | RunningRunnerState
+    | FailedRunnerState,
+    Field(discriminator="runner_state"),
+]
+RunnerStateParser: TypeAdapter[RunnerState[RunnerStateType]] = TypeAdapter(_RunnerState)
 
 
 class ShardAssignments(BaseModel):
diff --git a/shared/types/worker/shards.py b/shared/types/worker/shards.py
index 5b33457d..67361967 100644
--- a/shared/types/worker/shards.py
+++ b/shared/types/worker/shards.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Annotated, Generic, Literal, TypeVar
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, DirectoryPath, Field, TypeAdapter
 
@@ -11,22 +11,20 @@ class PartitionStrategy(str, Enum):
     pipeline = "pipeline"
 
 
-PartitionStrategyT = TypeVar(name="PartitionStrategyT", bound=PartitionStrategy)
-
-
-class ShardMetadata(BaseModel, Generic[PartitionStrategyT]):
+class ShardMetadata[PartitionStrategyT: PartitionStrategy](BaseModel):
     """
     Defines a specific shard of the model that is ready to be run on a device.
     Replaces previous `Shard` object.
     """
 
+    partition_strategy: PartitionStrategyT
     device_rank: int
     world_size: int
     model_id: ModelId
     model_path: DirectoryPath
 
 
-class PipelineShardMeta(ShardMetadata[PartitionStrategy.pipeline]):
+class PipelineShardMetadata(ShardMetadata[PartitionStrategy.pipeline]):
     """
     Pipeline parallelism shard meta.
""" @@ -38,13 +36,15 @@ class PipelineShardMeta(ShardMetadata[PartitionStrategy.pipeline]): end_layer: Annotated[int, Field(ge=0)] -_ShardMeta = Annotated[PipelineShardMeta, Field(discriminator="partition_strategy")] -ShardMetaAdapter: TypeAdapter[ShardMetadata[PartitionStrategy]] = TypeAdapter( - _ShardMeta +_ShardMetadata = Annotated[ + PipelineShardMetadata, Field(discriminator="partition_strategy") +] +ShardMetaParser: TypeAdapter[ShardMetadata[PartitionStrategy]] = TypeAdapter( + _ShardMetadata ) -class ShardPlacement(BaseModel, Generic[PartitionStrategyT]): +class ShardPlacement[PartitionStrategyT: PartitionStrategy](BaseModel): """ A shard placement is the description of a model distributed across a set of nodes. The Generic[PartitionStrategyT] enforces that the shard assignments all use the same partition strategy. diff --git a/uv.lock b/uv.lock index d08efbb3..dee246b4 100644 --- a/uv.lock +++ b/uv.lock @@ -42,18 +42,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916, upload-time = "2025-03-17T00:02:52.713Z" }, ] -[[package]] -name = "basedpyright" -version = "1.29.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nodejs-wheel-binaries", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/80/fb/bd92196a07e3b4ccee4ff2761a26a05bff77d4da089b67b4b1a547868099/basedpyright-1.29.4.tar.gz", hash = "sha256:2df1976f8591eedf4b4ce8f9d123f43e810cc8cb7cc83c53eec0e2f8044073d0", size = 21961481, upload-time = "2025-06-11T22:25:55.173Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/dc/180fe721a2574fb3aad4051adcca196ac2d18adaf75122f5eeb47436cca2/basedpyright-1.29.4-py3-none-any.whl", hash = "sha256:e087513979972f83010639c6c1a1c13dd3b1d24ee45f8ecff747962cc2063d6f", size = 11476859, upload-time = "2025-06-11T22:25:52.01Z" }, -] - [[package]] name = "certifi" version = "2025.6.15" @@ -88,7 +76,6 @@ darwin = [ [package.dev-dependencies] dev = [ - { name = "basedpyright", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "maturin", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "ruff", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -104,7 +91,6 @@ provides-extras = ["darwin"] [package.metadata.requires-dev] dev = [ - { name = "basedpyright", specifier = ">=1.29.4" }, { name = "maturin", specifier = ">=1.9.0" }, { name = "pytest", specifier = ">=8.4.0" }, { name = "ruff", specifier = ">=0.11.13" }, @@ -121,10 +107,14 @@ version = "0.1.0" source = { editable = "master" } dependencies = [ { name = "exo-shared", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] [package.metadata] -requires-dist = [{ name = "exo-shared", editable = "shared" }] +requires-dist = [ + { name = "exo-shared", editable = "shared" }, + { name = "fastapi", specifier = ">=0.116.0" }, +] [[package]] name = "exo-networking" @@ -136,11 +126,13 @@ name = "exo-shared" version = "0.1.0" source = { editable = "shared" } dependencies = [ + { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "openai", 
marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pathlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] [package.dev-dependencies] @@ -150,11 +142,13 @@ dev = [ [package.metadata] requires-dist = [ + { name = "networkx", specifier = ">=3.5" }, { name = "openai", specifier = ">=1.93.0" }, { name = "pathlib", specifier = ">=1.0.1" }, { name = "protobuf", specifier = ">=6.31.1" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "rich", specifier = ">=14.0.0" }, + { name = "rustworkx", specifier = ">=0.16.0" }, ] [package.metadata.requires-dev] @@ -171,6 +165,20 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "exo-shared", editable = "shared" }] +[[package]] +name = "fastapi" +version = "0.116.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "starlette", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/38/e1da78736143fd885c36213a3ccc493c384ae8fea6a0f0bc272ef42ebea8/fastapi-0.116.0.tar.gz", hash = "sha256:80dc0794627af0390353a6d1171618276616310d37d24faba6648398e57d687a", size = 296518, upload-time = "2025-07-07T15:09:27.82Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/68/d80347fe2360445b5f58cf290e588a4729746e7501080947e6cdae114b1f/fastapi-0.116.0-py3-none-any.whl", hash = "sha256:fdcc9ed272eaef038952923bef2b735c02372402d1203ee1210af4eea7a78d2b", size = 95625, upload-time = "2025-07-07T15:09:26.348Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -308,17 +316,36 @@ wheels = [ ] [[package]] -name = "nodejs-wheel-binaries" -version = "22.16.0" +name = "networkx" +version = "3.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0f/c6/66f36b7b0d528660dfb4a59cb9b8dd6a3f4c0a3939cd49c404a775ea4a63/nodejs_wheel_binaries-22.16.0.tar.gz", hash = "sha256:d695832f026df3a0cf9a089d222225939de9d1b67f8f0a353b79f015aabbe7e2", size = 8061, upload-time = "2025-05-22T07:27:52.149Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/dc/417a5c5f99e53a5d2b3be122506312731eb90fb9630c248e327e2e38cc6b/nodejs_wheel_binaries-22.16.0-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:986b715a96ed703f8ce0c15712f76fc42895cf09067d72b6ef29e8b334eccf64", size = 50957501, upload-time = "2025-05-22T07:27:20.132Z" }, - { url = "https://files.pythonhosted.org/packages/0e/dd/d6ce48209ed15f5d1fccb29eeaa111f962557123eaf4fd03a7316c42734c/nodejs_wheel_binaries-22.16.0-py2.py3-none-macosx_11_0_x86_64.whl", hash = "sha256:4ae3cf22138891cb44c3ee952862a257ce082b098b29024d7175684a9a77b0c0", size = 51891634, upload-time = 
"2025-05-22T07:27:24.029Z" }, - { url = "https://files.pythonhosted.org/packages/80/fa/a07e622fd87717eec3e5cff41575f85ad62717e8698884d28ca809266ca1/nodejs_wheel_binaries-22.16.0-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71f2de4dc0b64ae43e146897ce811f80ac4f9acfbae6ccf814226282bf4ef174", size = 57857862, upload-time = "2025-05-22T07:27:27.933Z" }, - { url = "https://files.pythonhosted.org/packages/1f/80/52736f9570a93f8e6b7942981dc9770eca2bc7aa1d200c1d54198374a6ca/nodejs_wheel_binaries-22.16.0-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbfccbcd558d2f142ccf66d8c3a098022bf4436db9525b5b8d32169ce185d99e", size = 58395868, upload-time = "2025-05-22T07:27:32.088Z" }, - { url = "https://files.pythonhosted.org/packages/0f/0e/53616a5ed8fc1fbe9e48bf132862da5a9abf5cc7f8483dab1722ec257187/nodejs_wheel_binaries-22.16.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:447ad796850eb52ca20356ad39b2d296ed8fef3f214921f84a1ccdad49f2eba1", size = 59712469, upload-time = "2025-05-22T07:27:37.193Z" }, - { url = "https://files.pythonhosted.org/packages/4a/cd/e2b5083df581fc1d08eb93feb6f8fbd3d56b113cef9b59d8e0fb7d4dd4f3/nodejs_wheel_binaries-22.16.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7f526ca6a132b0caf633566a2a78c6985fe92857e7bfdb37380f76205a10b808", size = 60763005, upload-time = "2025-05-22T07:27:41.39Z" }, + { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" }, +] + +[[package]] +name = "numpy" +version = "2.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/19/d7c972dfe90a353dbd3efbbe1d14a5951de80c99c9dc1b93cd998d51dc0f/numpy-2.3.1.tar.gz", hash = "sha256:1ec9ae20a4226da374362cca3c62cd753faf2f951440b0e3b98e93c235441d2b", size = 20390372, upload-time = "2025-06-21T12:28:33.469Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/bd/35ad97006d8abff8631293f8ea6adf07b0108ce6fec68da3c3fcca1197f2/numpy-2.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25a1992b0a3fdcdaec9f552ef10d8103186f5397ab45e2d25f8ac51b1a6b97e8", size = 20889381, upload-time = "2025-06-21T12:19:04.103Z" }, + { url = "https://files.pythonhosted.org/packages/f1/4f/df5923874d8095b6062495b39729178eef4a922119cee32a12ee1bd4664c/numpy-2.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7dea630156d39b02a63c18f508f85010230409db5b2927ba59c8ba4ab3e8272e", size = 14152726, upload-time = "2025-06-21T12:19:25.599Z" }, + { url = "https://files.pythonhosted.org/packages/8c/0f/a1f269b125806212a876f7efb049b06c6f8772cf0121139f97774cd95626/numpy-2.3.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:bada6058dd886061f10ea15f230ccf7dfff40572e99fef440a4a857c8728c9c0", size = 5105145, upload-time = "2025-06-21T12:19:34.782Z" }, + { url = "https://files.pythonhosted.org/packages/6d/63/a7f7fd5f375b0361682f6ffbf686787e82b7bbd561268e4f30afad2bb3c0/numpy-2.3.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:a894f3816eb17b29e4783e5873f92faf55b710c2519e5c351767c51f79d8526d", size = 6639409, upload-time = "2025-06-21T12:19:45.228Z" }, + { url = "https://files.pythonhosted.org/packages/bf/0d/1854a4121af895aab383f4aa233748f1df4671ef331d898e32426756a8a6/numpy-2.3.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = 
"sha256:18703df6c4a4fee55fd3d6e5a253d01c5d33a295409b03fda0c86b3ca2ff41a1", size = 14257630, upload-time = "2025-06-21T12:20:06.544Z" }, + { url = "https://files.pythonhosted.org/packages/50/30/af1b277b443f2fb08acf1c55ce9d68ee540043f158630d62cef012750f9f/numpy-2.3.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:5902660491bd7a48b2ec16c23ccb9124b8abfd9583c5fdfa123fe6b421e03de1", size = 16627546, upload-time = "2025-06-21T12:20:31.002Z" }, + { url = "https://files.pythonhosted.org/packages/6e/ec/3b68220c277e463095342d254c61be8144c31208db18d3fd8ef02712bcd6/numpy-2.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:36890eb9e9d2081137bd78d29050ba63b8dab95dff7912eadf1185e80074b2a0", size = 15562538, upload-time = "2025-06-21T12:20:54.322Z" }, + { url = "https://files.pythonhosted.org/packages/77/2b/4014f2bcc4404484021c74d4c5ee8eb3de7e3f7ac75f06672f8dcf85140a/numpy-2.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a780033466159c2270531e2b8ac063704592a0bc62ec4a1b991c7c40705eb0e8", size = 18360327, upload-time = "2025-06-21T12:21:21.053Z" }, + { url = "https://files.pythonhosted.org/packages/ea/19/a029cd335cf72f79d2644dcfc22d90f09caa86265cbbde3b5702ccef6890/numpy-2.3.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:b0b5397374f32ec0649dd98c652a1798192042e715df918c20672c62fb52d4b8", size = 20987593, upload-time = "2025-06-21T12:21:51.664Z" }, + { url = "https://files.pythonhosted.org/packages/25/91/8ea8894406209107d9ce19b66314194675d31761fe2cb3c84fe2eeae2f37/numpy-2.3.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c5bdf2015ccfcee8253fb8be695516ac4457c743473a43290fd36eba6a1777eb", size = 14300523, upload-time = "2025-06-21T12:22:13.583Z" }, + { url = "https://files.pythonhosted.org/packages/a6/7f/06187b0066eefc9e7ce77d5f2ddb4e314a55220ad62dd0bfc9f2c44bac14/numpy-2.3.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d70f20df7f08b90a2062c1f07737dd340adccf2068d0f1b9b3d56e2038979fee", size = 5227993, upload-time = "2025-06-21T12:22:22.53Z" }, + { url = "https://files.pythonhosted.org/packages/e8/ec/a926c293c605fa75e9cfb09f1e4840098ed46d2edaa6e2152ee35dc01ed3/numpy-2.3.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:2fb86b7e58f9ac50e1e9dd1290154107e47d1eef23a0ae9145ded06ea606f992", size = 6736652, upload-time = "2025-06-21T12:22:33.629Z" }, + { url = "https://files.pythonhosted.org/packages/e3/62/d68e52fb6fde5586650d4c0ce0b05ff3a48ad4df4ffd1b8866479d1d671d/numpy-2.3.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:23ab05b2d241f76cb883ce8b9a93a680752fbfcbd51c50eff0b88b979e471d8c", size = 14331561, upload-time = "2025-06-21T12:22:55.056Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ec/b74d3f2430960044bdad6900d9f5edc2dc0fb8bf5a0be0f65287bf2cbe27/numpy-2.3.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:ce2ce9e5de4703a673e705183f64fd5da5bf36e7beddcb63a25ee2286e71ca48", size = 16693349, upload-time = "2025-06-21T12:23:20.53Z" }, + { url = "https://files.pythonhosted.org/packages/0d/15/def96774b9d7eb198ddadfcbd20281b20ebb510580419197e225f5c55c3e/numpy-2.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c4913079974eeb5c16ccfd2b1f09354b8fed7e0d6f2cab933104a09a6419b1ee", size = 15642053, upload-time = "2025-06-21T12:23:43.697Z" }, + { url = "https://files.pythonhosted.org/packages/2b/57/c3203974762a759540c6ae71d0ea2341c1fa41d84e4971a8e76d7141678a/numpy-2.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:010ce9b4f00d5c036053ca684c77441f2f2c934fd23bee058b4d6f196efd8280", size = 18434184, upload-time = 
"2025-06-21T12:24:10.708Z" }, ] [[package]] @@ -477,6 +504,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/c0/b0b508193b0e8a1654ec683ebab18d309861f8bd64e3a2f9648b80d392cb/ruff-0.11.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:51c3f95abd9331dc5b87c47ac7f376db5616041173826dfd556cfe3d4977f492", size = 11602992, upload-time = "2025-06-05T21:00:06.249Z" }, ] +[[package]] +name = "rustworkx" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/c4/6d6ef39e57610d54c5f106dc3dece9eebce8b9d52d561ae092e3aede1b66/rustworkx-0.16.0.tar.gz", hash = "sha256:9f0dcb83f38d5ca2c3a683eb9b6951c8aec3262fbfe5141946a7ee5ba37e0bb6", size = 349524, upload-time = "2025-01-24T01:22:34.686Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/70/36f5916aee41ffe4f604ad75742eb1bb1b849fb568e010555f9d159cd93e/rustworkx-0.16.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:476a6c67b0142acd941691943750cc6737a48372304489969c2b62d30aaf4c27", size = 2141999, upload-time = "2025-01-24T01:21:50.3Z" }, + { url = "https://files.pythonhosted.org/packages/94/47/7e7c37fb73efcc87be6414b235534605c4008a4cdbd92a61db23b878eecd/rustworkx-0.16.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:bef2ef42870f806af93979b457e240f6dfa4f867ca33965c620f3a804409ed3a", size = 1940309, upload-time = "2025-01-24T01:21:52.053Z" }, + { url = "https://files.pythonhosted.org/packages/c6/42/a6d6b3137be55ef1d887becdf6b64b0917c7d437bd483065a88500a55603/rustworkx-0.16.0-cp39-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0db3a73bf68b3e66c08322a2fc95d3aa663d037d9b4e49c3509da4898d3529cc", size = 2195350, upload-time = "2025-01-24T01:21:53.785Z" }, + { url = "https://files.pythonhosted.org/packages/59/d2/1bc99df831c132c4b7420a85ce9150e065f4c993798f31b6a4229f238398/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f12a13d7486234fa2a84746d5e41f436bf9df43548043e7a232f48804ff8c61", size = 1971689, upload-time = "2025-01-24T17:09:26.338Z" }, + { url = "https://files.pythonhosted.org/packages/b5/3b/1125e7eb834f4408bcec3cee79947efd504c715fb7ab1876f8cd4bbca497/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:89efd5c3a4653ddacc55ca39f28b261d43deec7d678f8f8fc6b76b5087f1dfea", size = 3297342, upload-time = "2025-01-24T03:18:48.885Z" }, + { url = "https://files.pythonhosted.org/packages/4f/e2/e21187b255c6211d71db0d08a44fc16771038b2af41712d66c408d9bec16/rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec0c12aac8c54910ace20ac6ada4b890cd39f95f69100514715f8ad7af9041e4", size = 2110107, upload-time = "2025-01-24T01:21:58.884Z" }, + { url = "https://files.pythonhosted.org/packages/3c/79/e3fcff21f31253ea85ef196bf2fcabad7802b11468f7d3a5d592cd0ac789/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d650e39fc1a1534335f7517358ebfc3478bb235428463cfcd7c5750d50377b33", size = 2007544, upload-time = "2025-01-26T04:16:53.807Z" }, + { url = "https://files.pythonhosted.org/packages/67/04/741ed09c2b0dc0f360f85270c1179ed433785372ac9ab6ab26d3dd3ae02d/rustworkx-0.16.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:293180b83509ee9bff4c3af7ccc1024f6528d61b65d0cb7320bd31924f10cb71", size = 2172787, upload-time = "2025-01-24T01:22:01.282Z" }, +] + [[package]] name = 
"sniffio" version = "1.3.1" @@ -486,6 +532,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "starlette" +version = "0.46.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846, upload-time = "2025-04-13T13:56:17.942Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037, upload-time = "2025-04-13T13:56:16.21Z" }, +] + [[package]] name = "tqdm" version = "4.67.1" diff --git a/worker/logging.py b/worker/logging.py new file mode 100644 index 00000000..b61031be --- /dev/null +++ b/worker/logging.py @@ -0,0 +1,13 @@ +from typing import Literal +from collections.abc import Set + +from shared.logging.common import LogEntry, LogEntryType + + +class WorkerUninitialized(LogEntry[Literal["master_uninitialized"]]): + entry_destination: Set[LogEntryType] = {LogEntryType.cluster} + entry_type: Literal["master_uninitialized"] = "master_uninitialized" + message: str = "No master state found, creating new one." + + +WorkerLogEntries = WorkerUninitialized \ No newline at end of file