mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
more incomplete trash
This commit is contained in:
21
master/commands.py
Normal file
21
master/commands.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
|
||||
class BaseExternalCommand[T: str](BaseModel):
|
||||
command_type: T
|
||||
|
||||
|
||||
class ChatCompletionNonStreamingCommand(
|
||||
BaseExternalCommand[Literal["chat_completion_non_streaming"]]
|
||||
):
|
||||
command_type: Literal["chat_completion_non_streaming"] = (
|
||||
"chat_completion_non_streaming"
|
||||
)
|
||||
|
||||
|
||||
ExternalCommand = Annotated[
|
||||
ChatCompletionNonStreamingCommand, Field(discriminator="command_type")
|
||||
]
|
||||
ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand)
|
||||
@@ -91,6 +91,8 @@ MasterLogEntries = (
|
||||
MasterUninitializedLogEntry
|
||||
| MasterCommandReceivedLogEntry
|
||||
| MasterInvalidCommandReceivedLogEntry
|
||||
| MasterCommandRunnerNotRunningLogEntry
|
||||
| MasterStateManagerStoppedLogEntry
|
||||
| EventCategoryUnknownLogEntry
|
||||
| StateUpdateLoopAlreadyRunningLogEntry
|
||||
| StateUpdateLoopStartedLogEntry
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
from asyncio import CancelledError, Lock, Task
|
||||
from asyncio import CancelledError, Lock, Task, create_task
|
||||
from asyncio import Queue as AsyncQueue
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger, LogRecord
|
||||
from queue import Queue as PQueue
|
||||
from typing import Annotated, Literal
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from master.commands import ExternalCommand
|
||||
from master.env import MasterEnvironmentSchema
|
||||
from master.logging import (
|
||||
MasterCommandRunnerNotRunningLogEntry,
|
||||
MasterStateManagerStoppedLogEntry,
|
||||
MasterUninitializedLogEntry,
|
||||
)
|
||||
from master.router import QueueMapping
|
||||
from master.state_manager.sync import SyncStateManagerMapping
|
||||
from shared.constants import EXO_MASTER_STATE
|
||||
from shared.logger import (
|
||||
@@ -26,9 +27,11 @@ from shared.logger import (
|
||||
log,
|
||||
)
|
||||
from shared.types.events.common import (
|
||||
Apply,
|
||||
EventCategory,
|
||||
EventFetcherProtocol,
|
||||
EventFromEventLog,
|
||||
EventPublisher,
|
||||
State,
|
||||
)
|
||||
from shared.types.models.common import ModelId
|
||||
from shared.types.models.model import ModelInfo
|
||||
@@ -60,22 +63,18 @@ def get_master_state_dependency(data: object, logger: Logger) -> MasterState:
|
||||
return data
|
||||
|
||||
|
||||
class BaseExternalCommand[T: str](BaseModel):
|
||||
command_type: T
|
||||
|
||||
|
||||
class ChatCompletionNonStreamingCommand(
|
||||
BaseExternalCommand[Literal["chat_completion_non_streaming"]]
|
||||
):
|
||||
command_type: Literal["chat_completion_non_streaming"] = (
|
||||
"chat_completion_non_streaming"
|
||||
)
|
||||
|
||||
|
||||
ExternalCommand = Annotated[
|
||||
ChatCompletionNonStreamingCommand, Field(discriminator="command_type")
|
||||
]
|
||||
ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand)
|
||||
# Safety on Apply.
|
||||
def safely_apply[T: EventCategory](
|
||||
state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]]
|
||||
) -> State[T]:
|
||||
sorted_events = sorted(events, key=lambda event: event.idx_in_log)
|
||||
state = state.model_copy()
|
||||
for event in sorted_events:
|
||||
if event.idx_in_log <= state.last_event_applied_idx:
|
||||
continue
|
||||
state.last_event_applied_idx = event.idx_in_log
|
||||
state = apply_fn(state, event)
|
||||
return state
|
||||
|
||||
|
||||
class MasterEventLoop:
|
||||
@@ -84,24 +83,27 @@ class MasterEventLoop:
|
||||
def __init__(
|
||||
self,
|
||||
initial_state: MasterState,
|
||||
event_processor: EventFetcherProtocol[EventCategory],
|
||||
push_events_to_queue: Callable[[QueueMapping], None],
|
||||
event_publisher: EventPublisher[EventCategory],
|
||||
state_managers: SyncStateManagerMapping,
|
||||
logger: Logger,
|
||||
):
|
||||
self._state = initial_state
|
||||
self._state_lock = Lock()
|
||||
self._command_runner: Task[None] | None = None
|
||||
self._event_queues: QueueMapping
|
||||
self._command_runner: ...
|
||||
self._command_run_task: Task[None] | None = None
|
||||
self._command_queue: AsyncQueue[ExternalCommand] = AsyncQueue()
|
||||
self._response_queue: AsyncQueue[Response | StreamingResponse] = AsyncQueue()
|
||||
self._state_managers: SyncStateManagerMapping
|
||||
self._event_fetcher: EventFetcherProtocol[EventCategory]
|
||||
self._state_global_lock: Lock = Lock()
|
||||
self._push_events_to_queue: Callable[[QueueMapping], None]
|
||||
self._event_fetch_task: Task[None] | None = None
|
||||
self._logger = logger
|
||||
|
||||
@property
|
||||
def _is_command_runner_running(self) -> bool:
|
||||
return self._command_runner is not None and not self._command_runner.done()
|
||||
return self._command_run_task is not None and not self._command_run_task.done()
|
||||
|
||||
@property
|
||||
def _is_event_fetcher_running(self) -> bool:
|
||||
@@ -121,14 +123,26 @@ class MasterEventLoop:
|
||||
async def start(self) -> None:
|
||||
"""Start the background event loop."""
|
||||
|
||||
async def fetch_and_apply_events() -> None:
|
||||
while True:
|
||||
async with self._state_global_lock:
|
||||
for state in self._state_managers.values():
|
||||
self._push_events_to_queue(self._event_queues)
|
||||
safely_apply(
|
||||
state, apply_fn, self._event_queues[state.event_category]
|
||||
)
|
||||
|
||||
self._event_fetch_task = create_task(fetch_and_apply_events())
|
||||
self._command_run_task = create_task(self._command_runner())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the background event loop and persist state."""
|
||||
if not self._is_command_runner_running or not self._is_event_fetcher_running:
|
||||
raise RuntimeError("Command Runner Is Not Running")
|
||||
|
||||
assert self._command_runner is not None and self._event_fetch_task is not None
|
||||
assert self._command_run_task is not None and self._event_fetch_task is not None
|
||||
|
||||
for service in [self._event_fetch_task, self._command_runner]:
|
||||
for service in [self._event_fetch_task, self._command_run_task]:
|
||||
service.cancel()
|
||||
try:
|
||||
await service
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from asyncio import Queue, gather
|
||||
from logging import Logger
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal, Protocol, TypedDict
|
||||
|
||||
from master.sanity_checking import check_keys_in_map_match_enum_values
|
||||
from shared.types.events.common import (
|
||||
EventCategories,
|
||||
EventCategory,
|
||||
EventCategoryEnum,
|
||||
EventFetcherProtocol,
|
||||
EventFromEventLog,
|
||||
narrow_event_from_event_log_type,
|
||||
)
|
||||
@@ -40,12 +39,19 @@ class QueueMapping(TypedDict):
|
||||
check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum)
|
||||
|
||||
|
||||
class EventRouter:
|
||||
class EventRouterProtocol(Protocol):
|
||||
queue_map: QueueMapping
|
||||
start_idx: int
|
||||
|
||||
def sync_queues(self) -> None: ...
|
||||
|
||||
|
||||
class EventRouter(EventRouterProtocol):
|
||||
"""Routes events to appropriate services based on event categories."""
|
||||
|
||||
queue_map: QueueMapping
|
||||
event_fetcher: EventFetcherProtocol[EventCategory]
|
||||
_logger: Logger
|
||||
start_idx: int
|
||||
logger: Logger
|
||||
|
||||
async def _get_queue_by_category[T: EventCategory](
|
||||
self, category: T
|
||||
@@ -82,8 +88,3 @@ class EventRouter:
|
||||
await q2.put(narrow_event)
|
||||
|
||||
await gather(*[self._process_events(domain) for domain in EventCategoryEnum])
|
||||
|
||||
async def _get_events_to_process(
|
||||
self,
|
||||
) -> list[EventFromEventLog[EventCategories | EventCategory]]:
|
||||
"""Get events to process from the event fetcher."""
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
def main():
|
||||
print("Hello from worker!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user