fix: Event definitions, state definitions

This commit is contained in:
Arbion Halili
2025-07-14 21:41:14 +01:00
parent 70f0f09c05
commit df6626fa31
18 changed files with 234 additions and 184 deletions

View File

@@ -1,38 +1,56 @@
from asyncio import Lock, Queue, Task, create_task, gather
from collections.abc import Mapping
from enum import StrEnum
from typing import List, LiteralString, Protocol, Literal
from logging import Logger
from typing import Any, List, Literal, Protocol, Type, TypedDict
from master.logging import (
StateUpdateEffectHandlerErrorLogEntry,
StateUpdateErrorLogEntry,
StateUpdateLoopAlreadyRunningLogEntry,
StateUpdateLoopNotRunningLogEntry,
StateUpdateLoopStartedLogEntry,
StateUpdateLoopStoppedLogEntry,
)
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from shared.logger import log
from shared.types.events.common import (
Apply,
EffectHandler,
Event,
EventCategories,
EventCategory,
Event,
EventCategoryEnum,
EventFromEventLog,
EventFetcherProtocol,
EventFromEventLog,
StateAndEvent,
State,
Apply,
)
from asyncio import Lock, Queue, Task, gather, create_task
from typing import Any, Type, TypedDict
from collections.abc import Mapping
from shared.logger import log
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from master.logging import (
StateUpdateLoopAlreadyRunningLogEntry,
StateUpdateLoopStartedLogEntry,
StateUpdateLoopNotRunningLogEntry,
StateUpdateLoopStoppedLogEntry,
StateUpdateErrorLogEntry,
StateUpdateEffectHandlerErrorLogEntry,
)
# TypedDict with one asyncio Queue per event category; each queue carries
# EventFromEventLog items parameterised by that category's enum literal.
# The class is passed, together with EventCategoryEnum, to
# check_keys_in_map_match_enum_values right after its definition.
# NOTE(review): this span is a rendered diff hunk. The one-line field
# declarations appear to be the pre-change form and the wrapped declarations
# the post-change form of the same fields; MutatesRunnerStatus and
# MutatesTaskSagaState occur only in the wrapped form.
class QueueMapping(TypedDict):
MutatesTaskState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]]
MutatesControlPlaneState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]]
MutatesDataPlaneState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]]
MutatesInstanceState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]]
MutatesNodePerformanceState: Queue[EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]]
MutatesTaskState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
]
MutatesControlPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
]
MutatesDataPlaneState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
]
MutatesInstanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
]
MutatesNodePerformanceState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
]
# Queues for the two categories that have no one-line counterpart above.
MutatesRunnerStatus: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
]
MutatesTaskSagaState: Queue[
EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
]
def check_keys_in_map_match_enum_values[TEnum: StrEnum](
mapping_type: Type[Mapping[Any, Any]],
@@ -44,8 +62,10 @@ def check_keys_in_map_match_enum_values[TEnum: StrEnum](
f"StateDomainMapping keys {mapping_keys} do not match EventCategories values {category_values}"
)
check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum)
class AsyncUpdateStateFromEvents[EventCategoryT: EventCategory](Protocol):
"""Protocol for services that manage a specific state domain."""
@@ -119,7 +139,7 @@ class AsyncUpdateStateFromEvents[EventCategoryT: EventCategory](Protocol):
raise e
try:
for effect_handler in self._default_effects + self.extra_effects:
effect_handler((previous_state, event), updated_state)
effect_handler(StateAndEvent(previous_state, event), updated_state)
except Exception as e:
log(self._logger, StateUpdateEffectHandlerErrorLogEntry(error=e))
raise e
@@ -149,7 +169,9 @@ class EventRouter:
await self.queue_map[category].put(event_to_process)
return None
async def _submit_events(self, events: list[Event[EventCategory | EventCategories]]) -> None:
async def _submit_events(
self, events: list[Event[EventCategory | EventCategories]]
) -> None:
"""Route multiple events to their appropriate services."""
for event in events:
for category in event.event_category:

View File

@@ -1,5 +1,5 @@
from typing import Literal
from collections.abc import Set
from typing import Literal
from shared.logging.common import LogEntry, LogEntryType

View File

