mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)
New events
@@ -1,32 +0,0 @@
-from hashlib import sha3_224 as hasher
-from typing import Sequence
-from uuid import UUID
-
-from shared.types.events.common import EventCategory, EventId, IdemKeyGenerator, State
-
-
-def get_idem_tag_generator[EventCategoryT: EventCategory](
-    base: str,
-) -> IdemKeyGenerator[EventCategoryT]:
-    """Generates idempotency keys for events.
-
-    The keys are generated by hashing the state sequence number against a base string.
-    You can pick any base string, **so long as it's not used in any other function that generates idempotency keys**.
-    """
-
-    def get_idem_keys(state: State[EventCategoryT], num_keys: int) -> Sequence[EventId]:
-        def recurse(n: int, last: bytes) -> Sequence[EventId]:
-            if n == 0:
-                return []
-            next_hash = hasher(last).digest()
-            return (
-                EventId(UUID(bytes=next_hash, version=4)),
-                *recurse(n - 1, next_hash),
-            )
-
-        initial_bytes = state.last_event_applied_idx.to_bytes(
-            8, byteorder="big", signed=False
-        )
-        return recurse(num_keys, initial_bytes)
-
-    return get_idem_keys
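For reference, the removed generator's key chain in standalone form. This is only a sketch: sha3_224 digests are 28 bytes while a UUID takes exactly 16, so the truncation below is our assumption, not necessarily the original's behavior.

from hashlib import sha3_224
from uuid import UUID


def idem_key_chain(last_event_applied_idx: int, num_keys: int) -> list[UUID]:
    # Deterministic: the same (index, count) always yields the same keys,
    # so a retried emission reuses identical event ids.
    last = last_event_applied_idx.to_bytes(8, byteorder="big", signed=False)
    keys: list[UUID] = []
    for _ in range(num_keys):
        last = sha3_224(last).digest()
        keys.append(UUID(bytes=last[:16], version=4))  # 16-byte truncation is our assumption
    return keys


assert idem_key_chain(7, 2) == idem_key_chain(7, 2)  # stable across retries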
@@ -1,7 +1,6 @@
from contextlib import asynccontextmanager
from logging import Logger, LogRecord
from queue import Queue as PQueue
-from typing import Literal

from fastapi import FastAPI

@@ -19,9 +18,6 @@ from shared.logger import (
    create_queue_listener,
    log,
)
-from shared.types.events.common import (
-    EventCategoryEnum,
-)
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
from shared.types.worker.common import InstanceId
@@ -51,19 +47,8 @@ def get_state_dependency(data: object, logger: Logger) -> State:
    return data


-# What The Master Cares About
-MasterEventCategories = (
-    Literal[EventCategoryEnum.MutatesTopologyState]
-    | Literal[EventCategoryEnum.MutatesTaskState]
-    | Literal[EventCategoryEnum.MutatesTaskSagaState]
-    | Literal[EventCategoryEnum.MutatesRunnerStatus]
-    | Literal[EventCategoryEnum.MutatesInstanceState]
-    | Literal[EventCategoryEnum.MutatesNodePerformanceState]
-)
-
-
# Takes Care Of All States And Events Related To The Master
-class MasterEventLoopProtocol(NodeEventLoopProtocol[MasterEventCategories]): ...
+class MasterEventLoopProtocol(NodeEventLoopProtocol): ...


@asynccontextmanager
@@ -1,7 +1,7 @@
from queue import Queue
from typing import Mapping, Sequence

-from shared.types.events.common import BaseEvent, EventCategory
+from shared.types.events.registry import Event
from shared.types.graphs.topology import Topology
from shared.types.state import CachePolicy
from shared.types.tasks.common import Task
@@ -20,4 +20,4 @@ def get_instance_placement(
def get_transition_events(
    current_instances: Mapping[InstanceId, InstanceParams],
    target_instances: Mapping[InstanceId, InstanceParams],
-) -> Sequence[BaseEvent[EventCategory]]: ...
+) -> Sequence[Event]: ...
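get_transition_events reconciles the current instance map against a target map; a standalone sketch of that diffing idea (hypothetical action tuples, since the real function returns Events):

from typing import Mapping, Sequence


def transition_actions(
    current: Mapping[str, object], target: Mapping[str, object]
) -> Sequence[tuple[str, str]]:
    # Emit "create" for ids only in target, "delete" for ids only in current.
    creates = [("create", iid) for iid in target.keys() - current.keys()]
    deletes = [("delete", iid) for iid in current.keys() - target.keys()]
    return [*creates, *deletes]


assert transition_actions({"a": 1}, {"b": 2}) == [("create", "b"), ("delete", "a")]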
@@ -1,13 +0,0 @@
-from enum import StrEnum
-from typing import Any, Mapping, Type
-
-
-def check_keys_in_map_match_enum_values[TEnum: StrEnum](
-    mapping_type: Type[Mapping[Any, Any]],
-    enum: Type[TEnum],
-) -> None:
-    mapping_keys = set(mapping_type.__annotations__.keys())
-    category_values = set(e.value for e in enum)
-    assert mapping_keys == category_values, (
-        f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}"
-    )
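In miniature, the comparison this removed helper performs: TypedDict annotation keys against StrEnum values (hypothetical Color example, not from the codebase):

from enum import StrEnum
from typing import TypedDict


class Color(StrEnum):
    Red = "Red"
    Blue = "Blue"


class ColorHandlers(TypedDict):
    Red: str
    Blue: str


# Same check the helper asserted: the mapping covers exactly the enum's values.
assert set(ColorHandlers.__annotations__) == {c.value for c in Color}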
@@ -1,125 +0,0 @@
-from asyncio import Lock, Queue, Task, create_task
-from logging import Logger
-from typing import List, Literal, Protocol, TypedDict
-
-from master.logging import (
-    StateUpdateEffectHandlerErrorLogEntry,
-    StateUpdateErrorLogEntry,
-    StateUpdateLoopAlreadyRunningLogEntry,
-    StateUpdateLoopNotRunningLogEntry,
-    StateUpdateLoopStartedLogEntry,
-    StateUpdateLoopStoppedLogEntry,
-)
-from master.sanity_checking import check_keys_in_map_match_enum_values
-from shared.constants import get_error_reporting_message
-from shared.logger import log
-from shared.types.events.common import (
-    Apply,
-    EffectHandler,
-    EventCategory,
-    EventCategoryEnum,
-    EventFromEventLog,
-    State,
-    StateAndEvent,
-)
-
-
-class AsyncStateManager[EventCategoryT: EventCategory](Protocol):
-    """Protocol for services that manage a specific state domain."""
-
-    _task: Task[None] | None
-    _logger: Logger
-    _apply: Apply[EventCategoryT]
-    _default_effects: List[EffectHandler[EventCategoryT]]
-    extra_effects: List[EffectHandler[EventCategoryT]]
-    state: State[EventCategoryT]
-    queue: Queue[EventFromEventLog[EventCategoryT]]
-    lock: Lock
-
-    def __init__(
-        self,
-        state: State[EventCategoryT],
-        queue: Queue[EventFromEventLog[EventCategoryT]],
-        extra_effects: List[EffectHandler[EventCategoryT]],
-        logger: Logger,
-    ) -> None:
-        """Initialise the service with its event queue."""
-        self.state = state
-        self.queue = queue
-        self.extra_effects = extra_effects
-        self._logger = logger
-        self._task = None
-
-    async def read_state(self) -> State[EventCategoryT]:
-        """Get a thread-safe snapshot of this service's state domain."""
-        return self.state.model_copy(deep=True)
-
-    @property
-    def is_running(self) -> bool:
-        """Check if the service's event loop is running."""
-        return self._task is not None and not self._task.done()
-
-    async def start(self) -> None:
-        """Start the service's event loop."""
-        if self.is_running:
-            log(self._logger, StateUpdateLoopAlreadyRunningLogEntry())
-            raise RuntimeError("State Update Loop Already Running")
-        log(self._logger, StateUpdateLoopStartedLogEntry())
-        self._task = create_task(self._event_loop())
-
-    async def stop(self) -> None:
-        """Stop the service's event loop."""
-        if not self.is_running:
-            log(self._logger, StateUpdateLoopNotRunningLogEntry())
-            raise RuntimeError("State Update Loop Not Running")
-
-        assert self._task is not None, (
-            f"{get_error_reporting_message()}"
-            "BUG: is_running is True but _task is None, this should never happen!"
-        )
-        self._task.cancel()
-        log(self._logger, StateUpdateLoopStoppedLogEntry())
-
-    async def _event_loop(self) -> None:
-        """Event loop for the service."""
-        while True:
-            event = await self.queue.get()
-            previous_state = self.state.model_copy(deep=True)
-            try:
-                async with self.lock:
-                    updated_state = self._apply(
-                        self.state,
-                        event,
-                    )
-                    self.state = updated_state
-            except Exception as e:
-                log(self._logger, StateUpdateErrorLogEntry(error=e))
-                raise e
-            try:
-                for effect_handler in self._default_effects + self.extra_effects:
-                    effect_handler(StateAndEvent(previous_state, event), updated_state)
-            except Exception as e:
-                log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e))
-                raise e


-class AsyncStateManagerMapping(TypedDict):
-    MutatesTaskState: AsyncStateManager[Literal[EventCategoryEnum.MutatesTaskState]]
-    MutatesTaskSagaState: AsyncStateManager[
-        Literal[EventCategoryEnum.MutatesTaskSagaState]
-    ]
-    MutatesTopologyState: AsyncStateManager[
-        Literal[EventCategoryEnum.MutatesTopologyState]
-    ]
-    MutatesRunnerStatus: AsyncStateManager[
-        Literal[EventCategoryEnum.MutatesRunnerStatus]
-    ]
-    MutatesInstanceState: AsyncStateManager[
-        Literal[EventCategoryEnum.MutatesInstanceState]
-    ]
-    MutatesNodePerformanceState: AsyncStateManager[
-        Literal[EventCategoryEnum.MutatesNodePerformanceState]
-    ]
-
-
-check_keys_in_map_match_enum_values(AsyncStateManagerMapping, EventCategoryEnum)
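The contract this removed protocol captured, in miniature: _apply is a pure state transition taken under the lock, and effect handlers then observe the previous state, the event, and the updated state. A standalone sketch with a hypothetical counter domain:

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CounterState:
    total: int = 0


def apply_add(state: CounterState, amount: int) -> CounterState:
    # Pure transition: returns a new state instead of mutating in place.
    return replace(state, total=state.total + amount)


def print_effect(previous: CounterState, updated: CounterState) -> None:
    # Effects run after the transition and never feed back into it directly.
    print(f"total {previous.total} -> {updated.total}")


state = CounterState()
new_state = apply_add(state, 3)
print_effect(state, new_state)  # total 0 -> 3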
@@ -1,18 +0,0 @@
-from typing import Literal, TypedDict
-
-from master.sanity_checking import check_keys_in_map_match_enum_values
-from shared.types.events.common import EventCategoryEnum, State
-
-
-class SyncStateManagerMapping(TypedDict):
-    MutatesTaskState: State[Literal[EventCategoryEnum.MutatesTaskState]]
-    MutatesTaskSagaState: State[Literal[EventCategoryEnum.MutatesTaskSagaState]]
-    MutatesTopologyState: State[Literal[EventCategoryEnum.MutatesTopologyState]]
-    MutatesRunnerStatus: State[Literal[EventCategoryEnum.MutatesRunnerStatus]]
-    MutatesInstanceState: State[Literal[EventCategoryEnum.MutatesInstanceState]]
-    MutatesNodePerformanceState: State[
-        Literal[EventCategoryEnum.MutatesNodePerformanceState]
-    ]
-
-
-check_keys_in_map_match_enum_values(SyncStateManagerMapping, EventCategoryEnum)
@@ -12,12 +12,9 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlmodel import SQLModel

-from shared.types.events.common import (
-    BaseEvent,
-    EventCategories,
-    NodeId,
-)
-from shared.types.events.registry import Event, EventFromEventLogTyped, EventParser
+from shared.types.events.common import NodeId
+from shared.types.events.components import EventFromEventLog
+from shared.types.events.registry import Event, EventParser

from .types import StoredEvent

@@ -53,7 +50,7 @@ class AsyncSQLiteEventStorage:
        self._max_age_s = max_age_ms / 1000.0
        self._logger = logger or getLogger(__name__)

-        self._write_queue: Queue[tuple[BaseEvent[EventCategories], NodeId]] = Queue()
+        self._write_queue: Queue[tuple[Event, NodeId]] = Queue()
        self._batch_writer_task: Task[None] | None = None
        self._engine = None
        self._closed = False
@@ -72,7 +69,7 @@ class AsyncSQLiteEventStorage:

    async def append_events(
        self,
-        events: Sequence[BaseEvent[EventCategories]],
+        events: Sequence[Event],
        origin: NodeId
    ) -> None:
        """Append events to the log (fire-and-forget). The writes are batched and committed
@@ -86,7 +83,7 @@ class AsyncSQLiteEventStorage:
    async def get_events_since(
        self,
        last_idx: int
-    ) -> Sequence[EventFromEventLogTyped]:
+    ) -> Sequence[EventFromEventLog[Event]]:
        """Retrieve events after a specific index."""
        if self._closed:
            raise RuntimeError("Storage is closed")
@@ -101,7 +98,7 @@ class AsyncSQLiteEventStorage:
        )
        rows = result.fetchall()

-        events: list[EventFromEventLogTyped] = []
+        events: list[EventFromEventLog[Event]] = []
        for row in rows:
            rowid: int = cast(int, row[0])
            origin: str = cast(str, row[1])
@@ -112,7 +109,7 @@ class AsyncSQLiteEventStorage:
            else:
                event_data = cast(dict[str, Any], raw_event_data)
            event = await self._deserialize_event(event_data)
-            events.append(EventFromEventLogTyped(
+            events.append(EventFromEventLog(
                event=event,
                origin=NodeId(uuid=UUID(origin)),
                idx_in_log=rowid  # rowid becomes idx_in_log
@@ -170,7 +167,7 @@ class AsyncSQLiteEventStorage:
        loop = asyncio.get_event_loop()

        while not self._closed:
-            batch: list[tuple[BaseEvent[EventCategories], NodeId]] = []
+            batch: list[tuple[Event, NodeId]] = []

            try:
                # Block waiting for first item
@@ -208,7 +205,7 @@ class AsyncSQLiteEventStorage:
            if batch:
                await self._commit_batch(batch)

-    async def _commit_batch(self, batch: list[tuple[BaseEvent[EventCategories], NodeId]]) -> None:
+    async def _commit_batch(self, batch: list[tuple[Event, NodeId]]) -> None:
        """Commit a batch of events to SQLite."""
        assert self._engine is not None

@@ -218,7 +215,6 @@ class AsyncSQLiteEventStorage:
            stored_event = StoredEvent(
                origin=str(origin.uuid),
                event_type=str(event.event_type),
-                event_category=str(next(iter(event.event_category))),
                event_id=str(event.event_id),
                event_data=event.model_dump(mode='json')  # mode='json' ensures UUID conversion
            )
@@ -237,8 +233,8 @@ class AsyncSQLiteEventStorage:
        """Deserialize event data back to typed Event."""
        # EventParser expects the discriminator field for proper deserialization
        result = EventParser.validate_python(event_data)
-        # EventParser returns BaseEvent but we know it's actually a specific Event type
-        return result  # type: ignore[reportReturnType]
+        # EventParser returns Event type which is our union of all event types
+        return result

    async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]:
        """Return raw event data for testing purposes."""
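The EventParser.validate_python call relies on pydantic's discriminated-union support; a minimal sketch of that pattern with a hypothetical two-event union (not the project's registry):

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class Ping(BaseModel):
    event_type: Literal["Ping"] = "Ping"


class Pong(BaseModel):
    event_type: Literal["Pong"] = "Pong"


# event_type acts as the discriminator, so the adapter picks the concrete
# class straight from serialized event_data, as _deserialize_event does.
Parser = TypeAdapter(Annotated[Union[Ping, Pong], Field(discriminator="event_type")])

assert isinstance(Parser.validate_python({"event_type": "Ping"}), Ping)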
@@ -4,12 +4,9 @@ from typing import Any, Protocol, Sequence
from sqlalchemy import DateTime, Index
from sqlmodel import JSON, Column, Field, SQLModel

-from shared.types.events.common import (
-    BaseEvent,
-    EventCategories,
-    EventFromEventLog,
-    NodeId,
-)
+from shared.types.common import NodeId
+from shared.types.events.components import EventFromEventLog
+from shared.types.events.registry import Event


class StoredEvent(SQLModel, table=True):
@@ -23,7 +20,6 @@ class StoredEvent(SQLModel, table=True):
    rowid: int | None = Field(default=None, primary_key=True, alias="rowid")
    origin: str = Field(index=True)
    event_type: str = Field(index=True)
-    event_category: str = Field(index=True)
    event_id: str = Field(index=True)
    event_data: dict[str, Any] = Field(sa_column=Column(JSON))
    created_at: datetime = Field(
@@ -33,7 +29,6 @@ class StoredEvent(SQLModel, table=True):

    __table_args__ = (
        Index("idx_events_origin_created", "origin", "created_at"),
-        Index("idx_events_category_created", "event_category", "created_at"),
    )


class EventStorageProtocol(Protocol):
@@ -41,7 +36,7 @@ class EventStorageProtocol(Protocol):

    async def append_events(
        self,
-        events: Sequence[BaseEvent[EventCategories]],
+        events: Sequence[Event],
        origin: NodeId
    ) -> None:
        """Append events to the log (fire-and-forget).
@@ -54,7 +49,7 @@ class EventStorageProtocol(Protocol):
    async def get_events_since(
        self,
        last_idx: int
-    ) -> Sequence[EventFromEventLog[EventCategories]]:
+    ) -> Sequence[EventFromEventLog[Event]]:
        """Retrieve events after a specific index.

        Returns events in idx_in_log order.
@@ -7,7 +7,10 @@ from typing import Any, Hashable, Mapping, Protocol, Sequence
from fastapi.responses import Response, StreamingResponse

from shared.event_loops.commands import ExternalCommand
-from shared.types.events.common import Apply, EventCategory, EventFromEventLog, State
+from shared.types.events.registry import Event
+from shared.types.events.components import EventFromEventLog
+from shared.types.state import State
+from shared.types.events.components import Apply


class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]):
@@ -38,17 +41,16 @@ class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]):
        return len(self._store)


# Safety on Apply.
-def safely_apply[T: EventCategory](
-    state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]]
-) -> State[T]:
+def apply_events(
+    state: State, apply_fn: Apply, events: Sequence[EventFromEventLog[Event]]
+) -> State:
    sorted_events = sorted(events, key=lambda event: event.idx_in_log)
    state = state.model_copy()
-    for event in sorted_events:
-        if event.idx_in_log <= state.last_event_applied_idx:
+    for wrapped_event in sorted_events:
+        if wrapped_event.idx_in_log <= state.last_event_applied_idx:
            continue
-        state.last_event_applied_idx = event.idx_in_log
-        state = apply_fn(state, event)
+        state.last_event_applied_idx = wrapped_event.idx_in_log
+        state = apply_fn(state, wrapped_event.event)
    return state
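A quick standalone sketch of the replay contract apply_events enforces: entries are applied in idx_in_log order, anything at or below last_event_applied_idx is skipped, so redelivering a batch is a no-op (hypothetical toy types, not the project's State):

from dataclasses import dataclass, replace
from typing import Callable, Sequence


@dataclass(frozen=True)
class ToyState:
    last_event_applied_idx: int = 0
    seen: tuple[str, ...] = ()


@dataclass(frozen=True)
class ToyLogEntry:
    idx_in_log: int
    event: str


def toy_apply_events(
    state: ToyState,
    apply_fn: Callable[[ToyState, str], ToyState],
    entries: Sequence[ToyLogEntry],
) -> ToyState:
    # Same shape as apply_events: sort by log index, skip anything already
    # applied, record the index, then apply the unwrapped event.
    for entry in sorted(entries, key=lambda e: e.idx_in_log):
        if entry.idx_in_log <= state.last_event_applied_idx:
            continue
        state = replace(state, last_event_applied_idx=entry.idx_in_log)
        state = apply_fn(state, entry.event)
    return state


entries = [ToyLogEntry(2, "b"), ToyLogEntry(1, "a")]
add = lambda s, e: replace(s, seen=s.seen + (e,))
once = toy_apply_events(ToyState(), add, entries)
assert once.seen == ("a", "b")
assert toy_apply_events(once, add, entries) == once  # redelivery is a no-op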
@@ -69,11 +71,9 @@ class NodeCommandLoopProtocol(Protocol):
    async def _handle_command(self, command: ExternalCommand) -> None: ...


-class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol):
+class NodeEventGetterProtocol(Protocol):
    _event_fetcher: Task[Any] | None = None
-    _event_queues: ExhaustiveMapping[
-        EventCategoryT, AsyncQueue[EventFromEventLog[EventCategory]]
-    ]
+    _event_queue: AsyncQueue[EventFromEventLog[Event]]
    _logger: Logger

    @property
@@ -84,18 +84,18 @@ class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol):
    async def stop_event_fetcher(self) -> None: ...


-class NodeStateStorageProtocol[EventCategoryT: EventCategory](Protocol):
-    _state_managers: ExhaustiveMapping[EventCategoryT, State[EventCategoryT]]
+class NodeStateStorageProtocol(Protocol):
+    _state: State
    _state_lock: Lock
    _logger: Logger

    async def _read_state(
-        self, event_category: EventCategoryT
-    ) -> State[EventCategoryT]: ...
+        self,
+    ) -> State: ...


-class NodeStateManagerProtocol[EventCategoryT: EventCategory](
-    NodeEventGetterProtocol[EventCategoryT], NodeStateStorageProtocol[EventCategoryT]
+class NodeStateManagerProtocol(
+    NodeEventGetterProtocol, NodeStateStorageProtocol
):
    _state_manager: Task[Any] | None = None
    _logger: Logger
@@ -116,6 +116,6 @@ class NodeStateManagerProtocol[EventCategoryT: EventCategory](
    async def _apply_queued_events(self) -> None: ...


-class NodeEventLoopProtocol[EventCategoryT: EventCategory](
-    NodeCommandLoopProtocol, NodeStateManagerProtocol[EventCategoryT]
+class NodeEventLoopProtocol(
+    NodeCommandLoopProtocol, NodeStateManagerProtocol
): ...
@@ -1,78 +0,0 @@
-from asyncio.queues import Queue
-from typing import Sequence, cast, get_args
-
-from shared.event_loops.main import ExhaustiveMapping
-from shared.types.events.common import (
-    EventCategories,
-    EventCategory,
-    EventCategoryEnum,
-    EventFromEventLog,
-    narrow_event_from_event_log_type,
-)
-
-"""
-from asyncio import gather
-from logging import Logger
-from typing import Literal, Protocol, Sequence, TypedDict
-
-from master.sanity_checking import check_keys_in_map_match_enum_values
-from shared.types.events.common import EventCategoryEnum
-"""
-
-"""
-class EventQueues(TypedDict):
-    MutatesTaskState: Queue[
-        EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
-    ]
-    MutatesTaskSagaState: Queue[
-        EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
-    ]
-    MutatesControlPlaneState: Queue[
-        EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
-    ]
-    MutatesDataPlaneState: Queue[
-        EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
-    ]
-    MutatesRunnerStatus: Queue[
-        EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
-    ]
-    MutatesInstanceState: Queue[
-        EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
-    ]
-    MutatesNodePerformanceState: Queue[
-        EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
-    ]
-
-
-check_keys_in_map_match_enum_values(EventQueues, EventCategoryEnum)
-"""
-
-
-async def route_events[UnionOfRelevantEvents: EventCategory](
-    queue_map: ExhaustiveMapping[
-        UnionOfRelevantEvents, Queue[EventFromEventLog[EventCategory]]
-    ],
-    events: Sequence[EventFromEventLog[EventCategory | EventCategories]],
-) -> None:
-    """Route an event to the appropriate queue."""
-    tuple_of_categories: tuple[EventCategoryEnum, ...] = get_args(UnionOfRelevantEvents)
-    print(tuple_of_categories)
-    for event in events:
-        if isinstance(event.event.event_category, EventCategoryEnum):
-            category: EventCategory = event.event.event_category
-            if category not in tuple_of_categories:
-                continue
-            narrowed_event = narrow_event_from_event_log_type(event, category)
-            q1: Queue[EventFromEventLog[EventCategory]] = queue_map[
-                cast(UnionOfRelevantEvents, category)
-            ]  # TODO: make casting unnecessary
-            await q1.put(narrowed_event)
-        else:
-            for category in event.event.event_category:
-                if category not in tuple_of_categories:
-                    continue
-                narrow_event = narrow_event_from_event_log_type(event, category)
-                q2 = queue_map[
-                    cast(UnionOfRelevantEvents, category)
-                ]  # TODO: make casting unnecessary
-                await q2.put(narrow_event)
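The removed router fanned events out to per-category queues; the replacement design uses a single queue, but the filter-and-enqueue pattern in miniature looks like this (hypothetical categories, standalone):

import asyncio
from enum import Enum
from typing import Mapping, Sequence


class Category(str, Enum):
    A = "A"
    B = "B"


async def route(
    events: Sequence[tuple[Category, str]],
    queues: Mapping[Category, "asyncio.Queue[str]"],
) -> None:
    # Deliver each event only to the queue registered for its category.
    for category, payload in events:
        if category in queues:
            await queues[category].put(payload)


async def main() -> None:
    queues = {Category.A: asyncio.Queue(), Category.B: asyncio.Queue()}
    await route([(Category.A, "hello")], queues)
    assert queues[Category.A].qsize() == 1
    assert queues[Category.B].qsize() == 0


asyncio.run(main())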
@@ -14,8 +14,7 @@ from shared.types.common import NodeId
from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
from shared.types.events.events import (
    ChunkGenerated,
-    EventCategoryEnum,
-    StreamingEventTypes,
+    EventType,
)
from shared.types.tasks.common import TaskId

@@ -91,11 +90,10 @@ class TestAsyncSQLiteEventStorage:

        async with AsyncSession(storage._engine) as session:
            await session.execute(
-                text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
+                text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
                {
                    "origin": str(sample_node_id.uuid),
                    "event_type": "test_event",
-                    "event_category": "test_category",
                    "event_id": str(uuid4()),
                    "event_data": json.dumps(test_data)
                }
@@ -137,11 +135,10 @@ class TestAsyncSQLiteEventStorage:
        async with AsyncSession(storage._engine) as session:
            for record in test_records:
                await session.execute(
-                    text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
+                    text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
                    {
                        "origin": str(sample_node_id.uuid),
                        "event_type": record["event_type"],
-                        "event_category": "test_category",
                        "event_id": str(uuid4()),
                        "event_data": json.dumps(record)
                    }
@@ -180,18 +177,18 @@ class TestAsyncSQLiteEventStorage:
        async with AsyncSession(storage._engine) as session:
            # Origin 1 - record 1
            await session.execute(
-                text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
-                {"origin": str(origin1.uuid), "event_type": "event_1", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})}
+                text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
+                {"origin": str(origin1.uuid), "event_type": "event_1", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})}
            )
            # Origin 2 - record 2
            await session.execute(
-                text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
-                {"origin": str(origin2.uuid), "event_type": "event_2", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})}
+                text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
+                {"origin": str(origin2.uuid), "event_type": "event_2", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})}
            )
            # Origin 1 - record 3
            await session.execute(
-                text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
-                {"origin": str(origin1.uuid), "event_type": "event_3", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})}
+                text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
+                {"origin": str(origin1.uuid), "event_type": "event_3", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})}
            )
            await session.commit()

@@ -234,11 +231,10 @@ class TestAsyncSQLiteEventStorage:
        async with AsyncSession(storage._engine) as session:
            for i in range(10):
                await session.execute(
-                    text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
+                    text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
                    {
                        "origin": str(sample_node_id.uuid),
                        "event_type": f"event_{i}",
-                        "event_category": "test",
                        "event_id": str(uuid4()),
                        "event_data": json.dumps({"index": i})
                    }
@@ -325,11 +321,10 @@ class TestAsyncSQLiteEventStorage:
        assert storage._engine is not None
        async with AsyncSession(storage._engine) as session:
            await session.execute(
-                text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
+                text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
                {
                    "origin": str(sample_node_id.uuid),
                    "event_type": "complex_event",
-                    "event_category": "test",
                    "event_id": str(uuid4()),
                    "event_data": json.dumps(test_data)
                }
@@ -364,11 +359,10 @@ class TestAsyncSQLiteEventStorage:
        async with AsyncSession(storage._engine) as session:
            for i in range(count):
                await session.execute(
-                    text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
+                    text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
                    {
                        "origin": origin_id,
                        "event_type": f"batch_{batch_id}_event_{i}",
-                        "event_category": "test",
                        "event_id": str(uuid4()),
                        "event_data": json.dumps({"batch": batch_id, "item": i})
                    }
@@ -425,14 +419,12 @@ class TestAsyncSQLiteEventStorage:
        )

        chunk_generated_event = ChunkGenerated(
-            event_type=StreamingEventTypes.ChunkGenerated,
-            event_category=EventCategoryEnum.MutatesTaskState,
            task_id=task_id,
            chunk=token_chunk
        )

        # Store the event using the storage API
-        await storage.append_events([chunk_generated_event], sample_node_id)  # type: ignore[reportArgumentType]
+        await storage.append_events([chunk_generated_event], sample_node_id)

        # Wait for batch to be written
        await asyncio.sleep(0.5)

@@ -448,8 +440,7 @@ class TestAsyncSQLiteEventStorage:
        # Verify the event was deserialized correctly
        retrieved_event = retrieved_event_wrapper.event
        assert isinstance(retrieved_event, ChunkGenerated)
-        assert retrieved_event.event_type == StreamingEventTypes.ChunkGenerated
-        assert retrieved_event.event_category == EventCategoryEnum.MutatesTaskState
+        assert retrieved_event.event_type == EventType.ChunkGenerated
        assert retrieved_event.task_id == task_id

        # Verify the nested chunk was deserialized correctly
shared/types/events/categories.py (new file)
@@ -0,0 +1,10 @@
+
+from shared.types.events.events import (
+    MLXInferenceSagaPrepare,
+    MLXInferenceSagaStartPrepare,
+)
+
+TaskSagaEvent = (
+    MLXInferenceSagaPrepare
+    | MLXInferenceSagaStartPrepare
+)
shared/types/events/commands.py (new file)
@@ -0,0 +1,38 @@
+from enum import Enum
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Sequence,
+)
+
+if TYPE_CHECKING:
+    pass
+
+from pydantic import BaseModel
+
+from shared.types.common import NewUUID
+from shared.types.events.registry import Event
+from shared.types.state import State
+
+
+class CommandId(NewUUID):
+    pass
+
+
+class CommandTypes(str, Enum):
+    Create = "Create"
+    Update = "Update"
+    Delete = "Delete"
+
+
+class Command[
+    CommandType: CommandTypes,
+](BaseModel):
+    command_type: CommandType
+    command_id: CommandId
+
+
+type Decide[CommandTypeT: CommandTypes] = Callable[
+    [State, Command[CommandTypeT]],
+    Sequence[Event],
+]
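The Decide alias encodes command handling as a pure function from (state, command) to events; a standalone toy of that shape (hypothetical names, not the project's types):

from typing import Callable, Sequence

# Toy stand-ins: a Decide-style handler maps (state, command) -> events,
# leaving all mutation to the event-application step.
ToyDecide = Callable[[dict, str], Sequence[str]]


def decide_create(state: dict, command: str) -> Sequence[str]:
    if command == "Create" and "task" not in state:
        return ["TaskCreated"]
    return []  # deciding twice emits nothing new


assert decide_create({}, "Create") == ["TaskCreated"]
assert decide_create({"task": 1}, "Create") == []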
@@ -1,26 +1,16 @@
-from enum import Enum, StrEnum
+from enum import Enum
from typing import (
    TYPE_CHECKING,
-    Any,
    Callable,
-    FrozenSet,
-    Literal,
-    NamedTuple,
-    Protocol,
    Sequence,
-    cast,
+    Generic,
+    TypeVar,
)

if TYPE_CHECKING:
    pass

-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel

from shared.types.common import NewUUID, NodeId
-from shared.types.events.sanity_checking import (
-    assert_literal_union_covers_enum,
-    check_event_type_union_is_consistent_with_registry,
-)


class EventId(NewUUID):
@@ -32,114 +22,49 @@ class TimerId(NewUUID):
# Here are all the unique kinds of events that can be sent over the network.
-# I've defined them in different enums for clarity, but they're all part of the same set of possible events.
-class TaskSagaEventTypes(str, Enum):
+class EventType(str, Enum):
    # Task Saga Events
    MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
    MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"

-
-class TaskEventTypes(str, Enum):
-
    # Task Events
    TaskCreated = "TaskCreated"
    TaskStateUpdated = "TaskStateUpdated"
    TaskDeleted = "TaskDeleted"

-
-class StreamingEventTypes(str, Enum):
-
    # Streaming Events
    ChunkGenerated = "ChunkGenerated"

-
-class InstanceEventTypes(str, Enum):
-
    # Instance Events
    InstanceCreated = "InstanceCreated"
    InstanceDeleted = "InstanceDeleted"
    InstanceActivated = "InstanceActivated"
    InstanceDeactivated = "InstanceDeactivated"
    InstanceReplacedAtomically = "InstanceReplacedAtomically"

-
-class RunnerStatusEventTypes(str, Enum):
-
    # Runner Status Events
    RunnerStatusUpdated = "RunnerStatusUpdated"

-
-class NodePerformanceEventTypes(str, Enum):
-
    # Node Performance Events
    NodePerformanceMeasured = "NodePerformanceMeasured"

-
-class TopologyEventTypes(str, Enum):
-
    # Topology Events
    TopologyEdgeCreated = "TopologyEdgeCreated"
    TopologyEdgeReplacedAtomically = "TopologyEdgeReplacedAtomically"
    TopologyEdgeDeleted = "TopologyEdgeDeleted"
    WorkerConnected = "WorkerConnected"
    WorkerStatusUpdated = "WorkerStatusUpdated"
    WorkerDisconnected = "WorkerDisconnected"

-
-class TimerEventTypes(str, Enum):
-
    # Timer Events
    TimerCreated = "TimerCreated"
    TimerFired = "TimerFired"


-# Registry of all event type enums
-EVENT_TYPE_ENUMS = [
-    TaskEventTypes,
-    StreamingEventTypes,
-    InstanceEventTypes,
-    RunnerStatusEventTypes,
-    NodePerformanceEventTypes,
-    TopologyEventTypes,
-    TimerEventTypes,
-    TaskSagaEventTypes,
-]
+EventTypeT = TypeVar("EventTypeT", bound=EventType)


-# Here's the set of all possible events.
-EventTypes = (
-    TaskEventTypes
-    | StreamingEventTypes
-    | InstanceEventTypes
-    | RunnerStatusEventTypes
-    | NodePerformanceEventTypes
-    | TopologyEventTypes
-    | TimerEventTypes
-    | TaskSagaEventTypes
-)
-
-
-check_event_type_union_is_consistent_with_registry(EVENT_TYPE_ENUMS, EventTypes)
-
-
-class EventCategoryEnum(StrEnum):
-    MutatesTaskState = "MutatesTaskState"
-    MutatesTaskSagaState = "MutatesTaskSagaState"
-    MutatesRunnerStatus = "MutatesRunnerStatus"
-    MutatesInstanceState = "MutatesInstanceState"
-    MutatesNodePerformanceState = "MutatesNodePerformanceState"
-    MutatesTopologyState = "MutatesTopologyState"
-
-
-EventCategory = (
-    Literal[EventCategoryEnum.MutatesTopologyState]
-    | Literal[EventCategoryEnum.MutatesTaskState]
-    | Literal[EventCategoryEnum.MutatesTaskSagaState]
-    | Literal[EventCategoryEnum.MutatesRunnerStatus]
-    | Literal[EventCategoryEnum.MutatesInstanceState]
-    | Literal[EventCategoryEnum.MutatesNodePerformanceState]
-    | Literal[EventCategoryEnum.MutatesTopologyState]
-)
-
-EventCategories = FrozenSet[EventCategory]
-
-assert_literal_union_covers_enum(EventCategory, EventCategoryEnum)
-
-
-EventTypeT = EventTypes  # Type Alias placeholder; generic parameter will override
-
-
-class BaseEvent[
-    SetMembersT: EventCategories | EventCategory,
-    EventTypeLitT: EventTypes = EventTypes,
-](BaseModel):
-    event_type: EventTypeLitT
-    event_category: SetMembersT
+class BaseEvent(BaseModel, Generic[EventTypeT]):
+    event_type: EventTypeT
    event_id: EventId = EventId()

    def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool:
@@ -151,129 +76,4 @@ class BaseEvent[
        return True


-class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel):
-    event: BaseEvent[SetMembersT]
-    origin: NodeId
-    idx_in_log: int = Field(gt=0)
-
-    @model_validator(mode="after")
-    def check_event_was_sent_by_correct_node(
-        self,
-    ) -> "EventFromEventLog[SetMembersT]":
-        if self.event.check_event_was_sent_by_correct_node(self.origin):
-            return self
-        raise ValueError("Invalid Event: Origin ID Does Not Match")
-
-
-def narrow_event_type[T: EventCategory, Q: EventCategories | EventCategory](
-    event: BaseEvent[Q],
-    target_category: T,
-) -> BaseEvent[T]:
-    if target_category not in event.event_category:
-        raise ValueError(f"Event Does Not Contain Target Category {target_category}")
-
-    narrowed_event = event.model_copy(update={"event_category": {target_category}})
-    return cast(BaseEvent[T], narrowed_event)
-
-
-def narrow_event_from_event_log_type[
-    T: EventCategory,
-    Q: EventCategories | EventCategory,
-](
-    event: EventFromEventLog[Q],
-    target_category: T,
-) -> EventFromEventLog[T]:
-    if target_category not in event.event.event_category:
-        raise ValueError(f"Event Does Not Contain Target Category {target_category}")
-    narrowed_event = event.model_copy(
-        update={"event": narrow_event_type(event.event, target_category)}
-    )
-
-    return cast(EventFromEventLog[T], narrowed_event)
-
-
-class State[EventCategoryT: EventCategory](BaseModel):
-    event_category: EventCategoryT
-    last_event_applied_idx: int = Field(default=0, ge=0)
-
-
-# Definitions for Type Variables
-type Saga[EventCategoryT: EventCategory] = Callable[
-    [State[EventCategoryT], EventFromEventLog[EventCategoryT]],
-    Sequence[BaseEvent[EventCategories]],
-]
-type Apply[EventCategoryT: EventCategory] = Callable[
-    [State[EventCategoryT], EventFromEventLog[EventCategoryT]],
-    State[EventCategoryT],
-]
-
-
-class StateAndEvent[EventCategoryT: EventCategory](NamedTuple):
-    state: State[EventCategoryT]
-    event: EventFromEventLog[EventCategoryT]
-
-
-type EffectHandler[EventCategoryT: EventCategory] = Callable[
-    [StateAndEvent[EventCategoryT], State[EventCategoryT]], None
-]
-type EventPublisher = Callable[[BaseEvent[Any]], None]
-
-
-# A component that can publish events
-class EventPublisherProtocol(Protocol):
-    def send(self, events: Sequence[BaseEvent[EventCategories]]) -> None: ...
-
-
-# A component that can fetch events to apply
-class EventFetcherProtocol[EventCategoryT: EventCategory](Protocol):
-    def get_events_to_apply(
-        self, state: State[EventCategoryT]
-    ) -> Sequence[BaseEvent[EventCategoryT]]: ...
-
-
-# A component that can get the effect handler for a saga
-def get_saga_effect_handler[EventCategoryT: EventCategory](
-    saga: Saga[EventCategoryT], event_publisher: EventPublisher
-) -> EffectHandler[EventCategoryT]:
-    def effect_handler(state_and_event: StateAndEvent[EventCategoryT]) -> None:
-        trigger_state, trigger_event = state_and_event
-        for event in saga(trigger_state, trigger_event):
-            event_publisher(event)
-
-    return lambda state_and_event, _: effect_handler(state_and_event)
-
-
-def get_effects_from_sagas[EventCategoryT: EventCategory](
-    sagas: Sequence[Saga[EventCategoryT]],
-    event_publisher: EventPublisher,
-) -> Sequence[EffectHandler[EventCategoryT]]:
-    return [get_saga_effect_handler(saga, event_publisher) for saga in sagas]
-
-
-type IdemKeyGenerator[EventCategoryT: EventCategory] = Callable[
-    [State[EventCategoryT], int], Sequence[EventId]
-]
-
-
-class CommandId(NewUUID):
-    pass
-
-
-class CommandTypes(str, Enum):
-    Create = "Create"
-    Update = "Update"
-    Delete = "Delete"
-
-
-class Command[
-    EventCategoryT: EventCategories | EventCategory,
-    CommandType: CommandTypes,
-](BaseModel):
-    command_type: CommandType
-    command_id: CommandId
-
-
-type Decide[EventCategoryT: EventCategory, CommandTypeT: CommandTypes] = Callable[
-    [State[EventCategoryT], Command[EventCategoryT, CommandTypeT]],
-    Sequence[BaseEvent[EventCategoryT]],
-]
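The replacement BaseEvent is keyed only by EventType; a minimal standalone sketch of that generic pattern (our own toy names, not the project's modules):

from enum import Enum
from typing import Generic, Literal, TypeVar

from pydantic import BaseModel


class EventType(str, Enum):
    TaskCreated = "TaskCreated"


EventTypeT = TypeVar("EventTypeT", bound=EventType)


class BaseEvent(BaseModel, Generic[EventTypeT]):
    event_type: EventTypeT


# Each concrete event pins event_type to a single literal, which is what
# lets a discriminated-union parser recover the class from serialized data.
class TaskCreated(BaseEvent[Literal[EventType.TaskCreated]]):
    event_type: Literal[EventType.TaskCreated] = EventType.TaskCreated


assert TaskCreated().event_type is EventType.TaskCreated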
shared/types/events/components.py (new file)
@@ -0,0 +1,38 @@
+# components.py defines the small event functions, adapters etc.
+# this name could probably be improved.
+
+from typing import (
+    TYPE_CHECKING,
+)
+
+if TYPE_CHECKING:
+    pass
+
+from pydantic import BaseModel, Field, model_validator
+
+from typing import Callable
+
+from shared.types.common import NodeId
+from shared.types.state import State
+from shared.types.events.registry import Event
+
+
+class EventFromEventLog[T: Event](BaseModel):
+    event: T
+    origin: NodeId
+    idx_in_log: int = Field(gt=0)
+
+    @model_validator(mode="after")
+    def check_event_was_sent_by_correct_node(
+        self,
+    ) -> "EventFromEventLog[T]":
+        if self.event.check_event_was_sent_by_correct_node(self.origin):
+            return self
+        raise ValueError("Invalid Event: Origin ID Does Not Match")
+
+
+
+type Apply = Callable[
+    [State, Event],
+    State
+]
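The origin re-check on construction can be exercised with a standalone toy (hypothetical string origins rather than NodeId):

from pydantic import BaseModel, Field, ValidationError, model_validator


class ToyEvent(BaseModel):
    origin_claim: str

    def check_event_was_sent_by_correct_node(self, origin: str) -> bool:
        return origin == self.origin_claim


class ToyEventFromEventLog(BaseModel):
    event: ToyEvent
    origin: str
    idx_in_log: int = Field(gt=0)

    @model_validator(mode="after")
    def check_origin(self) -> "ToyEventFromEventLog":
        if self.event.check_event_was_sent_by_correct_node(self.origin):
            return self
        raise ValueError("Invalid Event: Origin ID Does Not Match")


ToyEventFromEventLog(event=ToyEvent(origin_claim="node-a"), origin="node-a", idx_in_log=1)
try:
    ToyEventFromEventLog(event=ToyEvent(origin_claim="node-a"), origin="node-b", idx_in_log=1)
except ValidationError:
    pass  # an origin mismatch is rejected at construction time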
@@ -6,14 +6,8 @@ from shared.types.common import NodeId
from shared.types.events.chunks import GenerationChunk
from shared.types.events.common import (
    BaseEvent,
-    EventCategoryEnum,
-    InstanceEventTypes,
-    NodePerformanceEventTypes,
-    RunnerStatusEventTypes,
-    StreamingEventTypes,
-    TaskEventTypes,
-    TaskSagaEventTypes,
-    TopologyEventTypes,
+    EventType,
    TimerId,
)
from shared.types.graphs.topology import (
    TopologyEdge,
@@ -27,156 +21,122 @@ from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import InstanceParams, TypeOfInstance
from shared.types.worker.runners import RunnerId, RunnerStatus

-TaskEvent = BaseEvent[EventCategoryEnum.MutatesTaskState]
-InstanceEvent = BaseEvent[EventCategoryEnum.MutatesInstanceState]
-TopologyEvent = BaseEvent[EventCategoryEnum.MutatesTopologyState]
-NodePerformanceEvent = BaseEvent[EventCategoryEnum.MutatesNodePerformanceState]


-class TaskCreated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskCreated]]):
-    event_type: Literal[TaskEventTypes.TaskCreated] = TaskEventTypes.TaskCreated
-    event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
+class TaskCreated(BaseEvent[EventType.TaskCreated]):
+    event_type: Literal[EventType.TaskCreated] = EventType.TaskCreated
    task_id: TaskId
    task: Task


# Covers Cancellation Of Task, Non-Cancelled Tasks Perist
-class TaskDeleted(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskDeleted]]):
-    event_type: Literal[TaskEventTypes.TaskDeleted] = TaskEventTypes.TaskDeleted
-    event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
+class TaskDeleted(BaseEvent[EventType.TaskDeleted]):
+    event_type: Literal[EventType.TaskDeleted] = EventType.TaskDeleted
    task_id: TaskId


-class TaskStateUpdated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskStateUpdated]]):
-    event_type: Literal[TaskEventTypes.TaskStateUpdated] = TaskEventTypes.TaskStateUpdated
-    event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
+class TaskStateUpdated(BaseEvent[EventType.TaskStateUpdated]):
+    event_type: Literal[EventType.TaskStateUpdated] = EventType.TaskStateUpdated
    task_id: TaskId
    task_status: TaskStatus


-class InstanceCreated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceCreated]]):
-    event_type: Literal[InstanceEventTypes.InstanceCreated] = InstanceEventTypes.InstanceCreated
-    event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
+class InstanceCreated(BaseEvent[EventType.InstanceCreated]):
+    event_type: Literal[EventType.InstanceCreated] = EventType.InstanceCreated
    instance_id: InstanceId
    instance_params: InstanceParams
    instance_type: TypeOfInstance


-class InstanceActivated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceActivated]]):
-    event_type: Literal[InstanceEventTypes.InstanceActivated] = InstanceEventTypes.InstanceActivated
-    event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
+class InstanceActivated(BaseEvent[EventType.InstanceActivated]):
+    event_type: Literal[EventType.InstanceActivated] = EventType.InstanceActivated
    instance_id: InstanceId


-class InstanceDeactivated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceDeactivated]]):
-    event_type: Literal[InstanceEventTypes.InstanceDeactivated] = InstanceEventTypes.InstanceDeactivated
-    event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
+class InstanceDeactivated(BaseEvent[EventType.InstanceDeactivated]):
+    event_type: Literal[EventType.InstanceDeactivated] = EventType.InstanceDeactivated
    instance_id: InstanceId


-class InstanceDeleted(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceDeleted]]):
-    event_type: Literal[InstanceEventTypes.InstanceDeleted] = InstanceEventTypes.InstanceDeleted
-    event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
+class InstanceDeleted(BaseEvent[EventType.InstanceDeleted]):
+    event_type: Literal[EventType.InstanceDeleted] = EventType.InstanceDeleted
    instance_id: InstanceId

-    transition: Tuple[InstanceId, InstanceId]


-class InstanceReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceReplacedAtomically]]):
-    event_type: Literal[InstanceEventTypes.InstanceReplacedAtomically] = InstanceEventTypes.InstanceReplacedAtomically
-    event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
+class InstanceReplacedAtomically(BaseEvent[EventType.InstanceReplacedAtomically]):
+    event_type: Literal[EventType.InstanceReplacedAtomically] = EventType.InstanceReplacedAtomically
    instance_to_replace: InstanceId
    new_instance_id: InstanceId


-class RunnerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesRunnerStatus, Literal[RunnerStatusEventTypes.RunnerStatusUpdated]]):
-    event_type: Literal[RunnerStatusEventTypes.RunnerStatusUpdated] = RunnerStatusEventTypes.RunnerStatusUpdated
-    event_category: Literal[EventCategoryEnum.MutatesRunnerStatus] = EventCategoryEnum.MutatesRunnerStatus
+class RunnerStatusUpdated(BaseEvent[EventType.RunnerStatusUpdated]):
+    event_type: Literal[EventType.RunnerStatusUpdated] = EventType.RunnerStatusUpdated
    runner_id: RunnerId
    runner_status: RunnerStatus


-class MLXInferenceSagaPrepare(BaseEvent[EventCategoryEnum.MutatesTaskSagaState, Literal[TaskSagaEventTypes.MLXInferenceSagaPrepare]]):
-    event_type: Literal[TaskSagaEventTypes.MLXInferenceSagaPrepare] = TaskSagaEventTypes.MLXInferenceSagaPrepare
-    event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = EventCategoryEnum.MutatesTaskSagaState
+class MLXInferenceSagaPrepare(BaseEvent[EventType.MLXInferenceSagaPrepare]):
+    event_type: Literal[EventType.MLXInferenceSagaPrepare] = EventType.MLXInferenceSagaPrepare
    task_id: TaskId
    instance_id: InstanceId


-class MLXInferenceSagaStartPrepare(BaseEvent[EventCategoryEnum.MutatesTaskSagaState, Literal[TaskSagaEventTypes.MLXInferenceSagaStartPrepare]]):
-    event_type: Literal[TaskSagaEventTypes.MLXInferenceSagaStartPrepare] = TaskSagaEventTypes.MLXInferenceSagaStartPrepare
-    event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = EventCategoryEnum.MutatesTaskSagaState
+class MLXInferenceSagaStartPrepare(BaseEvent[EventType.MLXInferenceSagaStartPrepare]):
+    event_type: Literal[EventType.MLXInferenceSagaStartPrepare] = EventType.MLXInferenceSagaStartPrepare
    task_id: TaskId
    instance_id: InstanceId


-class NodePerformanceMeasured(BaseEvent[EventCategoryEnum.MutatesNodePerformanceState, Literal[NodePerformanceEventTypes.NodePerformanceMeasured]]):
-    event_type: Literal[NodePerformanceEventTypes.NodePerformanceMeasured] = NodePerformanceEventTypes.NodePerformanceMeasured
-    event_category: Literal[EventCategoryEnum.MutatesNodePerformanceState] = EventCategoryEnum.MutatesNodePerformanceState
+class NodePerformanceMeasured(BaseEvent[EventType.NodePerformanceMeasured]):
+    event_type: Literal[EventType.NodePerformanceMeasured] = EventType.NodePerformanceMeasured
    node_id: NodeId
    node_profile: NodePerformanceProfile


-class WorkerConnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerConnected]]):
-    event_type: Literal[TopologyEventTypes.WorkerConnected] = TopologyEventTypes.WorkerConnected
-    event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
+class WorkerConnected(BaseEvent[EventType.WorkerConnected]):
+    event_type: Literal[EventType.WorkerConnected] = EventType.WorkerConnected
    edge: TopologyEdge


-class WorkerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerStatusUpdated]]):
-    event_type: Literal[TopologyEventTypes.WorkerStatusUpdated] = TopologyEventTypes.WorkerStatusUpdated
-    event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
+class WorkerStatusUpdated(BaseEvent[EventType.WorkerStatusUpdated]):
+    event_type: Literal[EventType.WorkerStatusUpdated] = EventType.WorkerStatusUpdated
    node_id: NodeId
    node_state: NodeStatus


-class WorkerDisconnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerDisconnected]]):
-    event_type: Literal[TopologyEventTypes.WorkerDisconnected] = TopologyEventTypes.WorkerDisconnected
-    event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
+class WorkerDisconnected(BaseEvent[EventType.WorkerDisconnected]):
+    event_type: Literal[EventType.WorkerDisconnected] = EventType.WorkerDisconnected
    vertex_id: NodeId


-class ChunkGenerated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[StreamingEventTypes.ChunkGenerated]]):
-    event_type: Literal[StreamingEventTypes.ChunkGenerated] = StreamingEventTypes.ChunkGenerated
-    event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
+class ChunkGenerated(BaseEvent[EventType.ChunkGenerated]):
+    event_type: Literal[EventType.ChunkGenerated] = EventType.ChunkGenerated
    task_id: TaskId
    chunk: GenerationChunk


-class TopologyEdgeCreated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeCreated]]):
-    event_type: Literal[TopologyEventTypes.TopologyEdgeCreated] = TopologyEventTypes.TopologyEdgeCreated
-    event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
+class TopologyEdgeCreated(BaseEvent[EventType.TopologyEdgeCreated]):
+    event_type: Literal[EventType.TopologyEdgeCreated] = EventType.TopologyEdgeCreated
    vertex: TopologyNode


-class TopologyEdgeReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically]]):
-    event_type: Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically] = TopologyEventTypes.TopologyEdgeReplacedAtomically
-    event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
+class TopologyEdgeReplacedAtomically(BaseEvent[EventType.TopologyEdgeReplacedAtomically]):
+    event_type: Literal[EventType.TopologyEdgeReplacedAtomically] = EventType.TopologyEdgeReplacedAtomically
    edge_id: TopologyEdgeId
    edge_profile: TopologyEdgeProfile


-class TopologyEdgeDeleted(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeDeleted]]):
-    event_type: Literal[TopologyEventTypes.TopologyEdgeDeleted] = TopologyEventTypes.TopologyEdgeDeleted
-    event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
+class TopologyEdgeDeleted(BaseEvent[EventType.TopologyEdgeDeleted]):
+    event_type: Literal[EventType.TopologyEdgeDeleted] = EventType.TopologyEdgeDeleted
    edge_id: TopologyEdgeId


-"""
-TEST_EVENT_CATEGORIES_TYPE = FrozenSet[
-    Literal[
-        EventCategoryEnum.MutatesTaskState,
-        EventCategoryEnum.MutatesControlPlaneState,
-    ]
-]
-TEST_EVENT_CATEGORIES = frozenset(
-    (
-        EventCategoryEnum.MutatesTaskState,
-        EventCategoryEnum.MutatesControlPlaneState,
-    )
-)
-
-class TestEvent(BaseEvent[TEST_EVENT_CATEGORIES_TYPE]):
-    event_category: TEST_EVENT_CATEGORIES_TYPE = TEST_EVENT_CATEGORIES
-    test_id: int
-"""
+class TimerCreated(BaseEvent[EventType.TimerCreated]):
+    event_type: Literal[EventType.TimerCreated] = EventType.TimerCreated
+    timer_id: TimerId
+    delay_seconds: float
+
+
+class TimerFired(BaseEvent[EventType.TimerFired]):
+    event_type: Literal[EventType.TimerFired] = EventType.TimerFired
+    timer_id: TimerId
@@ -1,25 +1,15 @@
|
||||
from types import UnionType
|
||||
from typing import Annotated, Any, Mapping, Type, get_args
|
||||
from typing import Annotated, Any, Mapping, Type, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
from shared.constants import get_error_reporting_message
|
||||
from shared.types.events.common import (
|
||||
BaseEvent,
|
||||
EventCategories,
|
||||
EventTypes,
|
||||
InstanceEventTypes,
|
||||
NodeId,
|
||||
NodePerformanceEventTypes,
|
||||
RunnerStatusEventTypes,
|
||||
StreamingEventTypes,
|
||||
TaskEventTypes,
|
||||
TaskSagaEventTypes,
|
||||
TopologyEventTypes,
|
||||
EventType,
|
||||
)
|
||||
from shared.types.events.events import (
|
||||
ChunkGenerated,
|
||||
InstanceActivated,
|
||||
InstanceCreated,
|
||||
InstanceDeactivated,
|
||||
InstanceDeleted,
|
||||
InstanceReplacedAtomically,
|
||||
MLXInferenceSagaPrepare,
|
||||
@@ -29,6 +19,8 @@ from shared.types.events.events import (
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStateUpdated,
|
||||
TimerCreated,
|
||||
TimerFired,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
TopologyEdgeReplacedAtomically,
|
||||
@@ -36,6 +28,11 @@ from shared.types.events.events import (
|
||||
WorkerDisconnected,
|
||||
WorkerStatusUpdated,
|
||||
)
|
||||
from shared.types.events.sanity_checking import (
|
||||
assert_event_union_covers_registry,
|
||||
check_registry_has_all_event_types,
|
||||
check_union_of_all_events_is_consistent_with_registry,
|
||||
)
|
||||
|
||||
"""
|
||||
class EventTypeNames(StrEnum):
|
||||
@@ -50,63 +47,38 @@ class EventTypeNames(StrEnum):
|
||||
|
||||
check_event_categories_are_defined_for_all_event_types(EVENT_TYPE_ENUMS, EventTypeNames)
|
||||
"""
|
||||
EventRegistry: Mapping[EventTypes, Type[Any]] = {
    TaskEventTypes.TaskCreated: TaskCreated,
    TaskEventTypes.TaskStateUpdated: TaskStateUpdated,
    TaskEventTypes.TaskDeleted: TaskDeleted,
    InstanceEventTypes.InstanceCreated: InstanceCreated,
    InstanceEventTypes.InstanceDeleted: InstanceDeleted,
    InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically,
    RunnerStatusEventTypes.RunnerStatusUpdated: RunnerStatusUpdated,
    NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured,
    TopologyEventTypes.WorkerConnected: WorkerConnected,
    TopologyEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
    TopologyEventTypes.WorkerDisconnected: WorkerDisconnected,
    StreamingEventTypes.ChunkGenerated: ChunkGenerated,
    TopologyEventTypes.TopologyEdgeCreated: TopologyEdgeCreated,
    TopologyEventTypes.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically,
    TopologyEventTypes.TopologyEdgeDeleted: TopologyEdgeDeleted,
    TaskSagaEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
    TaskSagaEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
EventRegistry: Mapping[EventType, Type[Any]] = {
    EventType.TaskCreated: TaskCreated,
    EventType.TaskStateUpdated: TaskStateUpdated,
    EventType.TaskDeleted: TaskDeleted,
    EventType.InstanceCreated: InstanceCreated,
    EventType.InstanceActivated: InstanceActivated,
    EventType.InstanceDeactivated: InstanceDeactivated,
    EventType.InstanceDeleted: InstanceDeleted,
    EventType.InstanceReplacedAtomically: InstanceReplacedAtomically,
    EventType.RunnerStatusUpdated: RunnerStatusUpdated,
    EventType.NodePerformanceMeasured: NodePerformanceMeasured,
    EventType.WorkerConnected: WorkerConnected,
    EventType.WorkerStatusUpdated: WorkerStatusUpdated,
    EventType.WorkerDisconnected: WorkerDisconnected,
    EventType.ChunkGenerated: ChunkGenerated,
    EventType.TopologyEdgeCreated: TopologyEdgeCreated,
    EventType.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically,
    EventType.TopologyEdgeDeleted: TopologyEdgeDeleted,
    EventType.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
    EventType.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
    EventType.TimerCreated: TimerCreated,
    EventType.TimerFired: TimerFired,
}

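The registry gives constant-time dispatch from a wire-level tag to its event class. A minimal sketch of that lookup (the raw payload and its task_id field are illustrative, and events are assumed to be Pydantic models):

raw = {"event_type": EventType.TaskCreated, "task_id": "t-1"}  # hypothetical payload
event_cls = EventRegistry[raw["event_type"]]
event = event_cls.model_validate(raw)  # Pydantic builds the concrete event class
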
# Sanity Check.
def check_registry_has_all_event_types() -> None:
    event_types: tuple[EventTypes, ...] = get_args(EventTypes)
    missing_event_types = set(event_types) - set(EventRegistry.keys())

    assert not missing_event_types, (
        f"{get_error_reporting_message()}"
        f"There's an event missing from the registry: {missing_event_types}"
    )


def check_union_of_all_events_is_consistent_with_registry(
    registry: Mapping[EventTypes, Type[Any]], union_type: UnionType
) -> None:
    type_of_each_registry_entry = set(registry.values())
    type_of_each_entry_in_union = set(get_args(union_type))
    missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union

    assert not missing_from_union, (
        f"{get_error_reporting_message()}"
        f"Event classes in registry are missing from all_events union: {missing_from_union}"
    )

    extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry

    assert not extra_in_union, (
        f"{get_error_reporting_message()}"
        f"Event classes in all_events union are missing from registry: {extra_in_union}"
    )

Event = (
AllEventsUnion = (
    TaskCreated
    | TaskStateUpdated
    | TaskDeleted
    | InstanceCreated
    | InstanceActivated
    | InstanceDeactivated
    | InstanceDeleted
    | InstanceReplacedAtomically
    | RunnerStatusUpdated
@@ -120,24 +92,16 @@ Event = (
    | TopologyEdgeDeleted
    | MLXInferenceSagaPrepare
    | MLXInferenceSagaStartPrepare
    | TimerCreated
    | TimerFired
)

# Run the sanity check
check_union_of_all_events_is_consistent_with_registry(EventRegistry, Event)
Event: TypeAlias = Annotated[AllEventsUnion, Field(discriminator="event_type")]
EventParser: TypeAdapter[Event] = TypeAdapter(Event)


_EventType = Annotated[Event, Field(discriminator="event_type")]
EventParser: TypeAdapter[BaseEvent[EventCategories]] = TypeAdapter(_EventType)


# Define a properly typed EventFromEventLog that preserves specific event types
class EventFromEventLogTyped(BaseModel):
    """Properly typed EventFromEventLog that preserves specific event types."""

    event: _EventType
    origin: NodeId
    idx_in_log: int = Field(gt=0)

    def check_event_was_sent_by_correct_node(self) -> bool:
        """Check if the event was sent by the correct node."""
        return self.event.check_event_was_sent_by_correct_node(self.origin)
assert_event_union_covers_registry(AllEventsUnion)
check_union_of_all_events_is_consistent_with_registry(EventRegistry, AllEventsUnion)
check_registry_has_all_event_types(EventRegistry)
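With event_type as the Pydantic discriminator, the TypeAdapter routes validation straight to the matching union member, so no registry lookup is needed at parse time. A minimal sketch (the JSON payload is illustrative, and it assumes EventType values match the member names):

raw = '{"event_type": "TimerFired", "timer_id": "00000000-0000-0000-0000-000000000000"}'
event = EventParser.validate_json(raw)  # returns a concrete TimerFired, not a bare BaseEvent
assert isinstance(event, TimerFired)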
@@ -1,68 +1,75 @@
from enum import Enum, StrEnum
from enum import StrEnum
from types import UnionType
from typing import Any, LiteralString, Sequence, Set, Type, get_args
from typing import Any, Mapping, Set, Type, cast, get_args

from pydantic.fields import FieldInfo

from shared.constants import get_error_reporting_message
from shared.types.events.common import EventType


def check_event_type_union_is_consistent_with_registry(
    event_type_enums: Sequence[Type[Enum]], event_types: UnionType
) -> None:
    """Assert that every enum value from _EVENT_TYPE_ENUMS satisfies EventTypes."""

    event_types_inferred_from_union = set(get_args(event_types))

    event_types_inferred_from_registry = [
        member for enum_class in event_type_enums for member in enum_class
    ]

    # Check that each registry value belongs to one of the types in the union
    for tag_of_event_type in event_types_inferred_from_registry:
        event_type = type(tag_of_event_type)
        assert event_type in event_types_inferred_from_union, (
            f"{get_error_reporting_message()}"
            f"There's a mismatch between the registry of event types and the union of possible event types."
            f"The enum value {tag_of_event_type} for type {event_type} is not covered by {event_types_inferred_from_union}."
        )


def check_event_categories_are_defined_for_all_event_types(
    event_definitions: Sequence[Type[Enum]], event_categories: Type[StrEnum]
) -> None:
    """Assert that the event category names are consistent with the event type enums."""

    expected_category_tags: list[str] = [
        enum_class.__name__ for enum_class in event_definitions
    ]
    tag_of_event_categories: list[str] = list(event_categories.__members__.values())
    assert tag_of_event_categories == expected_category_tags, (
        f"{get_error_reporting_message()}"
        f"The values of the enum EventCategories are not named after the event type enums."
        f"These are the missing categories: {set(expected_category_tags) - set(tag_of_event_categories)}"
        f"These are the extra categories: {set(tag_of_event_categories) - set(expected_category_tags)}"
    )


def assert_literal_union_covers_enum[TEnum: StrEnum](
def assert_event_union_covers_registry[TEnum: StrEnum](
    literal_union: UnionType,
    enum_type: Type[TEnum],
) -> None:
    enum_values: Set[Any] = {member.value for member in enum_type}
    """
    Ensure that our union of events (AllEventsUnion) has one member per element of Enum
    """
    enum_values: Set[str] = {member.value for member in EventType}

    def _flatten(tp: UnionType) -> Set[Any]:
        values: Set[Any] = set()
        args: tuple[LiteralString, ...] = get_args(tp)
        for arg in args:
            payloads: tuple[TEnum, ...] = get_args(arg)
            for payload in payloads:
                values.add(payload.value)
    def _flatten(tp: UnionType) -> Set[str]:
        values: Set[str] = set()
        args = get_args(tp)  # Get event classes from the union
        for arg in args:  # type: ignore[reportAny]
            # Cast to type since we know these are class types
            event_class = cast(type[Any], arg)
            # Each event class is a Pydantic model with model_fields
            if hasattr(event_class, 'model_fields'):
                model_fields = cast(dict[str, FieldInfo], event_class.model_fields)
                if 'event_type' in model_fields:
                    # Get the default value of the event_type field
                    event_type_field: FieldInfo = model_fields['event_type']
                    if hasattr(event_type_field, 'default'):
                        default_value = cast(EventType, event_type_field.default)
                        # The default is an EventType enum member, get its value
                        values.add(default_value.value)
        return values

    literal_values: Set[Any] = _flatten(literal_union)
    literal_values: Set[str] = _flatten(literal_union)

    assert enum_values == literal_values, (
        f"{get_error_reporting_message()}"
        f"The values of the enum {enum_type} are not covered by the literal union {literal_union}.\n"
        f"The values of the enum {EventType} are not covered by the literal union {literal_union}.\n"
        f"These are the missing values: {enum_values - literal_values}\n"
        f"These are the extra values: {literal_values - enum_values}\n"
    )

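The rewritten _flatten reads the discriminator off each model class rather than off Literal type arguments: in Pydantic v2, a field declared as Literal[...] with a default keeps that default on its FieldInfo. A standalone sketch of the same introspection, using a toy model rather than this repo's events:

from typing import Literal
from pydantic import BaseModel

class Ping(BaseModel):
    event_type: Literal["ping"] = "ping"

# model_fields maps field names to FieldInfo; the Literal default is stored there.
assert Ping.model_fields["event_type"].default == "ping"
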
def check_union_of_all_events_is_consistent_with_registry(
    registry: Mapping[EventType, Type[Any]], union_type: UnionType
) -> None:
    type_of_each_registry_entry = set(registry.values())
    type_of_each_entry_in_union = set(get_args(union_type))
    missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union

    assert not missing_from_union, (
        f"{get_error_reporting_message()}"
        f"Event classes in registry are missing from all_events union: {missing_from_union}"
    )

    extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry

    assert not extra_in_union, (
        f"{get_error_reporting_message()}"
        f"Event classes in all_events union are missing from registry: {extra_in_union}"
    )


def check_registry_has_all_event_types(event_registry: Mapping[EventType, Type[Any]]) -> None:
    # EventType is an enum, so iterate its members; get_args() only yields members of union types.
    event_types: tuple[EventType, ...] = tuple(EventType)
    missing_event_types = set(event_types) - set(event_registry.keys())

    assert not missing_event_types, (
        f"{get_error_reporting_message()}"
        f"There's an event missing from the registry: {missing_event_types}"
    )


# TODO: Check all events have an apply function.
# probably in a different place though.
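Both checks reduce to set differences between registry.values() and the union's get_args(). A toy, self-contained illustration of the failure mode they catch (the class names are hypothetical):

from typing import get_args

class A: ...
class B: ...

registry_values = {A}                 # B is a member of the union...
union_members = set(get_args(A | B))  # {A, B}
assert union_members - registry_values == {B}  # ...but was forgotten in the registry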
@@ -39,3 +39,6 @@ class State(BaseModel):
    task_inbox: List[Task] = Field(default_factory=list)
    task_outbox: List[Task] = Field(default_factory=list)
    cache_policy: CachePolicy = CachePolicy.KeepAll

    # TODO: implement / use this?
    last_event_applied_idx: int = Field(default=0, ge=0)
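A hedged sketch of how the new field could be used once implemented: an applier bumps the index monotonically as each log entry is folded into state, rejecting gaps and replays. Both apply_next and apply_event are hypothetical names, not part of this diff:

def apply_next(state: State, event: Event, idx_in_log: int) -> State:
    # Indices in the event log are assumed 1-based and contiguous.
    assert idx_in_log == state.last_event_applied_idx + 1, "gap or replay detected"
    new_state = apply_event(state, event)  # hypothetical per-event reducer
    new_state.last_event_applied_idx = idx_in_log
    return new_state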