more incomplete trash

This commit is contained in:
Arbion Halili
2025-07-15 13:40:21 +01:00
parent 9f96b6791f
commit 7fa7de8e83
5 changed files with 74 additions and 42 deletions

21
master/commands.py Normal file
View File

@@ -0,0 +1,21 @@
from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter
class BaseExternalCommand[T: str](BaseModel):
command_type: T
class ChatCompletionNonStreamingCommand(
BaseExternalCommand[Literal["chat_completion_non_streaming"]]
):
command_type: Literal["chat_completion_non_streaming"] = (
"chat_completion_non_streaming"
)
ExternalCommand = Annotated[
ChatCompletionNonStreamingCommand, Field(discriminator="command_type")
]
ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand)

View File

@@ -91,6 +91,8 @@ MasterLogEntries = (
MasterUninitializedLogEntry
| MasterCommandReceivedLogEntry
| MasterInvalidCommandReceivedLogEntry
| MasterCommandRunnerNotRunningLogEntry
| MasterStateManagerStoppedLogEntry
| EventCategoryUnknownLogEntry
| StateUpdateLoopAlreadyRunningLogEntry
| StateUpdateLoopStartedLogEntry

View File

@@ -1,20 +1,21 @@
from asyncio import CancelledError, Lock, Task
from asyncio import CancelledError, Lock, Task, create_task
from asyncio import Queue as AsyncQueue
from contextlib import asynccontextmanager
from logging import Logger, LogRecord
from queue import Queue as PQueue
from typing import Annotated, Literal
from typing import Callable, Sequence
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, TypeAdapter
from master.commands import ExternalCommand
from master.env import MasterEnvironmentSchema
from master.logging import (
MasterCommandRunnerNotRunningLogEntry,
MasterStateManagerStoppedLogEntry,
MasterUninitializedLogEntry,
)
from master.router import QueueMapping
from master.state_manager.sync import SyncStateManagerMapping
from shared.constants import EXO_MASTER_STATE
from shared.logger import (
@@ -26,9 +27,11 @@ from shared.logger import (
log,
)
from shared.types.events.common import (
Apply,
EventCategory,
EventFetcherProtocol,
EventFromEventLog,
EventPublisher,
State,
)
from shared.types.models.common import ModelId
from shared.types.models.model import ModelInfo
@@ -60,22 +63,18 @@ def get_master_state_dependency(data: object, logger: Logger) -> MasterState:
return data
class BaseExternalCommand[T: str](BaseModel):
command_type: T
class ChatCompletionNonStreamingCommand(
BaseExternalCommand[Literal["chat_completion_non_streaming"]]
):
command_type: Literal["chat_completion_non_streaming"] = (
"chat_completion_non_streaming"
)
ExternalCommand = Annotated[
ChatCompletionNonStreamingCommand, Field(discriminator="command_type")
]
ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand)
# Safety on Apply.
def safely_apply[T: EventCategory](
state: State[T], apply_fn: Apply[T], events: Sequence[EventFromEventLog[T]]
) -> State[T]:
sorted_events = sorted(events, key=lambda event: event.idx_in_log)
state = state.model_copy()
for event in sorted_events:
if event.idx_in_log <= state.last_event_applied_idx:
continue
state.last_event_applied_idx = event.idx_in_log
state = apply_fn(state, event)
return state
class MasterEventLoop:
@@ -84,24 +83,27 @@ class MasterEventLoop:
def __init__(
self,
initial_state: MasterState,
event_processor: EventFetcherProtocol[EventCategory],
push_events_to_queue: Callable[[QueueMapping], None],
event_publisher: EventPublisher[EventCategory],
state_managers: SyncStateManagerMapping,
logger: Logger,
):
self._state = initial_state
self._state_lock = Lock()
self._command_runner: Task[None] | None = None
self._event_queues: QueueMapping
self._command_runner: ...
self._command_run_task: Task[None] | None = None
self._command_queue: AsyncQueue[ExternalCommand] = AsyncQueue()
self._response_queue: AsyncQueue[Response | StreamingResponse] = AsyncQueue()
self._state_managers: SyncStateManagerMapping
self._event_fetcher: EventFetcherProtocol[EventCategory]
self._state_global_lock: Lock = Lock()
self._push_events_to_queue: Callable[[QueueMapping], None]
self._event_fetch_task: Task[None] | None = None
self._logger = logger
@property
def _is_command_runner_running(self) -> bool:
return self._command_runner is not None and not self._command_runner.done()
return self._command_run_task is not None and not self._command_run_task.done()
@property
def _is_event_fetcher_running(self) -> bool:
@@ -121,14 +123,26 @@ class MasterEventLoop:
async def start(self) -> None:
"""Start the background event loop."""
async def fetch_and_apply_events() -> None:
while True:
async with self._state_global_lock:
for state in self._state_managers.values():
self._push_events_to_queue(self._event_queues)
safely_apply(
state, apply_fn, self._event_queues[state.event_category]
)
self._event_fetch_task = create_task(fetch_and_apply_events())
self._command_run_task = create_task(self._command_runner())
async def stop(self) -> None:
"""Stop the background event loop and persist state."""
if not self._is_command_runner_running or not self._is_event_fetcher_running:
raise RuntimeError("Command Runner Is Not Running")
assert self._command_runner is not None and self._event_fetch_task is not None
assert self._command_run_task is not None and self._event_fetch_task is not None
for service in [self._event_fetch_task, self._command_runner]:
for service in [self._event_fetch_task, self._command_run_task]:
service.cancel()
try:
await service

View File

@@ -1,13 +1,12 @@
from asyncio import Queue, gather
from logging import Logger
from typing import Literal, TypedDict
from typing import Literal, Protocol, TypedDict
from master.sanity_checking import check_keys_in_map_match_enum_values
from shared.types.events.common import (
EventCategories,
EventCategory,
EventCategoryEnum,
EventFetcherProtocol,
EventFromEventLog,
narrow_event_from_event_log_type,
)
@@ -40,12 +39,19 @@ class QueueMapping(TypedDict):
check_keys_in_map_match_enum_values(QueueMapping, EventCategoryEnum)
class EventRouter:
class EventRouterProtocol(Protocol):
queue_map: QueueMapping
start_idx: int
def sync_queues(self) -> None: ...
class EventRouter(EventRouterProtocol):
"""Routes events to appropriate services based on event categories."""
queue_map: QueueMapping
event_fetcher: EventFetcherProtocol[EventCategory]
_logger: Logger
start_idx: int
logger: Logger
async def _get_queue_by_category[T: EventCategory](
self, category: T
@@ -82,8 +88,3 @@ class EventRouter:
await q2.put(narrow_event)
await gather(*[self._process_events(domain) for domain in EventCategoryEnum])
async def _get_events_to_process(
self,
) -> list[EventFromEventLog[EventCategories | EventCategory]]:
"""Get events to process from the event fetcher."""

View File

@@ -1,6 +0,0 @@
def main():
print("Hello from worker!")
if __name__ == "__main__":
main()