New events

This commit is contained in:
Matt Beton
2025-07-22 15:16:06 +01:00
committed by GitHub
parent 108128b620
commit 5adad08e09
19 changed files with 318 additions and 797 deletions

View File

@@ -1,32 +0,0 @@
from hashlib import sha3_224 as hasher
from typing import Sequence
from uuid import UUID
from shared.types.events.common import EventCategory, EventId, IdemKeyGenerator, State
def get_idem_tag_generator[EventCategoryT: EventCategory](
base: str,
) -> IdemKeyGenerator[EventCategoryT]:
"""Generates idempotency keys for events.
The keys are generated by hashing the state sequence number against a base string.
You can pick any base string, **so long as it's not used in any other function that generates idempotency keys**.
"""
def get_idem_keys(state: State[EventCategoryT], num_keys: int) -> Sequence[EventId]:
def recurse(n: int, last: bytes) -> Sequence[EventId]:
if n == 0:
return []
next_hash = hasher(last).digest()
return (
EventId(UUID(bytes=next_hash, version=4)),
*recurse(n - 1, next_hash),
)
initial_bytes = state.last_event_applied_idx.to_bytes(
8, byteorder="big", signed=False
)
return recurse(num_keys, initial_bytes)
return get_idem_keys

View File

@@ -1,7 +1,6 @@
from contextlib import asynccontextmanager
from logging import Logger, LogRecord
from queue import Queue as PQueue
from typing import Literal
from fastapi import FastAPI
@@ -19,9 +18,6 @@ from shared.logger import (
create_queue_listener,
log,
)
from shared.types.events.common import (
EventCategoryEnum,
)
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
from shared.types.worker.common import InstanceId
@@ -51,19 +47,8 @@ def get_state_dependency(data: object, logger: Logger) -> State:
return data
# What The Master Cares About
MasterEventCategories = (
Literal[EventCategoryEnum.MutatesTopologyState]
| Literal[EventCategoryEnum.MutatesTaskState]
| Literal[EventCategoryEnum.MutatesTaskSagaState]
| Literal[EventCategoryEnum.MutatesRunnerStatus]
| Literal[EventCategoryEnum.MutatesInstanceState]
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
)
# Takes Care Of All States And Events Related To The Master
class MasterEventLoopProtocol(NodeEventLoopProtocol[MasterEventCategories]): ...
class MasterEventLoopProtocol(NodeEventLoopProtocol): ...
@asynccontextmanager

View File

@@ -1,7 +1,7 @@
from queue import Queue
from typing import Mapping, Sequence
from shared.types.events.common import BaseEvent, EventCategory
from shared.types.events.registry import Event
from shared.types.graphs.topology import Topology
from shared.types.state import CachePolicy
from shared.types.tasks.common import Task
@@ -20,4 +20,4 @@ def get_instance_placement(
def get_transition_events(
current_instances: Mapping[InstanceId, InstanceParams],
target_instances: Mapping[InstanceId, InstanceParams],
) -> Sequence[BaseEvent[EventCategory]]: ...
) -> Sequence[Event]: ...

View File

@@ -1,13 +0,0 @@
from enum import StrEnum
from typing import Any, Mapping, Type
def check_keys_in_map_match_enum_values[TEnum: StrEnum](
mapping_type: Type[Mapping[Any, Any]],
enum: Type[TEnum],
) -> None:
mapping_keys = set(mapping_type.__annotations__.keys())
category_values = set(e.value for e in enum)
assert mapping_keys == category_values, (
f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}"
)

View File

@@ -1,125 +0,0 @@
from asyncio import Lock, Queue, Task, create_task
from logging import Logger
from typing import List, Literal, Protocol, TypedDict
from master.logging import (
StateUpdateEffectHandlerErrorLogEntry,
StateUpdateErrorLogEntry,
StateUpdateLoopAlreadyRunningLogEntry,
StateUpdateLoopNotRunningLogEntry,
StateUpdateLoopStartedLogEntry,
StateUpdateLoopStoppedLogEntry,
)
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.constants import get_error_reporting_message
from shared.logger import log
from shared.types.events.common import (
Apply,
EffectHandler,
EventCategory,
EventCategoryEnum,
EventFromEventLog,
State,
StateAndEvent,
)
class AsyncStateManager[EventCategoryT: EventCategory](Protocol):
"""Protocol for services that manage a specific state domain."""
_task: Task[None] | None
_logger: Logger
_apply: Apply[EventCategoryT]
_default_effects: List[EffectHandler[EventCategoryT]]
extra_effects: List[EffectHandler[EventCategoryT]]
state: State[EventCategoryT]
queue: Queue[EventFromEventLog[EventCategoryT]]
lock: Lock
def __init__(
self,
state: State[EventCategoryT],
queue: Queue[EventFromEventLog[EventCategoryT]],
extra_effects: List[EffectHandler[EventCategoryT]],
logger: Logger,
) -> None:
"""Initialise the service with its event queue."""
self.state = state
self.queue = queue
self.extra_effects = extra_effects
self._logger = logger
self._task = None
async def read_state(self) -> State[EventCategoryT]:
"""Get a thread-safe snapshot of this service's state domain."""
return self.state.model_copy(deep=True)
@property
def is_running(self) -> bool:
"""Check if the service's event loop is running."""
return self._task is not None and not self._task.done()
async def start(self) -> None:
"""Start the service's event loop."""
if self.is_running:
log(self._logger, StateUpdateLoopAlreadyRunningLogEntry())
raise RuntimeError("State Update Loop Already Running")
log(self._logger, StateUpdateLoopStartedLogEntry())
self._task = create_task(self._event_loop())
async def stop(self) -> None:
"""Stop the service's event loop."""
if not self.is_running:
log(self._logger, StateUpdateLoopNotRunningLogEntry())
raise RuntimeError("State Update Loop Not Running")
assert self._task is not None, (
f"{get_error_reporting_message()}"
"BUG: is_running is True but _task is None, this should never happen!"
)
self._task.cancel()
log(self._logger, StateUpdateLoopStoppedLogEntry())
async def _event_loop(self) -> None:
"""Event loop for the service."""
while True:
event = await self.queue.get()
previous_state = self.state.model_copy(deep=True)
try:
async with self.lock:
updated_state = self._apply(
self.state,
event,
)
self.state = updated_state
except Exception as e:
log(self._logger, StateUpdateErrorLogEntry(error=e))
raise e
try:
for effect_handler in self._default_effects + self.extra_effects:
effect_handler(StateAndEvent(previous_state, event), updated_state)
except Exception as e:
log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e))
raise e
class AsyncStateManagerMapping(TypedDict):
MutatesTaskState: AsyncStateManager[Literal[EventCategoryEnum.MutatesTaskState]]
MutatesTaskSagaState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesTaskSagaState]
]
MutatesTopologyState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesTopologyState]
]
MutatesRunnerStatus: AsyncStateManager[
Literal[EventCategoryEnum.MutatesRunnerStatus]
]
MutatesInstanceState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesInstanceState]
]
MutatesNodePerformanceState: AsyncStateManager[
Literal[EventCategoryEnum.MutatesNodePerformanceState]
]
check_keys_in_map_match_enum_values(AsyncStateManagerMapping, EventCategoryEnum)

