diff --git a/master/event_routing.py b/master/event_routing.py deleted file mode 100644 index d4697756..00000000 --- a/master/event_routing.py +++ /dev/null @@ -1,193 +0,0 @@ -from asyncio import Lock, Queue, Task, create_task, gather -from collections.abc import Mapping -from enum import StrEnum -from logging import Logger -from typing import Any, List, Literal, Protocol, Type, TypedDict - -from master.logging import ( - StateUpdateEffectHandlerErrorLogEntry, - StateUpdateErrorLogEntry, - StateUpdateLoopAlreadyRunningLogEntry, - StateUpdateLoopNotRunningLogEntry, - StateUpdateLoopStartedLogEntry, - StateUpdateLoopStoppedLogEntry, -) -from shared.constants import EXO_ERROR_REPORTING_MESSAGE -from shared.logger import log -from shared.types.events.common import ( - Apply, - EffectHandler, - EventCategories, - EventCategory, - EventCategoryEnum, - EventFetcherProtocol, - EventFromEventLog, - StateAndEvent, - State, - narrow_event_from_event_log_type, -) - - -class QueueMapping(TypedDict): - MutatesTaskState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]] - ] - MutatesTaskSagaState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]] - ] - MutatesControlPlaneState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]] - ] - MutatesDataPlaneState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]] - ] - MutatesRunnerStatus: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]] - ] - MutatesInstanceState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]] - ] - MutatesNodePerformanceState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]] - ] - - -def check_keys_in_map_match_enum_values[TEnum: StrEnum]( - mapping_type: Type[Mapping[Any, Any]], - enum: Type[TEnum], -) -> None: - mapping_keys = set(mapping_type.__annotations__.keys()) - category_values = set(e.value for e in enum) - assert mapping_keys == category_values, ( - f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}" - ) - - -check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum) - - -class AsyncUpdateStateFromEvents[EventCategoryT: EventCategory](Protocol): - """Protocol for services that manage a specific state domain.""" - - _task: Task[None] | None - _logger: Logger - _apply: Apply[EventCategoryT] - _default_effects: List[EffectHandler[EventCategoryT]] - extra_effects: List[EffectHandler[EventCategoryT]] - state: State[EventCategoryT] - queue: Queue[EventFromEventLog[EventCategoryT]] - lock: Lock - - def __init__( - self, - state: State[EventCategoryT], - queue: Queue[EventFromEventLog[EventCategoryT]], - extra_effects: List[EffectHandler[EventCategoryT]], - logger: Logger, - ) -> None: - """Initialise the service with its event queue.""" - self.state = state - self.queue = queue - self.extra_effects = extra_effects - self._logger = logger - self._task = None - - async def read_state(self) -> State[EventCategoryT]: - """Get a thread-safe snapshot of this service's state domain.""" - return self.state.model_copy(deep=True) - - @property - def is_running(self) -> bool: - """Check if the service's event loop is running.""" - return self._task is not None and not self._task.done() - - async def start(self) -> None: - """Start the service's event loop.""" - if self.is_running: - log(self._logger, StateUpdateLoopAlreadyRunningLogEntry()) - raise RuntimeError("State Update Loop Already Running") - log(self._logger, StateUpdateLoopStartedLogEntry()) - self._task = create_task(self._event_loop()) - - async def stop(self) -> None: - """Stop the service's event loop.""" - if not self.is_running: - log(self._logger, StateUpdateLoopNotRunningLogEntry()) - raise RuntimeError("State Update Loop Not Running") - - assert self._task is not None, ( - f"{EXO_ERROR_REPORTING_MESSAGE()}" - "BUG: is_running is True but _task is None, this should never happen!" - ) - self._task.cancel() - log(self._logger, StateUpdateLoopStoppedLogEntry()) - - async def _event_loop(self) -> None: - """Event loop for the service.""" - while True: - event = await self.queue.get() - previous_state = self.state.model_copy(deep=True) - try: - async with self.lock: - updated_state = self._apply( - self.state, - event, - ) - self.state = updated_state - except Exception as e: - log(self._logger, StateUpdateErrorLogEntry(error=e)) - raise e - try: - for effect_handler in self._default_effects + self.extra_effects: - effect_handler(StateAndEvent(previous_state, event), updated_state) - except Exception as e: - log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e)) - raise e - - -class EventRouter: - """Routes events to appropriate services based on event categories.""" - - queue_map: QueueMapping - event_fetcher: EventFetcherProtocol[EventCategory] - _logger: Logger - - async def _get_queue_by_category[T: EventCategory]( - self, category: T - ) -> Queue[EventFromEventLog[T]]: - """Get the queue for a given category.""" - category_str: str = category.value - queue: Queue[EventFromEventLog[T]] = self.queue_map[category_str] - return queue - - async def _process_events[T: EventCategory](self, category: T) -> None: - """Process events for a given domain.""" - queue: Queue[EventFromEventLog[T]] = await self._get_queue_by_category(category) - events_to_process: list[EventFromEventLog[T]] = [] - while not queue.empty(): - events_to_process.append(await queue.get()) - for event_to_process in events_to_process: - await self.queue_map[category.value].put(event_to_process) - return None - - async def _submit_events[T: EventCategory | EventCategories]( - self, events: list[EventFromEventLog[T]] - ) -> None: - """Route multiple events to their appropriate services.""" - for event in events: - if isinstance(event.event.event_category, EventCategory): - q1: Queue[EventFromEventLog[T]] = self.queue_map[event.event.event_category.value] - await q1.put(event) - elif isinstance(event.event.event_category, EventCategories): - for category in event.event.event_category: - narrow_event = narrow_event_from_event_log_type(event, category) - q2: Queue[EventFromEventLog[T]] = self.queue_map[category.value] - await q2.put(narrow_event) - - await gather( - *[self._process_events(domain) for domain in EventCategoryEnum] - ) - - async def _get_events_to_process(self) -> list[EventFromEventLog[EventCategories | EventCategory]]: - """Get events to process from the event fetcher.""" - ... diff --git a/master/idempotency.py b/master/idempotency.py index b4761707..2216da1b 100644 --- a/master/idempotency.py +++ b/master/idempotency.py @@ -5,7 +5,9 @@ from uuid import UUID from shared.types.events.common import EventCategory, EventId, IdemKeyGenerator, State -def get_idem_tag_generator[EventCategoryT: EventCategory](base: str) -> IdemKeyGenerator[EventCategoryT]: +def get_idem_tag_generator[EventCategoryT: EventCategory]( + base: str, +) -> IdemKeyGenerator[EventCategoryT]: """Generates idempotency keys for events. The keys are generated by hashing the state sequence number against a base string. @@ -22,7 +24,9 @@ def get_idem_tag_generator[EventCategoryT: EventCategory](base: str) -> IdemKeyG *recurse(n - 1, next_hash), ) - initial_bytes = state.last_event_applied_idx.to_bytes(8, byteorder="big", signed=False) + initial_bytes = state.last_event_applied_idx.to_bytes( + 8, byteorder="big", signed=False + ) return recurse(num_keys, initial_bytes) return get_idem_keys diff --git a/master/logging.py b/master/logging.py index 81e61dd4..f6df8808 100644 --- a/master/logging.py +++ b/master/logging.py @@ -26,6 +26,12 @@ class MasterInvalidCommandReceivedLogEntry( command_name: str +class MasterCommandRunnerNotRunningLogEntry: ... + + +class MasterStateManagerStoppedLogEntry: ... + + class EventCategoryUnknownLogEntry(LogEntry[Literal["event_category_unknown"]]): entry_destination: Set[LogEntryType] = {LogEntryType.cluster} entry_type: Literal["event_category_unknown"] = "event_category_unknown" diff --git a/master/main.py b/master/main.py index c5991806..9890003f 100644 --- a/master/main.py +++ b/master/main.py @@ -1,22 +1,21 @@ -from asyncio import CancelledError, Lock, Task, create_task +from asyncio import CancelledError, Lock, Task from asyncio import Queue as AsyncQueue -from queue import Queue as PQueue from contextlib import asynccontextmanager -from enum import Enum from logging import Logger, LogRecord -from typing import Annotated, Literal, Type +from queue import Queue as PQueue +from typing import Annotated, Literal from fastapi import FastAPI, Response from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field, TypeAdapter from master.env import MasterEnvironmentSchema -from master.event_routing import AsyncUpdateStateFromEvents, QueueMapping from master.logging import ( - MasterCommandReceivedLogEntry, - MasterInvalidCommandReceivedLogEntry, + MasterCommandRunnerNotRunningLogEntry, + MasterStateManagerStoppedLogEntry, MasterUninitializedLogEntry, ) +from master.state_manager.sync import SyncStateManagerMapping from shared.constants import EXO_MASTER_STATE from shared.logger import ( FilterLogByType, @@ -27,11 +26,9 @@ from shared.logger import ( log, ) from shared.types.events.common import ( - Event, EventCategory, EventFetcherProtocol, EventPublisher, - State, ) from shared.types.models.common import ModelId from shared.types.models.model import ModelInfo @@ -81,23 +78,7 @@ ExternalCommand = Annotated[ ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand) -class MasterBackgroundServices(str, Enum): - MAIN_LOOP = "main_loop" - - -class StateManager[T: EventCategory]: - state: State[T] - queue: AsyncQueue[Event[T]] - manager: AsyncUpdateStateFromEvents[T] - - def __init__( - self, - state: State[T], - queue: AsyncQueue[Event[T]], - ) -> None: ... - - -class MasterStateManager: +class MasterEventLoop: """Thread-safe manager for MasterState with independent event loop.""" def __init__( @@ -105,7 +86,7 @@ class MasterStateManager: initial_state: MasterState, event_processor: EventFetcherProtocol[EventCategory], event_publisher: EventPublisher[EventCategory], - state_updater: dict[EventCategory, AsyncUpdateStateFromEvents[EventCategory]], + state_managers: SyncStateManagerMapping, logger: Logger, ): self._state = initial_state @@ -113,14 +94,19 @@ class MasterStateManager: self._command_runner: Task[None] | None = None self._command_queue: AsyncQueue[ExternalCommand] = AsyncQueue() self._response_queue: AsyncQueue[Response | StreamingResponse] = AsyncQueue() - self._state_managers: dict[EventCategory, AsyncUpdateStateFromEvents[EventCategory]] = {} - self._asyncio_tasks: dict[EventCategory, Task[None]] = {} + self._state_managers: SyncStateManagerMapping + self._event_fetcher: EventFetcherProtocol[EventCategory] + self._event_fetch_task: Task[None] | None = None self._logger = logger @property def _is_command_runner_running(self) -> bool: return self._command_runner is not None and not self._command_runner.done() + @property + def _is_event_fetcher_running(self) -> bool: + return self._event_fetch_task is not None and not self._event_fetch_task.done() + async def send_command( self, command: ExternalCommand ) -> Response | StreamingResponse: @@ -134,19 +120,15 @@ class MasterStateManager: async def start(self) -> None: """Start the background event loop.""" - for category in self._state_managers: - self._asyncio_tasks[category] = create_task( - self._state_managers[category].start() - ) async def stop(self) -> None: """Stop the background event loop and persist state.""" - if not self._is_command_runner_running: + if not self._is_command_runner_running or not self._is_event_fetcher_running: raise RuntimeError("Command Runner Is Not Running") - assert self._command_runner is not None + assert self._command_runner is not None and self._event_fetch_task is not None - for service in [*self._asyncio_tasks.values(), self._command_runner]: + for service in [self._event_fetch_task, self._command_runner]: service.cancel() try: await service @@ -196,12 +178,14 @@ async def lifespan(app: FastAPI): cluster_listener.start() initial_state = get_master_state(logger) - app.state.master_state_manager = MasterStateManager(initial_state, logger) - await app.state.master_state_manager.start() + app.state.master_event_loop = MasterEventLoop( + initial_state, None, None, None, logger + ) + await app.state.master_event_loop.start() yield - await app.state.master_state_manager.stop() + await app.state.master_event_loop.stop() app = FastAPI(lifespan=lifespan) diff --git a/master/router.py b/master/router.py new file mode 100644 index 00000000..6da8359a --- /dev/null +++ b/master/router.py @@ -0,0 +1,89 @@ +from asyncio import Queue, gather +from logging import Logger +from typing import Literal, TypedDict + +from master.sanity_checking import check_keys_in_map_match_enum_values +from shared.types.events.common import ( + EventCategories, + EventCategory, + EventCategoryEnum, + EventFetcherProtocol, + EventFromEventLog, + narrow_event_from_event_log_type, +) + + +class QueueMapping(TypedDict): + MutatesTaskState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]] + ] + MutatesTaskSagaState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]] + ] + MutatesControlPlaneState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]] + ] + MutatesDataPlaneState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]] + ] + MutatesRunnerStatus: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]] + ] + MutatesInstanceState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]] + ] + MutatesNodePerformanceState: Queue[ + EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]] + ] + + +check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum) + + +class EventRouter: + """Routes events to appropriate services based on event categories.""" + + queue_map: QueueMapping + event_fetcher: EventFetcherProtocol[EventCategory] + _logger: Logger + + async def _get_queue_by_category[T: EventCategory]( + self, category: T + ) -> Queue[EventFromEventLog[T]]: + """Get the queue for a given category.""" + category_str: str = category.value + queue: Queue[EventFromEventLog[T]] = self.queue_map[category_str] + return queue + + async def _process_events[T: EventCategory](self, category: T) -> None: + """Process events for a given domain.""" + queue: Queue[EventFromEventLog[T]] = await self._get_queue_by_category(category) + events_to_process: list[EventFromEventLog[T]] = [] + while not queue.empty(): + events_to_process.append(await queue.get()) + for event_to_process in events_to_process: + await self.queue_map[category.value].put(event_to_process) + return None + + async def _submit_events[T: EventCategory | EventCategories]( + self, events: list[EventFromEventLog[T]] + ) -> None: + """Route multiple events to their appropriate services.""" + for event in events: + if isinstance(event.event.event_category, EventCategory): + q1: Queue[EventFromEventLog[T]] = self.queue_map[ + event.event.event_category.value + ] + await q1.put(event) + elif isinstance(event.event.event_category, EventCategories): + for category in event.event.event_category: + narrow_event = narrow_event_from_event_log_type(event, category) + q2: Queue[EventFromEventLog[T]] = self.queue_map[category.value] + await q2.put(narrow_event) + + await gather(*[self._process_events(domain) for domain in EventCategoryEnum]) + + async def _get_events_to_process( + self, + ) -> list[EventFromEventLog[EventCategories | EventCategory]]: + """Get events to process from the event fetcher.""" diff --git a/master/sanity_checking.py b/master/sanity_checking.py new file mode 100644 index 00000000..b472b9be --- /dev/null +++ b/master/sanity_checking.py @@ -0,0 +1,13 @@ +from enum import StrEnum +from typing import Any, Mapping, Type + + +def check_keys_in_map_match_enum_values[TEnum: StrEnum]( + mapping_type: Type[Mapping[Any, Any]], + enum: Type[TEnum], +) -> None: + mapping_keys = set(mapping_type.__annotations__.keys()) + category_values = set(e.value for e in enum) + assert mapping_keys == category_values, ( + f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}" + ) diff --git a/master/state_manager/async.py b/master/state_manager/async.py new file mode 100644 index 00000000..dcddfa25 --- /dev/null +++ b/master/state_manager/async.py @@ -0,0 +1,128 @@ +from asyncio import Lock, Queue, Task, create_task +from logging import Logger +from typing import List, Literal, Protocol, TypedDict + +from master.logging import ( + StateUpdateEffectHandlerErrorLogEntry, + StateUpdateErrorLogEntry, + StateUpdateLoopAlreadyRunningLogEntry, + StateUpdateLoopNotRunningLogEntry, + StateUpdateLoopStartedLogEntry, + StateUpdateLoopStoppedLogEntry, +) +from master.router import check_keys_in_map_match_enum_values +from shared.constants import EXO_ERROR_REPORTING_MESSAGE +from shared.logger import log +from shared.types.events.common import ( + Apply, + EffectHandler, + EventCategory, + EventCategoryEnum, + EventFromEventLog, + State, + StateAndEvent, +) + + +class AsyncStateManager[EventCategoryT: EventCategory](Protocol): + """Protocol for services that manage a specific state domain.""" + + _task: Task[None] | None + _logger: Logger + _apply: Apply[EventCategoryT] + _default_effects: List[EffectHandler[EventCategoryT]] + extra_effects: List[EffectHandler[EventCategoryT]] + state: State[EventCategoryT] + queue: Queue[EventFromEventLog[EventCategoryT]] + lock: Lock + + def __init__( + self, + state: State[EventCategoryT], + queue: Queue[EventFromEventLog[EventCategoryT]], + extra_effects: List[EffectHandler[EventCategoryT]], + logger: Logger, + ) -> None: + """Initialise the service with its event queue.""" + self.state = state + self.queue = queue + self.extra_effects = extra_effects + self._logger = logger + self._task = None + + async def read_state(self) -> State[EventCategoryT]: + """Get a thread-safe snapshot of this service's state domain.""" + return self.state.model_copy(deep=True) + + @property + def is_running(self) -> bool: + """Check if the service's event loop is running.""" + return self._task is not None and not self._task.done() + + async def start(self) -> None: + """Start the service's event loop.""" + if self.is_running: + log(self._logger, StateUpdateLoopAlreadyRunningLogEntry()) + raise RuntimeError("State Update Loop Already Running") + log(self._logger, StateUpdateLoopStartedLogEntry()) + self._task = create_task(self._event_loop()) + + async def stop(self) -> None: + """Stop the service's event loop.""" + if not self.is_running: + log(self._logger, StateUpdateLoopNotRunningLogEntry()) + raise RuntimeError("State Update Loop Not Running") + + assert self._task is not None, ( + f"{EXO_ERROR_REPORTING_MESSAGE()}" + "BUG: is_running is True but _task is None, this should never happen!" + ) + self._task.cancel() + log(self._logger, StateUpdateLoopStoppedLogEntry()) + + async def _event_loop(self) -> None: + """Event loop for the service.""" + while True: + event = await self.queue.get() + previous_state = self.state.model_copy(deep=True) + try: + async with self.lock: + updated_state = self._apply( + self.state, + event, + ) + self.state = updated_state + except Exception as e: + log(self._logger, StateUpdateErrorLogEntry(error=e)) + raise e + try: + for effect_handler in self._default_effects + self.extra_effects: + effect_handler(StateAndEvent(previous_state, event), updated_state) + except Exception as e: + log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e)) + raise e + + +class AsyncStateManagerMapping(TypedDict): + MutatesTaskState: AsyncStateManager[Literal[EventCategoryEnum.MutatesTaskState]] + MutatesTaskSagaState: AsyncStateManager[ + Literal[EventCategoryEnum.MutatesTaskSagaState] + ] + MutatesControlPlaneState: AsyncStateManager[ + Literal[EventCategoryEnum.MutatesControlPlaneState] + ] + MutatesDataPlaneState: AsyncStateManager[ + Literal[EventCategoryEnum.MutatesDataPlaneState] + ] + MutatesRunnerStatus: AsyncStateManager[ + Literal[EventCategoryEnum.MutatesRunnerStatus] + ] + MutatesInstanceState: AsyncStateManager[ + Literal[EventCategoryEnum.MutatesInstanceState] + ] + MutatesNodePerformanceState: AsyncStateManager[ + Literal[EventCategoryEnum.MutatesNodePerformanceState] + ] + + +check_keys_in_map_match_enum_values(AsyncStateManagerMapping, EventCategoryEnum) diff --git a/master/state_manager/sync.py b/master/state_manager/sync.py new file mode 100644 index 00000000..b411447e --- /dev/null +++ b/master/state_manager/sync.py @@ -0,0 +1,19 @@ +from typing import Literal, TypedDict + +from master.sanity_checking import check_keys_in_map_match_enum_values +from shared.types.events.common import EventCategoryEnum, State + + +class SyncStateManagerMapping(TypedDict): + MutatesTaskState: State[Literal[EventCategoryEnum.MutatesTaskState]] + MutatesTaskSagaState: State[Literal[EventCategoryEnum.MutatesTaskSagaState]] + MutatesControlPlaneState: State[Literal[EventCategoryEnum.MutatesControlPlaneState]] + MutatesDataPlaneState: State[Literal[EventCategoryEnum.MutatesDataPlaneState]] + MutatesRunnerStatus: State[Literal[EventCategoryEnum.MutatesRunnerStatus]] + MutatesInstanceState: State[Literal[EventCategoryEnum.MutatesInstanceState]] + MutatesNodePerformanceState: State[ + Literal[EventCategoryEnum.MutatesNodePerformanceState] + ] + + +check_keys_in_map_match_enum_values(SyncStateManagerMapping, EventCategoryEnum) diff --git a/shared/types/events/common.py b/shared/types/events/common.py index 3c9e9e2c..364d256f 100644 --- a/shared/types/events/common.py +++ b/shared/types/events/common.py @@ -164,13 +164,19 @@ def narrow_event_type[T: EventCategory, Q: EventCategories | EventCategory]( narrowed_event = event.model_copy(update={"event_category": {target_category}}) return cast(Event[T], narrowed_event) -def narrow_event_from_event_log_type[T: EventCategory, Q: EventCategories | EventCategory]( + +def narrow_event_from_event_log_type[ + T: EventCategory, + Q: EventCategories | EventCategory, +]( event: EventFromEventLog[Q], target_category: T, ) -> EventFromEventLog[T]: if target_category not in event.event.event_category: raise ValueError(f"Event Does Not Contain Target Category {target_category}") - narrowed_event = event.model_copy(update={"event": narrow_event_type(event.event, target_category)}) + narrowed_event = event.model_copy( + update={"event": narrow_event_type(event.event, target_category)} + ) return cast(EventFromEventLog[T], narrowed_event) @@ -199,7 +205,9 @@ class StateAndEvent[EventCategoryT: EventCategory](NamedTuple): type EffectHandler[EventCategoryT: EventCategory] = Callable[ [StateAndEvent[EventCategoryT], State[EventCategoryT]], None ] -type EventPublisher[EventCategoryT: EventCategory] = Callable[[Event[EventCategoryT]], None] +type EventPublisher[EventCategoryT: EventCategory] = Callable[ + [Event[EventCategoryT]], None +] # A component that can publish events