Add apply functions

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
This commit is contained in:
Gelu Vrabie
2025-07-24 11:02:20 +01:00
committed by GitHub
parent 3ab5609289
commit 56d3565781
7 changed files with 287 additions and 97 deletions

5
.gitignore vendored
View File

@@ -5,4 +5,7 @@ __pycache__
hosts_*.json
# hide direnv stuff
/.direnv
/.direnv
# TODO figure out how to properly solve the issue with these target directories showing up
networking/target/
networking/topology/target/

View File

@@ -246,7 +246,3 @@ class AsyncSQLiteEventStorage:
except Exception as e:
self._logger.error(f"Failed to commit batch: {e}")
raise
async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]:
"""Return raw event data for testing purposes."""
return event_data

View File

@@ -14,7 +14,6 @@ from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
CommandId,
_EventType,
)
from shared.types.events.chunks import ChunkType, TokenChunk
@@ -472,7 +471,6 @@ class TestAsyncSQLiteEventStorage:
# Verify the event was deserialized correctly
retrieved_event = retrieved_event_wrapper.event
assert isinstance(retrieved_event, ChunkGenerated)
assert retrieved_event.event_type == _EventType.ChunkGenerated
assert retrieved_event.command_id == command_id
# Verify the nested chunk was deserialized correctly

View File

@@ -1,99 +1,16 @@
# ruff: noqa: F403
# ruff: noqa: F405
import types
import typing
from typing import Annotated, Union
# Note: we are implementing internal details here, so importing private stuff is fine!!!
from pydantic import Field, TypeAdapter
from pydantic import TypeAdapter
from ...constants import get_error_reporting_message
from shared.types.events.components import EventFromEventLog
from ._apply import Event, apply
from ._common import *
from ._common import _BaseEvent, _EventType # pyright: ignore[reportPrivateUsage]
from ._events import *
_Event = Union[
TaskCreated,
TaskStateUpdated,
TaskDeleted,
InstanceCreated,
InstanceActivated,
InstanceDeactivated,
InstanceDeleted,
InstanceReplacedAtomically,
RunnerStatusUpdated,
NodePerformanceMeasured,
WorkerConnected,
WorkerStatusUpdated,
WorkerDisconnected,
ChunkGenerated,
TopologyEdgeCreated,
TopologyEdgeReplacedAtomically,
TopologyEdgeDeleted,
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
]
"""
Un-annotated union of all events. Only used internally to create the registry.
For all other usecases, use the annotated union of events :class:`Event` :)
"""
Event = Annotated[_Event, Field(discriminator="event_type")]
"""Type of events, a discriminated union."""
EventParser: TypeAdapter[Event] = TypeAdapter(Event)
"""Type adaptor to parse :class:`Event`s."""
def _check_event_type_consistency():
# Grab enum values from members
member_enum_values = [m for m in _EventType]
# grab enum values from the union => scrape the type annotation
union_enum_values: list[_EventType] = []
union_classes = list(typing.get_args(_Event))
for cls in union_classes: # pyright: ignore[reportAny]
assert issubclass(cls, object), (
f"{get_error_reporting_message()}",
f"The class {cls} is NOT a subclass of {object}."
)
# ensure the first base parameter is ALWAYS _BaseEvent
base_cls = list(types.get_original_bases(cls))
assert len(base_cls) >= 1 and issubclass(base_cls[0], object) \
and issubclass(base_cls[0], _BaseEvent), (
f"{get_error_reporting_message()}",
f"The class {cls} does NOT inherit from {_BaseEvent} {typing.get_origin(base_cls[0])}."
)
# grab type hints and extract the right values from it
cls_hints = typing.get_type_hints(cls)
assert "event_type" in cls_hints and \
typing.get_origin(cls_hints["event_type"]) is typing.Literal, ( # pyright: ignore[reportAny]
f"{get_error_reporting_message()}",
f"The class {cls} is missing a {typing.Literal}-annotated `event_type` field."
)
# make sure the value is an instance of `_EventType`
enum_value = list(typing.get_args(cls_hints["event_type"]))
assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), (
f"{get_error_reporting_message()}",
f"The `event_type` of {cls} has a non-{_EventType} literal-type."
)
union_enum_values.append(enum_value[0])
# ensure there is a 1:1 bijection between the two
for m in member_enum_values:
assert m in union_enum_values, (
f"{get_error_reporting_message()}",
f"There is no event-type registered for {m} in {_Event}."
)
union_enum_values.remove(m)
assert len(union_enum_values) == 0, (
f"{get_error_reporting_message()}",
f"The following events have multiple event types defined in {_Event}: {union_enum_values}."
)
_check_event_type_consistency()
__all__ = ["Event", "EventParser", "apply", "EventFromEventLog"]

View File