@@ -1,35 +1,41 @@
from asyncio import CancelledError, Lock, Queue, Task, create_task
from contextlib import asynccontextmanager
from enum import Enum
from logging import Logger, LogRecord
from typing import Annotated, Literal
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, TypeAdapter
from logging import Logger
from shared.types.events.common import Event, EventCategories, EventFetcherProtocol, EventPublisher, State
from shared.logger import (
configure_logger,
LogEntryType,
FilterLogByType,
create_queue_listener,
attach_to_queue,
from master.env import MasterEnvironmentSchema
from master.event_routing import AsyncUpdateStateFromEvents
from master.logging import (
MasterCommandReceivedLogEntry,
MasterInvalidCommandReceivedLogEntry,
MasterUninitializedLogEntry,
)
from shared.constants import EXO_MASTER_STATE
from shared.logger import (
FilterLogByType,
LogEntryType,
attach_to_queue,
configure_logger,
create_queue_listener,
log,
)
from shared.types.events.common import (
Event,
EventCategories,
EventFetcherProtocol,
EventPublisher,
State,
)
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import Instance
from shared.types.models.common import ModelId
from shared.types.models.model import ModelInfo
from shared.types.states.master import MasterState
from shared.constants import EXO_MASTER_STATE
from contextlib import asynccontextmanager
from logging import LogRecord
from typing import Annotated, Literal
from master.env import MasterEnvironmentSchema
from master.logging import (
MasterUninitializedLogEntry,
MasterCommandReceivedLogEntry,
MasterInvalidCommandReceivedLogEntry,
)
from master.event_routing import AsyncUpdateStateFromEvents
from shared.logger import log
from asyncio import Lock, Task, CancelledError, Queue, create_task
from enum import Enum
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import Instance
# Restore State
@@ -76,6 +82,7 @@ ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalComman
# String-valued enum naming background services. MasterStateManager matches on
# MAIN_LOOP when starting its event-loop task; MAIN_LOOP is the only member
# visible in this hunk.
class MasterBackgroundServices(str, Enum):
MAIN_LOOP = "main_loop"
class StateManager[T: EventCategories]:
state: State[T]
queue: Queue[Event[T]]
@@ -85,8 +92,8 @@ class StateManager[T: EventCategories]:
self,
state: State[T],
queue: Queue[Event[T]],
) -> None:
...
) -> None: ...
class MasterStateManager:
"""Thread-safe manager for MasterState with independent event loop."""
@@ -126,7 +133,9 @@ class MasterStateManager:
case MasterBackgroundServices.MAIN_LOOP:
if self._services[service]:
raise RuntimeError("State manager is already running")
self._services[service]: Task[None] = create_task(self._event_loop())
self._services[service]: Task[None] = create_task(
self._event_loop()
)
log(self._logger, MasterStateManagerStartedLogEntry())
case _:
raise ValueError(f"Unknown service: {service}")
@@ -155,7 +164,7 @@ class MasterStateManager:
events_one = self._event_processor.get_events_to_apply(
self._state.data_plane_network_state
)
case EventCategories.InstanceStateEventTypes:
case EventCategories.InstanceEventTypes:
events_one = self._event_processor.get_events_to_apply(
self._state.control_plane_network_state
)

View File

@@ -1,5 +1,5 @@
from pathlib import Path
import inspect
from pathlib import Path
EXO_HOME = Path.home() / ".exo"
EXO_EVENT_DB = EXO_HOME / "event_db.sqlite3"

View File

@@ -1,18 +1,18 @@
from typing import Set, Mapping
from dataclasses import dataclass
from pydantic import TypeAdapter
from typing import Mapping, Set
import rustworkx as rx
from pydantic import TypeAdapter
from shared.types.graphs.common import (
Edge,
EdgeData,
EdgeIdT,
EdgeTypeT,
MutableGraphProtocol,
Vertex,
VertexData,
EdgeIdT,
VertexIdT,
EdgeTypeT,
VertexTypeT,
)

View File

