Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)
tweak
@@ -17,7 +17,6 @@ from shared.logger import log
from shared.types.events.common import (
    Apply,
    EffectHandler,
    Event,
    EventCategories,
    EventCategory,
    EventCategoryEnum,
@@ -25,6 +24,7 @@ from shared.types.events.common import (
    EventFromEventLog,
    StateAndEvent,
    State,
    narrow_event_from_event_log_type,
)


@@ -32,24 +32,24 @@ class QueueMapping(TypedDict):
    MutatesTaskState: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]
    ]
    MutatesTaskSagaState: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
    ]
    MutatesControlPlaneState: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesControlPlaneState]]
    ]
    MutatesDataPlaneState: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesDataPlaneState]]
    ]
    MutatesRunnerStatus: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
    ]
    MutatesInstanceState: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesInstanceState]]
    ]
    MutatesNodePerformanceState: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesNodePerformanceState]]
    ]
    MutatesRunnerStatus: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesRunnerStatus]]
    ]
    MutatesTaskSagaState: Queue[
        EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskSagaState]]
    ]


def check_keys_in_map_match_enum_values[TEnum: StrEnum](
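Note: only the signature of check_keys_in_map_match_enum_values is visible at this hunk boundary. A minimal sketch of what such a guard could look like, assuming it simply compares a mapping's keys against the StrEnum's values (the parameter names and error message are illustrative, not taken from the repository):

    from enum import StrEnum
    from typing import Mapping

    def check_keys_in_map_match_enum_values[TEnum: StrEnum](
        mapping: Mapping[str, object], enum_type: type[TEnum]
    ) -> None:
        # Every enum value must appear exactly once as a key, with no extras.
        expected = {member.value for member in enum_type}
        actual = set(mapping.keys())
        if actual != expected:
            raise ValueError(f"Queue map keys {actual} do not match enum values {expected}")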
@@ -154,32 +154,40 @@ class EventRouter:

    async def _get_queue_by_category[T: EventCategory](
        self, category: T
    ) -> Queue[Event[T]]:
    ) -> Queue[EventFromEventLog[T]]:
        """Get the queue for a given category."""
        category_str: str = category.value
        queue: Queue[Event[T]] = self.queue_map[category_str]
        queue: Queue[EventFromEventLog[T]] = self.queue_map[category_str]
        return queue

    async def _process_events[T: EventCategory](self, category: T) -> None:
        """Process events for a given domain."""
        queue: Queue[Event[T]] = await self._get_queue_by_category(category)
        events_to_process: list[Event[T]] = []
        queue: Queue[EventFromEventLog[T]] = await self._get_queue_by_category(category)
        events_to_process: list[EventFromEventLog[T]] = []
        while not queue.empty():
            events_to_process.append(await queue.get())
        for event_to_process in events_to_process:
            await self.queue_map[category].put(event_to_process)
            await self.queue_map[category.value].put(event_to_process)
        return None

    async def _submit_events(
        self, events: list[Event[EventCategory | EventCategories]]
    async def _submit_events[T: EventCategory | EventCategories](
        self, events: list[EventFromEventLog[T]]
    ) -> None:
        """Route multiple events to their appropriate services."""
        for event in events:
            for category in event.event_category:
                await self._event_queues[category].put(event)

            if isinstance(event.event.event_category, EventCategory):
                q1: Queue[EventFromEventLog[T]] = self.queue_map[event.event.event_category.value]
                await q1.put(event)
            elif isinstance(event.event.event_category, EventCategories):
                for category in event.event.event_category:
                    narrow_event = narrow_event_from_event_log_type(event, category)
                    q2: Queue[EventFromEventLog[T]] = self.queue_map[category.value]
                    await q2.put(narrow_event)

        await gather(
            *[self._process_events(domain) for domain in self._event_queues.keys()]
            *[self._process_events(domain) for domain in EventCategoryEnum]
        )

    async def _get_events_to_process(self) -> list[Event[EventCategories]]:
    async def _get_events_to_process(self) -> list[EventFromEventLog[EventCategories | EventCategory]]:
        """Get events to process from the event fetcher."""
        ...
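The multi-category branch above fans a single log entry out to one queue per category its inner event touches. A free-standing sketch of that same fan-out, assuming queue_map is keyed by the enum's string values as in QueueMapping (the function itself is illustrative; the names come from this diff):

    async def fan_out(
        log_entry: EventFromEventLog[EventCategories],
        queue_map: QueueMapping,
    ) -> None:
        # Put one narrowed copy of the entry on each per-category queue.
        for category in log_entry.event.event_category:
            narrowed = narrow_event_from_event_log_type(log_entry, category)
            await queue_map[category.value].put(narrowed)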
@@ -1,13 +1,11 @@
from hashlib import sha3_224 as hasher
from typing import Sequence, TypeVar
from typing import Sequence
from uuid import UUID

from shared.types.events.common import EventCategories, EventId, IdemKeyGenerator, State

EventCategoryT = TypeVar("EventCategoryT", bound=EventCategories)
from shared.types.events.common import EventCategory, EventId, IdemKeyGenerator, State


def get_idem_tag_generator(base: str) -> IdemKeyGenerator[EventCategoryT]:
def get_idem_tag_generator[EventCategoryT: EventCategory](base: str) -> IdemKeyGenerator[EventCategoryT]:
    """Generates idempotency keys for events.

    The keys are generated by hashing the state sequence number against a base string.
@@ -24,7 +22,7 @@ def get_idem_tag_generator(base: str) -> IdemKeyGenerator[EventCategoryT]:
                *recurse(n - 1, next_hash),
            )

        initial_bytes = state.sequence_number.to_bytes(8, byteorder="big", signed=False)
        initial_bytes = state.last_event_applied_idx.to_bytes(8, byteorder="big", signed=False)
        return recurse(num_keys, initial_bytes)

    return get_idem_keys
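Only the tail of the key-generation helper appears in this hunk. A self-contained approximation of the chained-hash scheme it implies, assuming each key is the SHA3-224 digest of the base string concatenated with the previous digest, seeded from the state's last applied event index (the bodies below are reconstructed for illustration, not copied from the repository):

    from hashlib import sha3_224 as hasher

    def make_idem_keys(base: str, last_event_applied_idx: int, num_keys: int) -> tuple[bytes, ...]:
        def recurse(n: int, prev: bytes) -> tuple[bytes, ...]:
            # Each key chains off the previous digest, so the sequence is
            # deterministic for a given (base, index) pair.
            if n == 0:
                return ()
            next_hash = hasher(base.encode() + prev).digest()
            return (next_hash, *recurse(n - 1, next_hash))

        initial_bytes = last_event_applied_idx.to_bytes(8, byteorder="big", signed=False)
        return recurse(num_keys, initial_bytes)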
master/main.py
@@ -1,15 +1,17 @@
from asyncio import CancelledError, Lock, Queue, Task, create_task
from asyncio import CancelledError, Lock, Task, create_task
from asyncio import Queue as AsyncQueue
from queue import Queue as PQueue
from contextlib import asynccontextmanager
from enum import Enum
from logging import Logger, LogRecord
from typing import Annotated, Literal
from typing import Annotated, Literal, Type

from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, TypeAdapter

from master.env import MasterEnvironmentSchema
from master.event_routing import AsyncUpdateStateFromEvents
from master.event_routing import AsyncUpdateStateFromEvents, QueueMapping
from master.logging import (
    MasterCommandReceivedLogEntry,
    MasterInvalidCommandReceivedLogEntry,
@@ -26,7 +28,7 @@ from shared.logger import (
)
from shared.types.events.common import (
    Event,
    EventCategories,
    EventCategory,
    EventFetcherProtocol,
    EventPublisher,
    State,
@@ -83,15 +85,15 @@ class MasterBackgroundServices(str, Enum):
    MAIN_LOOP = "main_loop"


class StateManager[T: EventCategories]:
class StateManager[T: EventCategory]:
    state: State[T]
    queue: Queue[Event[T]]
    queue: AsyncQueue[Event[T]]
    manager: AsyncUpdateStateFromEvents[T]

    def __init__(
        self,
        state: State[T],
        queue: Queue[Event[T]],
        queue: AsyncQueue[Event[T]],
    ) -> None: ...
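StateManager is declared here with only attributes and a stubbed constructor. A sketch of how a per-category manager like this might drain its queue, assuming AsyncUpdateStateFromEvents exposes some apply-style coroutine (that method name is an assumption; it does not appear in this diff):

    class SketchStateManager[T: EventCategory]:
        def __init__(
            self,
            state: State[T],
            queue: AsyncQueue[Event[T]],
            manager: AsyncUpdateStateFromEvents[T],
        ) -> None:
            self.state = state
            self.queue = queue
            self.manager = manager

        async def start(self) -> None:
            # Apply events one at a time as they arrive on this category's queue.
            while True:
                event = await self.queue.get()
                self.state = await self.manager.apply(self.state, event)  # apply() is hypothetical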
@@ -101,51 +103,50 @@ class MasterStateManager:
    def __init__(
        self,
        initial_state: MasterState,
        event_processor: EventFetcherProtocol[EventCategories],
        event_publisher: EventPublisher[EventCategories],
        event_processor: EventFetcherProtocol[EventCategory],
        event_publisher: EventPublisher[EventCategory],
        state_updater: dict[EventCategory, AsyncUpdateStateFromEvents[EventCategory]],
        logger: Logger,
    ):
        self._state = initial_state
        self._state_lock = Lock()
        self._command_queue: Queue[ExternalCommand] = Queue()
        self._services: dict[MasterBackgroundServices, Task[None]] = {}
        self._command_runner: Task[None] | None = None
        self._command_queue: AsyncQueue[ExternalCommand] = AsyncQueue()
        self._response_queue: AsyncQueue[Response | StreamingResponse] = AsyncQueue()
        self._state_managers: dict[EventCategory, AsyncUpdateStateFromEvents[EventCategory]] = {}
        self._asyncio_tasks: dict[EventCategory, Task[None]] = {}
        self._logger = logger

    async def read_state(self) -> MasterState:
        """Get a thread-safe snapshot of the current state."""
        async with self._state_lock:
            return self._state.model_copy(deep=True)
    @property
    def _is_command_runner_running(self) -> bool:
        return self._command_runner is not None and not self._command_runner.done()

    async def send_command(
        self, command: ExternalCommand
    ) -> Response | StreamingResponse:
        """Send a command to the background event loop."""
        if self._services[MasterBackgroundServices.MAIN_LOOP]:
            self._command_queue.put(command)
            return Response(status_code=200)
        if self._is_command_runner_running:
            await self._command_queue.put(command)
            return await self._response_queue.get()
        else:
            raise RuntimeError("State manager is not running")
            log(self._logger, MasterCommandRunnerNotRunningLogEntry())
            raise RuntimeError("Command Runner Is Not Running")

    async def start(self) -> None:
        """Start the background event loop."""
        for service in MasterBackgroundServices:
            match service:
                case MasterBackgroundServices.MAIN_LOOP:
                    if self._services[service]:
                        raise RuntimeError("State manager is already running")
                    self._services[service]: Task[None] = create_task(
                        self._event_loop()
                    )
                    log(self._logger, MasterStateManagerStartedLogEntry())
                case _:
                    raise ValueError(f"Unknown service: {service}")
        for category in self._state_managers:
            self._asyncio_tasks[category] = create_task(
                self._state_managers[category].start()
            )

    async def stop(self) -> None:
        """Stop the background event loop and persist state."""
        if not self._services[MasterBackgroundServices.MAIN_LOOP]:
            raise RuntimeError("State manager is not running")
        if not self._is_command_runner_running:
            raise RuntimeError("Command Runner Is Not Running")

        for service in self._services.values():
        assert self._command_runner is not None

        for service in [*self._asyncio_tasks.values(), self._command_runner]:
            service.cancel()
            try:
                await service
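The command path above now round-trips through two asyncio queues: send_command enqueues the ExternalCommand and then awaits whatever Response the background command runner pushes onto _response_queue. A sketch of the calling side, assuming a FastAPI route hands commands to a module-level manager (the endpoint path, app, and state_manager names are illustrative):

    # Hypothetical FastAPI endpoint delegating to the state manager.
    @app.post("/command")
    async def handle_command(command: ChatCompletionNonStreamingCommand) -> Response:
        # Suspends until the background command runner publishes a response
        # for this command, or raises if the runner is not running.
        return await state_manager.send_command(command)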
@@ -154,53 +155,14 @@

        log(self._logger, MasterStateManagerStoppedLogEntry())

    async def _event_loop(self) -> None:
        """Independent event loop for processing commands and mutating state."""
        while True:
            try:
                async with self._state_lock:
                    match EventCategories:
                        case EventCategories.InstanceEventTypes:
                            events_one = self._event_processor.get_events_to_apply(
                                self._state.data_plane_network_state
                            )
                        case EventCategories.InstanceEventTypes:
                            events_one = self._event_processor.get_events_to_apply(
                                self._state.control_plane_network_state
                            )
                        case _:
                            raise ValueError(
                                f"Unknown event category: {event_category}"
                            )
                command = self._command_queue.get(timeout=5.0)
                match command:
                    case ChatCompletionNonStreamingCommand():
                        log(
                            self._logger,
                            MasterCommandReceivedLogEntry(
                                command_name=command.command_type
                            ),
                        )
                    case _:
                        log(
                            self._logger,
                            MasterInvalidCommandReceivedLogEntry(
                                command_name=command.command_type
                            ),
                        )
            except CancelledError:
                break
            except Exception as e:
                log(self._logger, MasterStateManagerErrorLogEntry(error=str(e)))


@asynccontextmanager
async def lifespan(app: FastAPI):
    logger = configure_logger("master")

    telemetry_queue: Queue[LogRecord] = Queue()
    metrics_queue: Queue[LogRecord] = Queue()
    cluster_queue: Queue[LogRecord] = Queue()
    telemetry_queue: PQueue[LogRecord] = PQueue()
    metrics_queue: PQueue[LogRecord] = PQueue()
    cluster_queue: PQueue[LogRecord] = PQueue()

    attach_to_queue(
        logger,
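The logging queues switch from asyncio.Queue to the thread-safe queue.Queue (imported as PQueue), which is what the stdlib logging handlers expect. attach_to_queue itself is not shown in this diff; a minimal stdlib-only equivalent, assuming it just wires a QueueHandler onto the logger and starts a listener over some downstream handler (the function name and handler choice below are illustrative):

    from logging import Logger, LogRecord, StreamHandler
    from logging.handlers import QueueHandler, QueueListener
    from queue import Queue as PQueue

    def attach_to_queue_sketch(logger: Logger, q: "PQueue[LogRecord]") -> QueueListener:
        # Records logged on `logger` are enqueued; the listener drains the
        # queue on its own thread and forwards them to the real handler(s).
        logger.addHandler(QueueHandler(q))
        listener = QueueListener(q, StreamHandler())
        listener.start()
        return listener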
@@ -1,6 +1,5 @@
from enum import Enum, StrEnum
from typing import (
    Any,
    Callable,
    FrozenSet,
    Literal,
@@ -110,8 +109,8 @@ check_event_type_union_is_consistent_with_registry(EVENT_TYPE_ENUMS, EventTypes)

class EventCategoryEnum(StrEnum):
    MutatesTaskState = "MutatesTaskState"
    MutatesRunnerStatus = "MutatesRunnerStatus"
    MutatesTaskSagaState = "MutatesTaskSagaState"
    MutatesRunnerStatus = "MutatesRunnerStatus"
    MutatesInstanceState = "MutatesInstanceState"
    MutatesNodePerformanceState = "MutatesNodePerformanceState"
    MutatesControlPlaneState = "MutatesControlPlaneState"
@@ -155,8 +154,8 @@ class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel)
            raise ValueError("Invalid Event: Origin ID Does Not Match")


def narrow_event_type[T: EventCategory](
    event: Event[EventCategories],
def narrow_event_type[T: EventCategory, Q: EventCategories | EventCategory](
    event: Event[Q],
    target_category: T,
) -> Event[T]:
    if target_category not in event.event_category:
@@ -165,6 +164,16 @@ def narrow_event_type[T: EventCategory](
    narrowed_event = event.model_copy(update={"event_category": {target_category}})
    return cast(Event[T], narrowed_event)

def narrow_event_from_event_log_type[T: EventCategory, Q: EventCategories | EventCategory](
    event: EventFromEventLog[Q],
    target_category: T,
) -> EventFromEventLog[T]:
    if target_category not in event.event.event_category:
        raise ValueError(f"Event Does Not Contain Target Category {target_category}")
    narrowed_event = event.model_copy(update={"event": narrow_event_type(event.event, target_category)})

    return cast(EventFromEventLog[T], narrowed_event)


class State[EventCategoryT: EventCategory](BaseModel):
    event_category: EventCategoryT
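Usage sketch for the new helper: given a log entry whose wrapped event is tagged with several categories, narrow_event_from_event_log_type returns a copy whose inner event carries exactly one of them, so it can be typed against a single per-category queue (the variable names are illustrative):

    # some_log_entry: EventFromEventLog[EventCategories]
    task_entry = narrow_event_from_event_log_type(
        some_log_entry, EventCategoryEnum.MutatesTaskState
    )
    # task_entry is now typed as
    # EventFromEventLog[Literal[EventCategoryEnum.MutatesTaskState]]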
@@ -190,7 +199,7 @@ class StateAndEvent[EventCategoryT: EventCategory](NamedTuple):
type EffectHandler[EventCategoryT: EventCategory] = Callable[
    [StateAndEvent[EventCategoryT], State[EventCategoryT]], None
]
type EventPublisher = Callable[[Event[Any]], None]
type EventPublisher[EventCategoryT: EventCategory] = Callable[[Event[EventCategoryT]], None]


# A component that can publish events
@@ -207,7 +216,7 @@ class EventFetcherProtocol[EventCategoryT: EventCategory](Protocol):

# A component that can get the effect handler for a saga
def get_saga_effect_handler[EventCategoryT: EventCategory](
    saga: Saga[EventCategoryT], event_publisher: EventPublisher
    saga: Saga[EventCategoryT], event_publisher: EventPublisher[EventCategoryT]
) -> EffectHandler[EventCategoryT]:
    def effect_handler(state_and_event: StateAndEvent[EventCategoryT]) -> None:
        trigger_state, trigger_event = state_and_event
@@ -219,7 +228,7 @@ def get_saga_effect_handler[EventCategoryT: EventCategory](

def get_effects_from_sagas[EventCategoryT: EventCategory](
    sagas: Sequence[Saga[EventCategoryT]],
    event_publisher: EventPublisher,
    event_publisher: EventPublisher[EventCategoryT],
) -> Sequence[EffectHandler[EventCategoryT]]:
    return [get_saga_effect_handler(saga, event_publisher) for saga in sagas]
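With EventPublisher now parameterized, the saga wiring is category-typed end to end. A usage sketch, assuming a concrete publisher for one category (the publisher body, task_sagas, current_state, and incoming_event are illustrative placeholders):

    # Hypothetical single-category publisher; real publishers live elsewhere.
    def publish_task_event(event: Event[Literal[EventCategoryEnum.MutatesTaskState]]) -> None:
        ...

    handlers = get_effects_from_sagas(task_sagas, publish_task_event)
    for handler in handlers:
        # EffectHandler takes the triggering (state, event) pair plus the
        # current state, per the type alias above.
        handler(StateAndEvent(current_state, incoming_event), current_state)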