@@ -0,0 +1,185 @@
from functools import singledispatch
from typing import Mapping, TypeVar
# from shared.topology import Topology
from shared.types.common import NodeId
from shared.types.events._events import Event
from shared.types.events.components import EventFromEventLog
from shared.types.profiling import NodePerformanceProfile
from shared.types.state import State
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import NodeStatus, RunnerId
from shared.types.worker.instances import BaseInstance, InstanceId, TypeOfInstance
from shared.types.worker.runners import RunnerStatus
from ._events import (
ChunkGenerated,
InstanceActivated,
InstanceCreated,
InstanceDeactivated,
InstanceDeleted,
InstanceReplacedAtomically,
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
NodePerformanceMeasured,
RunnerStatusUpdated,
TaskCreated,
TaskDeleted,
TaskStateUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TopologyEdgeReplacedAtomically,
WorkerConnected,
WorkerDisconnected,
WorkerStatusUpdated,
)
S = TypeVar("S", bound=State)
@singledispatch
def event_apply(state: State, event: Event) -> State:
raise RuntimeError(f"no handler for {type(event).__name__}")
def apply(state: State, event: EventFromEventLog[Event]) -> State:
new_state: State = event_apply(state, event.event)
return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log})
@event_apply.register
def apply_task_created(state: State, event: TaskCreated) -> State:
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register
def apply_task_deleted(state: State, event: TaskDeleted) -> State:
new_tasks: Mapping[TaskId, Task] = {tid: task for tid, task in state.tasks.items() if tid != event.task_id}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register
def apply_task_state_updated(state: State, event: TaskStateUpdated) -> State:
if event.task_id not in state.tasks:
return state
updated_task = state.tasks[event.task_id].model_copy(update={"task_status": event.task_status})
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register
def apply_instance_created(state: State, event: InstanceCreated) -> State:
instance = BaseInstance(instance_params=event.instance_params, instance_type=event.instance_type)
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_activated(state: State, event: InstanceActivated) -> State:
if event.instance_id not in state.instances:
return state
updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.ACTIVE})
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_deactivated(state: State, event: InstanceDeactivated) -> State:
if event.instance_id not in state.instances:
return state
updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.INACTIVE})
new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_deleted(state: State, event: InstanceDeleted) -> State:
new_instances: Mapping[InstanceId, BaseInstance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id}
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_instance_replaced_atomically(state: State, event: InstanceReplacedAtomically) -> State:
new_instances = dict(state.instances)
if event.instance_to_replace in new_instances:
del new_instances[event.instance_to_replace]
if event.new_instance_id in state.instances:
new_instances[event.new_instance_id] = state.instances[event.new_instance_id]
return state.model_copy(update={"instances": new_instances})
@event_apply.register
def apply_runner_status_updated(state: State, event: RunnerStatusUpdated) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {**state.runners, event.runner_id: event.runner_status}
return state.model_copy(update={"runners": new_runners})
@event_apply.register
def apply_node_performance_measured(state: State, event: NodePerformanceMeasured) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {**state.node_profiles, event.node_id: event.node_profile}
return state.model_copy(update={"node_profiles": new_profiles})
@event_apply.register
def apply_worker_status_updated(state: State, event: WorkerStatusUpdated) -> State:
new_node_status: Mapping[NodeId, NodeStatus] = {**state.node_status, event.node_id: event.node_state}
return state.model_copy(update={"node_status": new_node_status})
@event_apply.register
def apply_chunk_generated(state: State, event: ChunkGenerated) -> State:
return state
# TODO implemente these
@event_apply.register
def apply_worker_connected(state: State, event: WorkerConnected) -> State:
# source_node_id = event.edge.source_node_id
# sink_node_id = event.edge.sink_node_id
# new_node_status = dict(state.node_status)
# if source_node_id not in new_node_status:
# new_node_status[source_node_id] = NodeStatus.Idle
# if sink_node_id not in new_node_status:
# new_node_status[sink_node_id] = NodeStatus.Idle
# new_topology = Topology()
# new_topology.add_connection(event.edge)
# return state.model_copy(update={"node_status": new_node_status, "topology": new_topology})
return state
@event_apply.register
def apply_worker_disconnected(state: State, event: WorkerDisconnected) -> State:
# new_node_status: Mapping[NodeId, NodeStatus] = {nid: status for nid, status in state.node_status.items() if nid != event.vertex_id}
# new_topology = Topology()
# new_history = list(state.history) + [state.topology]
# return state.model_copy(update={
# "node_status": new_node_status,
# "topology": new_topology,
# "history": new_history
# })
return state
@event_apply.register
def apply_topology_edge_created(state: State, event: TopologyEdgeCreated) -> State:
# new_topology = Topology()
# new_topology.add_node(event.vertex, event.vertex.node_id)
# return state.model_copy(update={"topology": new_topology})
return state
@event_apply.register
def apply_topology_edge_replaced_atomically(state: State, event: TopologyEdgeReplacedAtomically) -> State:
# new_topology = Topology()
# new_topology.add_connection(event.edge)
# updated_connection = event.edge.model_copy(update={"connection_profile": event.edge_profile})
# new_topology.update_connection_profile(updated_connection)
# return state.model_copy(update={"topology": new_topology})
return state
@event_apply.register
def apply_topology_edge_deleted(state: State, event: TopologyEdgeDeleted) -> State:
# new_topology = Topology()
# return state.model_copy(update={"topology": new_topology})
return state
@event_apply.register
def apply_mlx_inference_saga_prepare(state: State, event: MLXInferenceSagaPrepare) -> State:
return state
@event_apply.register
def apply_mlx_inference_saga_start_prepare(state: State, event: MLXInferenceSagaStartPrepare) -> State:
return state

