fix: Some, still broken

This commit is contained in:
Arbion Halili
2025-07-15 12:58:50 +01:00
committed by Matt Beton
parent 9b3c105bea
commit 9f96b6791f
9 changed files with 295 additions and 237 deletions

View File

@@ -1,193 +0,0 @@
from asyncio import Lock, Queue, Task, create_task, gather
from collections.abc import Mapping
from enum import StrEnum
from logging import Logger
from typing import Any, List, Literal, Protocol, Type, TypedDict
from master.logging import (
StateUpdateEffectHandlerErrorLogEntry,
StateUpdateErrorLogEntry,
StateUpdateLoopAlreadyRunningLogEntry,
StateUpdateLoopNotRunningLogEntry,
StateUpdateLoopStartedLogEntry,
StateUpdateLoopStoppedLogEntry,
)
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from shared.logger import log
from shared.types.events.common import (
Apply,
EffectHandler,
EventCategories,
EventCategory,
EventCategoryEnum,
EventFetcherProtocol,
EventFromEventLog,
StateAndEvent,
State,
narrow_event_from_event_log_type,
)
class QueueMapping(TypedDict):
MutatesTaskState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
]
MutatesTaskSagaState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
]
MutatesControlPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
]
MutatesDataPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
]
MutatesRunnerStatus: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
]
MutatesInstanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
]
MutatesNodePerformanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
]
def check_keys_in_map_match_enum_values[TEnum: StrEnum](
mapping_type: Type[Mapping[Any, Any]],
enum: Type[TEnum],
) -> None:
mapping_keys = set(mapping_type.__annotations__.keys())
category_values = set(e.value for e in enum)
assert mapping_keys == category_values, (
f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}"
)
check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum)
class AsyncUpdateStateFromEvents[EventCategoryT: EventCategory](Protocol):
"""Protocol for services that manage a specific state domain."""
_task: Task[None] | None
_logger: Logger
_apply: Apply[EventCategoryT]
_default_effects: List[EffectHandler[EventCategoryT]]
extra_effects: List[EffectHandler[EventCategoryT]]
state: State[EventCategoryT]
queue: Queue[EventFromEventLog[EventCategoryT]]
lock: Lock
def __init__(
self,
state: State[EventCategoryT],
queue: Queue[EventFromEventLog[EventCategoryT]],
extra_effects: List[EffectHandler[EventCategoryT]],
logger: Logger,
) -> None:
"""Initialise the service with its event queue."""
self.state = state
self.queue = queue
self.extra_effects = extra_effects
self._logger = logger
self._task = None
async def read_state(self) -> State[EventCategoryT]:
"""Get a thread-safe snapshot of this service's state domain."""
return self.state.model_copy(deep=True)
@property
def is_running(self) -> bool:
"""Check if the service's event loop is running."""
return self._task is not None and not self._task.done()
async def start(self) -> None:
"""Start the service's event loop."""
if self.is_running:
log(self._logger, StateUpdateLoopAlreadyRunningLogEntry())
raise RuntimeError("State Update Loop Already Running")
log(self._logger, StateUpdateLoopStartedLogEntry())
self._task = create_task(self._event_loop())
async def stop(self) -> None:
"""Stop the service's event loop."""
if not self.is_running:
log(self._logger, StateUpdateLoopNotRunningLogEntry())
raise RuntimeError("State Update Loop Not Running")
assert self._task is not None, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
"BUG: is_running is True but _task is None, this should never happen!"
)
self._task.cancel()
log(self._logger, StateUpdateLoopStoppedLogEntry())
async def _event_loop(self) -> None:
"""Event loop for the service."""
while True:
event = await self.queue.get()
previous_state = self.state.model_copy(deep=True)
try:
async with self.lock:
updated_state = self._apply(
self.state,
event,
)
self.state = updated_state
except Exception as e:
log(self._logger, StateUpdateErrorLogEntry(error=e))
raise e
try:
for effect_handler in self._default_effects + self.extra_effects:
effect_handler(StateAndEvent(previous_state, event), updated_state)
except Exception as e:
log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e))
raise e
class EventRouter:
"""Routes events to appropriate services based on event categories."""
queue_map: QueueMapping
event_fetcher: EventFetcherProtocol[EventCategory]
_logger: Logger
async def _get_queue_by_category[T: EventCategory](
self, category: T
) -> Queue[EventFromEventLog[T]]:
"""Get the queue for a given category."""
category_str: str = category.value
queue: Queue[EventFromEventLog[T]] = self.queue_map[category_str]
return queue
async def _process_events[T: EventCategory](self, category: T) -> None:
"""Process events for a given domain."""
queue: Queue[EventFromEventLog[T]] = await self._get_queue_by_category(category)
events_to_process: list[EventFromEventLog[T]] = []
while not queue.empty():
events_to_process.append(await queue.get())
for event_to_process in events_to_process:
await self.queue_map[category.value].put(event_to_process)
return None
async def _submit_events[T: EventCategory | EventCategories](
self, events: list[EventFromEventLog[T]]
) -> None:
"""Route multiple events to their appropriate services."""
for event in events:
if isinstance(event.event.event_category, EventCategory):
q1: Queue[EventFromEventLog[T]] = self.queue_map[event.event.event_category.value]
await q1.put(event)
elif isinstance(event.event.event_category, EventCategories):
for category in event.event.event_category:
narrow_event = narrow_event_from_event_log_type(event, category)
q2: Queue[EventFromEventLog[T]] = self.queue_map[category.value]
await q2.put(narrow_event)
await gather(
*[self._process_events(domain) for domain in EventCategoryEnum]
)
async def _get_events_to_process(self) -> list[EventFromEventLog[EventCategories | EventCategory]]:
"""Get events to process from the event fetcher."""
...