View File

@@ -1,18 +0,0 @@
from typing import Literal, TypedDict
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.types.events.common import EventCategoryEnum, State
class SyncStateManagerMapping(TypedDict):
MutatesTaskState: State[Literal[EventCategoryEnum.MutatesTaskState]]
MutatesTaskSagaState: State[Literal[EventCategoryEnum.MutatesTaskSagaState]]
MutatesTopologyState: State[Literal[EventCategoryEnum.MutatesTopologyState]]
MutatesRunnerStatus: State[Literal[EventCategoryEnum.MutatesRunnerStatus]]
MutatesInstanceState: State[Literal[EventCategoryEnum.MutatesInstanceState]]
MutatesNodePerformanceState: State[
Literal[EventCategoryEnum.MutatesNodePerformanceState]
]
check_keys_in_map_match_enum_values(SyncStateManagerMapping, EventCategoryEnum)

View File

@@ -12,12 +12,9 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlmodel import SQLModel
from shared.types.events.common import (
BaseEvent,
EventCategories,
NodeId,
)
from shared.types.events.registry import Event, EventFromEventLogTyped, EventParser
from shared.types.events.common import NodeId
from shared.types.events.components import EventFromEventLog
from shared.types.events.registry import Event, EventParser
from .types import StoredEvent
@@ -53,7 +50,7 @@ class AsyncSQLiteEventStorage:
self._max_age_s = max_age_ms / 1000.0
self._logger = logger or getLogger(__name__)
self._write_queue: Queue[tuple[BaseEvent[EventCategories], NodeId]] = Queue()
self._write_queue: Queue[tuple[Event, NodeId]] = Queue()
self._batch_writer_task: Task[None] | None = None
self._engine = None
self._closed = False
@@ -72,7 +69,7 @@ class AsyncSQLiteEventStorage:
async def append_events(
self,
events: Sequence[BaseEvent[EventCategories]],
events: Sequence[Event],
origin: NodeId
) -> None:
"""Append events to the log (fire-and-forget). The writes are batched and committed
@@ -86,7 +83,7 @@ class AsyncSQLiteEventStorage:
async def get_events_since(
self,
last_idx: int
) -> Sequence[EventFromEventLogTyped]:
) -> Sequence[EventFromEventLog[Event]]:
"""Retrieve events after a specific index."""
if self._closed:
raise RuntimeError("Storage is closed")
@@ -101,7 +98,7 @@ class AsyncSQLiteEventStorage:
)
rows = result.fetchall()
events: list[EventFromEventLogTyped] = []
events: list[EventFromEventLog[Event]] = []
for row in rows:
rowid: int = cast(int, row[0])
origin: str = cast(str, row[1])
@@ -112,7 +109,7 @@ class AsyncSQLiteEventStorage:
else:
event_data = cast(dict[str, Any], raw_event_data)
event = await self._deserialize_event(event_data)
events.append(EventFromEventLogTyped(
events.append(EventFromEventLog(
event=event,
origin=NodeId(uuid=UUID(origin)),
idx_in_log=rowid # rowid becomes idx_in_log
@@ -170,7 +167,7 @@ class AsyncSQLiteEventStorage:
loop = asyncio.get_event_loop()
while not self._closed:
batch: list[tuple[BaseEvent[EventCategories], NodeId]] = []
batch: list[tuple[Event, NodeId]] = []
try:
# Block waiting for first item
@@ -208,7 +205,7 @@ class AsyncSQLiteEventStorage:
if batch:
await self._commit_batch(batch)
async def _commit_batch(self, batch: list[tuple[BaseEvent[EventCategories], NodeId]]) -> None:
async def _commit_batch(self, batch: list[tuple[Event, NodeId]]) -> None:
"""Commit a batch of events to SQLite."""
assert self._engine is not None
@@ -218,7 +215,6 @@ class AsyncSQLiteEventStorage:
stored_event = StoredEvent(
origin=str(origin.uuid),
event_type=str(event.event_type),
event_category=str(next(iter(event.event_category))),
event_id=str(event.event_id),
event_data=event.model_dump(mode='json') # mode='json' ensures UUID conversion
)
@@ -237,8 +233,8 @@ class AsyncSQLiteEventStorage:
"""Deserialize event data back to typed Event."""
# EventParser expects the discriminator field for proper deserialization
result = EventParser.validate_python(event_data)
# EventParser returns BaseEvent but we know it's actually a specific Event type
return result # type: ignore[reportReturnType]
# EventParser returns Event type which is our union of all event types
return result
async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]:
"""Return raw event data for testing purposes."""

View File

@@ -4,12 +4,9 @@ from typing import Any, Protocol, Sequence
from sqlalchemy import DateTime, Index
from sqlmodel import JSON, Column, Field, SQLModel
from shared.types.events.common import (
BaseEvent,
EventCategories,
EventFromEventLog,
NodeId,
)
from shared.types.common import NodeId
from shared.types.events.components import EventFromEventLog
from shared.types.events.registry import Event
class StoredEvent(SQLModel, table=True):
@@ -23,7 +20,6 @@ class StoredEvent(SQLModel, table=True):
rowid: int | None = Field(default=None, primary_key=True, alias="rowid")
origin: str = Field(index=True)
event_type: str = Field(index=True)
event_category: str = Field(index=True)
event_id: str = Field(index=True)
event_data: dict[str, Any] = Field(sa_column=Column(JSON))
created_at: datetime = Field(
@@ -33,7 +29,6 @@ class StoredEvent(SQLModel, table=True):
__table_args__ = (
Index("idx_events_origin_created", "origin", "created_at"),
Index("idx_events_category_created", "event_category", "created_at"),
)
class EventStorageProtocol(Protocol):
@@ -41,7 +36,7 @@ class EventStorageProtocol(Protocol):
async def append_events(
self,
events: Sequence[BaseEvent[EventCategories]],
events: Sequence[Event],
origin: NodeId
) -> None:
"""Append events to the log (fire-and-forget).
@@ -54,7 +49,7 @@ class EventStorageProtocol(Protocol):
async def get_events_since(
self,
last_idx: int
) -> Sequence[EventFromEventLog[EventCategories]]:
) -> Sequence[EventFromEventLog[Event]]:
"""Retrieve events after a specific index.
Returns events in idx_in_log order.

View File

@@ -7,7 +7,10 @@ from typing import Any, Hashable, Mapping, Protocol, Sequence
from fastapi.responses import Response, StreamingResponse
from shared.event_loops.commands import ExternalCommand
from shared.types.events.common import Apply, EventCategory, EventFromEventLog, State
from shared.types.events.registry import Event
from shared.types.events.components import EventFromEventLog
from shared.types.state import State
from shared.types.events.components import Apply
class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]):
@@ -38,17 +41,16 @@ class ExhaustiveMapping[K: Hashable, V](MutableMapping[K, V]):
return len(self._store)
# Safety on Apply.
def safely_apply[T: EventCategory](
state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]]
) -> State[T]:
def apply_events(
state: State, apply_fn: Apply, events: Sequence[EventFromEventLog[Event]]
) -> State:
sorted_events = sorted(events, key=lambda event: event.idx_in_log)
state = state.model_copy()
for event in sorted_events:
if event.idx_in_log <= state.last_event_applied_idx:
for wrapped_event in sorted_events:
if wrapped_event.idx_in_log <= state.last_event_applied_idx:
continue
state.last_event_applied_idx = event.idx_in_log
state = apply_fn(state, event)
state.last_event_applied_idx = wrapped_event.idx_in_log
state = apply_fn(state, wrapped_event.event)
return state
@@ -69,11 +71,9 @@ class NodeCommandLoopProtocol(Protocol):
async def _handle_command(self, command: ExternalCommand) -> None: ...
class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol):
class NodeEventGetterProtocol(Protocol):
_event_fetcher: Task[Any] | None = None
_event_queues: ExhaustiveMapping[
EventCategoryT, AsyncQueue[EventFromEventLog[EventCategory]]
]
_event_queue: AsyncQueue[EventFromEventLog[Event]]
_logger: Logger
@property
@@ -84,18 +84,18 @@ class NodeEventGetterProtocol[EventCategoryT: EventCategory](Protocol):
async def stop_event_fetcher(self) -> None: ...
class NodeStateStorageProtocol[EventCategoryT: EventCategory](Protocol):
_state_managers: ExhaustiveMapping[EventCategoryT, State[EventCategoryT]]
class NodeStateStorageProtocol(Protocol):
_state: State
_state_lock: Lock
_logger: Logger
async def _read_state(
self, event_category: EventCategoryT
) -> State[EventCategoryT]: ...
self,
) -> State: ...
class NodeStateManagerProtocol[EventCategoryT: EventCategory](
NodeEventGetterProtocol[EventCategoryT], NodeStateStorageProtocol[EventCategoryT]
class NodeStateManagerProtocol(
NodeEventGetterProtocol, NodeStateStorageProtocol
):
_state_manager: Task[Any] | None = None
_logger: Logger
@@ -116,6 +116,6 @@ class NodeStateManagerProtocol[EventCategoryT: EventCategory](
async def _apply_queued_events(self) -> None: ...
class NodeEventLoopProtocol[EventCategoryT: EventCategory](
NodeCommandLoopProtocol, NodeStateManagerProtocol[EventCategoryT]
class NodeEventLoopProtocol(
NodeCommandLoopProtocol, NodeStateManagerProtocol
): ...

View File

@@ -1,78 +0,0 @@
from asyncio.queues import Queue
from typing import Sequence, cast, get_args
from shared.event_loops.main import ExhaustiveMapping
from shared.types.events.common import (
EventCategories,
EventCategory,
EventCategoryEnum,
EventFromEventLog,
narrow_event_from_event_log_type,
)
"""
from asyncio import gather
from logging import Logger
from typing import Literal, Protocol, Sequence, TypedDict
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.types.events.common import EventCategoryEnum
"""
"""
class EventQueues(TypedDict):
MutatesTaskState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
]
MutatesTaskSagaState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
]
MutatesControlPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
]
MutatesDataPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
]
MutatesRunnerStatus: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
]
MutatesInstanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
]
MutatesNodePerformanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
]
check_keys_in_map_match_enum_values(EventQueues, EventCategoryEnum)
"""
async def route_events[UnionOfRelevantEvents: EventCategory](
queue_map: ExhaustiveMapping[
UnionOfRelevantEvents, Queue[EventFromEventLog[EventCategory]]
],
events: Sequence[EventFromEventLog[EventCategory | EventCategories]],
) -> None:
"""Route an event to the appropriate queue."""
tuple_of_categories: tuple[EventCategoryEnum, ...] = get_args(UnionOfRelevantEvents)
print(tuple_of_categories)
for event in events:
if isinstance(event.event.event_category, EventCategoryEnum):
category: EventCategory = event.event.event_category
if category not in tuple_of_categories:
continue
narrowed_event = narrow_event_from_event_log_type(event, category)
q1: Queue[EventFromEventLog[EventCategory]] = queue_map[
cast(UnionOfRelevantEvents, category)
] # TODO: make casting unnecessary
await q1.put(narrowed_event)
else:
for category in event.event.event_category:
if category not in tuple_of_categories:
continue
narrow_event = narrow_event_from_event_log_type(event, category)
q2 = queue_map[
cast(UnionOfRelevantEvents, category)
] # TODO: make casting unnecessary
await q2.put(narrow_event)

View File

@@ -14,8 +14,7 @@ from shared.types.common import NodeId
from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
from shared.types.events.events import (
ChunkGenerated,
EventCategoryEnum,
StreamingEventTypes,
EventType,
)
from shared.types.tasks.common import TaskId
@@ -91,11 +90,10 @@ class TestAsyncSQLiteEventStorage:
async with AsyncSession(storage._engine) as session:
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"event_type": "test_event",
"event_category": "test_category",
"event_id": str(uuid4()),
"event_data": json.dumps(test_data)
}
@@ -137,11 +135,10 @@ class TestAsyncSQLiteEventStorage:
async with AsyncSession(storage._engine) as session:
for record in test_records:
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"event_type": record["event_type"],
"event_category": "test_category",
"event_id": str(uuid4()),
"event_data": json.dumps(record)
}
@@ -180,18 +177,18 @@ class TestAsyncSQLiteEventStorage:
async with AsyncSession(storage._engine) as session:
# Origin 1 - record 1
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
{"origin": str(origin1.uuid), "event_type": "event_1", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})}
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{"origin": str(origin1.uuid), "event_type": "event_1", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 1})}
)
# Origin 2 - record 2
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
{"origin": str(origin2.uuid), "event_type": "event_2", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})}
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{"origin": str(origin2.uuid), "event_type": "event_2", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin2", "seq": 2})}
)
# Origin 1 - record 3
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
{"origin": str(origin1.uuid), "event_type": "event_3", "event_category": "test", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})}
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{"origin": str(origin1.uuid), "event_type": "event_3", "event_id": str(uuid4()), "event_data": json.dumps({"from": "origin1", "seq": 3})}
)
await session.commit()
@@ -234,11 +231,10 @@ class TestAsyncSQLiteEventStorage:
async with AsyncSession(storage._engine) as session:
for i in range(10):
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"event_type": f"event_{i}",
"event_category": "test",
"event_id": str(uuid4()),
"event_data": json.dumps({"index": i})
}
@@ -325,11 +321,10 @@ class TestAsyncSQLiteEventStorage:
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"event_type": "complex_event",
"event_category": "test",
"event_id": str(uuid4()),
"event_data": json.dumps(test_data)
}
@@ -364,11 +359,10 @@ class TestAsyncSQLiteEventStorage:
async with AsyncSession(storage._engine) as session:
for i in range(count):
await session.execute(
text("INSERT INTO events (origin, event_type, event_category, event_id, event_data) VALUES (:origin, :event_type, :event_category, :event_id, :event_data)"),
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": origin_id,
"event_type": f"batch_{batch_id}_event_{i}",
"event_category": "test",
"event_id": str(uuid4()),
"event_data": json.dumps({"batch": batch_id, "item": i})
}
@@ -425,14 +419,12 @@ class TestAsyncSQLiteEventStorage:
)
chunk_generated_event = ChunkGenerated(
event_type=StreamingEventTypes.ChunkGenerated,
event_category=EventCategoryEnum.MutatesTaskState,
task_id=task_id,
chunk=token_chunk
)
# Store the event using the storage API
await storage.append_events([chunk_generated_event], sample_node_id) # type: ignore[reportArgumentType]
await storage.append_events([chunk_generated_event], sample_node_id)
# Wait for batch to be written
await asyncio.sleep(0.5)
@@ -448,8 +440,7 @@ class TestAsyncSQLiteEventStorage:
# Verify the event was deserialized correctly
retrieved_event = retrieved_event_wrapper.event
assert isinstance(retrieved_event, ChunkGenerated)
assert retrieved_event.event_type == StreamingEventTypes.ChunkGenerated
assert retrieved_event.event_category == EventCategoryEnum.MutatesTaskState
assert retrieved_event.event_type == EventType.ChunkGenerated
assert retrieved_event.task_id == task_id
# Verify the nested chunk was deserialized correctly

View File

@@ -0,0 +1,10 @@
from shared.types.events.events import (
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
)
TaskSagaEvent = (
MLXInferenceSagaPrepare
| MLXInferenceSagaStartPrepare
)

View File

@@ -0,0 +1,38 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Callable,
Sequence,
)
if TYPE_CHECKING:
pass
from pydantic import BaseModel
from shared.types.common import NewUUID
from shared.types.events.registry import Event
from shared.types.state import State
class CommandId(NewUUID):
pass
class CommandTypes(str, Enum):
Create = "Create"
Update = "Update"
Delete = "Delete"
class Command[
CommandType: CommandTypes,
](BaseModel):
command_type: CommandType
command_id: CommandId
type Decide[CommandTypeT: CommandTypes] = Callable[
[State, Command[CommandTypeT]],
Sequence[Event],
]

View File

@@ -1,26 +1,16 @@
from enum import Enum, StrEnum
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
FrozenSet,
Literal,
NamedTuple,
Protocol,
Sequence,
cast,
Generic,
TypeVar,
)
if TYPE_CHECKING:
pass
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel
from shared.types.common import NewUUID, NodeId
from shared.types.events.sanity_checking import (
assert_literal_union_covers_enum,
check_event_type_union_is_consistent_with_registry,
)
class EventId(NewUUID):
@@ -32,114 +22,49 @@ class TimerId(NewUUID):
# Here are all the unique kinds of events that can be sent over the network.
# I've defined them in different enums for clarity, but they're all part of the same set of possible events.
class TaskSagaEventTypes(str, Enum):
class EventType(str, Enum):
# Task Saga Events
MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"
class TaskEventTypes(str, Enum):
# Task Events
TaskCreated = "TaskCreated"
TaskStateUpdated = "TaskStateUpdated"
TaskDeleted = "TaskDeleted"
class StreamingEventTypes(str, Enum):
# Streaming Events
ChunkGenerated = "ChunkGenerated"
class InstanceEventTypes(str, Enum):
# Instance Events
InstanceCreated = "InstanceCreated"
InstanceDeleted = "InstanceDeleted"
InstanceActivated = "InstanceActivated"
InstanceDeactivated = "InstanceDeactivated"
InstanceReplacedAtomically = "InstanceReplacedAtomically"
class RunnerStatusEventTypes(str, Enum):
# Runner Status Events
RunnerStatusUpdated = "RunnerStatusUpdated"
class NodePerformanceEventTypes(str, Enum):
# Node Performance Events
NodePerformanceMeasured = "NodePerformanceMeasured"
class TopologyEventTypes(str, Enum):
# Topology Events
TopologyEdgeCreated = "TopologyEdgeCreated"
TopologyEdgeReplacedAtomically = "TopologyEdgeReplacedAtomically"
TopologyEdgeDeleted = "TopologyEdgeDeleted"
WorkerConnected = "WorkerConnected"
WorkerStatusUpdated = "WorkerStatusUpdated"
WorkerDisconnected = "WorkerDisconnected"
class TimerEventTypes(str, Enum):
# Timer Events
TimerCreated = "TimerCreated"
TimerFired = "TimerFired"
# Registry of all event type enums
EVENT_TYPE_ENUMS = [
TaskEventTypes,
StreamingEventTypes,
InstanceEventTypes,
RunnerStatusEventTypes,
NodePerformanceEventTypes,
TopologyEventTypes,
TimerEventTypes,
TaskSagaEventTypes,
]
EventTypeT = TypeVar("EventTypeT", bound=EventType)
# Here's the set of all possible events.
EventTypes = (
TaskEventTypes
| StreamingEventTypes
| InstanceEventTypes
| RunnerStatusEventTypes
| NodePerformanceEventTypes
| TopologyEventTypes
| TimerEventTypes
| TaskSagaEventTypes
)
check_event_type_union_is_consistent_with_registry(EVENT_TYPE_ENUMS, EventTypes)
class EventCategoryEnum(StrEnum):
MutatesTaskState = "MutatesTaskState"
MutatesTaskSagaState = "MutatesTaskSagaState"
MutatesRunnerStatus = "MutatesRunnerStatus"
MutatesInstanceState = "MutatesInstanceState"
MutatesNodePerformanceState = "MutatesNodePerformanceState"
MutatesTopologyState = "MutatesTopologyState"
EventCategory = (
Literal[EventCategoryEnum.MutatesTopologyState]
| Literal[EventCategoryEnum.MutatesTaskState]
| Literal[EventCategoryEnum.MutatesTaskSagaState]
| Literal[EventCategoryEnum.MutatesRunnerStatus]
| Literal[EventCategoryEnum.MutatesInstanceState]
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
| Literal[EventCategoryEnum.MutatesTopologyState]
)
EventCategories = FrozenSet[EventCategory]
assert_literal_union_covers_enum(EventCategory, EventCategoryEnum)
EventTypeT = EventTypes # Type Alias placeholder; generic parameter will override
class BaseEvent[
SetMembersT: EventCategories | EventCategory,
EventTypeLitT: EventTypes = EventTypes,
](BaseModel):
event_type: EventTypeLitT
event_category: SetMembersT
class BaseEvent(BaseModel, Generic[EventTypeT]):
event_type: EventTypeT
event_id: EventId = EventId()
def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool:
@@ -151,129 +76,4 @@ class BaseEvent[
return True
class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel):
event: BaseEvent[SetMembersT]
origin: NodeId
idx_in_log: int = Field(gt=0)
@model_validator(mode="after")
def check_event_was_sent_by_correct_node(
self,
) -> "EventFromEventLog[SetMembersT]":
if self.event.check_event_was_sent_by_correct_node(self.origin):
return self
raise ValueError("Invalid Event: Origin ID Does Not Match")
def narrow_event_type[T: EventCategory, Q: EventCategories | EventCategory](
event: BaseEvent[Q],
target_category: T,
) -> BaseEvent[T]:
if target_category not in event.event_category:
raise ValueError(f"Event Does Not Contain Target Category {target_category}")
narrowed_event = event.model_copy(update={"event_category": {target_category}})
return cast(BaseEvent[T], narrowed_event)
def narrow_event_from_event_log_type[
T: EventCategory,
Q: EventCategories | EventCategory,
](
event: EventFromEventLog[Q],
target_category: T,
) -> EventFromEventLog[T]:
if target_category not in event.event.event_category:
raise ValueError(f"Event Does Not Contain Target Category {target_category}")
narrowed_event = event.model_copy(
update={"event": narrow_event_type(event.event, target_category)}
)
return cast(EventFromEventLog[T], narrowed_event)
class State[EventCategoryT: EventCategory](BaseModel):
event_category: EventCategoryT
last_event_applied_idx: int = Field(default=0, ge=0)
# Definitions for Type Variables
type Saga[EventCategoryT: EventCategory] = Callable[
[State[EventCategoryT], EventFromEventLog[EventCategoryT]],
Sequence[BaseEvent[EventCategories]],
]
type Apply[EventCategoryT: EventCategory] = Callable[
[State[EventCategoryT], EventFromEventLog[EventCategoryT]],
State[EventCategoryT],
]
class StateAndEvent[EventCategoryT: EventCategory](NamedTuple):
state: State[EventCategoryT]
event: EventFromEventLog[EventCategoryT]
type EffectHandler[EventCategoryT: EventCategory] = Callable[
[StateAndEvent[EventCategoryT], State[EventCategoryT]], None
]
type EventPublisher = Callable[[BaseEvent[Any]], None]
# A component that can publish events
class EventPublisherProtocol(Protocol):
def send(self, events: Sequence[BaseEvent[EventCategories]]) -> None: ...
# A component that can fetch events to apply
class EventFetcherProtocol[EventCategoryT: EventCategory](Protocol):
def get_events_to_apply(
self, state: State[EventCategoryT]
) -> Sequence[BaseEvent[EventCategoryT]]: ...
# A component that can get the effect handler for a saga
def get_saga_effect_handler[EventCategoryT: EventCategory](
saga: Saga[EventCategoryT], event_publisher: EventPublisher
) -> EffectHandler[EventCategoryT]:
def effect_handler(state_and_event: StateAndEvent[EventCategoryT]) -> None:
trigger_state, trigger_event = state_and_event
for event in saga(trigger_state, trigger_event):
event_publisher(event)
return lambda state_and_event, _: effect_handler(state_and_event)
def get_effects_from_sagas[EventCategoryT: EventCategory](
sagas: Sequence[Saga[EventCategoryT]],
event_publisher: EventPublisher,
) -> Sequence[EffectHandler[EventCategoryT]]:
return [get_saga_effect_handler(saga, event_publisher) for saga in sagas]
type IdemKeyGenerator[EventCategoryT: EventCategory] = Callable[
[State[EventCategoryT], int], Sequence[EventId]
]
class CommandId(NewUUID):
pass
class CommandTypes(str, Enum):
Create = "Create"
Update = "Update"
Delete = "Delete"
class Command[
EventCategoryT: EventCategories | EventCategory,
CommandType: CommandTypes,
](BaseModel):
command_type: CommandType
command_id: CommandId
type Decide[EventCategoryT: EventCategory, CommandTypeT: CommandTypes] = Callable[
[State[EventCategoryT], Command[EventCategoryT, CommandTypeT]],
Sequence[BaseEvent[EventCategoryT]],
]

View File

@@ -0,0 +1,38 @@
# components.py defines the small event functions, adapters etc.
# this name could probably be improved.
from typing import (
TYPE_CHECKING,
)
if TYPE_CHECKING:
pass
from pydantic import BaseModel, Field, model_validator
from typing import Callable
from shared.types.common import NodeId
from shared.types.state import State
from shared.types.events.registry import Event
class EventFromEventLog[T: Event](BaseModel):
event: T
origin: NodeId
idx_in_log: int = Field(gt=0)
@model_validator(mode="after")
def check_event_was_sent_by_correct_node(
self,
) -> "EventFromEventLog[T]":
if self.event.check_event_was_sent_by_correct_node(self.origin):
return self
raise ValueError("Invalid Event: Origin ID Does Not Match")
type Apply = Callable[
[State, Event],
State
]

View File

@@ -6,14 +6,8 @@ from shared.types.common import NodeId
from shared.types.events.chunks import GenerationChunk
from shared.types.events.common import (
BaseEvent,
EventCategoryEnum,
InstanceEventTypes,
NodePerformanceEventTypes,
RunnerStatusEventTypes,
StreamingEventTypes,
TaskEventTypes,
TaskSagaEventTypes,
TopologyEventTypes,
EventType,
TimerId,
)
from shared.types.graphs.topology import (
TopologyEdge,
@@ -27,156 +21,122 @@ from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import InstanceParams, TypeOfInstance
from shared.types.worker.runners import RunnerId, RunnerStatus
TaskEvent = BaseEvent[EventCategoryEnum.MutatesTaskState]
InstanceEvent = BaseEvent[EventCategoryEnum.MutatesInstanceState]
TopologyEvent = BaseEvent[EventCategoryEnum.MutatesTopologyState]
NodePerformanceEvent = BaseEvent[EventCategoryEnum.MutatesNodePerformanceState]
class TaskCreated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskCreated]]):
event_type: Literal[TaskEventTypes.TaskCreated] = TaskEventTypes.TaskCreated
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
class TaskCreated(BaseEvent[EventType.TaskCreated]):
event_type: Literal[EventType.TaskCreated] = EventType.TaskCreated
task_id: TaskId
task: Task
# Covers Cancellation Of Task, Non-Cancelled Tasks Perist
class TaskDeleted(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskDeleted]]):
event_type: Literal[TaskEventTypes.TaskDeleted] = TaskEventTypes.TaskDeleted
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
class TaskDeleted(BaseEvent[EventType.TaskDeleted]):
event_type: Literal[EventType.TaskDeleted] = EventType.TaskDeleted
task_id: TaskId
class TaskStateUpdated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[TaskEventTypes.TaskStateUpdated]]):
event_type: Literal[TaskEventTypes.TaskStateUpdated] = TaskEventTypes.TaskStateUpdated
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
class TaskStateUpdated(BaseEvent[EventType.TaskStateUpdated]):
event_type: Literal[EventType.TaskStateUpdated] = EventType.TaskStateUpdated
task_id: TaskId
task_status: TaskStatus
class InstanceCreated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceCreated]]):
event_type: Literal[InstanceEventTypes.InstanceCreated] = InstanceEventTypes.InstanceCreated
event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
class InstanceCreated(BaseEvent[EventType.InstanceCreated]):
event_type: Literal[EventType.InstanceCreated] = EventType.InstanceCreated
instance_id: InstanceId
instance_params: InstanceParams
instance_type: TypeOfInstance
class InstanceActivated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceActivated]]):
event_type: Literal[InstanceEventTypes.InstanceActivated] = InstanceEventTypes.InstanceActivated
event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
class InstanceActivated(BaseEvent[EventType.InstanceActivated]):
event_type: Literal[EventType.InstanceActivated] = EventType.InstanceActivated
instance_id: InstanceId
class InstanceDeactivated(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceDeactivated]]):
event_type: Literal[InstanceEventTypes.InstanceDeactivated] = InstanceEventTypes.InstanceDeactivated
event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
class InstanceDeactivated(BaseEvent[EventType.InstanceDeactivated]):
event_type: Literal[EventType.InstanceDeactivated] = EventType.InstanceDeactivated
instance_id: InstanceId
class InstanceDeleted(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceDeleted]]):
event_type: Literal[InstanceEventTypes.InstanceDeleted] = InstanceEventTypes.InstanceDeleted
event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
class InstanceDeleted(BaseEvent[EventType.InstanceDeleted]):
event_type: Literal[EventType.InstanceDeleted] = EventType.InstanceDeleted
instance_id: InstanceId
transition: Tuple[InstanceId, InstanceId]
class InstanceReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesInstanceState, Literal[InstanceEventTypes.InstanceReplacedAtomically]]):
event_type: Literal[InstanceEventTypes.InstanceReplacedAtomically] = InstanceEventTypes.InstanceReplacedAtomically
event_category: Literal[EventCategoryEnum.MutatesInstanceState] = EventCategoryEnum.MutatesInstanceState
class InstanceReplacedAtomically(BaseEvent[EventType.InstanceReplacedAtomically]):
event_type: Literal[EventType.InstanceReplacedAtomically] = EventType.InstanceReplacedAtomically
instance_to_replace: InstanceId
new_instance_id: InstanceId
class RunnerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesRunnerStatus, Literal[RunnerStatusEventTypes.RunnerStatusUpdated]]):
event_type: Literal[RunnerStatusEventTypes.RunnerStatusUpdated] = RunnerStatusEventTypes.RunnerStatusUpdated
event_category: Literal[EventCategoryEnum.MutatesRunnerStatus] = EventCategoryEnum.MutatesRunnerStatus
class RunnerStatusUpdated(BaseEvent[EventType.RunnerStatusUpdated]):
event_type: Literal[EventType.RunnerStatusUpdated] = EventType.RunnerStatusUpdated
runner_id: RunnerId
runner_status: RunnerStatus
class MLXInferenceSagaPrepare(BaseEvent[EventCategoryEnum.MutatesTaskSagaState, Literal[TaskSagaEventTypes.MLXInferenceSagaPrepare]]):
event_type: Literal[TaskSagaEventTypes.MLXInferenceSagaPrepare] = TaskSagaEventTypes.MLXInferenceSagaPrepare
event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = EventCategoryEnum.MutatesTaskSagaState
class MLXInferenceSagaPrepare(BaseEvent[EventType.MLXInferenceSagaPrepare]):
event_type: Literal[EventType.MLXInferenceSagaPrepare] = EventType.MLXInferenceSagaPrepare
task_id: TaskId
instance_id: InstanceId
class MLXInferenceSagaStartPrepare(BaseEvent[EventCategoryEnum.MutatesTaskSagaState, Literal[TaskSagaEventTypes.MLXInferenceSagaStartPrepare]]):
event_type: Literal[TaskSagaEventTypes.MLXInferenceSagaStartPrepare] = TaskSagaEventTypes.MLXInferenceSagaStartPrepare
event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = EventCategoryEnum.MutatesTaskSagaState
class MLXInferenceSagaStartPrepare(BaseEvent[EventType.MLXInferenceSagaStartPrepare]):
event_type: Literal[EventType.MLXInferenceSagaStartPrepare] = EventType.MLXInferenceSagaStartPrepare
task_id: TaskId
instance_id: InstanceId
class NodePerformanceMeasured(BaseEvent[EventCategoryEnum.MutatesNodePerformanceState, Literal[NodePerformanceEventTypes.NodePerformanceMeasured]]):
event_type: Literal[NodePerformanceEventTypes.NodePerformanceMeasured] = NodePerformanceEventTypes.NodePerformanceMeasured
event_category: Literal[EventCategoryEnum.MutatesNodePerformanceState] = EventCategoryEnum.MutatesNodePerformanceState
class NodePerformanceMeasured(BaseEvent[EventType.NodePerformanceMeasured]):
event_type: Literal[EventType.NodePerformanceMeasured] = EventType.NodePerformanceMeasured
node_id: NodeId
node_profile: NodePerformanceProfile
class WorkerConnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerConnected]]):
event_type: Literal[TopologyEventTypes.WorkerConnected] = TopologyEventTypes.WorkerConnected
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
class WorkerConnected(BaseEvent[EventType.WorkerConnected]):
event_type: Literal[EventType.WorkerConnected] = EventType.WorkerConnected
edge: TopologyEdge
class WorkerStatusUpdated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerStatusUpdated]]):
event_type: Literal[TopologyEventTypes.WorkerStatusUpdated] = TopologyEventTypes.WorkerStatusUpdated
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
class WorkerStatusUpdated(BaseEvent[EventType.WorkerStatusUpdated]):
event_type: Literal[EventType.WorkerStatusUpdated] = EventType.WorkerStatusUpdated
node_id: NodeId
node_state: NodeStatus
class WorkerDisconnected(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.WorkerDisconnected]]):
event_type: Literal[TopologyEventTypes.WorkerDisconnected] = TopologyEventTypes.WorkerDisconnected
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
class WorkerDisconnected(BaseEvent[EventType.WorkerDisconnected]):
event_type: Literal[EventType.WorkerDisconnected] = EventType.WorkerDisconnected
vertex_id: NodeId
class ChunkGenerated(BaseEvent[EventCategoryEnum.MutatesTaskState, Literal[StreamingEventTypes.ChunkGenerated]]):
event_type: Literal[StreamingEventTypes.ChunkGenerated] = StreamingEventTypes.ChunkGenerated
event_category: Literal[EventCategoryEnum.MutatesTaskState] = EventCategoryEnum.MutatesTaskState
class ChunkGenerated(BaseEvent[EventType.ChunkGenerated]):
event_type: Literal[EventType.ChunkGenerated] = EventType.ChunkGenerated
task_id: TaskId
chunk: GenerationChunk
class TopologyEdgeCreated(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeCreated]]):
event_type: Literal[TopologyEventTypes.TopologyEdgeCreated] = TopologyEventTypes.TopologyEdgeCreated
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
class TopologyEdgeCreated(BaseEvent[EventType.TopologyEdgeCreated]):
event_type: Literal[EventType.TopologyEdgeCreated] = EventType.TopologyEdgeCreated
vertex: TopologyNode
class TopologyEdgeReplacedAtomically(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically]]):
event_type: Literal[TopologyEventTypes.TopologyEdgeReplacedAtomically] = TopologyEventTypes.TopologyEdgeReplacedAtomically
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
class TopologyEdgeReplacedAtomically(BaseEvent[EventType.TopologyEdgeReplacedAtomically]):
event_type: Literal[EventType.TopologyEdgeReplacedAtomically] = EventType.TopologyEdgeReplacedAtomically
edge_id: TopologyEdgeId
edge_profile: TopologyEdgeProfile
class TopologyEdgeDeleted(BaseEvent[EventCategoryEnum.MutatesTopologyState, Literal[TopologyEventTypes.TopologyEdgeDeleted]]):
event_type: Literal[TopologyEventTypes.TopologyEdgeDeleted] = TopologyEventTypes.TopologyEdgeDeleted
event_category: Literal[EventCategoryEnum.MutatesTopologyState] = EventCategoryEnum.MutatesTopologyState
class TopologyEdgeDeleted(BaseEvent[EventType.TopologyEdgeDeleted]):
event_type: Literal[EventType.TopologyEdgeDeleted] = EventType.TopologyEdgeDeleted
edge_id: TopologyEdgeId
"""
TEST_EVENT_CATEGORIES_TYPE = FrozenSet[
Literal[
EventCategoryEnum.MutatesTaskState,
EventCategoryEnum.MutatesControlPlaneState,
]
]
TEST_EVENT_CATEGORIES = frozenset(
(
EventCategoryEnum.MutatesTaskState,
EventCategoryEnum.MutatesControlPlaneState,
)
)
class TimerCreated(BaseEvent[EventType.TimerCreated]):
event_type: Literal[EventType.TimerCreated] = EventType.TimerCreated
timer_id: TimerId
delay_seconds: float
class TestEvent(BaseEvent[TEST_EVENT_CATEGORIES_TYPE]):
event_category: TEST_EVENT_CATEGORIES_TYPE = TEST_EVENT_CATEGORIES
test_id: int
"""
class TimerFired(BaseEvent[EventType.TimerFired]):
event_type: Literal[EventType.TimerFired] = EventType.TimerFired
timer_id: TimerId

View File

@@ -1,25 +1,15 @@
from types import UnionType
from typing import Annotated, Any, Mapping, Type, get_args
from typing import Annotated, Any, Mapping, Type, TypeAlias
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import Field, TypeAdapter
from shared.constants import get_error_reporting_message
from shared.types.events.common import (
BaseEvent,
EventCategories,
EventTypes,
InstanceEventTypes,
NodeId,
NodePerformanceEventTypes,
RunnerStatusEventTypes,
StreamingEventTypes,
TaskEventTypes,
TaskSagaEventTypes,
TopologyEventTypes,
EventType,
)
from shared.types.events.events import (
ChunkGenerated,
InstanceActivated,
InstanceCreated,
InstanceDeactivated,
InstanceDeleted,
InstanceReplacedAtomically,
MLXInferenceSagaPrepare,
@@ -29,6 +19,8 @@ from shared.types.events.events import (
TaskCreated,
TaskDeleted,
TaskStateUpdated,
TimerCreated,
TimerFired,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TopologyEdgeReplacedAtomically,
@@ -36,6 +28,11 @@ from shared.types.events.events import (
WorkerDisconnected,
WorkerStatusUpdated,
)
from shared.types.events.sanity_checking import (
assert_event_union_covers_registry,
check_registry_has_all_event_types,
check_union_of_all_events_is_consistent_with_registry,
)
"""
class EventTypeNames(StrEnum):
@@ -50,63 +47,38 @@ class EventTypeNames(StrEnum):
check_event_categories_are_defined_for_all_event_types(EVENT_TYPE_ENUMS, EventTypeNames)
"""
EventRegistry: Mapping[EventTypes, Type[Any]] = {
TaskEventTypes.TaskCreated: TaskCreated,
TaskEventTypes.TaskStateUpdated: TaskStateUpdated,
TaskEventTypes.TaskDeleted: TaskDeleted,
InstanceEventTypes.InstanceCreated: InstanceCreated,
InstanceEventTypes.InstanceDeleted: InstanceDeleted,
InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically,
RunnerStatusEventTypes.RunnerStatusUpdated: RunnerStatusUpdated,
NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured,
TopologyEventTypes.WorkerConnected: WorkerConnected,
TopologyEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
TopologyEventTypes.WorkerDisconnected: WorkerDisconnected,
StreamingEventTypes.ChunkGenerated: ChunkGenerated,
TopologyEventTypes.TopologyEdgeCreated: TopologyEdgeCreated,
TopologyEventTypes.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically,
TopologyEventTypes.TopologyEdgeDeleted: TopologyEdgeDeleted,
TaskSagaEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
TaskSagaEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
EventRegistry: Mapping[EventType, Type[Any]] = {
EventType.TaskCreated: TaskCreated,
EventType.TaskStateUpdated: TaskStateUpdated,
EventType.TaskDeleted: TaskDeleted,
EventType.InstanceCreated: InstanceCreated,
EventType.InstanceActivated: InstanceActivated,
EventType.InstanceDeactivated: InstanceDeactivated,
EventType.InstanceDeleted: InstanceDeleted,
EventType.InstanceReplacedAtomically: InstanceReplacedAtomically,
EventType.RunnerStatusUpdated: RunnerStatusUpdated,
EventType.NodePerformanceMeasured: NodePerformanceMeasured,
EventType.WorkerConnected: WorkerConnected,
EventType.WorkerStatusUpdated: WorkerStatusUpdated,
EventType.WorkerDisconnected: WorkerDisconnected,
EventType.ChunkGenerated: ChunkGenerated,
EventType.TopologyEdgeCreated: TopologyEdgeCreated,
EventType.TopologyEdgeReplacedAtomically: TopologyEdgeReplacedAtomically,
EventType.TopologyEdgeDeleted: TopologyEdgeDeleted,
EventType.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
EventType.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
EventType.TimerCreated: TimerCreated,
EventType.TimerFired: TimerFired,
}
# Sanity Check.
def check_registry_has_all_event_types() -> None:
event_types: tuple[EventTypes, ...] = get_args(EventTypes)
missing_event_types = set(event_types) - set(EventRegistry.keys())
assert not missing_event_types, (
f"{get_error_reporting_message()}"
f"There's an event missing from the registry: {missing_event_types}"
)
def check_union_of_all_events_is_consistent_with_registry(
registry: Mapping[EventTypes, Type[Any]], union_type: UnionType
) -> None:
type_of_each_registry_entry = set(registry.values())
type_of_each_entry_in_union = set(get_args(union_type))
missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union
assert not missing_from_union, (
f"{get_error_reporting_message()}"
f"Event classes in registry are missing from all_events union: {missing_from_union}"
)
extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry
assert not extra_in_union, (
f"{get_error_reporting_message()}"
f"Event classes in all_events union are missing from registry: {extra_in_union}"
)
Event = (
AllEventsUnion = (
TaskCreated
| TaskStateUpdated
| TaskDeleted
| InstanceCreated
| InstanceActivated
| InstanceDeactivated
| InstanceDeleted
| InstanceReplacedAtomically
| RunnerStatusUpdated
@@ -120,24 +92,16 @@ Event = (
| TopologyEdgeDeleted
| MLXInferenceSagaPrepare
| MLXInferenceSagaStartPrepare
| TimerCreated
| TimerFired
)
# Run the sanity check
check_union_of_all_events_is_consistent_with_registry(EventRegistry, Event)
Event: TypeAlias = Annotated[AllEventsUnion, Field(discriminator="event_type")]
EventParser: TypeAdapter[Event] = TypeAdapter(Event)
_EventType = Annotated[Event, Field(discriminator="event_type")]
EventParser: TypeAdapter[BaseEvent[EventCategories]] = TypeAdapter(_EventType)
# Define a properly typed EventFromEventLog that preserves specific event types
class EventFromEventLogTyped(BaseModel):
"""Properly typed EventFromEventLog that preserves specific event types."""
event: _EventType
origin: NodeId
idx_in_log: int = Field(gt=0)
def check_event_was_sent_by_correct_node(self) -> bool:
"""Check if the event was sent by the correct node."""
return self.event.check_event_was_sent_by_correct_node(self.origin)
assert_event_union_covers_registry(AllEventsUnion)
check_union_of_all_events_is_consistent_with_registry(EventRegistry, AllEventsUnion)
check_registry_has_all_event_types(EventRegistry)

View File

@@ -1,68 +1,75 @@
from enum import Enum, StrEnum
from enum import StrEnum
from types import UnionType
from typing import Any, LiteralString, Sequence, Set, Type, get_args
from typing import Any, Mapping, Set, Type, cast, get_args
from pydantic.fields import FieldInfo
from shared.constants import get_error_reporting_message
from shared.types.events.common import EventType
def check_event_type_union_is_consistent_with_registry(
event_type_enums: Sequence[Type[Enum]], event_types: UnionType
) -> None:
"""Assert that every enum value from _EVENT_TYPE_ENUMS satisfies EventTypes."""
event_types_inferred_from_union = set(get_args(event_types))
event_types_inferred_from_registry = [
member for enum_class in event_type_enums for member in enum_class
]
# Check that each registry value belongs to one of the types in the union
for tag_of_event_type in event_types_inferred_from_registry:
event_type = type(tag_of_event_type)
assert event_type in event_types_inferred_from_union, (
f"{get_error_reporting_message()}"
f"There's a mismatch between the registry of event types and the union of possible event types."
f"The enum value {tag_of_event_type} for type {event_type} is not covered by {event_types_inferred_from_union}."
)
def check_event_categories_are_defined_for_all_event_types(
event_definitions: Sequence[Type[Enum]], event_categories: Type[StrEnum]
) -> None:
"""Assert that the event category names are consistent with the event type enums."""
expected_category_tags: list[str] = [
enum_class.__name__ for enum_class in event_definitions
]
tag_of_event_categories: list[str] = list(event_categories.__members__.values())
assert tag_of_event_categories == expected_category_tags, (
f"{get_error_reporting_message()}"
f"The values of the enum EventCategories are not named after the event type enums."
f"These are the missing categories: {set(expected_category_tags) - set(tag_of_event_categories)}"
f"These are the extra categories: {set(tag_of_event_categories) - set(expected_category_tags)}"
)
def assert_literal_union_covers_enum[TEnum: StrEnum](
def assert_event_union_covers_registry[TEnum: StrEnum](
literal_union: UnionType,
enum_type: Type[TEnum],
) -> None:
enum_values: Set[Any] = {member.value for member in enum_type}
"""
Ensure that our union of events (AllEventsUnion) has one member per element of Enum
"""
enum_values: Set[str] = {member.value for member in EventType}
def _flatten(tp: UnionType) -> Set[Any]:
values: Set[Any] = set()
args: tuple[LiteralString, ...] = get_args(tp)
for arg in args:
payloads: tuple[TEnum, ...] = get_args(arg)
for payload in payloads:
values.add(payload.value)
def _flatten(tp: UnionType) -> Set[str]:
values: Set[str] = set()
args = get_args(tp) # Get event classes from the union
for arg in args: # type: ignore[reportAny]
# Cast to type since we know these are class types
event_class = cast(type[Any], arg)
# Each event class is a Pydantic model with model_fields
if hasattr(event_class, 'model_fields'):
model_fields = cast(dict[str, FieldInfo], event_class.model_fields)
if 'event_type' in model_fields:
# Get the default value of the event_type field
event_type_field: FieldInfo = model_fields['event_type']
if hasattr(event_type_field, 'default'):
default_value = cast(EventType, event_type_field.default)
# The default is an EventType enum member, get its value
values.add(default_value.value)
return values
literal_values: Set[Any] = _flatten(literal_union)
literal_values: Set[str] = _flatten(literal_union)
assert enum_values == literal_values, (
f"{get_error_reporting_message()}"
f"The values of the enum {enum_type} are not covered by the literal union {literal_union}.\n"
f"The values of the enum {EventType} are not covered by the literal union {literal_union}.\n"
f"These are the missing values: {enum_values - literal_values}\n"
f"These are the extra values: {literal_values - enum_values}\n"
)
def check_union_of_all_events_is_consistent_with_registry(
registry: Mapping[EventType, Type[Any]], union_type: UnionType
) -> None:
type_of_each_registry_entry = set(registry.values())
type_of_each_entry_in_union = set(get_args(union_type))
missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union
assert not missing_from_union, (
f"{get_error_reporting_message()}"
f"Event classes in registry are missing from all_events union: {missing_from_union}"
)
extra_in_union = type_of_each_entry_in_union - type_of_each_registry_entry
assert not extra_in_union, (
f"{get_error_reporting_message()}"
f"Event classes in all_events union are missing from registry: {extra_in_union}"
)
def check_registry_has_all_event_types(event_registry: Mapping[EventType, Type[Any]]) -> None:
event_types: tuple[EventType, ...] = get_args(EventType)
missing_event_types = set(event_types) - set(event_registry.keys())
assert not missing_event_types, (
f"{get_error_reporting_message()}"
f"There's an event missing from the registry: {missing_event_types}"
)
# TODO: Check all events have an apply function.
# probably in a different place though.

View File

@@ -39,3 +39,6 @@ class State(BaseModel):
task_inbox: List[Task] = Field(default_factory=list)
task_outbox: List[Task] = Field(default_factory=list)
cache_policy: CachePolicy = CachePolicy.KeepAll
# TODO: implement / use this?
last_event_applied_idx: int = Field(default=0, ge=0)