Mirror of https://github.com/exo-explore/exo.git, synced 2025-12-23 22:27:50 -05:00
.gitignore (vendored): 5 changed lines
@@ -5,4 +5,7 @@ __pycache__
hosts_*.json

# hide direnv stuff
/.direnv
/.direnv
# TODO figure out how to properly solve the issue with these target directories showing up
networking/target/
networking/topology/target/

@@ -246,7 +246,3 @@ class AsyncSQLiteEventStorage:
        except Exception as e:
            self._logger.error(f"Failed to commit batch: {e}")
            raise

    async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]:
        """Return raw event data for testing purposes."""
        return event_data

@@ -14,7 +14,6 @@ from shared.types.common import NodeId
from shared.types.events import (
    ChunkGenerated,
    CommandId,
    _EventType,
)
from shared.types.events.chunks import ChunkType, TokenChunk

@@ -472,7 +471,6 @@ class TestAsyncSQLiteEventStorage:
        # Verify the event was deserialized correctly
        retrieved_event = retrieved_event_wrapper.event
        assert isinstance(retrieved_event, ChunkGenerated)
        assert retrieved_event.event_type == _EventType.ChunkGenerated
        assert retrieved_event.command_id == command_id

        # Verify the nested chunk was deserialized correctly

@@ -1,99 +1,16 @@
# ruff: noqa: F403
# ruff: noqa: F405

import types
import typing
from typing import Annotated, Union

# Note: we are implementing internal details here, so importing private stuff is fine!!!
from pydantic import Field, TypeAdapter
from pydantic import TypeAdapter

from ...constants import get_error_reporting_message
from shared.types.events.components import EventFromEventLog

from ._apply import Event, apply
from ._common import *
from ._common import _BaseEvent, _EventType  # pyright: ignore[reportPrivateUsage]
from ._events import *

_Event = Union[
    TaskCreated,
    TaskStateUpdated,
    TaskDeleted,
    InstanceCreated,
    InstanceActivated,
    InstanceDeactivated,
    InstanceDeleted,
    InstanceReplacedAtomically,
    RunnerStatusUpdated,
    NodePerformanceMeasured,
    WorkerConnected,
    WorkerStatusUpdated,
    WorkerDisconnected,
    ChunkGenerated,
    TopologyEdgeCreated,
    TopologyEdgeReplacedAtomically,
    TopologyEdgeDeleted,
    MLXInferenceSagaPrepare,
    MLXInferenceSagaStartPrepare,
]
"""
Un-annotated union of all events. Only used internally to create the registry.
For all other usecases, use the annotated union of events :class:`Event` :)
"""

Event = Annotated[_Event, Field(discriminator="event_type")]
"""Type of events, a discriminated union."""

EventParser: TypeAdapter[Event] = TypeAdapter(Event)
"""Type adaptor to parse :class:`Event`s."""


def _check_event_type_consistency():
    # Grab enum values from members
    member_enum_values = [m for m in _EventType]

    # grab enum values from the union => scrape the type annotation
    union_enum_values: list[_EventType] = []
    union_classes = list(typing.get_args(_Event))
    for cls in union_classes:  # pyright: ignore[reportAny]
        assert issubclass(cls, object), (
            f"{get_error_reporting_message()}",
            f"The class {cls} is NOT a subclass of {object}."
        )

        # ensure the first base parameter is ALWAYS _BaseEvent
        base_cls = list(types.get_original_bases(cls))
        assert len(base_cls) >= 1 and issubclass(base_cls[0], object) \
            and issubclass(base_cls[0], _BaseEvent), (
            f"{get_error_reporting_message()}",
            f"The class {cls} does NOT inherit from {_BaseEvent} {typing.get_origin(base_cls[0])}."
        )

        # grab type hints and extract the right values from it
        cls_hints = typing.get_type_hints(cls)
        assert "event_type" in cls_hints and \
            typing.get_origin(cls_hints["event_type"]) is typing.Literal, (  # pyright: ignore[reportAny]
            f"{get_error_reporting_message()}",
            f"The class {cls} is missing a {typing.Literal}-annotated `event_type` field."
        )

        # make sure the value is an instance of `_EventType`
        enum_value = list(typing.get_args(cls_hints["event_type"]))
        assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), (
            f"{get_error_reporting_message()}",
            f"The `event_type` of {cls} has a non-{_EventType} literal-type."
        )
        union_enum_values.append(enum_value[0])

    # ensure there is a 1:1 bijection between the two
    for m in member_enum_values:
        assert m in union_enum_values, (
            f"{get_error_reporting_message()}",
            f"There is no event-type registered for {m} in {_Event}."
        )
        union_enum_values.remove(m)
    assert len(union_enum_values) == 0, (
        f"{get_error_reporting_message()}",
        f"The following events have multiple event types defined in {_Event}: {union_enum_values}."
    )


_check_event_type_consistency()
__all__ = ["Event", "EventParser", "apply", "EventFromEventLog"]

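A standalone sketch of the pattern above, with toy event classes rather than exo's real models (names here are hypothetical): the Field(discriminator="event_type") annotation plus a TypeAdapter lets a raw dict read back from the event log be validated straight into the matching concrete event class, keyed only on its event_type literal.

# Hypothetical toy example, not exo code.
from enum import Enum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class ToyEventType(str, Enum):
    Ping = "Ping"
    Pong = "Pong"


class Ping(BaseModel):
    event_type: Literal[ToyEventType.Ping] = ToyEventType.Ping
    payload: str


class Pong(BaseModel):
    event_type: Literal[ToyEventType.Pong] = ToyEventType.Pong


# Discriminated union: pydantic picks the subclass by inspecting only event_type.
ToyEvent = Annotated[Union[Ping, Pong], Field(discriminator="event_type")]
ToyEventParser: TypeAdapter[ToyEvent] = TypeAdapter(ToyEvent)

# A serialized event (e.g. a row read back from storage) parses into the right class.
event = ToyEventParser.validate_python({"event_type": "Ping", "payload": "hello"})
assert isinstance(event, Ping) and event.payload == "hello"
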
shared/types/events/_apply.py (new file): 185 added lines
@@ -0,0 +1,185 @@
from functools import singledispatch
from typing import Mapping, TypeVar

# from shared.topology import Topology
from shared.types.common import NodeId
from shared.types.events._events import Event
from shared.types.events.components import EventFromEventLog
from shared.types.profiling import NodePerformanceProfile
from shared.types.state import State
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import NodeStatus, RunnerId
from shared.types.worker.instances import BaseInstance, InstanceId, TypeOfInstance
from shared.types.worker.runners import RunnerStatus

from ._events import (
    ChunkGenerated,
    InstanceActivated,
    InstanceCreated,
    InstanceDeactivated,
    InstanceDeleted,
    InstanceReplacedAtomically,
    MLXInferenceSagaPrepare,
    MLXInferenceSagaStartPrepare,
    NodePerformanceMeasured,
    RunnerStatusUpdated,
    TaskCreated,
    TaskDeleted,
    TaskStateUpdated,
    TopologyEdgeCreated,
    TopologyEdgeDeleted,
    TopologyEdgeReplacedAtomically,
    WorkerConnected,
    WorkerDisconnected,
    WorkerStatusUpdated,
)

S = TypeVar("S", bound=State)

@singledispatch
def event_apply(state: State, event: Event) -> State:
    raise RuntimeError(f"no handler for {type(event).__name__}")

def apply(state: State, event: EventFromEventLog[Event]) -> State:
    new_state: State = event_apply(state, event.event)
    return new_state.model_copy(update={"last_event_applied_idx": event.idx_in_log})

@event_apply.register
def apply_task_created(state: State, event: TaskCreated) -> State:
    new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task}
    return state.model_copy(update={"tasks": new_tasks})

@event_apply.register
def apply_task_deleted(state: State, event: TaskDeleted) -> State:
    new_tasks: Mapping[TaskId, Task] = {tid: task for tid, task in state.tasks.items() if tid != event.task_id}
    return state.model_copy(update={"tasks": new_tasks})

@event_apply.register
def apply_task_state_updated(state: State, event: TaskStateUpdated) -> State:
    if event.task_id not in state.tasks:
        return state

    updated_task = state.tasks[event.task_id].model_copy(update={"task_status": event.task_status})
    new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}
    return state.model_copy(update={"tasks": new_tasks})

@event_apply.register
def apply_instance_created(state: State, event: InstanceCreated) -> State:
    instance = BaseInstance(instance_params=event.instance_params, instance_type=event.instance_type)
    new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: instance}
    return state.model_copy(update={"instances": new_instances})

@event_apply.register
def apply_instance_activated(state: State, event: InstanceActivated) -> State:
    if event.instance_id not in state.instances:
        return state

    updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.ACTIVE})
    new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
    return state.model_copy(update={"instances": new_instances})

@event_apply.register
def apply_instance_deactivated(state: State, event: InstanceDeactivated) -> State:
    if event.instance_id not in state.instances:
        return state

    updated_instance = state.instances[event.instance_id].model_copy(update={"type": TypeOfInstance.INACTIVE})
    new_instances: Mapping[InstanceId, BaseInstance] = {**state.instances, event.instance_id: updated_instance}
    return state.model_copy(update={"instances": new_instances})

@event_apply.register
def apply_instance_deleted(state: State, event: InstanceDeleted) -> State:
    new_instances: Mapping[InstanceId, BaseInstance] = {iid: inst for iid, inst in state.instances.items() if iid != event.instance_id}
    return state.model_copy(update={"instances": new_instances})

@event_apply.register
def apply_instance_replaced_atomically(state: State, event: InstanceReplacedAtomically) -> State:
    new_instances = dict(state.instances)
    if event.instance_to_replace in new_instances:
        del new_instances[event.instance_to_replace]
    if event.new_instance_id in state.instances:
        new_instances[event.new_instance_id] = state.instances[event.new_instance_id]
    return state.model_copy(update={"instances": new_instances})

@event_apply.register
def apply_runner_status_updated(state: State, event: RunnerStatusUpdated) -> State:
    new_runners: Mapping[RunnerId, RunnerStatus] = {**state.runners, event.runner_id: event.runner_status}
    return state.model_copy(update={"runners": new_runners})

@event_apply.register
def apply_node_performance_measured(state: State, event: NodePerformanceMeasured) -> State:
    new_profiles: Mapping[NodeId, NodePerformanceProfile] = {**state.node_profiles, event.node_id: event.node_profile}
    return state.model_copy(update={"node_profiles": new_profiles})

@event_apply.register
def apply_worker_status_updated(state: State, event: WorkerStatusUpdated) -> State:
    new_node_status: Mapping[NodeId, NodeStatus] = {**state.node_status, event.node_id: event.node_state}
    return state.model_copy(update={"node_status": new_node_status})

@event_apply.register
def apply_chunk_generated(state: State, event: ChunkGenerated) -> State:
    return state

# TODO implemente these
@event_apply.register
def apply_worker_connected(state: State, event: WorkerConnected) -> State:
    # source_node_id = event.edge.source_node_id
    # sink_node_id = event.edge.sink_node_id

    # new_node_status = dict(state.node_status)
    # if source_node_id not in new_node_status:
    #     new_node_status[source_node_id] = NodeStatus.Idle
    # if sink_node_id not in new_node_status:
    #     new_node_status[sink_node_id] = NodeStatus.Idle

    # new_topology = Topology()
    # new_topology.add_connection(event.edge)

    # return state.model_copy(update={"node_status": new_node_status, "topology": new_topology})
    return state

@event_apply.register
def apply_worker_disconnected(state: State, event: WorkerDisconnected) -> State:
    # new_node_status: Mapping[NodeId, NodeStatus] = {nid: status for nid, status in state.node_status.items() if nid != event.vertex_id}

    # new_topology = Topology()

    # new_history = list(state.history) + [state.topology]

    # return state.model_copy(update={
    #     "node_status": new_node_status,
    #     "topology": new_topology,
    #     "history": new_history
    # })
    return state


@event_apply.register
def apply_topology_edge_created(state: State, event: TopologyEdgeCreated) -> State:
    # new_topology = Topology()
    # new_topology.add_node(event.vertex, event.vertex.node_id)
    # return state.model_copy(update={"topology": new_topology})
    return state

@event_apply.register
def apply_topology_edge_replaced_atomically(state: State, event: TopologyEdgeReplacedAtomically) -> State:
    # new_topology = Topology()
    # new_topology.add_connection(event.edge)
    # updated_connection = event.edge.model_copy(update={"connection_profile": event.edge_profile})
    # new_topology.update_connection_profile(updated_connection)
    # return state.model_copy(update={"topology": new_topology})
    return state

@event_apply.register
def apply_topology_edge_deleted(state: State, event: TopologyEdgeDeleted) -> State:
    # new_topology = Topology()
    # return state.model_copy(update={"topology": new_topology})
    return state

@event_apply.register
def apply_mlx_inference_saga_prepare(state: State, event: MLXInferenceSagaPrepare) -> State:
    return state

@event_apply.register
def apply_mlx_inference_saga_start_prepare(state: State, event: MLXInferenceSagaStartPrepare) -> State:
    return state

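A standalone sketch of the reducer pattern behind event_apply and apply above, using toy dataclasses rather than exo's State and Event models (all names below are hypothetical). Replaying the log is a left fold of the handler over the stored events; the toy dispatches on the event as the first argument, since functools.singledispatch selects handlers by the type of the first parameter.

# Hypothetical toy example, not exo code.
from dataclasses import dataclass, replace
from functools import reduce, singledispatch


@dataclass(frozen=True)
class ToyState:
    counter: int = 0


@dataclass(frozen=True)
class Incremented:
    amount: int


@dataclass(frozen=True)
class Reset:
    pass


@singledispatch
def toy_apply(event: object, state: ToyState) -> ToyState:
    raise RuntimeError(f"no handler for {type(event).__name__}")


@toy_apply.register
def _(event: Incremented, state: ToyState) -> ToyState:
    # Handlers never mutate: each returns a fresh copy of the state.
    return replace(state, counter=state.counter + event.amount)


@toy_apply.register
def _(event: Reset, state: ToyState) -> ToyState:
    return ToyState()


# Rebuilding state from an event log is a plain left fold over the handlers.
log = [Incremented(1), Incremented(2), Reset(), Incremented(5)]
final = reduce(lambda state, event: toy_apply(event, state), log, ToyState())
assert final.counter == 5
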
@@ -1,6 +1,12 @@
import types
import typing
from enum import Enum
from typing import TYPE_CHECKING

from shared.constants import get_error_reporting_message

from ._events import _Event  # pyright: ignore[reportPrivateUsage]

if TYPE_CHECKING:
    pass

@@ -67,7 +73,7 @@ class _EventType(str, Enum):
    # TimerFired = "TimerFired"


class _BaseEvent[T: _EventType](BaseModel):  # pyright: ignore[reportUnusedClass]
class _BaseEvent[T: _EventType](BaseModel):
    """
    This is the event base-class, to please the Pydantic gods.
    PLEASE don't use this for anything unless you know why you are doing so,

@@ -84,3 +90,58 @@ class _BaseEvent[T: _EventType](BaseModel):  # pyright: ignore[reportUnusedClass]
        Subclasses can override this method to implement specific validation logic.
        """
        return True


def _check_event_type_consistency():
    # Grab enum values from members
    member_enum_values = [m for m in _EventType]

    # grab enum values from the union => scrape the type annotation
    union_enum_values: list[_EventType] = []
    union_classes = list(typing.get_args(_Event))
    for cls in union_classes:  # pyright: ignore[reportAny]
        assert issubclass(cls, object), (
            f"{get_error_reporting_message()}",
            f"The class {cls} is NOT a subclass of {object}."
        )

        # ensure the first base parameter is ALWAYS _BaseEvent
        base_cls = list(types.get_original_bases(cls))
        assert len(base_cls) >= 1 and issubclass(base_cls[0], object) \
            and issubclass(base_cls[0], _BaseEvent), (
            f"{get_error_reporting_message()}",
            f"The class {cls} does NOT inherit from {_BaseEvent} {typing.get_origin(base_cls[0])}."
        )

        # grab type hints and extract the right values from it
        cls_hints = typing.get_type_hints(cls)
        assert "event_type" in cls_hints and \
            typing.get_origin(cls_hints["event_type"]) is typing.Literal, (  # pyright: ignore[reportAny]
            f"{get_error_reporting_message()}",
            f"The class {cls} is missing a {typing.Literal}-annotated `event_type` field."
        )

        # make sure the value is an instance of `_EventType`
        enum_value = list(typing.get_args(cls_hints["event_type"]))
        assert len(enum_value) == 1 and isinstance(enum_value[0], _EventType), (
            f"{get_error_reporting_message()}",
            f"The `event_type` of {cls} has a non-{_EventType} literal-type."
        )
        union_enum_values.append(enum_value[0])

    # ensure there is a 1:1 bijection between the two
    for m in member_enum_values:
        assert m in union_enum_values, (
            f"{get_error_reporting_message()}",
            f"There is no event-type registered for {m} in {_Event}."
        )
        union_enum_values.remove(m)
    assert len(union_enum_values) == 0, (
        f"{get_error_reporting_message()}",
        f"The following events have multiple event types defined in {_Event}: {union_enum_values}."
    )


_check_event_type_consistency()

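A standalone illustration of the introspection the consistency check relies on, using plain classes rather than the pydantic event models above (Python 3.12+ for types.get_original_bases; all names are hypothetical). types.get_original_bases recovers the parametrized base class, while typing.get_type_hints and get_args pull the single enum member out of the event_type Literal annotation.

# Hypothetical toy example, not exo code.
import types
from enum import Enum
from typing import Generic, Literal, TypeVar, get_args, get_origin, get_type_hints


class Kind(str, Enum):
    Ping = "Ping"


T = TypeVar("T", bound=Kind)


class Base(Generic[T]):
    pass


class Ping(Base[Literal[Kind.Ping]]):
    event_type: Literal[Kind.Ping] = Kind.Ping


# The parametrized base survives on the subclass: roughly (Base[Literal[Kind.Ping]],).
print(types.get_original_bases(Ping))

# The event_type annotation is a Literal wrapping exactly one enum member.
hint = get_type_hints(Ping)["event_type"]
assert get_origin(hint) is Literal
assert get_args(hint) == (Kind.Ping,)
assert isinstance(get_args(hint)[0], Kind)
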
@@ -1,4 +1,6 @@
from typing import Literal
from typing import Annotated, Literal, Union

from pydantic import Field

from shared.topology import Connection, ConnectionProfile, Node, NodePerformanceProfile
from shared.types.common import NodeId

@@ -123,6 +125,34 @@ class TopologyEdgeDeleted(_BaseEvent[_EventType.TopologyEdgeDeleted]):
    event_type: Literal[_EventType.TopologyEdgeDeleted] = _EventType.TopologyEdgeDeleted
    edge: Connection

_Event = Union[
    TaskCreated,
    TaskStateUpdated,
    TaskDeleted,
    InstanceCreated,
    InstanceActivated,
    InstanceDeactivated,
    InstanceDeleted,
    InstanceReplacedAtomically,
    RunnerStatusUpdated,
    NodePerformanceMeasured,
    WorkerConnected,
    WorkerStatusUpdated,
    WorkerDisconnected,
    ChunkGenerated,
    TopologyEdgeCreated,
    TopologyEdgeReplacedAtomically,
    TopologyEdgeDeleted,
    MLXInferenceSagaPrepare,
    MLXInferenceSagaStartPrepare,
]
"""
Un-annotated union of all events. Only used internally to create the registry.
For all other usecases, use the annotated union of events :class:`Event` :)
"""

Event = Annotated[_Event, Field(discriminator="event_type")]
"""Type of events, a discriminated union."""

# class TimerCreated(_BaseEvent[_EventType.TimerCreated]):
#     event_type: Literal[_EventType.TimerCreated] = _EventType.TimerCreated