View File

@@ -5,7 +5,9 @@ from uuid import UUID
from shared.types.events.common import EventCategory, EventId, IdemKeyGenerator, State
def get_idem_tag_generator[EventCategoryT: EventCategory](base: str) -> IdemKeyGenerator[EventCategoryT]:
def get_idem_tag_generator[EventCategoryT: EventCategory](
base: str,
) -> IdemKeyGenerator[EventCategoryT]:
"""Generates idempotency keys for events.
The keys are generated by hashing the state sequence number against a base string.
@@ -22,7 +24,9 @@ def get_idem_tag_generator[EventCategoryT: EventCategory](base: str) -> IdemKeyG
*recurse(n - 1, next_hash),
)
initial_bytes = state.last_event_applied_idx.to_bytes(8, byteorder="big", signed=False)
initial_bytes = state.last_event_applied_idx.to_bytes(
8, byteorder="big", signed=False
)
return recurse(num_keys, initial_bytes)
return get_idem_keys

View File

@@ -26,6 +26,12 @@ class MasterInvalidCommandReceivedLogEntry(
command_name: str
class MasterCommandRunnerNotRunningLogEntry: ...
class MasterStateManagerStoppedLogEntry: ...
class EventCategoryUnknownLogEntry(LogEntry[Literal["event_category_unknown"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["event_category_unknown"] = "event_category_unknown"

View File

@@ -1,22 +1,21 @@
from asyncio import CancelledError, Lock, Task, create_task
from asyncio import CancelledError, Lock, Task
from asyncio import Queue as AsyncQueue
from queue import Queue as PQueue
from contextlib import asynccontextmanager
from enum import Enum
from logging import Logger, LogRecord
from typing import Annotated, Literal, Type
from queue import Queue as PQueue
from typing import Annotated, Literal
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, TypeAdapter
from master.env import MasterEnvironmentSchema
from master.event_routing import AsyncUpdateStateFromEvents, QueueMapping
from master.logging import (
MasterCommandReceivedLogEntry,
MasterInvalidCommandReceivedLogEntry,
MasterCommandRunnerNotRunningLogEntry,
MasterStateManagerStoppedLogEntry,
MasterUninitializedLogEntry,
)
from master.state_manager.sync import SyncStateManagerMapping
from shared.constants import EXO_MASTER_STATE
from shared.logger import (
FilterLogByType,
@@ -27,11 +26,9 @@ from shared.logger import (
log,
)
from shared.types.events.common import (
Event,
EventCategory,
EventFetcherProtocol,
EventPublisher,
State,
)
from shared.types.models.common import ModelId
from shared.types.models.model import ModelInfo
@@ -81,23 +78,7 @@ ExternalCommand = Annotated[
ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand)
class MasterBackgroundServices(str, Enum):
MAIN_LOOP = "main_loop"
class StateManager[T: EventCategory]:
state: State[T]
queue: AsyncQueue[Event[T]]
manager: AsyncUpdateStateFromEvents[T]
def __init__(
self,
state: State[T],
queue: AsyncQueue[Event[T]],
) -> None: ...
class MasterStateManager:
class MasterEventLoop:
"""Thread-safe manager for MasterState with independent event loop."""
def __init__(
@@ -105,7 +86,7 @@ class MasterStateManager:
initial_state: MasterState,
event_processor: EventFetcherProtocol[EventCategory],
event_publisher: EventPublisher[EventCategory],
state_updater: dict[EventCategory, AsyncUpdateStateFromEvents[EventCategory]],
state_managers: SyncStateManagerMapping,
logger: Logger,
):
self._state = initial_state
@@ -113,14 +94,19 @@ class MasterStateManager:
self._command_runner: Task[None] | None = None
self._command_queue: AsyncQueue[ExternalCommand] = AsyncQueue()
self._response_queue: AsyncQueue[Response | StreamingResponse] = AsyncQueue()
self._state_managers: dict[EventCategory, AsyncUpdateStateFromEvents[EventCategory]] = {}
self._asyncio_tasks: dict[EventCategory, Task[None]] = {}
self._state_managers: SyncStateManagerMapping
self._event_fetcher: EventFetcherProtocol[EventCategory]
self._event_fetch_task: Task[None] | None = None
self._logger = logger
@property
def _is_command_runner_running(self) -> bool:
return self._command_runner is not None and not self._command_runner.done()
@property
def _is_event_fetcher_running(self) -> bool:
return self._event_fetch_task is not None and not self._event_fetch_task.done()
async def send_command(
self, command: ExternalCommand
) -> Response | StreamingResponse:
@@ -134,19 +120,15 @@ class MasterStateManager:
async def start(self) -> None:
"""Start the background event loop."""
for category in self._state_managers:
self._asyncio_tasks[category] = create_task(
self._state_managers[category].start()
)
async def stop(self) -> None:
"""Stop the background event loop and persist state."""
if not self._is_command_runner_running:
if not self._is_command_runner_running or not self._is_event_fetcher_running:
raise RuntimeError("Command Runner Is Not Running")
assert self._command_runner is not None
assert self._command_runner is not None and self._event_fetch_task is not None
for service in [*self._asyncio_tasks.values(), self._command_runner]:
for service in [self._event_fetch_task, self._command_runner]:
service.cancel()
try:
await service
@@ -196,12 +178,14 @@ async def lifespan(app: FastAPI):
cluster_listener.start()
initial_state = get_master_state(logger)
app.state.master_state_manager = MasterStateManager(initial_state, logger)
await app.state.master_state_manager.start()
app.state.master_event_loop = MasterEventLoop(
initial_state, None, None, None, logger
)
await app.state.master_event_loop.start()
yield
await app.state.master_state_manager.stop()
await app.state.master_event_loop.stop()
app = FastAPI(lifespan=lifespan)

89
master/router.py Normal file
View File

@@ -0,0 +1,89 @@
from asyncio import Queue, gather
from logging import Logger
from typing import Literal, TypedDict
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.types.events.common import (
EventCategories,
EventCategory,
EventCategoryEnum,
EventFetcherProtocol,
EventFromEventLog,
narrow_event_from_event_log_type,
)
class QueueMapping(TypedDict):
MutatesTaskState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
]
MutatesTaskSagaState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
]
MutatesControlPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
]
MutatesDataPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
]
MutatesRunnerStatus: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
]
MutatesInstanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
]
MutatesNodePerformanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
]
check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum)
class EventRouter:
"""Routes events to appropriate services based on event categories."""
queue_map: QueueMapping
event_fetcher: EventFetcherProtocol[EventCategory]
_logger: Logger
async def _get_queue_by_category[T: EventCategory](
self, category: T
) -> Queue[EventFromEventLog[T]]:
"""Get the queue for a given category."""
category_str: str = category.value
queue: Queue[EventFromEventLog[T]] = self.queue_map[category_str]
return queue
async def _process_events[T: EventCategory](self, category: T) -> None:
"""Process events for a given domain."""
queue: Queue[EventFromEventLog[T]] = await self._get_queue_by_category(category)
events_to_process: list[EventFromEventLog[T]] = []
while not queue.empty():
events_to_process.append(await queue.get())
for event_to_process in events_to_process:
await self.queue_map[category.value].put(event_to_process)
return None
async def _submit_events[T: EventCategory | EventCategories](
self, events: list[EventFromEventLog[T]]
) -> None:
"""Route multiple events to their appropriate services."""
for event in events:
if isinstance(event.event.event_category, EventCategory):
q1: Queue[EventFromEventLog[T]] = self.queue_map[
event.event.event_category.value
]
await q1.put(event)
elif isinstance(event.event.event_category, EventCategories):
for category in event.event.event_category:
narrow_event = narrow_event_from_event_log_type(event, category)
q2: Queue[EventFromEventLog[T]] = self.queue_map[category.value]
await q2.put(narrow_event)
await gather(*[self._process_events(domain) for domain in EventCategoryEnum])
async def _get_events_to_process(
self,
) -> list[EventFromEventLog[EventCategories | EventCategory]]:
"""Get events to process from the event fetcher."""

13
master/sanity_checking.py Normal file
View File

@@ -0,0 +1,13 @@
from enum import StrEnum
from typing import Any, Mapping, Type
def check_keys_in_map_match_enum_values[TEnum: StrEnum](
mapping_type: Type[Mapping[Any, Any]],
enum: Type[TEnum],
) -> None:
mapping_keys = set(mapping_type.__annotations__.keys())
category_values = set(e.value for e in enum)
assert mapping_keys == category_values, (
f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}"
)

View File

@@ -0,0 +1,128 @@
from asyncio import Lock, Queue, Task, create_task
from logging import Logger
from typing import List, Literal, Protocol, TypedDict
from master.logging import (
StateUpdateEffectHandlerErrorLogEntry,
StateUpdateErrorLogEntry,
StateUpdateLoopAlreadyRunningLogEntry,
StateUpdateLoopNotRunningLogEntry,
StateUpdateLoopStartedLogEntry,
StateUpdateLoopStoppedLogEntry,
)
from master.router import check_keys_in_map_match_enum_values
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from shared.logger import log
from shared.types.events.common import (
Apply,
EffectHandler,
EventCategory,
EventCategoryEnum,
EventFromEventLog,
State,
StateAndEvent,
)
class AsyncStateManager[EventCategoryT: EventCategory](Protocol):
"""Protocol for services that manage a specific state domain."""
_task: Task[None] | None
_logger: Logger
_apply: Apply[EventCategoryT]
_default_effects: List[EffectHandler[EventCategoryT]]
extra_effects: List[EffectHandler[EventCategoryT]]
state: State[EventCategoryT]
queue: Queue[EventFromEventLog[EventCategoryT]]
lock: Lock
def __init__(
self,
state: State[EventCategoryT],
queue: Queue[EventFromEventLog[EventCategoryT]],
extra_effects: List[EffectHandler[EventCategoryT]],
logger: Logger,
) -> None:
"""Initialise the service with its event queue."""
self.state = state
self.queue = queue
self.extra_effects = extra_effects
self._logger = logger
self._task = None
async def read_state(self) -> State[EventCategoryT]:
"""Get a thread-safe snapshot of this service's state domain."""
return self.state.model_copy(deep=True)
@property
def is_running(self) -> bool:
"""Check if the service's event loop is running."""
return self._task is not None and not self._task.done()
async def start(self) -> None:
"""Start the service's event loop."""
if self.is_running:
log(self._logger, StateUpdateLoopAlreadyRunningLogEntry())
raise RuntimeError("State Update Loop Already Running")
log(self._logger, StateUpdateLoopStartedLogEntry())
self._task = create_task(self._event_loop())
async def stop(self) -> None:
"""Stop the service's event loop."""
if not self.is_running:
log(self._logger, StateUpdateLoopNotRunningLogEntry())
raise RuntimeError("State Update Loop Not Running")
assert self._task is not None, (
f"{EXO_ERROR_REPORTING_MESSAGE()}"
"BUG: is_running is True but _task is None, this should never happen!"
)
self._task.cancel()
log(self._logger, StateUpdateLoopStoppedLogEntry())
async def _event_loop(self) -> None:
"""Event loop for the service."""
while True:
event = await self.queue.get()
previous_state = self.state.model_copy(deep=True)
try:
async with self.lock:
updated_state = self._apply(
self.state,
event,
)
self.state = updated_state
except Exception as e:
log(self._logger, StateUpdateErrorLogEntry(error=e))
raise e
try:
for effect_handler in self._default_effects + self.extra_effects:
effect_handler(StateAndEvent(previous_state, event), updated_state)
except Exception as e:
log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e))
raise e
class AsyncStateManagerMapping(TypedDict):
MutatesTaskState: AsyncStateManager[Literal[EventCategoryEnum.MutatesTaskState]]
MutatesTaskSagaState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesTaskSagaState]
]
MutatesControlPlaneState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesControlPlaneState]
]
MutatesDataPlaneState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesDataPlaneState]
]
MutatesRunnerStatus: AsyncStateManager[
Literal[EventCategoryEnum.MutatesRunnerStatus]
]
MutatesInstanceState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesInstanceState]
]
MutatesNodePerformanceState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesNodePerformanceState]
]
check_keys_in_map_match_enum_values(AsyncStateManagerMapping, EventCategoryEnum)

