From 5adad08e09102f1ce12a9eb3fc19fca671bfe692 Mon Sep 17 00:00:00 2001 From: Matt Beton Date: Tue, 22 Jul 2025 15:16:06 +0100 Subject: [PATCH] New events --- master/idempotency.py | 32 ---- master/main.py | 17 +- master/placement.py | 4 +- master/sanity_checking.py | 13 -- master/state_manager/async.py | 125 ------------- master/state_manager/sync.py | 18 -- shared/db/sqlite/connector.py | 28 ++- shared/db/sqlite/types.py | 15 +- shared/event_loops/main.py | 42 ++--- shared/event_loops/router.py | 78 -------- shared/tests/test_sqlite_connector.py | 37 ++-- shared/types/events/categories.py | 10 + shared/types/events/commands.py | 38 ++++ shared/types/events/common.py | 246 +++---------------------- shared/types/events/components.py | 38 ++++ shared/types/events/events.py | 136 +++++--------- shared/types/events/registry.py | 124 +++++-------- shared/types/events/sanity_checking.py | 111 +++++------ shared/types/state.py | 3 + 19 files changed, 318 insertions(+), 797 deletions(-) delete mode 100644 master/idempotency.py delete mode 100644 master/sanity_checking.py delete mode 100644 master/state_manager/async.py delete mode 100644 master/state_manager/sync.py delete mode 100644 shared/event_loops/router.py create mode 100644 shared/types/events/categories.py create mode 100644 shared/types/events/commands.py create mode 100644 shared/types/events/components.py diff --git a/master/idempotency.py b/master/idempotency.py deleted file mode 100644 index 2216da1b..00000000 --- a/master/idempotency.py +++ /dev/null @@ -1,32 +0,0 @@ -from hashlib import sha3_224 as hasher -from typing import Sequence -from uuid import UUID - -from shared.types.events.common import EventCategory, EventId, IdemKeyGenerator, State - - -def get_idem_tag_generator[EventCategoryT: EventCategory]( - base: str, -) -> IdemKeyGenerator[EventCategoryT]: - """Generates idempotency keys for events. - - The keys are generated by hashing the state sequence number against a base string. - You can pick any base string, **so long as it's not used in any other function that generates idempotency keys**. - """ - - def get_idem_keys(state: State[EventCategoryT], num_keys: int) -> Sequence[EventId]: - def recurse(n: int, last: bytes) -> Sequence[EventId]: - if n == 0: - return [] - next_hash = hasher(last).digest() - return ( - EventId(UUID(bytes=next_hash, version=4)), - *recurse(n - 1, next_hash), - ) - - initial_bytes = state.last_event_applied_idx.to_bytes( - 8, byteorder="big", signed=False - ) - return recurse(num_keys, initial_bytes) - - return get_idem_keys diff --git a/master/main.py b/master/main.py index 9a131e0e..8e4dadeb 100644 --- a/master/main.py +++ b/master/main.py @@ -1,7 +1,6 @@ from contextlib import asynccontextmanager from logging import Logger, LogRecord from queue import Queue as PQueue -from typing import Literal from fastapi import FastAPI @@ -19,9 +18,6 @@ from shared.logger import ( create_queue_listener, log, ) -from shared.types.events.common import ( - EventCategoryEnum, -) from shared.types.models import ModelId, ModelMetadata from shared.types.state import State from shared.types.worker.common import InstanceId @@ -51,19 +47,8 @@ def get_state_dependency(data: object, logger: Logger) -> State: return data -# What The Master Cares About -MasterEventCategories = ( - Literal[EventCategoryEnum.MutatesTopologyState] - | Literal[EventCategoryEnum.MutatesTaskState] - | Literal[EventCategoryEnum.MutatesTaskSagaState] - | Literal[EventCategoryEnum.MutatesRunnerStatus] - | Literal[EventCategoryEnum.MutatesInstanceState] - | Literal[EventCategoryEnum.MutatesNodePerformanceState] -) - - # Takes Care Of All States And Events Related To The Master -class MasterEventLoopProtocol(NodeEventLoopProtocol[MasterEventCategories]): ... +class MasterEventLoopProtocol(NodeEventLoopProtocol): ... @asynccontextmanager diff --git a/master/placement.py b/master/placement.py index 2eaf9ad0..9803816f 100644 --- a/master/placement.py +++ b/master/placement.py @@ -1,7 +1,7 @@ from queue import Queue from typing import Mapping, Sequence -from shared.types.events.common import BaseEvent, EventCategory +from shared.types.events.registry import Event from shared.types.graphs.topology import Topology from shared.types.state import CachePolicy from shared.types.tasks.common import Task @@ -20,4 +20,4 @@ def get_instance_placement( def get_transition_events( current_instances: Mapping[InstanceId, InstanceParams], target_instances: Mapping[InstanceId, InstanceParams], -) -> Sequence[BaseEvent[EventCategory]]: ... +) -> Sequence[Event]: ... diff --git a/master/sanity_checking.py b/master/sanity_checking.py deleted file mode 100644 index b472b9be..00000000 --- a/master/sanity_checking.py +++ /dev/null @@ -1,13 +0,0 @@ -from enum import StrEnum -from typing import Any, Mapping, Type - - -def check_keys_in_map_match_enum_values[TEnum: StrEnum]( - mapping_type: Type[Mapping[Any, Any]], - enum: Type[TEnum], -) -> None: - mapping_keys = set(mapping_type.__annotations__.keys()) - category_values = set(e.value for e in enum) - assert mapping_keys == category_values, ( - f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}" - ) diff --git a/master/state_manager/async.py b/master/state_manager/async.py deleted file mode 100644 index 4774d786..00000000 --- a/master/state_manager/async.py +++ /dev/null @@ -1,125 +0,0 @@ -from asyncio import Lock, Queue, Task, create_task -from logging import Logger -from typing import List, Literal, Protocol, TypedDict - -from master.logging import ( - StateUpdateEffectHandlerErrorLogEntry, - StateUpdateErrorLogEntry, - StateUpdateLoopAlreadyRunningLogEntry, - StateUpdateLoopNotRunningLogEntry, - StateUpdateLoopStartedLogEntry, - StateUpdateLoopStoppedLogEntry, -) -from master.sanity_checking import check_keys_in_map_match_enum_values -from shared.constants import get_error_reporting_message -from shared.logger import log -from shared.types.events.common import ( - Apply, - EffectHandler, - EventCategory, - EventCategoryEnum, - EventFromEventLog, - State, - StateAndEvent, -) - - -class AsyncStateManager[EventCategoryT: EventCategory](Protocol): - """Protocol for services that manage a specific state domain.""" - - _task: Task[None] | None - _logger: Logger - _apply: Apply[EventCategoryT] - _default_effects: List[EffectHandler[EventCategoryT]] - extra_effects: List[EffectHandler[EventCategoryT]] - state: State[EventCategoryT] - queue: Queue[EventFromEventLog[EventCategoryT]] - lock: Lock - - def __init__( - self, - state: State[EventCategoryT], - queue: Queue[EventFromEventLog[EventCategoryT]], - extra_effects: List[EffectHandler[EventCategoryT]], - logger: Logger, - ) -> None: - """Initialise the service with its event queue.""" - self.state = state - self.queue = queue - self.extra_effects = extra_effects - self._logger = logger - self._task = None - - async def read_state(self) -> State[EventCategoryT]: - """Get a thread-safe snapshot of this service's state domain.""" - return self.state.model_copy(deep=True) - - @property - def is_running(self) -> bool: - """Check if the service's event loop is running.""" - return self._task is not None and not self._task.done() - - async def start(self) -> None: - """Start the service's event loop.""" - if self.is_running: - log(self._logger, StateUpdateLoopAlreadyRunningLogEntry()) - raise RuntimeError("State Update Loop Already Running") - log(self._logger, StateUpdateLoopStartedLogEntry()) - self._task = create_task(self._event_loop()) - - async def stop(self) -> None: - """Stop the service's event loop.""" - if not self.is_running: - log(self._logger, StateUpdateLoopNotRunningLogEntry()) - raise RuntimeError("State Update Loop Not Running") - - assert self._task is not None, ( - f"{get_error_reporting_message()}" - "BUG: is_running is True but _task is None, this should never happen!" - ) - self._task.cancel() - log(self._logger, StateUpdateLoopStoppedLogEntry()) - - async def _event_loop(self) -> None: - """Event loop for the service.""" - while True: - event = await self.queue.get() - previous_state = self.state.model_copy(deep=True) - try: - async with self.lock: - updated_state = self._apply( - self.state, - event, - ) - self.state = updated_state - except Exception as e: - log(self._logger, StateUpdateErrorLogEntry(error=e)) - raise e - try: - for effect_handler in self._default_effects + self.extra_effects: - effect_handler(StateAndEvent(previous_state, event), updated_state) - except Exception as e: - log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e)) - raise e - - -class AsyncStateManagerMapping(TypedDict): - MutatesTaskState: AsyncStateManager[Literal[EventCategoryEnum.MutatesTaskState]] - MutatesTaskSagaState: AsyncStateManager[ - Literal[EventCategoryEnum.MutatesTaskSagaState] - ] - MutatesTopologyState: AsyncStateManager[ - Literal[EventCategoryEnum.MutatesTopologyState] - ] - MutatesRunnerStatus: AsyncStateManager[ - Literal[EventCategoryEnum.MutatesRunnerStatus] - ] - MutatesInstanceState: AsyncStateManager[ - Literal[EventCategoryEnum.MutatesInstanceState] - ] - MutatesNodePerformanceState: AsyncStateManager[ - Literal[EventCategoryEnum.MutatesNodePerformanceState] - ] - - -check_keys_in_map_match_enum_values(AsyncStateManagerMapping, EventCategoryEnum) diff --git a/master/state_manager/sync.py b/master/state_manager/sync.py deleted file mode 100644 index 4c4c70ba..00000000 --- a/master/state_manager/sync.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Literal, TypedDict - -from master.sanity_checking import check_keys_in_map_match_enum_values -from shared.types.events.common import EventCategoryEnum, State - - -class SyncStateManagerMapping(TypedDict): - MutatesTaskState: State[Literal[EventCategoryEnum.MutatesTaskState]] - MutatesTaskSagaState: State[Literal[EventCategoryEnum.MutatesTaskSagaState]] - MutatesTopologyState: State[Literal[EventCategoryEnum.MutatesTopologyState]] - MutatesRunnerStatus: State[Literal[EventCategoryEnum.MutatesRunnerStatus]] - MutatesInstanceState: State[Literal[EventCategoryEnum.MutatesInstanceState]] - MutatesNodePerformanceState: State[ - Literal[EventCategoryEnum.MutatesNodePerformanceState] - ] - - -check_keys_in_map_match_enum_values(SyncStateManagerMapping, EventCategoryEnum) diff --git a/shared/db/sqlite/connector.py b/shared/db/sqlite/connector.py index b0abff65..44de9efd 100644 --- a/shared/db/sqlite/connector.py +++ b/shared/db/sqlite/connector.py @@ -12,12 +12,9 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlmodel import SQLModel -from shared.types.events.common import ( - BaseEvent, - EventCategories, - NodeId, -) -from shared.types.events.registry import Event, EventFromEventLogTyped, EventParser +from shared.types.events.common import NodeId +from shared.types.events.components import EventFromEventLog +from shared.types.events.registry import Event, EventParser from .types import StoredEvent @@ -53,7 +50,7 @@ class AsyncSQLiteEventStorage: self._max_age_s = max_age_ms / 1000.0 self._logger = logger or getLogger(__name__) - self._write_queue: Queue[tuple[BaseEvent[EventCategories], NodeId]] = Queue() + self._write_queue: Queue[tuple[Event, NodeId]] = Queue() self._batch_writer_task: Task[None] | None = None self._engine = None self._closed = False @@ -72,7 +69,7 @@ class AsyncSQLiteEventStorage: async def append_events( self, - events: Sequence[BaseEvent[EventCategories]], + events: Sequence[Event], origin: NodeId ) -> None: """Append events to the log (fire-and-forget). The writes are batched and committed @@ -86,7 +83,7 @@ class AsyncSQLiteEventStorage: async def get_events_since( self, last_idx: int - ) -> Sequence[EventFromEventLogTyped]: + ) -> Sequence[EventFromEventLog[Event]]: """Retrieve events after a specific index.""" if self._closed: raise RuntimeError("Storage is closed") @@ -101,7 +98,7 @@ class AsyncSQLiteEventStorage: ) rows = result.fetchall() - events: list[EventFromEventLogTyped] = [] + events: list[EventFromEventLog[Event]] = [] for row in rows: rowid: int = cast(int, row[0]) origin: str = cast(str, row[1]) @@ -112,7 +109,7 @@ class AsyncSQLiteEventStorage: else: event_data = cast(dict[str, Any], raw_event_data) event = await self._deserialize_event(event_data) - events.append(EventFromEventLogTyped( + events.append(EventFromEventLog( event=event, origin=NodeId(uuid=UUID(origin)), idx_in_log=rowid # rowid becomes idx_in_log @@ -170,7 +167,7 @@ class AsyncSQLiteEventStorage: loop = asyncio.get_event_loop() while not self._closed: - batch: list[tuple[BaseEvent[EventCategories], NodeId]] = [] + batch: list[tuple[Event, NodeId]] = [] try: # Block waiting for first item @@ -208,7 +205,7 @@ class AsyncSQLiteEventStorage: if batch: await self._commit_batch(batch) - async def _commit_batch(self, batch: list[tuple[BaseEvent[EventCategories], NodeId]]) -> None: + async def _commit_batch(self, batch: list[tuple[Event, NodeId]]) -> None: """Commit a batch of events to SQLite.""" assert self._engine is not None @@ -218,7 +215,6 @@ class AsyncSQLiteEventStorage: stored_event = StoredEvent( origin=str(origin.uuid), event_type=str(event.event_type), - event_category=str(next(iter(event.event_category))), event_id=str(event.event_id), event_data=event.model_dump(mode='json') # mode='json' ensures UUID conversion ) @@ -237,8 +233,8 @@ class AsyncSQLiteEventStorage: """Deserialize event data back to typed Event.""" # EventParser expects the discriminator field for proper deserialization result = EventParser.validate_python(event_data) - # EventParser returns BaseEvent but we know it's actually a specific Event type - return result # type: ignore[reportReturnType] + # EventParser returns Event type which is our union of all event types + return result async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]: """Return raw event data for testing purposes.""" diff --git a/shared/db/sqlite/types.py b/shared/db/sqlite/types.py index 4b623e0c..880de7b3 100644 --- a/shared/db/sqlite/types.py +++ b/shared/db/sqlite/types.py @@ -4,12 +4,9 @@ from typing import Any, Protocol, Sequence from sqlalchemy import DateTime, Index from sqlmodel import JSON, Column, Field, SQLModel -from shared.types.events.common import ( - BaseEvent, - EventCategories, - EventFromEventLog, - NodeId, -) +from shared.types.common import NodeId +from shared.types.events.components import EventFromEventLog +from shared.types.events.registry import Event class StoredEvent(SQLModel, table=True): @@ -23,7 +20,6 @@ class StoredEvent(SQLModel, table=True): rowid: int | None = Field(default=None, primary_key=True, alias="rowid") origin: str = Field(index=True) event_type: str = Field(index=True) - event_category: str = Field(index=True) event_id: str = Field(index=True) event_data: dict[str, Any] = Field(sa_column=Column(JSON)) created_at: datetime = Field( @@ -33,7 +29,6 @@ class StoredEvent(SQLModel, table=True): __table_args__ = ( Index("idx_events_origin_created", "origin", "created_at"), - Index("idx_events_category_created", "event_category", "created_at"), ) class EventStorageProtocol(Protocol): @@ -41,7 +36,7 @@ class EventStorageProtocol(Protocol): async def append_events( self, - events: Sequence[BaseEvent[EventCategories]], + events: Sequence[Event], origin: NodeId ) -> None: """Append events to the log (fire-and-forget). @@ -54,7 +49,7 @@ class EventStorageProtocol(Protocol): async def get_events_since( self, last_idx: int - ) -> Sequence[EventFromEventLog[EventCategories]]: + ) -> Sequence[EventFromEventLog[Event]]: """Retrieve events after a specific index. Returns events in idx_in_log order. diff --git a/shared/event_loops/main.py b/shared/event_loops/main.py index c997028d..e89b4716 100644 --- a/shared/event_loops/main.py +++ b/shared/event_loops/main.py @@ -7,7 +7,10 @@ from typing import Any, Hashable, Mapping, Protocol, Sequence from fastapi.responses import Response, StreamingResponse from shared.event_loops.commands import ExternalCommand -from shared.types.events.common import Apply, EventCategory, EventFromEventLog, State +from shared.types.events.registry import Event +from shared.types.events.components import EventFromEventLog +from shared.types.state import State +from shared.types.events.components import Apply class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]): @@ -38,17 +41,16 @@ class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]): return len(self._store) -# Safety on Apply. -def safely_apply[T: EventCategory]( - state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]] -) -> State[T]: +def apply_events( + state: State, apply_fn: Apply, events: Sequence[EventFromEventLog[Event]] +) -> State: sorted_events = sorted(events, key=lambda event: event.idx_in_log) state = state.model_copy() - for event in sorted_events: - if event.idx_in_log <= state.last_event_applied_idx: + for wrapped_event in sorted_events: + if wrapped_event.idx_in_log <= state.last_event_applied_idx: continue - state.last_event_applied_idx = event.idx_in_log - state = apply_fn(state, event) + state.last_event_applied_idx = wrapped_event.idx_in_log + state = apply_fn(state, wrapped_event.event) return state @@ -69,11 +71,9 @@ class NodeCommandLoopProtocol(Protocol): async def _handle_command(self, command: ExternalCommand) -> None: ... -class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol): +class NodeEventGetterProtocol(Protocol): _event_fetcher: Task[Any] | None = None - _event_queues: ExhaustiveMapping[ - EventCategoryT, AsyncQueue[EventFromEventLog[EventCategory]] - ] + _event_queue: AsyncQueue[EventFromEventLog[Event]] _logger: Logger @property @@ -84,18 +84,18 @@ class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol): async def stop_event_fetcher(self) -> None: ... -class NodeStateStorageProtocol[EventCategoryT: EventCategory](Protocol): - _state_managers: ExhaustiveMapping[EventCategoryT, State[EventCategoryT]] +class NodeStateStorageProtocol(Protocol): + _state: State _state_lock: Lock _logger: Logger async def _read_state( - self, event_category: EventCategoryT - ) -> State[EventCategoryT]: ... + self, + ) -> State: ... -class NodeStateManagerProtocol[EventCategoryT: EventCategory]( - NodeEventGetterProtocol[EventCategoryT], NodeStateStorageProtocol[EventCategoryT] +class NodeStateManagerProtocol( + NodeEventGetterProtocol, NodeStateStorageProtocol ): _state_manager: Task[Any] | None = None _logger: Logger @@ -116,6 +116,6 @@ class NodeStateManagerProtocol[EventCategoryT: EventCategory]( async def _apply_queued_events(self) -> None: ... -class NodeEventLoopProtocol[EventCategoryT: EventCategory]( - NodeCommandLoopProtocol, NodeStateManagerProtocol[EventCategoryT] +class NodeEventLoopProtocol( + NodeCommandLoopProtocol, NodeStateManagerProtocol ): ... diff --git a/shared/event_loops/router.py b/shared/event_loops/router.py deleted file mode 100644 index 3dc27efe..00000000 --- a/shared/event_loops/router.py +++ /dev/null @@ -1,78 +0,0 @@ -from asyncio.queues import Queue -from typing import Sequence, cast, get_args - -from shared.event_loops.main import ExhaustiveMapping -from shared.types.events.common import ( - EventCategories, - EventCategory, - EventCategoryEnum, - EventFromEventLog, - narrow_event_from_event_log_type, -) - -""" -from asyncio import gather -from logging import Logger -from typing import Literal, Protocol, Sequence, TypedDict - -from master.sanity_checking import check_keys_in_map_match_enum_values -from shared.types.events.common import EventCategoryEnum -""" - -""" -class EventQueues(TypedDict): - MutatesTaskState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]] - ] - MutatesTaskSagaState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]] - ] - MutatesControlPlaneState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]] - ] - MutatesDataPlaneState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]] - ] - MutatesRunnerStatus: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]] - ] - MutatesInstanceState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]] - ] - MutatesNodePerformanceState: Queue[ - EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]] - ] - - -check_keys_in_map_match_enum_values(EventQueues, EventCategoryEnum) -""" - - -async def route_events[UnionOfRelevantEvents: EventCategory]( - queue_map: ExhaustiveMapping[ - UnionOfRelevantEvents, Queue[EventFromEventLog[EventCategory]] - ], - events: Sequence[EventFromEventLog[EventCategory | EventCategories]], -) -> None: - """Route an event to the appropriate queue.""" - tuple_of_categories: tuple[EventCategoryEnum, ...] = get_args(UnionOfRelevantEvents) - print(tuple_of_categories) - for event in events: - if isinstance(event.event.event_category, EventCategoryEnum): - category: EventCategory = event.event.event_category - if category not in tuple_of_categories: - continue - narrowed_event = narrow_event_from_event_log_type(event, category) - q1: Queue[EventFromEventLog[EventCategory]] = queue_map[ - cast(UnionOfRelevantEvents, category) - ] # TODO: make casting unnecessary - await q1.put(narrowed_event) - else: - for category in event.event.event_category: - if category not in tuple_of_categories: - continue - narrow_event = narrow_event_from_event_log_type(event, category) - q2 = queue_map[ - cast(UnionOfRelevantEvents, category) - ] # TODO: make casting unnecessary - await q2.put(narrow_event) diff --git a/shared/tests/test_sqlite_connector.py b/shared/tests/test_sqlite_connector.py index 6d3ec13f..c78e51dc 100644 --- a/shared/tests/test_sqlite_connector.py +++ b/shared/tests/test_sqlite_connector.py @@ -14,8 +14,7 @@ from shared.types.common import NodeId from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData from shared.types.events.events import ( ChunkGenerated, - EventCategoryEnum, - StreamingEventTypes, + EventType, ) from shared.types.tasks.common import TaskId @@ -91,11 +90,10 @@ class TestAsyncSQLiteEventStorage: async with AsyncSession(storage._engine) as session: await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), { "origin": str(sample_node_id.uuid), "event_type": "test_event", - "event_category": "test_category", "event_id": str(uuid4()), "event_data": json.dumps(test_data) } @@ -137,11 +135,10 @@ class TestAsyncSQLiteEventStorage: async with AsyncSession(storage._engine) as session: for record in test_records: await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), { "origin": str(sample_node_id.uuid), "event_type": record["event_type"], - "event_category": "test_category", "event_id": str(uuid4()), "event_data": json.dumps(record) } @@ -180,18 +177,18 @@ class TestAsyncSQLiteEventStorage: async with AsyncSession(storage._engine) as session: # Origin 1 - record 1 await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), - {"origin": str(origin1.uuid), "event_type": "event_1", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})} + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), + {"origin": str(origin1.uuid), "event_type": "event_1", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})} ) # Origin 2 - record 2 await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), - {"origin": str(origin2.uuid), "event_type": "event_2", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})} + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), + {"origin": str(origin2.uuid), "event_type": "event_2", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})} ) # Origin 1 - record 3 await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), - {"origin": str(origin1.uuid), "event_type": "event_3", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})} + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), + {"origin": str(origin1.uuid), "event_type": "event_3", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})} ) await session.commit() @@ -234,11 +231,10 @@ class TestAsyncSQLiteEventStorage: async with AsyncSession(storage._engine) as session: for i in range(10): await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), { "origin": str(sample_node_id.uuid), "event_type": f"event_{i}", - "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"index": i}) } @@ -325,11 +321,10 @@ class TestAsyncSQLiteEventStorage: assert storage._engine is not None async with AsyncSession(storage._engine) as session: await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), { "origin": str(sample_node_id.uuid), "event_type": "complex_event", - "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps(test_data) } @@ -364,11 +359,10 @@ class TestAsyncSQLiteEventStorage: async with AsyncSession(storage._engine) as session: for i in range(count): await session.execute( - text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"), + text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"), { "origin": origin_id, "event_type": f"batch_{batch_id}_event_{i}", - "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"batch": batch_id, "item": i}) } @@ -425,14 +419,12 @@ class TestAsyncSQLiteEventStorage: ) chunk_generated_event = ChunkGenerated( - event_type=StreamingEventTypes.ChunkGenerated, - event_category=EventCategoryEnum.MutatesTaskState, task_id=task_id, chunk=token_chunk ) # Store the event using the storage API - await storage.append_events([chunk_generated_event], sample_node_id) # type: ignore[reportArgumentType] + await storage.append_events([chunk_generated_event], sample_node_id) # Wait for batch to be written await asyncio.sleep(0.5) @@ -448,8 +440,7 @@ class TestAsyncSQLiteEventStorage: # Verify the event was deserialized correctly retrieved_event = retrieved_event_wrapper.event assert isinstance(retrieved_event, ChunkGenerated) - assert retrieved_event.event_type == StreamingEventTypes.ChunkGenerated - assert retrieved_event.event_category == EventCategoryEnum.MutatesTaskState + assert retrieved_event.event_type == EventType.ChunkGenerated assert retrieved_event.task_id == task_id # Verify the nested chunk was deserialized correctly diff --git a/shared/types/events/categories.py b/shared/types/events/categories.py new file mode 100644 index 00000000..0059348c --- /dev/null +++ b/shared/types/events/categories.py @@ -0,0 +1,10 @@ + +from shared.types.events.events import ( + MLXInferenceSagaPrepare, + MLXInferenceSagaStartPrepare, +) + +TaskSagaEvent = ( + MLXInferenceSagaPrepare + | MLXInferenceSagaStartPrepare +) \ No newline at end of file diff --git a/shared/types/events/commands.py b/shared/types/events/commands.py new file mode 100644 index 00000000..9d7cd1ff --- /dev/null +++ b/shared/types/events/commands.py @@ -0,0 +1,38 @@ +from enum import Enum +from typing import ( + TYPE_CHECKING, + Callable, + Sequence, +) + +if TYPE_CHECKING: + pass + +from pydantic import BaseModel + +from shared.types.common import NewUUID +from shared.types.events.registry import Event +from shared.types.state import State + + +class CommandId(NewUUID): + pass + + +class CommandTypes(str, Enum): + Create = "Create" + Update = "Update" + Delete = "Delete" + + +class Command[ + CommandType: CommandTypes, +](BaseModel): + command_type: CommandType + command_id: CommandId + + +type Decide[CommandTypeT: CommandTypes] = Callable[ + [State, Command[CommandTypeT]], + Sequence[Event], +] diff --git a/shared/types/events/common.py b/shared/types/events/common.py index cdae35c9..f19f17a4 100644 --- a/shared/types/events/common.py +++ b/shared/types/events/common.py @@ -1,26 +1,16 @@ -from enum import Enum, StrEnum +from enum import Enum from typing import ( TYPE_CHECKING, - Any, - Callable, - FrozenSet, - Literal, - NamedTuple, - Protocol, - Sequence, - cast, + Generic, + TypeVar, ) if TYPE_CHECKING: pass -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel from shared.types.common import NewUUID, NodeId -from shared.types.events.sanity_checking import ( - assert_literal_union_covers_enum, - check_event_type_union_is_consistent_with_registry, -) class EventId(NewUUID): @@ -32,114 +22,49 @@ class TimerId(NewUUID): # Here are all the unique kinds of events that can be sent over the network. -# I've defined them in different enums for clarity, but they're all part of the same set of possible events. -class TaskSagaEventTypes(str, Enum): +class EventType(str, Enum): + # Task Saga Events MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare" MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare" - - -class TaskEventTypes(str, Enum): + + # Task Events TaskCreated = "TaskCreated" TaskStateUpdated = "TaskStateUpdated" TaskDeleted = "TaskDeleted" - - -class StreamingEventTypes(str, Enum): + + # Streaming Events ChunkGenerated = "ChunkGenerated" - - -class InstanceEventTypes(str, Enum): + + # Instance Events InstanceCreated = "InstanceCreated" InstanceDeleted = "InstanceDeleted" InstanceActivated = "InstanceActivated" InstanceDeactivated = "InstanceDeactivated" InstanceReplacedAtomically = "InstanceReplacedAtomically" - - -class RunnerStatusEventTypes(str, Enum): + + # Runner Status Events RunnerStatusUpdated = "RunnerStatusUpdated" - - -class NodePerformanceEventTypes(str, Enum): + + # Node Performance Events NodePerformanceMeasured = "NodePerformanceMeasured" - - -class TopologyEventTypes(str, Enum): + + # Topology Events TopologyEdgeCreated = "TopologyEdgeCreated" TopologyEdgeReplacedAtomically = "TopologyEdgeReplacedAtomically" TopologyEdgeDeleted = "TopologyEdgeDeleted" WorkerConnected = "WorkerConnected" WorkerStatusUpdated = "WorkerStatusUpdated" WorkerDisconnected = "WorkerDisconnected" - - -class TimerEventTypes(str, Enum): + + # Timer Events TimerCreated = "TimerCreated" TimerFired = "TimerFired" - -# Registry of all event type enums -EVENT_TYPE_ENUMS = [ - TaskEventTypes, - StreamingEventTypes, - InstanceEventTypes, - RunnerStatusEventTypes, - NodePerformanceEventTypes, - TopologyEventTypes, - TimerEventTypes, - TaskSagaEventTypes, -] +EventTypeT = TypeVar("EventTypeT", bound=EventType) -# Here's the set of all possible events. -EventTypes = ( - TaskEventTypes - | StreamingEventTypes - | InstanceEventTypes - | RunnerStatusEventTypes - | NodePerformanceEventTypes - | TopologyEventTypes - | TimerEventTypes - | TaskSagaEventTypes -) - - -check_event_type_union_is_consistent_with_registry(EVENT_TYPE_ENUMS, EventTypes) - - -class EventCategoryEnum(StrEnum): - MutatesTaskState = "MutatesTaskState" - MutatesTaskSagaState = "MutatesTaskSagaState" - MutatesRunnerStatus = "MutatesRunnerStatus" - MutatesInstanceState = "MutatesInstanceState" - MutatesNodePerformanceState = "MutatesNodePerformanceState" - MutatesTopologyState = "MutatesTopologyState" - - -EventCategory = ( - Literal[EventCategoryEnum.MutatesTopologyState] - | Literal[EventCategoryEnum.MutatesTaskState] - | Literal[EventCategoryEnum.MutatesTaskSagaState] - | Literal[EventCategoryEnum.MutatesRunnerStatus] - | Literal[EventCategoryEnum.MutatesInstanceState] - | Literal[EventCategoryEnum.MutatesNodePerformanceState] - | Literal[EventCategoryEnum.MutatesTopologyState] -) - -EventCategories = FrozenSet[EventCategory] - -assert_literal_union_covers_enum(EventCategory, EventCategoryEnum) - - -EventTypeT = EventTypes # Type Alias placeholder; generic parameter will override - - -class BaseEvent[ - SetMembersT: EventCategories | EventCategory, - EventTypeLitT: EventTypes = EventTypes, -](BaseModel): - event_type: EventTypeLitT - event_category: SetMembersT +class BaseEvent(BaseModel, Generic[EventTypeT]): + event_type: EventTypeT event_id: EventId = EventId() def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool: @@ -151,129 +76,4 @@ class BaseEvent[ return True -class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel): - event: BaseEvent[SetMembersT] - origin: NodeId - idx_in_log: int = Field(gt=0) - @model_validator(mode="after") - def check_event_was_sent_by_correct_node( - self, - ) -> "EventFromEventLog[SetMembersT]": - if self.event.check_event_was_sent_by_correct_node(self.origin): - return self - raise ValueError("Invalid Event: Origin ID Does Not Match") - - -def narrow_event_type[T: EventCategory, Q: EventCategories | EventCategory]( - event: BaseEvent[Q], - target_category: T, -) -> BaseEvent[T]: - if target_category not in event.event_category: - raise ValueError(f"Event Does Not Contain Target Category {target_category}") - - narrowed_event = event.model_copy(update={"event_category": {target_category}}) - return cast(BaseEvent[T], narrowed_event) - - -def narrow_event_from_event_log_type[ - T: EventCategory, - Q: EventCategories | EventCategory, -]( - event: EventFromEventLog[Q], - target_category: T, -) -> EventFromEventLog[T]: - if target_category not in event.event.event_category: - raise ValueError(f"Event Does Not Contain Target Category {target_category}") - narrowed_event = event.model_copy( - update={"event": narrow_event_type(event.event, target_category)} - ) - - return cast(EventFromEventLog[T], narrowed_event) - - -class State[EventCategoryT: EventCategory](BaseModel): - event_category: EventCategoryT - last_event_applied_idx: int = Field(default=0, ge=0) - - -# Definitions for Type Variables -type Saga[EventCategoryT: EventCategory] = Callable[ - [State[EventCategoryT], EventFromEventLog[EventCategoryT]], - Sequence[BaseEvent[EventCategories]], -] -type Apply[EventCategoryT: EventCategory] = Callable[ - [State[EventCategoryT], EventFromEventLog[EventCategoryT]], - State[EventCategoryT], -] - - -class StateAndEvent[EventCategoryT: EventCategory](NamedTuple): - state: State[EventCategoryT] - event: EventFromEventLog[EventCategoryT] - - -type EffectHandler[EventCategoryT: EventCategory] = Callable[ - [StateAndEvent[EventCategoryT], State[EventCategoryT]], None -] -type EventPublisher = Callable[[BaseEvent[Any]], None] - - -# A component that can publish events -class EventPublisherProtocol(Protocol): - def send(self, events: Sequence[BaseEvent[EventCategories]]) -> None: ... - - -# A component that can fetch events to apply -class EventFetcherProtocol[EventCategoryT: EventCategory](Protocol): - def get_events_to_apply( - self, state: State[EventCategoryT] - ) -> Sequence[BaseEvent[EventCategoryT]]: ... - - -# A component that can get the effect handler for a saga -def get_saga_effect_handler[EventCategoryT: EventCategory]( - saga: Saga[EventCategoryT], event_publisher: EventPublisher -) -> EffectHandler[EventCategoryT]: - def effect_handler(state_and_event: StateAndEvent[EventCategoryT]) -> None: - trigger_state, trigger_event = state_and_event - for event in saga(trigger_state, trigger_event): - event_publisher(event) - - return lambda state_and_event, _: effect_handler(state_and_event) - - -def get_effects_from_sagas[EventCategoryT: EventCategory]( - sagas: Sequence[Saga[EventCategoryT]], - event_publisher: EventPublisher, -) -> Sequence[EffectHandler[EventCategoryT]]: - return [get_saga_effect_handler(saga, event_publisher) for saga in sagas] - - -type IdemKeyGenerator[EventCategoryT: EventCategory] = Callable[ - [State[EventCategoryT], int], Sequence[EventId] -] - - -class CommandId(NewUUID): - pass - - -class CommandTypes(str, Enum): - Create = "Create" - Update = "Update" - Delete = "Delete" - - -class Command[ - EventCategoryT: EventCategories | EventCategory, - CommandType: CommandTypes, -](BaseModel): - command_type: CommandType - command_id: CommandId - - -type Decide[EventCategoryT: EventCategory, CommandTypeT: CommandTypes] = Callable[ - [State[EventCategoryT], Command[EventCategoryT, CommandTypeT]], - Sequence[BaseEvent[EventCategoryT]], -] diff --git a/shared/types/events/components.py b/shared/types/events/components.py new file mode 100644 index 00000000..0c5f90e1 --- /dev/null +++ b/shared/types/events/components.py @@ -0,0 +1,38 @@ +# components.py defines the small event functions, adapters etc. +# this name could probably be improved. + +from typing import ( + TYPE_CHECKING, +) + +if TYPE_CHECKING: + pass + +from pydantic import BaseModel, Field, model_validator + +from typing import Callable + +from shared.types.common import NodeId +from shared.types.state import State +from shared.types.events.registry import Event + + +class EventFromEventLog[T: Event](BaseModel): + event: T + origin: NodeId + idx_in_log: int = Field(gt=0) + + @model_validator(mode="after") + def check_event_was_sent_by_correct_node( + self, + ) -> "EventFromEventLog[T]": + if self.event.check_event_was_sent_by_correct_node(self.origin): + return self + raise ValueError("Invalid Event: Origin ID Does Not Match") + + + +type Apply = Callable[ + [State, Event], + State +] \ No newline at end of file diff --git a/shared/types/events/events.py b/shared/types/events/events.py index 8def7eff..478e82de 100644 --- a/shared/types/events/events.py +++ b/shared/types/events/events.py @@ -6,14 +6,8 @@ from shared.types.common import NodeId from shared.types.events.chunks import GenerationChunk from shared.types.events.common import ( BaseEvent, - EventCategoryEnum, - InstanceEventTypes, - NodePerformanceEventTypes, - RunnerStatusEventTypes, - StreamingEventTypes, - TaskEventTypes, - TaskSagaEventTypes, - TopologyEventTypes, + EventType, + TimerId, ) from shared.types.graphs.topology import ( TopologyEdge, @@ -27,156 +21,122 @@ from shared.types.worker.common import InstanceId, NodeStatus from shared.types.worker.instances import InstanceParams, TypeOfInstance from shared.types.worker.runners import RunnerId, RunnerStatus -TaskEvent = BaseEvent[EventCategoryEnum.MutatesTaskState] -InstanceEvent = BaseEvent[EventCategoryEnum.MutatesInstanceState] -TopologyEvent = BaseEvent[EventCategoryEnum.MutatesTopologyState] -NodePerformanceEvent = BaseEvent[EventCategoryEnum.MutatesNodePerformanceState] - -class TaskCreated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskCreated]]): - event_type: Literal[TaskEventTypes.TaskCreated] = TaskEventTypes.TaskCreated - event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState +class TaskCreated(BaseEvent[EventType.TaskCreated]): + event_type: Literal[EventType.TaskCreated] = EventType.TaskCreated task_id: TaskId task: Task -# Covers Cancellation Of Task, Non-Cancelled Tasks Perist -class TaskDeleted(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskDeleted]]): - event_type: Literal[TaskEventTypes.TaskDeleted] = TaskEventTypes.TaskDeleted - event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState +class TaskDeleted(BaseEvent[EventType.TaskDeleted]): + event_type: Literal[EventType.TaskDeleted] = EventType.TaskDeleted task_id: TaskId -class TaskStateUpdated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskStateUpdated]]): - event_type: Literal[TaskEventTypes.TaskStateUpdated] = TaskEventTypes.TaskStateUpdated - event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState +class TaskStateUpdated(BaseEvent[EventType.TaskStateUpdated]): + event_type: Literal[EventType.TaskStateUpdated] = EventType.TaskStateUpdated task_id: TaskId task_status: TaskStatus -class InstanceCreated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceCreated]]): - event_type: Literal[InstanceEventTypes.InstanceCreated] = InstanceEventTypes.InstanceCreated - event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState +class InstanceCreated(BaseEvent[EventType.InstanceCreated]): + event_type: Literal[EventType.InstanceCreated] = EventType.InstanceCreated instance_id: InstanceId instance_params: InstanceParams instance_type: TypeOfInstance -class InstanceActivated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceActivated]]): - event_type: Literal[InstanceEventTypes.InstanceActivated] = InstanceEventTypes.InstanceActivated - event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState +class InstanceActivated(BaseEvent[EventType.InstanceActivated]): + event_type: Literal[EventType.InstanceActivated] = EventType.InstanceActivated instance_id: InstanceId -class InstanceDeactivated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceDeactivated]]): - event_type: Literal[InstanceEventTypes.InstanceDeactivated] = InstanceEventTypes.InstanceDeactivated - event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState +class InstanceDeactivated(BaseEvent[EventType.InstanceDeactivated]): + event_type: Literal[EventType.InstanceDeactivated] = EventType.InstanceDeactivated instance_id: InstanceId -class InstanceDeleted(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceDeleted]]): - event_type: Literal[InstanceEventTypes.InstanceDeleted] = InstanceEventTypes.InstanceDeleted - event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState +class InstanceDeleted(BaseEvent[EventType.InstanceDeleted]): + event_type: Literal[EventType.InstanceDeleted] = EventType.InstanceDeleted instance_id: InstanceId transition: Tuple[InstanceId, InstanceId] -class InstanceReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceReplacedAtomically]]): - event_type: Literal[InstanceEventTypes.InstanceReplacedAtomically] = InstanceEventTypes.InstanceReplacedAtomically - event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState +class InstanceReplacedAtomically(BaseEvent[EventType.InstanceReplacedAtomically]): + event_type: Literal[EventType.InstanceReplacedAtomically] = EventType.InstanceReplacedAtomically instance_to_replace: InstanceId new_instance_id: InstanceId -class RunnerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesRunnerStatus, Literal[RunnerStatusEventTypes.RunnerStatusUpdated]]): - event_type: Literal[RunnerStatusEventTypes.RunnerStatusUpdated] = RunnerStatusEventTypes.RunnerStatusUpdated - event_category: Literal[EventCategoryEnum.MutatesRunnerStatus] = EventCategoryEnum.MutatesRunnerStatus +class RunnerStatusUpdated(BaseEvent[EventType.RunnerStatusUpdated]): + event_type: Literal[EventType.RunnerStatusUpdated] = EventType.RunnerStatusUpdated runner_id: RunnerId runner_status: RunnerStatus -class MLXInferenceSagaPrepare(BaseEvent[EventCategoryEnum.MutatesTaskSagaState, Literal[TaskSagaEventTypes.MLXInferenceSagaPrepare]]): - event_type: Literal[TaskSagaEventTypes.MLXInferenceSagaPrepare] = TaskSagaEventTypes.MLXInferenceSagaPrepare - event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = EventCategoryEnum.MutatesTaskSagaState +class MLXInferenceSagaPrepare(BaseEvent[EventType.MLXInferenceSagaPrepare]): + event_type: Literal[EventType.MLXInferenceSagaPrepare] = EventType.MLXInferenceSagaPrepare task_id: TaskId instance_id: InstanceId -class MLXInferenceSagaStartPrepare(BaseEvent[EventCategoryEnum.MutatesTaskSagaState, Literal[TaskSagaEventTypes.MLXInferenceSagaStartPrepare]]): - event_type: Literal[TaskSagaEventTypes.MLXInferenceSagaStartPrepare] = TaskSagaEventTypes.MLXInferenceSagaStartPrepare - event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = EventCategoryEnum.MutatesTaskSagaState +class MLXInferenceSagaStartPrepare(BaseEvent[EventType.MLXInferenceSagaStartPrepare]): + event_type: Literal[EventType.MLXInferenceSagaStartPrepare] = EventType.MLXInferenceSagaStartPrepare task_id: TaskId instance_id: InstanceId -class NodePerformanceMeasured(BaseEvent[EventCategoryEnum.MutatesNodePerformanceState, Literal[NodePerformanceEventTypes.NodePerformanceMeasured]]): - event_type: Literal[NodePerformanceEventTypes.NodePerformanceMeasured] = NodePerformanceEventTypes.NodePerformanceMeasured - event_category: Literal[EventCategoryEnum.MutatesNodePerformanceState] = EventCategoryEnum.MutatesNodePerformanceState +class NodePerformanceMeasured(BaseEvent[EventType.NodePerformanceMeasured]): + event_type: Literal[EventType.NodePerformanceMeasured] = EventType.NodePerformanceMeasured node_id: NodeId node_profile: NodePerformanceProfile -class WorkerConnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerConnected]]): - event_type: Literal[TopologyEventTypes.WorkerConnected] = TopologyEventTypes.WorkerConnected - event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState +class WorkerConnected(BaseEvent[EventType.WorkerConnected]): + event_type: Literal[EventType.WorkerConnected] = EventType.WorkerConnected edge: TopologyEdge -class WorkerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerStatusUpdated]]): - event_type: Literal[TopologyEventTypes.WorkerStatusUpdated] = TopologyEventTypes.WorkerStatusUpdated - event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState +class WorkerStatusUpdated(BaseEvent[EventType.WorkerStatusUpdated]): + event_type: Literal[EventType.WorkerStatusUpdated] = EventType.WorkerStatusUpdated node_id: NodeId node_state: NodeStatus -class WorkerDisconnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerDisconnected]]): - event_type: Literal[TopologyEventTypes.WorkerDisconnected] = TopologyEventTypes.WorkerDisconnected - event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState +class WorkerDisconnected(BaseEvent[EventType.WorkerDisconnected]): + event_type: Literal[EventType.WorkerDisconnected] = EventType.WorkerDisconnected vertex_id: NodeId -class ChunkGenerated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[StreamingEventTypes.ChunkGenerated]]): - event_type: Literal[StreamingEventTypes.ChunkGenerated] = StreamingEventTypes.ChunkGenerated - event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState +class ChunkGenerated(BaseEvent[EventType.ChunkGenerated]): + event_type: Literal[EventType.ChunkGenerated] = EventType.ChunkGenerated task_id: TaskId chunk: GenerationChunk -class TopologyEdgeCreated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeCreated]]): - event_type: Literal[TopologyEventTypes.TopologyEdgeCreated] = TopologyEventTypes.TopologyEdgeCreated - event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState +class TopologyEdgeCreated(BaseEvent[EventType.TopologyEdgeCreated]): + event_type: Literal[EventType.TopologyEdgeCreated] = EventType.TopologyEdgeCreated vertex: TopologyNode -class TopologyEdgeReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically]]): - event_type: Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically] = TopologyEventTypes.TopologyEdgeReplacedAtomically - event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState +class TopologyEdgeReplacedAtomically(BaseEvent[EventType.TopologyEdgeReplacedAtomically]): + event_type: Literal[EventType.TopologyEdgeReplacedAtomically] = EventType.TopologyEdgeReplacedAtomically edge_id: TopologyEdgeId edge_profile: TopologyEdgeProfile -class TopologyEdgeDeleted(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeDeleted]]): - event_type: Literal[TopologyEventTypes.TopologyEdgeDeleted] = TopologyEventTypes.TopologyEdgeDeleted - event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState +class TopologyEdgeDeleted(BaseEvent[EventType.TopologyEdgeDeleted]): + event_type: Literal[EventType.TopologyEdgeDeleted] = EventType.TopologyEdgeDeleted edge_id: TopologyEdgeId -""" -TEST_EVENT_CATEGORIES_TYPE = FrozenSet[ - Literal[ - EventCategoryEnum.MutatesTaskState, - EventCategoryEnum.MutatesControlPlaneState, - ] -] -TEST_EVENT_CATEGORIES = frozenset( - ( - EventCategoryEnum.MutatesTaskState, - EventCategoryEnum.MutatesControlPlaneState, - ) -) + +class TimerCreated(BaseEvent[EventType.TimerCreated]): + event_type: Literal[EventType.TimerCreated] = EventType.TimerCreated + timer_id: TimerId + delay_seconds: float -class TestEvent(BaseEvent[TEST_EVENT_CATEGORIES_TYPE]): - event_category: TEST_EVENT_CATEGORIES_TYPE = TEST_EVENT_CATEGORIES - test_id: int -""" \ No newline at end of file +class TimerFired(BaseEvent[EventType.TimerFired]): + event_type: Literal[EventType.TimerFired] = EventType.TimerFired + timer_id: TimerId \ No newline at end of file diff --git a/shared/types/events/registry.py b/shared/types/events/registry.py index 8ba17138..959ada0f 100644 --- a/shared/types/events/registry.py +++ b/shared/types/events/registry.py @@ -1,25 +1,15 @@ -from types import UnionType -from typing import Annotated, Any, Mapping, Type, get_args +from typing import Annotated, Any, Mapping, Type, TypeAlias -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import Field, TypeAdapter -from shared.constants import get_error_reporting_message from shared.types.events.common import ( - BaseEvent, - EventCategories, - EventTypes, - InstanceEventTypes, - NodeId, - NodePerformanceEventTypes, - RunnerStatusEventTypes, - StreamingEventTypes, - TaskEventTypes, - TaskSagaEventTypes, - TopologyEventTypes, + EventType, ) from shared.types.events.events import ( ChunkGenerated, + InstanceActivated, InstanceCreated, + InstanceDeactivated, InstanceDeleted, InstanceReplacedAtomically, MLXInferenceSagaPrepare, @@ -29,6 +19,8 @@ from shared.types.events.events import ( TaskCreated, TaskDeleted, TaskStateUpdated, + TimerCreated, + TimerFired, TopologyEdgeCreated, TopologyEdgeDeleted, TopologyEdgeReplacedAtomically, @@ -36,6 +28,11 @@ from shared.types.events.events import ( WorkerDisconnected, WorkerStatusUpdated, ) +from shared.types.events.sanity_checking import ( + assert_event_union_covers_registry, + check_registry_has_all_event_types, + check_union_of_all_events_is_consistent_with_registry, +) """ class EventTypeNames(StrEnum): @@ -50,63 +47,38 @@ class EventTypeNames(StrEnum): check_event_categories_are_defined_for_all_event_types(EVENT_TYPE_ENUMS, EventTypeNames) """ -EventRegistry: Mapping[EventTypes, Type[Any]] = { - TaskEventTypes.TaskCreated: TaskCreated, - TaskEventTypes.TaskStateUpdated: TaskStateUpdated, - TaskEventTypes.TaskDeleted: TaskDeleted, - InstanceEventTypes.InstanceCreated: InstanceCreated, - InstanceEventTypes.InstanceDeleted: InstanceDeleted, - InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically, - RunnerStatusEventTypes.RunnerStatusUpdated: RunnerStatusUpdated, - NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured, - TopologyEventTypes.WorkerConnected: WorkerConnected, - TopologyEventTypes.WorkerStatusUpdated: WorkerStatusUpdated, - TopologyEventTypes.WorkerDisconnected: WorkerDisconnected, - StreamingEventTypes.ChunkGenerated: ChunkGenerated, - TopologyEventTypes.TopologyEdgeCreated: TopologyEdgeCreated, - TopologyEventTypes.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically, - TopologyEventTypes.TopologyEdgeDeleted: TopologyEdgeDeleted, - TaskSagaEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare, - TaskSagaEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare, +EventRegistry: Mapping[EventType, Type[Any]] = { + EventType.TaskCreated: TaskCreated, + EventType.TaskStateUpdated: TaskStateUpdated, + EventType.TaskDeleted: TaskDeleted, + EventType.InstanceCreated: InstanceCreated, + EventType.InstanceActivated: InstanceActivated, + EventType.InstanceDeactivated: InstanceDeactivated, + EventType.InstanceDeleted: InstanceDeleted, + EventType.InstanceReplacedAtomically: InstanceReplacedAtomically, + EventType.RunnerStatusUpdated: RunnerStatusUpdated, + EventType.NodePerformanceMeasured: NodePerformanceMeasured, + EventType.WorkerConnected: WorkerConnected, + EventType.WorkerStatusUpdated: WorkerStatusUpdated, + EventType.WorkerDisconnected: WorkerDisconnected, + EventType.ChunkGenerated: ChunkGenerated, + EventType.TopologyEdgeCreated: TopologyEdgeCreated, + EventType.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically, + EventType.TopologyEdgeDeleted: TopologyEdgeDeleted, + EventType.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare, + EventType.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare, + EventType.TimerCreated: TimerCreated, + EventType.TimerFired: TimerFired, } -# Sanity Check. -def check_registry_has_all_event_types() -> None: - event_types: tuple[EventTypes, ...] = get_args(EventTypes) - missing_event_types = set(event_types) - set(EventRegistry.keys()) - - assert not missing_event_types, ( - f"{get_error_reporting_message()}" - f"There's an event missing from the registry: {missing_event_types}" - ) - - -def check_union_of_all_events_is_consistent_with_registry( - registry: Mapping[EventTypes, Type[Any]], union_type: UnionType -) -> None: - type_of_each_registry_entry = set(registry.values()) - type_of_each_entry_in_union = set(get_args(union_type)) - missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union - - assert not missing_from_union, ( - f"{get_error_reporting_message()}" - f"Event classes in registry are missing from all_events union: {missing_from_union}" - ) - - extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry - - assert not extra_in_union, ( - f"{get_error_reporting_message()}" - f"Event classes in all_events union are missing from registry: {extra_in_union}" - ) - - -Event = ( +AllEventsUnion = ( TaskCreated | TaskStateUpdated | TaskDeleted | InstanceCreated + | InstanceActivated + | InstanceDeactivated | InstanceDeleted | InstanceReplacedAtomically | RunnerStatusUpdated @@ -120,24 +92,16 @@ Event = ( | TopologyEdgeDeleted | MLXInferenceSagaPrepare | MLXInferenceSagaStartPrepare + | TimerCreated + | TimerFired ) -# Run the sanity check -check_union_of_all_events_is_consistent_with_registry(EventRegistry, Event) +Event: TypeAlias = Annotated[AllEventsUnion, Field(discriminator="event_type")] +EventParser: TypeAdapter[Event] = TypeAdapter(Event) -_EventType = Annotated[Event, Field(discriminator="event_type")] -EventParser: TypeAdapter[BaseEvent[EventCategories]] = TypeAdapter(_EventType) -# Define a properly typed EventFromEventLog that preserves specific event types - -class EventFromEventLogTyped(BaseModel): - """Properly typed EventFromEventLog that preserves specific event types.""" - event: _EventType - origin: NodeId - idx_in_log: int = Field(gt=0) - - def check_event_was_sent_by_correct_node(self) -> bool: - """Check if the event was sent by the correct node.""" - return self.event.check_event_was_sent_by_correct_node(self.origin) +assert_event_union_covers_registry(AllEventsUnion) +check_union_of_all_events_is_consistent_with_registry(EventRegistry, AllEventsUnion) +check_registry_has_all_event_types(EventRegistry) \ No newline at end of file diff --git a/shared/types/events/sanity_checking.py b/shared/types/events/sanity_checking.py index ca489f23..def11557 100644 --- a/shared/types/events/sanity_checking.py +++ b/shared/types/events/sanity_checking.py @@ -1,68 +1,75 @@ -from enum import Enum, StrEnum +from enum import StrEnum from types import UnionType -from typing import Any, LiteralString, Sequence, Set, Type, get_args +from typing import Any, Mapping, Set, Type, cast, get_args + +from pydantic.fields import FieldInfo from shared.constants import get_error_reporting_message +from shared.types.events.common import EventType -def check_event_type_union_is_consistent_with_registry( - event_type_enums: Sequence[Type[Enum]], event_types: UnionType -) -> None: - """Assert that every enum value from _EVENT_TYPE_ENUMS satisfies EventTypes.""" - - event_types_inferred_from_union = set(get_args(event_types)) - - event_types_inferred_from_registry = [ - member for enum_class in event_type_enums for member in enum_class - ] - - # Check that each registry value belongs to one of the types in the union - for tag_of_event_type in event_types_inferred_from_registry: - event_type = type(tag_of_event_type) - assert event_type in event_types_inferred_from_union, ( - f"{get_error_reporting_message()}" - f"There's a mismatch between the registry of event types and the union of possible event types." - f"The enum value {tag_of_event_type} for type {event_type} is not covered by {event_types_inferred_from_union}." - ) - - -def check_event_categories_are_defined_for_all_event_types( - event_definitions: Sequence[Type[Enum]], event_categories: Type[StrEnum] -) -> None: - """Assert that the event category names are consistent with the event type enums.""" - - expected_category_tags: list[str] = [ - enum_class.__name__ for enum_class in event_definitions - ] - tag_of_event_categories: list[str] = list(event_categories.__members__.values()) - assert tag_of_event_categories == expected_category_tags, ( - f"{get_error_reporting_message()}" - f"The values of the enum EventCategories are not named after the event type enums." - f"These are the missing categories: {set(expected_category_tags) - set(tag_of_event_categories)}" - f"These are the extra categories: {set(tag_of_event_categories) - set(expected_category_tags)}" - ) - - -def assert_literal_union_covers_enum[TEnum: StrEnum]( +def assert_event_union_covers_registry[TEnum: StrEnum]( literal_union: UnionType, - enum_type: Type[TEnum], ) -> None: - enum_values: Set[Any] = {member.value for member in enum_type} + """ + Ensure that our union of events (AllEventsUnion) has one member per element of Enum + """ + enum_values: Set[str] = {member.value for member in EventType} - def _flatten(tp: UnionType) -> Set[Any]: - values: Set[Any] = set() - args: tuple[LiteralString, ...] = get_args(tp) - for arg in args: - payloads: tuple[TEnum, ...] = get_args(arg) - for payload in payloads: - values.add(payload.value) + def _flatten(tp: UnionType) -> Set[str]: + values: Set[str] = set() + args = get_args(tp) # Get event classes from the union + for arg in args: # type: ignore[reportAny] + # Cast to type since we know these are class types + event_class = cast(type[Any], arg) + # Each event class is a Pydantic model with model_fields + if hasattr(event_class, 'model_fields'): + model_fields = cast(dict[str, FieldInfo], event_class.model_fields) + if 'event_type' in model_fields: + # Get the default value of the event_type field + event_type_field: FieldInfo = model_fields['event_type'] + if hasattr(event_type_field, 'default'): + default_value = cast(EventType, event_type_field.default) + # The default is an EventType enum member, get its value + values.add(default_value.value) return values - literal_values: Set[Any] = _flatten(literal_union) + literal_values: Set[str] = _flatten(literal_union) assert enum_values == literal_values, ( f"{get_error_reporting_message()}" - f"The values of the enum {enum_type} are not covered by the literal union {literal_union}.\n" + f"The values of the enum {EventType} are not covered by the literal union {literal_union}.\n" f"These are the missing values: {enum_values - literal_values}\n" f"These are the extra values: {literal_values - enum_values}\n" ) + +def check_union_of_all_events_is_consistent_with_registry( + registry: Mapping[EventType, Type[Any]], union_type: UnionType +) -> None: + type_of_each_registry_entry = set(registry.values()) + type_of_each_entry_in_union = set(get_args(union_type)) + missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union + + assert not missing_from_union, ( + f"{get_error_reporting_message()}" + f"Event classes in registry are missing from all_events union: {missing_from_union}" + ) + + extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry + + assert not extra_in_union, ( + f"{get_error_reporting_message()}" + f"Event classes in all_events union are missing from registry: {extra_in_union}" + ) + +def check_registry_has_all_event_types(event_registry: Mapping[EventType, Type[Any]]) -> None: + event_types: tuple[EventType, ...] = get_args(EventType) + missing_event_types = set(event_types) - set(event_registry.keys()) + + assert not missing_event_types, ( + f"{get_error_reporting_message()}" + f"There's an event missing from the registry: {missing_event_types}" + ) + +# TODO: Check all events have an apply function. +# probably in a different place though. \ No newline at end of file diff --git a/shared/types/state.py b/shared/types/state.py index 59d51957..0712d525 100644 --- a/shared/types/state.py +++ b/shared/types/state.py @@ -39,3 +39,6 @@ class State(BaseModel): task_inbox: List[Task] = Field(default_factory=list) task_outbox: List[Task] = Field(default_factory=list) cache_policy: CachePolicy = CachePolicy.KeepAll + + # TODO: implement / use this? + last_event_applied_idx: int = Field(default=0, ge=0)