diff --git a/.gitignore b/.gitignore index 8ac70684..16f168d6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,7 @@ __pycache__ hosts_*.json # hide direnv stuff -/.direnv \ No newline at end of file +/.direnv +# TODO figure out how to properly solve the issue with these target directories showing up +networking/target/ +networking/topology/target/ diff --git a/shared/db/sqlite/connector.py b/shared/db/sqlite/connector.py index 2009c8c0..cb7fe2e6 100644 --- a/shared/db/sqlite/connector.py +++ b/shared/db/sqlite/connector.py @@ -246,7 +246,3 @@ class AsyncSQLiteEventStorage: except Exception as e: self._logger.error(f"Failed to commit batch: {e}") raise - - async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]: - """Return raw event data for testing purposes.""" - return event_data diff --git a/shared/tests/test_sqlite_connector.py b/shared/tests/test_sqlite_connector.py index 9e4c8b4d..deacd72e 100644 --- a/shared/tests/test_sqlite_connector.py +++ b/shared/tests/test_sqlite_connector.py @@ -14,7 +14,6 @@ from shared.types.common import NodeId from shared.types.events import ( ChunkGenerated, CommandId, - _EventType, ) from shared.types.events.chunks import ChunkType, TokenChunk @@ -472,7 +471,6 @@ class TestAsyncSQLiteEventStorage: # Verify the event was deserialized correctly retrieved_event = retrieved_event_wrapper.event assert isinstance(retrieved_event, ChunkGenerated) - assert retrieved_event.event_type == _EventType.ChunkGenerated assert retrieved_event.command_id == command_id # Verify the nested chunk was deserialized correctly diff --git a/shared/types/events/__init__.py b/shared/types/events/__init__.py index db6adbd5..b3c5ac1b 100644 --- a/shared/types/events/__init__.py +++ b/shared/types/events/__init__.py @@ -1,99 +1,16 @@ # ruff: noqa: F403 # ruff: noqa: F405 -import types -import typing -from typing import Annotated, Union - # Note: we are implementing internal details here, so importing private stuff is fine!!! -from pydantic import Field, TypeAdapter +from pydantic import TypeAdapter -from ...constants import get_error_reporting_message +from shared.types.events.components import EventFromEventLog + +from ._apply import Event, apply from ._common import * -from ._common import _BaseEvent, _EventType # pyright: ignore[reportPrivateUsage] from ._events import * -_Event = Union[ - TaskCreated, - TaskStateUpdated, - TaskDeleted, - InstanceCreated, - InstanceActivated, - InstanceDeactivated, - InstanceDeleted, - InstanceReplacedAtomically, - RunnerStatusUpdated, - NodePerformanceMeasured, - WorkerConnected, - WorkerStatusUpdated, - WorkerDisconnected, - ChunkGenerated, - TopologyEdgeCreated, - TopologyEdgeReplacedAtomically, - TopologyEdgeDeleted, - MLXInferenceSagaPrepare, - MLXInferenceSagaStartPrepare, -] -""" -Un-annotated union of all events. Only used internally to create the registry. -For all other usecases, use the annotated union of events :class:`Event` :) -""" - -Event = Annotated[_Event, Field(discriminator="event_type")] -"""Type of events, a discriminated union.""" - EventParser: TypeAdapter[Event] = TypeAdapter(Event) """Type adaptor to parse :class:`Event`s.""" - -def _check_event_type_consistency(): - # Grab enum values from members - member_enum_values = [m for m in _EventType] - - # grab enum values from the union => scrape the type annotation - union_enum_values: list[_EventType] = [] - union_classes = list(typing.get_args(_Event)) - for cls in union_classes: # pyright: ignore[reportAny] - assert issubclass(cls, object), ( - f"{get_error_reporting_message()}", - f"The class {cls} is NOT a subclass of {object}." - ) - - # ensure the first base parameter is ALWAYS _BaseEvent - base_cls = list(types.get_original_bases(cls)) - assert len(base_cls) >= 1 and issubclass(base_cls[0], object) \ - and issubclass(base_cls[0], _BaseEvent), ( - f"{get_error_reporting_message()}", - f"The class {cls} does NOT inherit from {_BaseEvent} {typing.get_origin(base_cls[0])}." - ) - - # grab type hints and extract the right values from it - cls_hints = typing.get_type_hints(cls) - assert "event_type" in cls_hints and \ - typing.get_origin(cls_hints["event_type"]) is typing.Literal, ( # pyright: ignore[reportAny] - f"{get_error_reporting_message()}", - f"The class {cls} is missing a {typing.Literal}-annotated `event_type` field." - ) - - # make sure the value is an instance of `_EventType` - enum_value = list(typing.get_args(cls_hints["event_type"])) - assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), ( - f"{get_error_reporting_message()}", - f"The `event_type` of {cls} has a non-{_EventType} literal-type." - ) - union_enum_values.append(enum_value[0]) - - # ensure there is a 1:1 bijection between the two - for m in member_enum_values: - assert m in union_enum_values, ( - f"{get_error_reporting_message()}", - f"There is no event-type registered for {m} in {_Event}." - ) - union_enum_values.remove(m) - assert len(union_enum_values) == 0, ( - f"{get_error_reporting_message()}", - f"The following events have multiple event types defined in {_Event}: {union_enum_values}." - ) - - -_check_event_type_consistency() +__all__ = ["Event", "EventParser", "apply", "EventFromEventLog"] diff --git a/shared/types/events/_apply.py b/shared/types/events/_apply.py new file mode 100644 index 00000000..205517d9 --- /dev/null +++ b/shared/types/events/_apply.py @@ -0,0 +1,185 @@ +from functools import singledispatch +from typing import Mapping, TypeVar + +# from shared.topology import Topology +from shared.types.common import NodeId +from shared.types.events._events import Event +from shared.types.events.components import EventFromEventLog +from shared.types.profiling import NodePerformanceProfile +from shared.types.state import State +from shared.types.tasks import Task, TaskId +from shared.types.worker.common import NodeStatus, RunnerId +from shared.types.worker.instances import BaseInstance, InstanceId, TypeOfInstance +from shared.types.worker.runners import RunnerStatus + +from ._events import ( + ChunkGenerated, + InstanceActivated, + InstanceCreated, + InstanceDeactivated, + InstanceDeleted, + InstanceReplacedAtomically, + MLXInferenceSagaPrepare, + MLXInferenceSagaStartPrepare, + NodePerformanceMeasured, + RunnerStatusUpdated, + TaskCreated, + TaskDeleted, + TaskStateUpdated, + TopologyEdgeCreated, + TopologyEdgeDeleted, + TopologyEdgeReplacedAtomically, + WorkerConnected, + WorkerDisconnected, + WorkerStatusUpdated, +) + +S = TypeVar("S", bound=State) + +@singledispatch +def event_apply(state: State, event: Event) -> State: + raise RuntimeError(f"no handler for {type(event).__name__}") + +def apply(state: State, event: EventFromEventLog[Event]) -> State: + new_state: State = event_apply(state, event.event) + return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log}) + +@event_apply.register +def apply_task_created(state: State, event: TaskCreated) -> State: + new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task} + return state.model_copy(update={"tasks": new_tasks}) + +@event_apply.register +def apply_task_deleted(state: State, event: TaskDeleted) -> State: + new_tasks: Mapping[TaskId, Task] = {tid: task for tid, task in state.tasks.items() if tid != event.task_id} + return state.model_copy(update={"tasks": new_tasks}) + +@event_apply.register +def apply_task_state_updated(state: State, event: TaskStateUpdated) -> State: + if event.task_id not in state.tasks: + return state + + updated_task = state.tasks[event.task_id].model_copy(update={"task_status": event.task_status}) + new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task} + return state.model_copy(update={"tasks": new_tasks}) + +@event_apply.register +def apply_instance_created(state: State, event: InstanceCreated) -> State: + instance = BaseInstance(instance_params=event.instance_params, instance_type=event.instance_type) + new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: instance} + return state.model_copy(update={"instances": new_instances}) + +@event_apply.register +def apply_instance_activated(state: State, event: InstanceActivated) -> State: + if event.instance_id not in state.instances: + return state + + updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.ACTIVE}) + new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance} + return state.model_copy(update={"instances": new_instances}) + +@event_apply.register +def apply_instance_deactivated(state: State, event: InstanceDeactivated) -> State: + if event.instance_id not in state.instances: + return state + + updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.INACTIVE}) + new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance} + return state.model_copy(update={"instances": new_instances}) + +@event_apply.register +def apply_instance_deleted(state: State, event: InstanceDeleted) -> State: + new_instances: Mapping[InstanceId, BaseInstance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id} + return state.model_copy(update={"instances": new_instances}) + +@event_apply.register +def apply_instance_replaced_atomically(state: State, event: InstanceReplacedAtomically) -> State: + new_instances = dict(state.instances) + if event.instance_to_replace in new_instances: + del new_instances[event.instance_to_replace] + if event.new_instance_id in state.instances: + new_instances[event.new_instance_id] = state.instances[event.new_instance_id] + return state.model_copy(update={"instances": new_instances}) + +@event_apply.register +def apply_runner_status_updated(state: State, event: RunnerStatusUpdated) -> State: + new_runners: Mapping[RunnerId, RunnerStatus] = {**state.runners, event.runner_id: event.runner_status} + return state.model_copy(update={"runners": new_runners}) + +@event_apply.register +def apply_node_performance_measured(state: State, event: NodePerformanceMeasured) -> State: + new_profiles: Mapping[NodeId, NodePerformanceProfile] = {**state.node_profiles, event.node_id: event.node_profile} + return state.model_copy(update={"node_profiles": new_profiles}) + +@event_apply.register +def apply_worker_status_updated(state: State, event: WorkerStatusUpdated) -> State: + new_node_status: Mapping[NodeId, NodeStatus] = {**state.node_status, event.node_id: event.node_state} + return state.model_copy(update={"node_status": new_node_status}) + +@event_apply.register +def apply_chunk_generated(state: State, event: ChunkGenerated) -> State: + return state + +# TODO implemente these +@event_apply.register +def apply_worker_connected(state: State, event: WorkerConnected) -> State: + # source_node_id = event.edge.source_node_id + # sink_node_id = event.edge.sink_node_id + + # new_node_status = dict(state.node_status) + # if source_node_id not in new_node_status: + # new_node_status[source_node_id] = NodeStatus.Idle + # if sink_node_id not in new_node_status: + # new_node_status[sink_node_id] = NodeStatus.Idle + + # new_topology = Topology() + # new_topology.add_connection(event.edge) + + # return state.model_copy(update={"node_status": new_node_status, "topology": new_topology}) + return state + +@event_apply.register +def apply_worker_disconnected(state: State, event: WorkerDisconnected) -> State: + # new_node_status: Mapping[NodeId, NodeStatus] = {nid: status for nid, status in state.node_status.items() if nid != event.vertex_id} + + # new_topology = Topology() + + # new_history = list(state.history) + [state.topology] + + # return state.model_copy(update={ + # "node_status": new_node_status, + # "topology": new_topology, + # "history": new_history + # }) + return state + + +@event_apply.register +def apply_topology_edge_created(state: State, event: TopologyEdgeCreated) -> State: + # new_topology = Topology() + # new_topology.add_node(event.vertex, event.vertex.node_id) + # return state.model_copy(update={"topology": new_topology}) + return state + +@event_apply.register +def apply_topology_edge_replaced_atomically(state: State, event: TopologyEdgeReplacedAtomically) -> State: + # new_topology = Topology() + # new_topology.add_connection(event.edge) + # updated_connection = event.edge.model_copy(update={"connection_profile": event.edge_profile}) + # new_topology.update_connection_profile(updated_connection) + # return state.model_copy(update={"topology": new_topology}) + return state + +@event_apply.register +def apply_topology_edge_deleted(state: State, event: TopologyEdgeDeleted) -> State: + # new_topology = Topology() + # return state.model_copy(update={"topology": new_topology}) + return state + +@event_apply.register +def apply_mlx_inference_saga_prepare(state: State, event: MLXInferenceSagaPrepare) -> State: + return state + +@event_apply.register +def apply_mlx_inference_saga_start_prepare(state: State, event: MLXInferenceSagaStartPrepare) -> State: + return state \ No newline at end of file diff --git a/shared/types/events/_common.py b/shared/types/events/_common.py index a5a1b18a..53d2d4aa 100644 --- a/shared/types/events/_common.py +++ b/shared/types/events/_common.py @@ -1,6 +1,12 @@ +import types +import typing from enum import Enum from typing import TYPE_CHECKING +from shared.constants import get_error_reporting_message + +from ._events import _Event # pyright: ignore[reportPrivateUsage] + if TYPE_CHECKING: pass @@ -67,7 +73,7 @@ class _EventType(str, Enum): # TimerFired = "TimerFired" -class _BaseEvent[T: _EventType](BaseModel): # pyright: ignore[reportUnusedClass] +class _BaseEvent[T: _EventType](BaseModel): """ This is the event base-class, to please the Pydantic gods. PLEASE don't use this for anything unless you know why you are doing so, @@ -84,3 +90,58 @@ class _BaseEvent[T: _EventType](BaseModel): # pyright: ignore[reportUnusedClass Subclasses can override this method to implement specific validation logic. """ return True + + + +def _check_event_type_consistency(): + # Grab enum values from members + member_enum_values = [m for m in _EventType] + + # grab enum values from the union => scrape the type annotation + union_enum_values: list[_EventType] = [] + union_classes = list(typing.get_args(_Event)) + for cls in union_classes: # pyright: ignore[reportAny] + assert issubclass(cls, object), ( + f"{get_error_reporting_message()}", + f"The class {cls} is NOT a subclass of {object}." + ) + + # ensure the first base parameter is ALWAYS _BaseEvent + base_cls = list(types.get_original_bases(cls)) + assert len(base_cls) >= 1 and issubclass(base_cls[0], object) \ + and issubclass(base_cls[0], _BaseEvent), ( + f"{get_error_reporting_message()}", + f"The class {cls} does NOT inherit from {_BaseEvent} {typing.get_origin(base_cls[0])}." + ) + + # grab type hints and extract the right values from it + cls_hints = typing.get_type_hints(cls) + assert "event_type" in cls_hints and \ + typing.get_origin(cls_hints["event_type"]) is typing.Literal, ( # pyright: ignore[reportAny] + f"{get_error_reporting_message()}", + f"The class {cls} is missing a {typing.Literal}-annotated `event_type` field." + ) + + # make sure the value is an instance of `_EventType` + enum_value = list(typing.get_args(cls_hints["event_type"])) + assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), ( + f"{get_error_reporting_message()}", + f"The `event_type` of {cls} has a non-{_EventType} literal-type." + ) + union_enum_values.append(enum_value[0]) + + # ensure there is a 1:1 bijection between the two + for m in member_enum_values: + assert m in union_enum_values, ( + f"{get_error_reporting_message()}", + f"There is no event-type registered for {m} in {_Event}." + ) + union_enum_values.remove(m) + assert len(union_enum_values) == 0, ( + f"{get_error_reporting_message()}", + f"The following events have multiple event types defined in {_Event}: {union_enum_values}." + ) + + +_check_event_type_consistency() + diff --git a/shared/types/events/_events.py b/shared/types/events/_events.py index 07da96b9..06494877 100644 --- a/shared/types/events/_events.py +++ b/shared/types/events/_events.py @@ -1,4 +1,6 @@ -from typing import Literal +from typing import Annotated, Literal, Union + +from pydantic import Field from shared.topology import Connection, ConnectionProfile, Node, NodePerformanceProfile from shared.types.common import NodeId @@ -123,6 +125,34 @@ class TopologyEdgeDeleted(_BaseEvent[_EventType.TopologyEdgeDeleted]): event_type: Literal[_EventType.TopologyEdgeDeleted] = _EventType.TopologyEdgeDeleted edge: Connection +_Event = Union[ + TaskCreated, + TaskStateUpdated, + TaskDeleted, + InstanceCreated, + InstanceActivated, + InstanceDeactivated, + InstanceDeleted, + InstanceReplacedAtomically, + RunnerStatusUpdated, + NodePerformanceMeasured, + WorkerConnected, + WorkerStatusUpdated, + WorkerDisconnected, + ChunkGenerated, + TopologyEdgeCreated, + TopologyEdgeReplacedAtomically, + TopologyEdgeDeleted, + MLXInferenceSagaPrepare, + MLXInferenceSagaStartPrepare, +] +""" +Un-annotated union of all events. Only used internally to create the registry. +For all other usecases, use the annotated union of events :class:`Event` :) +""" + +Event = Annotated[_Event, Field(discriminator="event_type")] +"""Type of events, a discriminated union.""" # class TimerCreated(_BaseEvent[_EventType.TimerCreated]): # event_type: Literal[_EventType.TimerCreated] = _EventType.TimerCreated