View File

@@ -0,0 +1,19 @@
from typing import Literal, TypedDict
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.types.events.common import EventCategoryEnum, State
class SyncStateManagerMapping(TypedDict):
MutatesTaskState: State[Literal[EventCategoryEnum.MutatesTaskState]]
MutatesTaskSagaState: State[Literal[EventCategoryEnum.MutatesTaskSagaState]]
MutatesControlPlaneState: State[Literal[EventCategoryEnum.MutatesControlPlaneState]]
MutatesDataPlaneState: State[Literal[EventCategoryEnum.MutatesDataPlaneState]]
MutatesRunnerStatus: State[Literal[EventCategoryEnum.MutatesRunnerStatus]]
MutatesInstanceState: State[Literal[EventCategoryEnum.MutatesInstanceState]]
MutatesNodePerformanceState: State[
Literal[EventCategoryEnum.MutatesNodePerformanceState]
]
check_keys_in_map_match_enum_values(SyncStateManagerMapping, EventCategoryEnum)

View File

@@ -164,13 +164,19 @@ def narrow_event_type[T: EventCategory, Q: EventCategories | EventCategory](
narrowed_event = event.model_copy(update={"event_category": {target_category}})
return cast(Event[T], narrowed_event)
def narrow_event_from_event_log_type[T: EventCategory, Q: EventCategories | EventCategory](
def narrow_event_from_event_log_type[
T: EventCategory,
Q: EventCategories | EventCategory,
](
event: EventFromEventLog[Q],
target_category: T,
) -> EventFromEventLog[T]:
if target_category not in event.event.event_category:
raise ValueError(f"Event Does Not Contain Target Category {target_category}")
narrowed_event = event.model_copy(update={"event": narrow_event_type(event.event, target_category)})
narrowed_event = event.model_copy(
update={"event": narrow_event_type(event.event, target_category)}
)
return cast(EventFromEventLog[T], narrowed_event)
@@ -199,7 +205,9 @@ class StateAndEvent[EventCategoryT: EventCategory](NamedTuple):
type EffectHandler[EventCategoryT: EventCategory] = Callable[
[StateAndEvent[EventCategoryT], State[EventCategoryT]], None
]
type EventPublisher[EventCategoryT: EventCategory] = Callable[[Event[EventCategoryT]], None]
type EventPublisher[EventCategoryT: EventCategory] = Callable[
[Event[EventCategoryT]], None
]
# A component that can publish events