View File

@@ -1,6 +1,12 @@
import types
import typing
from enum import Enum
from typing import TYPE_CHECKING
from shared.constants import get_error_reporting_message
from ._events import _Event # pyright: ignore[reportPrivateUsage]
if TYPE_CHECKING:
pass
@@ -67,7 +73,7 @@ class _EventType(str, Enum):
# TimerFired = "TimerFired"
class _BaseEvent[T: _EventType](BaseModel): # pyright: ignore[reportUnusedClass]
class _BaseEvent[T: _EventType](BaseModel):
"""
This is the event base-class, to please the Pydantic gods.
PLEASE don't use this for anything unless you know why you are doing so,
@@ -84,3 +90,58 @@ class _BaseEvent[T: _EventType](BaseModel): # pyright: ignore[reportUnusedClass
Subclasses can override this method to implement specific validation logic.
"""
return True
def _check_event_type_consistency():
# Grab enum values from members
member_enum_values = [m for m in _EventType]
# grab enum values from the union => scrape the type annotation
union_enum_values: list[_EventType] = []
union_classes = list(typing.get_args(_Event))
for cls in union_classes: # pyright: ignore[reportAny]
assert issubclass(cls, object), (
f"{get_error_reporting_message()}",
f"The class {cls} is NOT a subclass of {object}."
)
# ensure the first base parameter is ALWAYS _BaseEvent
base_cls = list(types.get_original_bases(cls))
assert len(base_cls) >= 1 and issubclass(base_cls[0], object) \
and issubclass(base_cls[0], _BaseEvent), (
f"{get_error_reporting_message()}",
f"The class {cls} does NOT inherit from {_BaseEvent} {typing.get_origin(base_cls[0])}."
)
# grab type hints and extract the right values from it
cls_hints = typing.get_type_hints(cls)
assert "event_type" in cls_hints and \
typing.get_origin(cls_hints["event_type"]) is typing.Literal, ( # pyright: ignore[reportAny]
f"{get_error_reporting_message()}",
f"The class {cls} is missing a {typing.Literal}-annotated `event_type` field."
)
# make sure the value is an instance of `_EventType`
enum_value = list(typing.get_args(cls_hints["event_type"]))
assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), (
f"{get_error_reporting_message()}",
f"The `event_type` of {cls} has a non-{_EventType} literal-type."
)
union_enum_values.append(enum_value[0])
# ensure there is a 1:1 bijection between the two
for m in member_enum_values:
assert m in union_enum_values, (
f"{get_error_reporting_message()}",
f"There is no event-type registered for {m} in {_Event}."
)
union_enum_values.remove(m)
assert len(union_enum_values) == 0, (
f"{get_error_reporting_message()}",
f"The following events have multiple event types defined in {_Event}: {union_enum_values}."
)
_check_event_type_consistency()

View File

@@ -1,4 +1,6 @@
from typing import Literal
from typing import Annotated, Literal, Union
from pydantic import Field
from shared.topology import Connection, ConnectionProfile, Node, NodePerformanceProfile
from shared.types.common import NodeId
@@ -123,6 +125,34 @@ class TopologyEdgeDeleted(_BaseEvent[_EventType.TopologyEdgeDeleted]):
event_type: Literal[_EventType.TopologyEdgeDeleted] = _EventType.TopologyEdgeDeleted
edge: Connection
_Event = Union[
TaskCreated,
TaskStateUpdated,
TaskDeleted,
InstanceCreated,
InstanceActivated,
InstanceDeactivated,
InstanceDeleted,
InstanceReplacedAtomically,
RunnerStatusUpdated,
NodePerformanceMeasured,
WorkerConnected,
WorkerStatusUpdated,
WorkerDisconnected,
ChunkGenerated,
TopologyEdgeCreated,
TopologyEdgeReplacedAtomically,
TopologyEdgeDeleted,
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
]
"""
Un-annotated union of all events. Only used internally to create the registry.
For all other usecases, use the annotated union of events :class:`Event` :)
"""
Event = Annotated[_Event, Field(discriminator="event_type")]
"""Type of events, a discriminated union."""
# class TimerCreated(_BaseEvent[_EventType.TimerCreated]):
# event_type: Literal[_EventType.TimerCreated] = _EventType.TimerCreated