@@ -2,14 +2,13 @@ import logging
import logging.handlers
from collections.abc import Sequence, Set
from queue import Queue
from typing import Annotated
from pydantic import Field, TypeAdapter
from rich.logging import RichHandler
from typing import Annotated
from pydantic import Field, TypeAdapter
from shared.logging.common import LogEntryType
from master.logging import MasterLogEntries
from shared.logging.common import LogEntryType
from worker.logging import WorkerLogEntries
LogEntries = Annotated[

View File

@@ -1,8 +1,8 @@
from collections.abc import Set
from enum import Enum
from typing import Generic, TypeVar
from pydantic import BaseModel
from collections.abc import Set
from pydantic import BaseModel
LogEntryTypeT = TypeVar("LogEntryTypeT", bound=str)

View File

@@ -1,24 +1,22 @@
from enum import Enum, StrEnum
from typing import (
Annotated,
Any,
Callable,
FrozenSet,
Literal,
NamedTuple,
Protocol,
Sequence,
cast,
)
import annotated_types
from shared.types.events.sanity_checking import (
check_event_type_union_is_consistent_with_registry,
assert_literal_union_covers_enum,
)
from pydantic import BaseModel, Field, model_validator
from shared.types.common import NewUUID, NodeId
from typing import Callable, Sequence, Protocol
from shared.types.events.sanity_checking import (
assert_literal_union_covers_enum,
check_event_type_union_is_consistent_with_registry,
)
class EventId(NewUUID):
@@ -31,7 +29,7 @@ class TimerId(NewUUID):
# Here are all the unique kinds of events that can be sent over the network.
# I've defined them in different enums for clarity, but they're all part of the same set of possible events.
class MLXEventTypes(str, Enum):
class TaskSagaEventTypes(str, Enum):
MLXInferenceSagaPrepare = "MLXInferenceSagaPrepare"
MLXInferenceSagaStartPrepare = "MLXInferenceSagaStartPrepare"
@@ -54,8 +52,8 @@ class InstanceEventTypes(str, Enum):
InstanceReplacedAtomically = "InstanceReplacedAtomically"
# NOTE(review): rendered diff hunk. InstanceStateEventTypes appears to be the
# pre-change definition and RunnerStatusEventTypes its replacement; the
# EVENT_TYPE_ENUMS list and the EventTypes union in this file show the same
# swap. Confirm against the real file.
class InstanceStateEventTypes(str, Enum):
InstanceSagaRunnerStateUpdated = "InstanceSagaRunnerStateUpdated"
# String-valued event-type enum with the single member RunnerStatusUpdated.
class RunnerStatusEventTypes(str, Enum):
RunnerStatusUpdated = "RunnerStatusUpdated"
class NodePerformanceEventTypes(str, Enum):
@@ -84,12 +82,12 @@ EVENT_TYPE_ENUMS = [
TaskEventTypes,
StreamingEventTypes,
InstanceEventTypes,
InstanceStateEventTypes,
RunnerStatusEventTypes,
NodePerformanceEventTypes,
DataPlaneEventTypes,
ControlPlaneEventTypes,
TimerEventTypes,
MLXEventTypes,
TaskSagaEventTypes,
]
@@ -98,12 +96,12 @@ EventTypes = (
TaskEventTypes
| StreamingEventTypes
| InstanceEventTypes
| InstanceStateEventTypes
| RunnerStatusEventTypes
| NodePerformanceEventTypes
| ControlPlaneEventTypes
| DataPlaneEventTypes
| TimerEventTypes
| MLXEventTypes
| TaskSagaEventTypes
)
@@ -112,6 +110,8 @@ check_event_type_union_is_consistent_with_registry(EVENT_TYPE_ENUMS, EventTypes)
class EventCategoryEnum(StrEnum):
MutatesTaskState = "MutatesTaskState"
MutatesRunnerStatus = "MutatesRunnerStatus"
MutatesTaskSagaState = "MutatesTaskSagaState"
MutatesInstanceState = "MutatesInstanceState"
MutatesNodePerformanceState = "MutatesNodePerformanceState"
MutatesControlPlaneState = "MutatesControlPlaneState"
@@ -121,6 +121,8 @@ class EventCategoryEnum(StrEnum):
EventCategory = (
Literal[EventCategoryEnum.MutatesControlPlaneState]
| Literal[EventCategoryEnum.MutatesTaskState]
| Literal[EventCategoryEnum.MutatesTaskSagaState]
| Literal[EventCategoryEnum.MutatesRunnerStatus]
| Literal[EventCategoryEnum.MutatesInstanceState]
| Literal[EventCategoryEnum.MutatesNodePerformanceState]
| Literal[EventCategoryEnum.MutatesDataPlaneState]
@@ -130,6 +132,7 @@ EventCategories = FrozenSet[EventCategory]
assert_literal_union_covers_enum(EventCategory, EventCategoryEnum)
class Event[SetMembersT: EventCategories | EventCategory](BaseModel):
event_type: EventTypes
event_category: SetMembersT

View File

@@ -3,20 +3,20 @@ from __future__ import annotations
from typing import Literal, Tuple
from shared.types.common import NodeId
from shared.types.events.chunks import GenerationChunk
from shared.types.events.common import (
Event,
EventTypes,
EventCategoryEnum,
ControlPlaneEventTypes,
DataPlaneEventTypes,
Event,
EventCategoryEnum,
EventTypes,
InstanceEventTypes,
InstanceStateEventTypes,
MLXEventTypes,
NodePerformanceEventTypes,
RunnerStatusEventTypes,
StreamingEventTypes,
TaskEventTypes,
TaskSagaEventTypes,
)
from shared.types.events.chunks import GenerationChunk
from shared.types.networking.control_plane import (
ControlPlaneEdgeId,
ControlPlaneEdgeType,
@@ -37,7 +37,7 @@ from shared.types.tasks.common import (
)
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import InstanceParams, TypeOfInstance
from shared.types.worker.runners import RunnerId, RunnerState, RunnerStateType
from shared.types.worker.runners import RunnerId, RunnerStatus, RunnerStatusType
MLXEvent = Event[
frozenset(
@@ -101,22 +101,22 @@ class InstanceReplacedAtomically(Event[EventCategoryEnum.MutatesInstanceState]):
event_type: EventTypes = InstanceEventTypes.InstanceReplacedAtomically
instance_to_replace: InstanceId
new_instance_id: InstanceId
# NOTE(review): rendered diff hunk. The InstanceSagaRunnerStateUpdated header,
# the event_type line under it and the RunnerState-typed state_update line
# appear to be pre-change text; RunnerStatusUpdated is the post-change class.
class InstanceSagaRunnerStateUpdated(Event[EventCategoryEnum.MutatesInstanceState]):
event_type: EventTypes = InstanceStateEventTypes.InstanceSagaRunnerStateUpdated
# Event in the MutatesRunnerStatus category that carries a (runner id, status)
# pair together with the id of the instance it refers to.
class RunnerStatusUpdated(Event[EventCategoryEnum.MutatesRunnerStatus]):
event_type: EventTypes = RunnerStatusEventTypes.RunnerStatusUpdated
instance_id: InstanceId
state_update: Tuple[RunnerId, RunnerState[RunnerStateType]]
# NOTE(review): the field keeps the name state_update although its payload type
# is now RunnerStatus — confirm whether the old name is intentional.
state_update: Tuple[RunnerId, RunnerStatus[RunnerStatusType]]
# NOTE(review): rendered diff hunk. The first header/event_type pair
# (MutatesTaskState, MLXEventTypes) appears to be pre-change text; the second
# pair (MutatesTaskSagaState, TaskSagaEventTypes) is the post-change form.
class MLXInferenceSagaPrepare(Event[EventCategoryEnum.MutatesTaskState]):
event_type: EventTypes = MLXEventTypes.MLXInferenceSagaPrepare
# Event in the MutatesTaskSagaState category identifying a task and an instance.
class MLXInferenceSagaPrepare(Event[EventCategoryEnum.MutatesTaskSagaState]):
event_type: EventTypes = TaskSagaEventTypes.MLXInferenceSagaPrepare
task_id: TaskId
instance_id: InstanceId
# NOTE(review): rendered diff hunk. The first header/event_type pair
# (MutatesTaskState, MLXEventTypes) appears to be pre-change text; the second
# pair (MutatesTaskSagaState, TaskSagaEventTypes) is the post-change form.
class MLXInferenceSagaStartPrepare(Event[EventCategoryEnum.MutatesTaskState]):
event_type: EventTypes = MLXEventTypes.MLXInferenceSagaStartPrepare
# Event in the MutatesTaskSagaState category identifying a task and an instance.
class MLXInferenceSagaStartPrepare(Event[EventCategoryEnum.MutatesTaskSagaState]):
event_type: EventTypes = TaskSagaEventTypes.MLXInferenceSagaStartPrepare
task_id: TaskId
instance_id: InstanceId

View File

@@ -1,41 +1,41 @@
from typing import Any, Mapping, Type, get_args
from types import UnionType
from typing import Annotated, Any, Mapping, Type, get_args
from pydantic import Field, TypeAdapter
from shared.constants import EXO_ERROR_REPORTING_MESSAGE
from shared.types.events.common import (
ControlPlaneEventTypes,
DataPlaneEventTypes,
Event,
EventCategories,
EventTypes,
TaskEventTypes,
InstanceEventTypes,
NodePerformanceEventTypes,
ControlPlaneEventTypes,
RunnerStatusEventTypes,
StreamingEventTypes,
DataPlaneEventTypes,
MLXEventTypes,
InstanceStateEventTypes,
TaskEventTypes,
TaskSagaEventTypes,
)
from shared.types.events.events import (
TaskCreated,
TaskStateUpdated,
TaskDeleted,
ChunkGenerated,
DataPlaneEdgeCreated,
DataPlaneEdgeDeleted,
DataPlaneEdgeReplacedAtomically,
InstanceCreated,
InstanceDeleted,
InstanceReplacedAtomically,
InstanceSagaRunnerStateUpdated,
NodePerformanceMeasured,
WorkerConnected,
WorkerStatusUpdated,
WorkerDisconnected,
ChunkGenerated,
DataPlaneEdgeCreated,
DataPlaneEdgeReplacedAtomically,
DataPlaneEdgeDeleted,
MLXInferenceSagaPrepare,
MLXInferenceSagaStartPrepare,
NodePerformanceMeasured,
RunnerStatusUpdated,
TaskCreated,
TaskDeleted,
TaskStateUpdated,
WorkerConnected,
WorkerDisconnected,
WorkerStatusUpdated,
)
from pydantic import TypeAdapter
from typing import Annotated
from pydantic import Field
from shared.types.events.common import EventCategories
"""
class EventTypeNames(StrEnum):
@@ -58,7 +58,7 @@ EventRegistry: Mapping[EventTypes, Type[Any]] = {
InstanceEventTypes.InstanceCreated: InstanceCreated,
InstanceEventTypes.InstanceDeleted: InstanceDeleted,
InstanceEventTypes.InstanceReplacedAtomically: InstanceReplacedAtomically,
InstanceStateEventTypes.InstanceSagaRunnerStateUpdated: InstanceSagaRunnerStateUpdated,
RunnerStatusEventTypes.RunnerStatusUpdated: RunnerStatusUpdated,
NodePerformanceEventTypes.NodePerformanceMeasured: NodePerformanceMeasured,
ControlPlaneEventTypes.WorkerConnected: WorkerConnected,
ControlPlaneEventTypes.WorkerStatusUpdated: WorkerStatusUpdated,
@@ -67,8 +67,8 @@ EventRegistry: Mapping[EventTypes, Type[Any]] = {
DataPlaneEventTypes.DataPlaneEdgeCreated: DataPlaneEdgeCreated,
DataPlaneEventTypes.DataPlaneEdgeReplacedAtomically: DataPlaneEdgeReplacedAtomically,
DataPlaneEventTypes.DataPlaneEdgeDeleted: DataPlaneEdgeDeleted,
MLXEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
MLXEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
TaskSagaEventTypes.MLXInferenceSagaPrepare: MLXInferenceSagaPrepare,
TaskSagaEventTypes.MLXInferenceSagaStartPrepare: MLXInferenceSagaStartPrepare,
}
@@ -86,9 +86,7 @@ def check_registry_has_all_event_types() -> None:
def check_union_of_all_events_is_consistent_with_registry(
registry: Mapping[EventTypes, Type[Any]], union_type: UnionType
) -> None:
type_of_each_registry_entry = set(
type(event_type) for event_type in registry.keys()
)
type_of_each_registry_entry = set(type(event_type) for event_type in registry)
type_of_each_entry_in_union = set(get_args(union_type))
missing_from_union = type_of_each_registry_entry - type_of_each_entry_in_union
@@ -112,7 +110,7 @@ AllEvents = (
| InstanceCreated
| InstanceDeleted
| InstanceReplacedAtomically
| InstanceSagaRunnerStateUpdated
| RunnerStatusUpdated
| NodePerformanceMeasured
| WorkerConnected
| WorkerStatusUpdated

View File

@@ -1,6 +1,6 @@
from typing import LiteralString, Sequence, Set, Any, Type, get_args
from types import UnionType
from enum import Enum, StrEnum
from types import UnionType
from typing import Any, LiteralString, Sequence, Set, Type, get_args
from shared.constants import EXO_ERROR_REPORTING_MESSAGE

View File

@@ -1,3 +1,4 @@
from shared.graphs.networkx import NetworkXGraph
from shared.types.common import NodeId
from shared.types.networking.control_plane import ControlPlaneEdgeId
from shared.types.networking.data_plane import (
@@ -5,7 +6,6 @@ from shared.types.networking.data_plane import (
DataPlaneEdgeId,
)
from shared.types.worker.common import NodeStatus
from shared.graphs.networkx import NetworkXGraph
class DataPlaneTopology(

View File

@@ -4,19 +4,19 @@ from queue import Queue
from typing import Generic, Literal, TypeVar
from pydantic import BaseModel, TypeAdapter
from shared.types.worker.common import NodeStatus
from shared.types.common import NodeId
from shared.types.events.common import (
Event,
EventCategory,
EventCategoryEnum,
State,
)
from shared.types.graphs.resource_graph import ResourceGraph
from shared.types.networking.data_plane import (
DataPlaneEdge,
DataPlaneEdgeId,
DataPlaneEdgeAdapter,
DataPlaneEdgeId,
)
from shared.types.networking.topology import (
ControlPlaneTopology,
@@ -27,7 +27,8 @@ from shared.types.networking.topology import (
from shared.types.profiling.common import NodePerformanceProfile
from shared.types.states.shared import SharedState
from shared.types.tasks.common import TaskParams, TaskType
from shared.types.worker.instances import InstanceParams, InstanceId
from shared.types.worker.common import NodeStatus
from shared.types.worker.instances import InstanceId, InstanceParams
class ExternalCommand(BaseModel): ...
@@ -44,13 +45,13 @@ class CachePolicy(BaseModel, Generic[CachePolicyTypeT]):
policy_type: CachePolicyTypeT
# NOTE(review): rendered diff hunk. The header parameterised with
# EventCategory.MutatesNodePerformanceState appears to be pre-change text; the
# EventCategoryEnum form is the post-change header.
class NodePerformanceProfileState(State[EventCategory.MutatesNodePerformanceState]):
# State for the MutatesNodePerformanceState category: a mapping from NodeId to
# NodePerformanceProfile.
# NOTE(review): unlike DataPlaneNetworkState and ControlPlaneNetworkState below,
# no event_category default is visible here — confirm whether that is intended.
class NodePerformanceProfileState(State[EventCategoryEnum.MutatesNodePerformanceState]):
node_profiles: Mapping[NodeId, NodePerformanceProfile]
class DataPlaneNetworkState(State[EventCategory.MutatesDataPlaneState]):
event_category: Literal[EventCategory.MutatesDataPlaneState] = (
EventCategory.MutatesDataPlaneState
class DataPlaneNetworkState(State[EventCategoryEnum.MutatesDataPlaneState]):
event_category: Literal[EventCategoryEnum.MutatesDataPlaneState] = (
EventCategoryEnum.MutatesDataPlaneState
)
topology: DataPlaneTopology = DataPlaneTopology(
edge_base=DataPlaneEdgeAdapter, vertex_base=TypeAdapter(None)
@@ -61,9 +62,9 @@ class DataPlaneNetworkState(State[EventCategory.MutatesDataPlaneState]):
def add_edge(self, edge: DataPlaneEdge) -> None: ...
class ControlPlaneNetworkState(State[EventCategory.MutatesControlPlaneState]):
event_category: Literal[EventCategory.MutatesControlPlaneState] = (
EventCategory.MutatesControlPlaneState
class ControlPlaneNetworkState(State[EventCategoryEnum.MutatesControlPlaneState]):
event_category: Literal[EventCategoryEnum.MutatesControlPlaneState] = (
EventCategoryEnum.MutatesControlPlaneState
)
topology: ControlPlaneTopology = ControlPlaneTopology(
edge_base=TypeAdapter(None), vertex_base=TypeAdapter(NodeStatus)

View File

@@ -4,29 +4,52 @@ from typing import Literal, Sequence
from pydantic import BaseModel
from shared.types.common import NodeId
from shared.types.events.common import EventCategories, State
from shared.types.tasks.common import Task, TaskId, TaskStatusType, TaskType
from shared.types.events.common import EventCategoryEnum, State
from shared.types.tasks.common import (
Task,
TaskId,
TaskSagaEntry,
TaskStatusType,
TaskType,
)
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import BaseInstance
from shared.types.worker.runners import RunnerId, RunnerStatus, RunnerStatusType
# NOTE(review): rendered diff hunk. The KnownInstances header and the
# EventCategories.InstanceStateEventTypes lines appear to be pre-change text;
# Instances with EventCategoryEnum.MutatesInstanceState is the post-change
# class. The closing parenthesis and the instances field are shared by both.
class KnownInstances(State[EventCategories.InstanceStateEventTypes]):
event_category: Literal[EventCategories.InstanceStateEventTypes] = (
EventCategories.InstanceStateEventTypes
# State for the MutatesInstanceState category; event_category defaults to that
# enum member and instances maps InstanceId to BaseInstance (empty by default).
class Instances(State[EventCategoryEnum.MutatesInstanceState]):
event_category: Literal[EventCategoryEnum.MutatesInstanceState] = (
EventCategoryEnum.MutatesInstanceState
)
instances: Mapping[InstanceId, BaseInstance] = {}
# NOTE(review): rendered diff hunk. The header and event_category lines that use
# EventCategories.TaskEventTypes appear to be pre-change text; the
# EventCategoryEnum.MutatesTaskState lines are the post-change form. The closing
# parenthesis and the tasks field are shared by both.
class Tasks(State[EventCategories.TaskEventTypes]):
event_category: Literal[EventCategories.TaskEventTypes] = (
EventCategories.TaskEventTypes
# State for the MutatesTaskState category; tasks maps TaskId to Task (empty by
# default).
class Tasks(State[EventCategoryEnum.MutatesTaskState]):
event_category: Literal[EventCategoryEnum.MutatesTaskState] = (
EventCategoryEnum.MutatesTaskState
)
tasks: Mapping[TaskId, Task[TaskType, TaskStatusType]] = {}
# State for the MutatesTaskSagaState category; event_category defaults to that
# enum member.
class TaskSagas(State[EventCategoryEnum.MutatesTaskSagaState]):
event_category: Literal[EventCategoryEnum.MutatesTaskSagaState] = (
EventCategoryEnum.MutatesTaskSagaState
)
# Per-task sequence of TaskSagaEntry records, keyed by TaskId; empty by default.
task_sagas: Mapping[TaskId, Sequence[TaskSagaEntry]] = {}
# State for the MutatesRunnerStatus category; event_category defaults to that
# enum member.
class Runners(State[EventCategoryEnum.MutatesRunnerStatus]):
event_category: Literal[EventCategoryEnum.MutatesRunnerStatus] = (
EventCategoryEnum.MutatesRunnerStatus
)
# Status of each runner, keyed by RunnerId; empty by default.
runner_statuses: Mapping[RunnerId, RunnerStatus[RunnerStatusType]] = {}
# Pydantic model that aggregates the per-category state objects, each created
# with its default constructor.
# NOTE(review): rendered diff hunk. known_instances and compute_tasks appear to
# be the pre-change attributes; instances, runners, tasks and task_sagas are the
# post-change ones. The renames change the public attribute names, so callers
# that read known_instances or compute_tasks need updating — verify.
class SharedState(BaseModel):
known_instances: KnownInstances = KnownInstances()
compute_tasks: Tasks = Tasks()
instances: Instances = Instances()
runners: Runners = Runners()
tasks: Tasks = Tasks()
task_sagas: TaskSagas = TaskSagas()
# Stub with no body in this hunk; declared to return a NodeId.
def get_node_id(self) -> NodeId: ...

View File

@@ -83,7 +83,7 @@ class TaskState[TaskStatusTypeT: TaskStatusType, TaskTypeT: TaskType](BaseModel)
# Generic pydantic model for a task: its type, its parameters, a TaskState
# value, and the InstanceId it runs on.
class BaseTask[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](BaseModel):
task_type: TaskTypeT
task_params: TaskParams[TaskTypeT]
task_state: TaskState[TaskStatusTypeT, TaskTypeT]
# NOTE(review): rendered diff hunk — the task_state line above appears to be the
# pre-change field and task_stats its replacement. The annotation is still
# TaskState, so "task_stats" looks like a typo for "task_state"; confirm, since
# the rename changes the model's field name for every consumer.
task_stats: TaskState[TaskStatusTypeT, TaskTypeT]
on_instance: InstanceId
@@ -100,6 +100,11 @@ BaseTaskParser: TypeAdapter[BaseTask[TaskType, TaskStatusType]] = TypeAdapter(
)
# Pydantic model holding a TaskId and an InstanceId; used as the element type of
# the task_sagas mapping in the shared state module.
class TaskSagaEntry(BaseModel):
task_id: TaskId
instance_id: InstanceId
@final
class Task[TaskTypeT: TaskType, TaskStatusTypeT: TaskStatusType](
BaseTask[TaskTypeT, TaskStatusTypeT]

View File

@@ -1,13 +1,9 @@
from collections.abc import Mapping
from enum import Enum
from pydantic import BaseModel
from shared.types.worker.common import InstanceId
from shared.types.worker.runners import (
RunnerId,
RunnerState,
RunnerStateType,
ShardAssignments,
)
@@ -28,11 +24,3 @@ class BaseInstance(BaseModel):
# BaseInstance extended with its InstanceId.
class Instance(BaseInstance):
instance_id: InstanceId
class BaseInstanceSaga(BaseModel):
runner_states: Mapping[RunnerId, RunnerState[RunnerStateType]]
class InstanceSaga(BaseInstanceSaga):
instance_id: InstanceId

View File

@@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Generic, Literal, TypeVar, Annotated
from typing import Annotated, Generic, Literal, TypeVar
from pydantic import BaseModel, Field, TypeAdapter, model_validator
@@ -11,7 +11,7 @@ from shared.types.worker.downloads import BaseDownloadProgress, DownloadStatus
from shared.types.worker.shards import PartitionStrategy, ShardMetadata
class RunnerStateType(str, Enum):
class RunnerStatusType(str, Enum):
Rejected = "Rejected"
Starting = "Starting"
Downloading = "Downloading"
@@ -19,44 +19,46 @@ class RunnerStateType(str, Enum):
Failed = "Failed"
# NOTE(review): rendered diff hunk. The RunnerState* lines appear to be
# pre-change text and the RunnerStatus* lines their post-change replacements.
RunnerStateTypeT = TypeVar("RunnerStateTypeT", bound=RunnerStateType)
# Type variable bounded by RunnerStatusType, used to parameterise RunnerStatus.
RunnerStatusTypeT = TypeVar("RunnerStatusTypeT", bound=RunnerStatusType)
class RunnerState(BaseModel, Generic[RunnerStateTypeT]):
runner_state: RunnerStateTypeT
# Generic base model for a runner status; runner_status holds the status kind.
class RunnerStatus(BaseModel, Generic[RunnerStatusTypeT]):
runner_status: RunnerStatusTypeT
# NOTE(review): rendered diff hunk — the RejectedRunnerState lines appear to be
# pre-change text; RejectedRunnerStatus is the post-change class.
class RejectedRunnerState(RunnerState[RunnerStateType.Rejected]):
runner_state: Literal[RunnerStateType.Rejected]
# RunnerStatus variant whose runner_status is fixed to RunnerStatusType.Rejected.
class RejectedRunnerStatus(RunnerStatus[RunnerStatusType.Rejected]):
runner_status: Literal[RunnerStatusType.Rejected]
# NOTE(review): rendered diff hunk — the StartingRunnerState lines appear to be
# pre-change text; StartingRunnerStatus is the post-change class.
class StartingRunnerState(RunnerState[RunnerStateType.Starting]):
runner_state: Literal[RunnerStateType.Starting]
# RunnerStatus variant whose runner_status is fixed to RunnerStatusType.Starting.
class StartingRunnerStatus(RunnerStatus[RunnerStatusType.Starting]):
runner_status: Literal[RunnerStatusType.Starting]
# NOTE(review): rendered diff hunk — the DownloadingRunnerState lines appear to
# be pre-change text; DownloadingRunnerStatus is the post-change class. The
# download_progress field is shared by both.
class DownloadingRunnerState(RunnerState[RunnerStateType.Downloading]):
runner_state: Literal[RunnerStateType.Downloading]
# RunnerStatus variant fixed to RunnerStatusType.Downloading that additionally
# carries a BaseDownloadProgress value.
class DownloadingRunnerStatus(RunnerStatus[RunnerStatusType.Downloading]):
runner_status: Literal[RunnerStatusType.Downloading]
download_progress: BaseDownloadProgress[DownloadStatus]
# NOTE(review): rendered diff hunk — the RunningRunnerState lines appear to be
# pre-change text; RunningRunnerStatus is the post-change class.
class RunningRunnerState(RunnerState[RunnerStateType.Running]):
runner_state: Literal[RunnerStateType.Running]
# RunnerStatus variant whose runner_status is fixed to RunnerStatusType.Running.
class RunningRunnerStatus(RunnerStatus[RunnerStatusType.Running]):
runner_status: Literal[RunnerStatusType.Running]
# NOTE(review): rendered diff hunk — the FailedRunnerState lines appear to be
# pre-change text; FailedRunnerStatus is the post-change class. The
# error_message field is shared by both.
class FailedRunnerState(RunnerState[RunnerStateType.Failed]):
runner_state: Literal[RunnerStateType.Failed]
# RunnerStatus variant fixed to RunnerStatusType.Failed with an optional error
# message that defaults to None.
class FailedRunnerStatus(RunnerStatus[RunnerStatusType.Failed]):
runner_status: Literal[RunnerStatusType.Failed]
error_message: str | None = None
# NOTE(review): rendered diff hunk — the _RunnerState header and the
# *RunnerState members appear to be pre-change text; _RunnerStatus and the
# *RunnerStatus members are the post-change form. "Field," and the closing
# bracket are shared by both.
_RunnerState = Annotated[
RejectedRunnerState
| StartingRunnerState
| DownloadingRunnerState
| RunningRunnerState
| FailedRunnerState,
# Union of the five concrete runner-status models, wrapped in Annotated.
_RunnerStatus = Annotated[
RejectedRunnerStatus
| StartingRunnerStatus
| DownloadingRunnerStatus
| RunningRunnerStatus
| FailedRunnerStatus,
# NOTE(review): Field is passed as the bare function object, not called, and no
# discriminator is given. Presumably Field(discriminator="runner_status") was
# intended so the union is resolved by the runner_status tag — confirm.
Field,
]
# NOTE(review): rendered diff hunk — RunnerStateParser appears to be the
# pre-change line and RunnerStatusParser its post-change replacement.
RunnerStateParser: TypeAdapter[RunnerState[RunnerStateType]] = TypeAdapter(_RunnerState)
# Module-level pydantic TypeAdapter built from the _RunnerStatus annotated union.
RunnerStatusParser: TypeAdapter[RunnerStatus[RunnerStatusType]] = TypeAdapter(
_RunnerStatus
)
class ShardAssignments(BaseModel):

View File

@@ -1,5 +1,5 @@
from typing import Literal
from collections.abc import Set
from typing import Literal
from shared.logging.common import LogEntry, LogEntryType
@@ -10,4 +10,4 @@ class WorkerUninitialized(LogEntry[Literal["master_uninitialized"]]):
message: str = "No master state found, creating new one."
WorkerLogEntries = WorkerUninitialized
WorkerLogEntries = WorkerUninitialized