mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
fix: Event definitions, state definitions
This commit is contained in:
@@ -1,38 +1,56 @@
|
||||
from asyncio import Lock, Queue, Task, create_task, gather
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import List, LiteralString, Protocol, Literal
|
||||
from logging import Logger
|
||||
from typing import Any, List, Literal, Protocol, Type, TypedDict
|
||||
|
||||
from master.logging import (
|
||||
StateUpdateEffectHandlerErrorLogEntry,
|
||||
StateUpdateErrorLogEntry,
|
||||
StateUpdateLoopAlreadyRunningLogEntry,
|
||||
StateUpdateLoopNotRunningLogEntry,
|
||||
StateUpdateLoopStartedLogEntry,
|
||||
StateUpdateLoopStoppedLogEntry,
|
||||
)
|
||||
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
|
||||
from shared.logger import log
|
||||
from shared.types.events.common import (
|
||||
Apply,
|
||||
EffectHandler,
|
||||
Event,
|
||||
EventCategories,
|
||||
EventCategory,
|
||||
Event,
|
||||
EventCategoryEnum,
|
||||
EventFromEventLog,
|
||||
EventFetcherProtocol,
|
||||
EventFromEventLog,
|
||||
StateAndEvent,
|
||||
State,
|
||||
Apply,
|
||||
)
|
||||
from asyncio import Lock, Queue, Task, gather, create_task
|
||||
from typing import Any, Type, TypedDict
|
||||
from collections.abc import Mapping
|
||||
from shared.logger import log
|
||||
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
|
||||
from master.logging import (
|
||||
StateUpdateLoopAlreadyRunningLogEntry,
|
||||
StateUpdateLoopStartedLogEntry,
|
||||
StateUpdateLoopNotRunningLogEntry,
|
||||
StateUpdateLoopStoppedLogEntry,
|
||||
StateUpdateErrorLogEntry,
|
||||
StateUpdateEffectHandlerErrorLogEntry,
|
||||
)
|
||||
|
||||
|
||||
class QueueMapping(TypedDict):
|
||||
MutatesTaskState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]]
|
||||
MutatesControlPlaneState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]]
|
||||
MutatesDataPlaneState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]]
|
||||
MutatesInstanceState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]]
|
||||
MutatesNodePerformanceState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]]
|
||||
MutatesTaskState: Queue[
|
||||
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
|
||||
]
|
||||
MutatesControlPlaneState: Queue[
|
||||
EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
|
||||
]
|
||||
MutatesDataPlaneState: Queue[
|
||||
EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
|
||||
]
|
||||
MutatesInstanceState: Queue[
|
||||
EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
|
||||
]
|
||||
MutatesNodePerformanceState: Queue[
|
||||
EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
|
||||
]
|
||||
MutatesRunnerStatus: Queue[
|
||||
EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
|
||||
]
|
||||
MutatesTaskSagaState: Queue[
|
||||
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
|
||||
]
|
||||
|
||||
|
||||
def check_keys_in_map_match_enum_values[TEnum: StrEnum](
|
||||
mapping_type: Type[Mapping[Any, Any]],
|
||||
@@ -44,8 +62,10 @@ def check_keys_in_map_match_enum_values[TEnum: StrEnum](
|
||||
f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}"
|
||||
)
|
||||
|
||||
|
||||
check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum)
|
||||
|
||||
|
||||
class AsyncUpdateStateFromEvents[EventCategoryT: EventCategory](Protocol):
|
||||
"""Protocol for services that manage a specific state domain."""
|
||||
|
||||
@@ -119,7 +139,7 @@ class AsyncUpdateStateFromEvents[EventCategoryT: EventCategory](Protocol):
|
||||
raise e
|
||||
try:
|
||||
for effect_handler in self._default_effects + self.extra_effects:
|
||||
effect_handler((previous_state, event), updated_state)
|
||||
effect_handler(StateAndEvent(previous_state, event), updated_state)
|
||||
except Exception as e:
|
||||
log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e))
|
||||
raise e
|
||||
@@ -149,7 +169,9 @@ class EventRouter:
|
||||
await self.queue_map[category].put(event_to_process)
|
||||
return None
|
||||
|
||||
async def _submit_events(self, events: list[Event[EventCategory | EventCategories]]) -> None:
|
||||
async def _submit_events(
|
||||
self, events: list[Event[EventCategory | EventCategories]]
|
||||
) -> None:
|
||||
"""Route multiple events to their appropriate services."""
|
||||
for event in events:
|
||||
for category in event.event_category:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Literal
|
||||
from collections.abc import Set
|
||||
from typing import Literal
|
||||
|
||||
from shared.logging.common import LogEntry, LogEntryType
|
||||
|
||||
|
||||
@@ -1,35 +1,41 @@
|
||||
from asyncio import CancelledError, Lock, Queue, Task, create_task
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from logging import Logger, LogRecord
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from logging import Logger
|
||||
|
||||
from shared.types.events.common import Event, EventCategories, EventFetcherProtocol, EventPublisher, State
|
||||
from shared.logger import (
|
||||
configure_logger,
|
||||
LogEntryType,
|
||||
FilterLogByType,
|
||||
create_queue_listener,
|
||||
attach_to_queue,
|
||||
from master.env import MasterEnvironmentSchema
|
||||
from master.event_routing import AsyncUpdateStateFromEvents
|
||||
from master.logging import (
|
||||
MasterCommandReceivedLogEntry,
|
||||
MasterInvalidCommandReceivedLogEntry,
|
||||
MasterUninitializedLogEntry,
|
||||
)
|
||||
from shared.constants import EXO_MASTER_STATE
|
||||
from shared.logger import (
|
||||
FilterLogByType,
|
||||
LogEntryType,
|
||||
attach_to_queue,
|
||||
configure_logger,
|
||||
create_queue_listener,
|
||||
log,
|
||||
)
|
||||
from shared.types.events.common import (
|
||||
Event,
|
||||
EventCategories,
|
||||
EventFetcherProtocol,
|
||||
EventPublisher,
|
||||
State,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import Instance
|
||||
from shared.types.models.common import ModelId
|
||||
from shared.types.models.model import ModelInfo
|
||||
from shared.types.states.master import MasterState
|
||||
from shared.constants import EXO_MASTER_STATE
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import LogRecord
|
||||
from typing import Annotated, Literal
|
||||
from master.env import MasterEnvironmentSchema
|
||||
from master.logging import (
|
||||
MasterUninitializedLogEntry,
|
||||
MasterCommandReceivedLogEntry,
|
||||
MasterInvalidCommandReceivedLogEntry,
|
||||
)
|
||||
from master.event_routing import AsyncUpdateStateFromEvents
|
||||
from shared.logger import log
|
||||
from asyncio import Lock, Task, CancelledError, Queue, create_task
|
||||
from enum import Enum
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import Instance
|
||||
|
||||
|
||||
# Restore State
|
||||
@@ -76,6 +82,7 @@ ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalComman
|
||||
class MasterBackgroundServices(str, Enum):
|
||||
MAIN_LOOP = "main_loop"
|
||||
|
||||
|
||||
class StateManager[T: EventCategories]:
|
||||
state: State[T]
|
||||
queue: Queue[Event[T]]
|
||||
@@ -85,8 +92,8 @@ class StateManager[T: EventCategories]:
|
||||
self,
|
||||
state: State[T],
|
||||
queue: Queue[Event[T]],
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class MasterStateManager:
|
||||
"""Thread-safe manager for MasterState with independent event loop."""
|
||||
@@ -126,7 +133,9 @@ class MasterStateManager:
|
||||
case MasterBackgroundServices.MAIN_LOOP:
|
||||
if self._services[service]:
|
||||
raise RuntimeError("State manager is already running")
|
||||
self._services[service]: Task[None] = create_task(self._event_loop())
|
||||
self._services[service]: Task[None] = create_task(
|
||||
self._event_loop()
|
||||
)
|
||||
log(self._logger, MasterStateManagerStartedLogEntry())
|
||||
case _:
|
||||
raise ValueError(f"Unknown service: {service}")
|
||||
@@ -155,7 +164,7 @@ class MasterStateManager:
|
||||
events_one = self._event_processor.get_events_to_apply(
|
||||
self._state.data_plane_network_state
|
||||
)
|
||||
case EventCategories.InstanceStateEventTypes:
|
||||
case EventCategories.InstanceEventTypes:
|
||||
events_one = self._event_processor.get_events_to_apply(
|
||||
self._state.control_plane_network_state
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
|
||||
EXO_HOME = Path.home() / ".exo"
|
||||
EXO_EVENT_DB = EXO_HOME / "event_db.sqlite3"
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
from typing import Set, Mapping
|
||||
from dataclasses import dataclass
|
||||
from pydantic import TypeAdapter
|
||||
from typing import Mapping, Set
|
||||
|
||||
import rustworkx as rx
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from shared.types.graphs.common import (
|
||||
Edge,
|
||||
EdgeData,
|
||||
EdgeIdT,
|
||||
EdgeTypeT,
|
||||
MutableGraphProtocol,
|
||||
Vertex,
|
||||
VertexData,
|
||||
EdgeIdT,
|
||||
VertexIdT,
|
||||
EdgeTypeT,
|
||||
VertexTypeT,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,14 +2,13 @@ import logging
|
||||
import logging.handlers
|
||||
from collections.abc import Sequence, Set
|
||||
from queue import Queue
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field, TypeAdapter
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from typing import Annotated
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
from shared.logging.common import LogEntryType
|
||||
from master.logging import MasterLogEntries
|
||||
from shared.logging.common import LogEntryType
|
||||
from worker.logging import WorkerLogEntries
|
||||
|
||||
LogEntries = Annotated[
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from collections.abc import Set
|
||||
from enum import Enum
|
||||
from typing import Generic, TypeVar
|
||||
from pydantic import BaseModel
|
||||
|
||||
from collections.abc import Set
|
||||
from pydantic import BaseModel
|
||||
|
||||
LogEntryTypeT = TypeVar("LogEntryTypeT", bound=str)
|
||||
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
from enum import Enum, StrEnum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
FrozenSet,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Protocol,
|
||||
Sequence,
|
||||
cast,
|
||||
)
|
||||
|
||||
import annotated_types
|
||||
|
||||
from shared.types.events.sanity_checking import (
|
||||
check_event_type_union_is_consistent_with_registry,
|
||||
assert_literal_union_covers_enum,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from shared.types.common import NewUUID, NodeId
|
||||
from typing import Callable, Sequence, Protocol
|
||||
from shared.types.events.sanity_checking import (
|
||||
assert_literal_union_covers_enum,
|
||||
check_event_type_union_is_consistent_with_registry,
|
||||
)
|
||||
|
||||
|
||||
class EventId(NewUUID):
|
||||
@@ -31,7 +29,7 @@ class TimerId(NewUUID):
|
||||
|
||||
# Here are all the unique kinds of events that can be sent over the network.
|
||||
# I've defined them in different enums for clarity, but they're all part of the same set of possible events.
|
||||
class MLXEventTypes(str, Enum):
|
||||
class TaskSagaEventTypes(str, Enum):
|
||||
MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
|
||||
MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"
|
||||
|
||||
@@ -54,8 +52,8 @@ class InstanceEventTypes(str, Enum):
|
||||
InstanceReplacedAtomically = "InstanceReplacedAtomically"
|
||||
|
||||
|
||||
class InstanceStateEventTypes(str, Enum):
|
||||
InstanceSagaRunnerStateUpdated = "InstanceSagaRunnerStateUpdated"
|
||||
class RunnerStatusEventTypes(str, Enum):
|
||||
RunnerStatusUpdated = "RunnerStatusUpdated"
|
||||
|
||||
|
||||
class NodePerformanceEventTypes(str, Enum):
|
||||
@@ -84,12 +82,12 @@ EVENT_TYPE_ENUMS = [
|
||||
TaskEventTypes,
|
||||
StreamingEventTypes,
|
||||
InstanceEventTypes,
|
||||
InstanceStateEventTypes,
|
||||
RunnerStatusEventTypes,
|
||||
NodePerformanceEventTypes,
|
||||
DataPlaneEventTypes,
|
||||
ControlPlaneEventTypes,
|
||||
TimerEventTypes,
|
||||
MLXEventTypes,
|
||||
TaskSagaEventTypes,
|
||||
]
|
||||
|
||||
|
||||
@@ -98,12 +96,12 @@ EventTypes = (
|
||||
TaskEventTypes
|
||||
| StreamingEventTypes
|
||||
| InstanceEventTypes
|
||||
| InstanceStateEventTypes
|
||||
| RunnerStatusEventTypes
|
||||
| NodePerformanceEventTypes
|
||||
| ControlPlaneEventTypes
|
||||
| DataPlaneEventTypes
|
||||
| TimerEventTypes
|
||||
| MLXEventTypes
|
||||
| TaskSagaEventTypes
|
||||
)
|
||||
|
||||
|
||||
@@ -112,6 +110,8 @@ check_event_type_union_is_consistent_with_registry(EVENT_TYPE_ENUMS, EventTypes)
|
||||
|
||||
class EventCategoryEnum(StrEnum):
|
||||
MutatesTaskState = "MutatesTaskState"
|
||||
MutatesRunnerStatus = "MutatesRunnerStatus"
|
||||
MutatesTaskSagaState = "MutatesTaskSagaState"
|
||||
MutatesInstanceState = "MutatesInstanceState"
|
||||
MutatesNodePerformanceState = "MutatesNodePerformanceState"
|
||||
MutatesControlPlaneState = "MutatesControlPlaneState"
|
||||
@@ -121,6 +121,8 @@ class EventCategoryEnum(StrEnum):
|
||||
EventCategory = (
|
||||
Literal[EventCategoryEnum.MutatesControlPlaneState]
|
||||
| Literal[EventCategoryEnum.MutatesTaskState]
|
||||
| Literal[EventCategoryEnum.MutatesTaskSagaState]
|
||||
| Literal[EventCategoryEnum.MutatesRunnerStatus]
|
||||
| Literal[EventCategoryEnum.MutatesInstanceState]
|
||||
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
|
||||
| Literal[EventCategoryEnum.MutatesDataPlaneState]
|
||||
@@ -130,6 +132,7 @@ EventCategories = FrozenSet[EventCategory]
|
||||
|
||||
assert_literal_union_covers_enum(EventCategory, EventCategoryEnum)
|
||||
|
||||
|
||||
class Event[SetMembersT: EventCategories | EventCategory](BaseModel):
|
||||
event_type: EventTypes
|
||||
event_category: SetMembersT
|
||||
|
||||
@@ -3,20 +3,20 @@ from __future__ import annotations
|
||||
from typing import Literal, Tuple
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.chunks import GenerationChunk
|
||||
from shared.types.events.common import (
|
||||
Event,
|
||||
EventTypes,
|
||||
EventCategoryEnum,
|
||||
ControlPlaneEventTypes,
|
||||
DataPlaneEventTypes,
|
||||
Event,
|
||||
EventCategoryEnum,
|
||||
EventTypes,
|
||||
InstanceEventTypes,
|
||||
InstanceStateEventTypes,
|
||||
MLXEventTypes,
|
||||
NodePerformanceEventTypes,
|
||||
RunnerStatusEventTypes,
|
||||
StreamingEventTypes,
|
||||
TaskEventTypes,
|
||||
TaskSagaEventTypes,
|
||||
)
|
||||
from shared.types.events.chunks import GenerationChunk
|
||||
from shared.types.networking.control_plane import (
|
||||
ControlPlaneEdgeId,
|
||||
ControlPlaneEdgeType,
|
||||
@@ -37,7 +37,7 @@ from shared.types.tasks.common import (
|
||||
)
|
||||
from shared.types.worker.common import InstanceId, NodeStatus
|
||||
from shared.types.worker.instances import InstanceParams, TypeOfInstance
|
||||
from shared.types.worker.runners import RunnerId, RunnerState, RunnerStateType
|
||||
from shared.types.worker.runners import RunnerId, RunnerStatus, RunnerStatusType
|
||||
|
||||
MLXEvent = Event[
|
||||
frozenset(
|
||||
@@ -101,22 +101,22 @@ class InstanceReplacedAtomically(Event[EventCategoryEnum.MutatesInstanceState]):
|
||||
event_type: EventTypes = InstanceEventTypes.InstanceReplacedAtomically
|
||||
instance_to_replace: InstanceId
|
||||
new_instance_id: InstanceId
|
||||
|
||||
|
||||
class InstanceSagaRunnerStateUpdated(Event[EventCategoryEnum.MutatesInstanceState]):
|
||||
event_type: EventTypes = InstanceStateEventTypes.InstanceSagaRunnerStateUpdated
|
||||
|
||||
class RunnerStatusUpdated(Event[EventCategoryEnum.MutatesRunnerStatus]):
|
||||
event_type: EventTypes = RunnerStatusEventTypes.RunnerStatusUpdated
|
||||
instance_id: InstanceId
|
||||
state_update: Tuple[RunnerId, RunnerState[RunnerStateType]]
|
||||
state_update: Tuple[RunnerId, RunnerStatus[RunnerStatusType]]
|
||||
|
||||
|
||||
class MLXInferenceSagaPrepare(Event[EventCategoryEnum.MutatesTaskState]):
|
||||
event_type: EventTypes = MLXEventTypes.MLXInferenceSagaPrepare
|
||||
class MLXInferenceSagaPrepare(Event[EventCategoryEnum.MutatesTaskSagaState]):
|
||||
event_type: EventTypes = TaskSagaEventTypes.MLXInferenceSagaPrepare
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class MLXInferenceSagaStartPrepare(Event[EventCategoryEnum.MutatesTaskState]):
|
||||
event_type: EventTypes = MLXEventTypes.MLXInferenceSagaStartPrepare
|
||||
class MLXInferenceSagaStartPrepare(Event[EventCategoryEnum.MutatesTaskSagaState]):
|
||||
event_type: EventTypes = TaskSagaEventTypes.MLXInferenceSagaStartPrepare
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
@@ -1,41 +1,41 @@
|
||||
from typing import Any, Mapping, Type, get_args
|
||||
from types import UnionType
|
||||
from typing import Annotated, Any, Mapping, Type, get_args
|
||||
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
|
||||
from shared.types.events.common import (
|
||||
ControlPlaneEventTypes,
|
||||
DataPlaneEventTypes,
|
||||
Event,
|
||||
EventCategories,
|
||||
EventTypes,
|
||||
TaskEventTypes,
|
||||
InstanceEventTypes,
|
||||
NodePerformanceEventTypes,
|
||||
ControlPlaneEventTypes,
|
||||
RunnerStatusEventTypes,
|
||||
StreamingEventTypes,
|
||||
DataPlaneEventTypes,
|
||||
MLXEventTypes,
|
||||
InstanceStateEventTypes,
|
||||
TaskEventTypes,
|
||||
TaskSagaEventTypes,
|
||||
)
|
||||
from shared.types.events.events import (
|
||||
TaskCreated,
|
||||
TaskStateUpdated,
|
||||
TaskDeleted,
|
||||
ChunkGenerated,
|
||||
DataPlaneEdgeCreated,
|
||||
DataPlaneEdgeDeleted,
|
||||
DataPlaneEdgeReplacedAtomically,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
InstanceReplacedAtomically,
|
||||
InstanceSagaRunnerStateUpdated,
|
||||
NodePerformanceMeasured,
|
||||
WorkerConnected,
|
||||
WorkerStatusUpdated,
|
||||
WorkerDisconnected,
|
||||
ChunkGenerated,
|
||||
DataPlaneEdgeCreated,
|
||||
DataPlaneEdgeReplacedAtomically,
|
||||
DataPlaneEdgeDeleted,
|
||||
MLXInferenceSagaPrepare,
|
||||
MLXInferenceSagaStartPrepare,
|
||||
NodePerformanceMeasured,
|
||||
RunnerStatusUpdated,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStateUpdated,
|
||||
WorkerConnected,
|
||||
WorkerDisconnected,
|
||||
WorkerStatusUpdated,
|
||||
)
|
||||
from pydantic import TypeAdapter
|
||||
from typing import Annotated
|
||||
from pydantic import Field
|
||||
from shared.types.events.common import EventCategories
|
||||
|
||||
"""
|
||||
class EventTypeNames(StrEnum):
|
||||
@@ -58,7 +58,7 @@ EventRegistry: Mapping[EventTypes, Type[Any]] = {
|
||||
InstanceEventTypes.InstanceCreated: InstanceCreated,
|
||||
InstanceEventTypes.InstanceDeleted: InstanceDeleted,
|
||||
InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically,
|
||||
InstanceStateEventTypes.InstanceSagaRunnerStateUpdated: InstanceSagaRunnerStateUpdated,
|
||||
RunnerStatusEventTypes.RunnerStatusUpdated: RunnerStatusUpdated,
|
||||
NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured,
|
||||
ControlPlaneEventTypes.WorkerConnected: WorkerConnected,
|
||||
ControlPlaneEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
|
||||
@@ -67,8 +67,8 @@ EventRegistry: Mapping[EventTypes, Type[Any]] = {
|
||||
DataPlaneEventTypes.DataPlaneEdgeCreated: DataPlaneEdgeCreated,
|
||||
DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically: DataPlaneEdgeReplacedAtomically,
|
||||
DataPlaneEventTypes.DataPlaneEdgeDeleted: DataPlaneEdgeDeleted,
|
||||
MLXEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
|
||||
MLXEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
|
||||
TaskSagaEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
|
||||
TaskSagaEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
|
||||
}
|
||||
|
||||
|
||||
@@ -86,9 +86,7 @@ def check_registry_has_all_event_types() -> None:
|
||||
def check_union_of_all_events_is_consistent_with_registry(
|
||||
registry: Mapping[EventTypes, Type[Any]], union_type: UnionType
|
||||
) -> None:
|
||||
type_of_each_registry_entry = set(
|
||||
type(event_type) for event_type in registry.keys()
|
||||
)
|
||||
type_of_each_registry_entry = set(type(event_type) for event_type in registry)
|
||||
type_of_each_entry_in_union = set(get_args(union_type))
|
||||
missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union
|
||||
|
||||
@@ -112,7 +110,7 @@ AllEvents = (
|
||||
| InstanceCreated
|
||||
| InstanceDeleted
|
||||
| InstanceReplacedAtomically
|
||||
| InstanceSagaRunnerStateUpdated
|
||||
| RunnerStatusUpdated
|
||||
| NodePerformanceMeasured
|
||||
| WorkerConnected
|
||||
| WorkerStatusUpdated
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import LiteralString, Sequence, Set, Any, Type, get_args
|
||||
from types import UnionType
|
||||
from enum import Enum, StrEnum
|
||||
from types import UnionType
|
||||
from typing import Any, LiteralString, Sequence, Set, Type, get_args
|
||||
|
||||
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from shared.graphs.networkx import NetworkXGraph
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.networking.control_plane import ControlPlaneEdgeId
|
||||
from shared.types.networking.data_plane import (
|
||||
@@ -5,7 +6,6 @@ from shared.types.networking.data_plane import (
|
||||
DataPlaneEdgeId,
|
||||
)
|
||||
from shared.types.worker.common import NodeStatus
|
||||
from shared.graphs.networkx import NetworkXGraph
|
||||
|
||||
|
||||
class DataPlaneTopology(
|
||||
|
||||
@@ -4,19 +4,19 @@ from queue import Queue
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from shared.types.worker.common import NodeStatus
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.common import (
|
||||
Event,
|
||||
EventCategory,
|
||||
EventCategoryEnum,
|
||||
State,
|
||||
)
|
||||
from shared.types.graphs.resource_graph import ResourceGraph
|
||||
from shared.types.networking.data_plane import (
|
||||
DataPlaneEdge,
|
||||
DataPlaneEdgeId,
|
||||
DataPlaneEdgeAdapter,
|
||||
DataPlaneEdgeId,
|
||||
)
|
||||
from shared.types.networking.topology import (
|
||||
ControlPlaneTopology,
|
||||
@@ -27,7 +27,8 @@ from shared.types.networking.topology import (
|
||||
from shared.types.profiling.common import NodePerformanceProfile
|
||||
from shared.types.states.shared import SharedState
|
||||
from shared.types.tasks.common import TaskParams, TaskType
|
||||
from shared.types.worker.instances import InstanceParams, InstanceId
|
||||
from shared.types.worker.common import NodeStatus
|
||||
from shared.types.worker.instances import InstanceId, InstanceParams
|
||||
|
||||
|
||||
class ExternalCommand(BaseModel): ...
|
||||
@@ -44,13 +45,13 @@ class CachePolicy(BaseModel, Generic[CachePolicyTypeT]):
|
||||
policy_type: CachePolicyTypeT
|
||||
|
||||
|
||||
class NodePerformanceProfileState(State[EventCategory.MutatesNodePerformanceState]):
|
||||
class NodePerformanceProfileState(State[EventCategoryEnum.MutatesNodePerformanceState]):
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile]
|
||||
|
||||
|
||||
class DataPlaneNetworkState(State[EventCategory.MutatesDataPlaneState]):
|
||||
event_category: Literal[EventCategory.MutatesDataPlaneState] = (
|
||||
EventCategory.MutatesDataPlaneState
|
||||
class DataPlaneNetworkState(State[EventCategoryEnum.MutatesDataPlaneState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = (
|
||||
EventCategoryEnum.MutatesDataPlaneState
|
||||
)
|
||||
topology: DataPlaneTopology = DataPlaneTopology(
|
||||
edge_base=DataPlaneEdgeAdapter, vertex_base=TypeAdapter(None)
|
||||
@@ -61,9 +62,9 @@ class DataPlaneNetworkState(State[EventCategory.MutatesDataPlaneState]):
|
||||
def add_edge(self, edge: DataPlaneEdge) -> None: ...
|
||||
|
||||
|
||||
class ControlPlaneNetworkState(State[EventCategory.MutatesControlPlaneState]):
|
||||
event_category: Literal[EventCategory.MutatesControlPlaneState] = (
|
||||
EventCategory.MutatesControlPlaneState
|
||||
class ControlPlaneNetworkState(State[EventCategoryEnum.MutatesControlPlaneState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = (
|
||||
EventCategoryEnum.MutatesControlPlaneState
|
||||
)
|
||||
topology: ControlPlaneTopology = ControlPlaneTopology(
|
||||
edge_base=TypeAdapter(None), vertex_base=TypeAdapter(NodeStatus)
|
||||
|
||||
@@ -4,29 +4,52 @@ from typing import Literal, Sequence
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.common import EventCategories, State
|
||||
from shared.types.tasks.common import Task, TaskId, TaskStatusType, TaskType
|
||||
from shared.types.events.common import EventCategoryEnum, State
|
||||
from shared.types.tasks.common import (
|
||||
Task,
|
||||
TaskId,
|
||||
TaskSagaEntry,
|
||||
TaskStatusType,
|
||||
TaskType,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import BaseInstance
|
||||
from shared.types.worker.runners import RunnerId, RunnerStatus, RunnerStatusType
|
||||
|
||||
|
||||
class KnownInstances(State[EventCategories.InstanceStateEventTypes]):
|
||||
event_category: Literal[EventCategories.InstanceStateEventTypes] = (
|
||||
EventCategories.InstanceStateEventTypes
|
||||
class Instances(State[EventCategoryEnum.MutatesInstanceState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesInstanceState] = (
|
||||
EventCategoryEnum.MutatesInstanceState
|
||||
)
|
||||
instances: Mapping[InstanceId, BaseInstance] = {}
|
||||
|
||||
|
||||
class Tasks(State[EventCategories.TaskEventTypes]):
|
||||
event_category: Literal[EventCategories.TaskEventTypes] = (
|
||||
EventCategories.TaskEventTypes
|
||||
class Tasks(State[EventCategoryEnum.MutatesTaskState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesTaskState] = (
|
||||
EventCategoryEnum.MutatesTaskState
|
||||
)
|
||||
tasks: Mapping[TaskId, Task[TaskType, TaskStatusType]] = {}
|
||||
|
||||
|
||||
class TaskSagas(State[EventCategoryEnum.MutatesTaskSagaState]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = (
|
||||
EventCategoryEnum.MutatesTaskSagaState
|
||||
)
|
||||
task_sagas: Mapping[TaskId, Sequence[TaskSagaEntry]] = {}
|
||||
|
||||
|
||||
class Runners(State[EventCategoryEnum.MutatesRunnerStatus]):
|
||||
event_category: Literal[EventCategoryEnum.MutatesRunnerStatus] = (
|
||||
EventCategoryEnum.MutatesRunnerStatus
|
||||
)
|
||||
runner_statuses: Mapping[RunnerId, RunnerStatus[RunnerStatusType]] = {}
|
||||
|
||||
|
||||
class SharedState(BaseModel):
|
||||
known_instances: KnownInstances = KnownInstances()
|
||||
compute_tasks: Tasks = Tasks()
|
||||
instances: Instances = Instances()
|
||||
runners: Runners = Runners()
|
||||
tasks: Tasks = Tasks()
|
||||
task_sagas: TaskSagas = TaskSagas()
|
||||
|
||||
def get_node_id(self) -> NodeId: ...
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ class TaskState[TaskStatusTypeT: TaskStatusType, TaskTypeT: TaskType](BaseModel)
|
||||
class BaseTask[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel):
|
||||
task_type: TaskTypeT
|
||||
task_params: TaskParams[TaskTypeT]
|
||||
task_state: TaskState[TaskStatusTypeT, TaskTypeT]
|
||||
task_stats: TaskState[TaskStatusTypeT, TaskTypeT]
|
||||
on_instance: InstanceId
|
||||
|
||||
|
||||
@@ -100,6 +100,11 @@ BaseTaskParser: TypeAdapter[BaseTask[TaskType, TaskStatusType]] = TypeAdapter(
|
||||
)
|
||||
|
||||
|
||||
class TaskSagaEntry(BaseModel):
|
||||
task_id: TaskId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
@final
|
||||
class Task[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](
|
||||
BaseTask[TaskTypeT, TaskStatusTypeT]
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.runners import (
|
||||
RunnerId,
|
||||
RunnerState,
|
||||
RunnerStateType,
|
||||
ShardAssignments,
|
||||
)
|
||||
|
||||
@@ -28,11 +24,3 @@ class BaseInstance(BaseModel):
|
||||
|
||||
class Instance(BaseInstance):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class BaseInstanceSaga(BaseModel):
|
||||
runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
|
||||
|
||||
|
||||
class InstanceSaga(BaseInstanceSaga):
|
||||
instance_id: InstanceId
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Generic, Literal, TypeVar, Annotated
|
||||
from typing import Annotated, Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
|
||||
@@ -11,7 +11,7 @@ from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
|
||||
from shared.types.worker.shards import PartitionStrategy, ShardMetadata
|
||||
|
||||
|
||||
class RunnerStateType(str, Enum):
|
||||
class RunnerStatusType(str, Enum):
|
||||
Rejected = "Rejected"
|
||||
Starting = "Starting"
|
||||
Downloading = "Downloading"
|
||||
@@ -19,44 +19,46 @@ class RunnerStateType(str, Enum):
|
||||
Failed = "Failed"
|
||||
|
||||
|
||||
RunnerStateTypeT = TypeVar("RunnerStateTypeT", bound=RunnerStateType)
|
||||
RunnerStatusTypeT = TypeVar("RunnerStatusTypeT", bound=RunnerStatusType)
|
||||
|
||||
|
||||
class RunnerState(BaseModel, Generic[RunnerStateTypeT]):
|
||||
runner_state: RunnerStateTypeT
|
||||
class RunnerStatus(BaseModel, Generic[RunnerStatusTypeT]):
|
||||
runner_status: RunnerStatusTypeT
|
||||
|
||||
|
||||
class RejectedRunnerState(RunnerState[RunnerStateType.Rejected]):
|
||||
runner_state: Literal[RunnerStateType.Rejected]
|
||||
class RejectedRunnerStatus(RunnerStatus[RunnerStatusType.Rejected]):
|
||||
runner_status: Literal[RunnerStatusType.Rejected]
|
||||
|
||||
|
||||
class StartingRunnerState(RunnerState[RunnerStateType.Starting]):
|
||||
runner_state: Literal[RunnerStateType.Starting]
|
||||
class StartingRunnerStatus(RunnerStatus[RunnerStatusType.Starting]):
|
||||
runner_status: Literal[RunnerStatusType.Starting]
|
||||
|
||||
|
||||
class DownloadingRunnerState(RunnerState[RunnerStateType.Downloading]):
|
||||
runner_state: Literal[RunnerStateType.Downloading]
|
||||
class DownloadingRunnerStatus(RunnerStatus[RunnerStatusType.Downloading]):
|
||||
runner_status: Literal[RunnerStatusType.Downloading]
|
||||
download_progress: BaseDownloadProgress[DownloadStatus]
|
||||
|
||||
|
||||
class RunningRunnerState(RunnerState[RunnerStateType.Running]):
|
||||
runner_state: Literal[RunnerStateType.Running]
|
||||
class RunningRunnerStatus(RunnerStatus[RunnerStatusType.Running]):
|
||||
runner_status: Literal[RunnerStatusType.Running]
|
||||
|
||||
|
||||
class FailedRunnerState(RunnerState[RunnerStateType.Failed]):
|
||||
runner_state: Literal[RunnerStateType.Failed]
|
||||
class FailedRunnerStatus(RunnerStatus[RunnerStatusType.Failed]):
|
||||
runner_status: Literal[RunnerStatusType.Failed]
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
_RunnerState = Annotated[
|
||||
RejectedRunnerState
|
||||
| StartingRunnerState
|
||||
| DownloadingRunnerState
|
||||
| RunningRunnerState
|
||||
| FailedRunnerState,
|
||||
_RunnerStatus = Annotated[
|
||||
RejectedRunnerStatus
|
||||
| StartingRunnerStatus
|
||||
| DownloadingRunnerStatus
|
||||
| RunningRunnerStatus
|
||||
| FailedRunnerStatus,
|
||||
Field,
|
||||
]
|
||||
RunnerStateParser: TypeAdapter[RunnerState[RunnerStateType]] = TypeAdapter(_RunnerState)
|
||||
RunnerStatusParser: TypeAdapter[RunnerStatus[RunnerStatusType]] = TypeAdapter(
|
||||
_RunnerStatus
|
||||
)
|
||||
|
||||
|
||||
class ShardAssignments(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Literal
|
||||
from collections.abc import Set
|
||||
from typing import Literal
|
||||
|
||||
from shared.logging.common import LogEntry, LogEntryType
|
||||
|
||||
@@ -10,4 +10,4 @@ class WorkerUninitialized(LogEntry[Literal["master_uninitialized"]]):
|
||||
message: str = "No master state found, creating new one."
|
||||
|
||||
|
||||
WorkerLogEntries = WorkerUninitialized
|
||||
WorkerLogEntries = WorkerUninitialized
|
||||
|
||||
Reference in New Issue
Block a user