mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
272 lines
8.6 KiB
Python
272 lines
8.6 KiB
Python
from fastapi import FastAPI, Response
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field, TypeAdapter
|
|
from logging import Logger
|
|
|
|
from shared.types.events.common import Event, EventCategories, EventFetcherProtocol, EventPublisher, State
|
|
from shared.logger import (
|
|
configure_logger,
|
|
LogEntryType,
|
|
FilterLogByType,
|
|
create_queue_listener,
|
|
attach_to_queue,
|
|
)
|
|
from shared.types.worker.common import InstanceId
|
|
from shared.types.worker.instances import Instance
|
|
from shared.types.models.common import ModelId
|
|
from shared.types.models.model import ModelInfo
|
|
from shared.types.states.master import MasterState
|
|
from shared.constants import EXO_MASTER_STATE
|
|
from contextlib import asynccontextmanager
|
|
from logging import LogRecord
|
|
from typing import Annotated, Literal
|
|
from master.env import MasterEnvironmentSchema
|
|
from master.logging import (
|
|
MasterUninitializedLogEntry,
|
|
MasterCommandReceivedLogEntry,
|
|
MasterInvalidCommandReceivedLogEntry,
|
|
)
|
|
from master.event_routing import AsyncUpdateStateFromEvents
|
|
from shared.logger import log
|
|
from asyncio import Lock, Task, CancelledError, Queue, create_task
|
|
from enum import Enum
|
|
|
|
|
|
# Restore State
|
|
def get_master_state(logger: Logger) -> MasterState:
    """Load the persisted master state, or fall back to a fresh one.

    Args:
        logger: Receives a ``MasterUninitializedLogEntry`` when the path
            ``EXO_MASTER_STATE`` does not exist.

    Returns:
        The ``MasterState`` parsed from the JSON text stored at
        ``EXO_MASTER_STATE`` when that path exists; otherwise a
        default-constructed ``MasterState``.
    """
    # Guard clause: nothing on disk means there is nothing to restore.
    if not EXO_MASTER_STATE.exists():
        log(logger, MasterUninitializedLogEntry())
        return MasterState()

    with open(EXO_MASTER_STATE, "r") as state_file:
        return MasterState.model_validate_json(state_file.read())
|
|
|
|
|
|
# FastAPI Dependencies
|
|
def check_env_vars_defined(data: object, logger: Logger) -> MasterEnvironmentSchema:
    """Narrow *data* to ``MasterEnvironmentSchema``.

    Args:
        data: Candidate object to check.
        logger: Accepted for signature compatibility; not used by this check.

    Returns:
        *data* itself when it is a ``MasterEnvironmentSchema`` instance.

    Raises:
        RuntimeError: If *data* is any other kind of object.
    """
    if isinstance(data, MasterEnvironmentSchema):
        return data
    raise RuntimeError("Environment Variables Not Found")
|
|
|
|
|
|
def get_master_state_dependency(data: object, logger: Logger) -> MasterState:
    """Narrow *data* to ``MasterState``.

    Args:
        data: Candidate object to check.
        logger: Accepted for signature compatibility; not used by this check.

    Returns:
        *data* itself when it is a ``MasterState`` instance.

    Raises:
        RuntimeError: If *data* is any other kind of object.
    """
    if isinstance(data, MasterState):
        return data
    raise RuntimeError("Master State Not Found")
|
|
|
|
|
|
class BaseExternalCommand[T: str](BaseModel):
|
|
command_type: T
|
|
|
|
|
|
class ChatCompletionNonStreamingCommand(
    BaseExternalCommand[Literal["chat_completion_non_streaming"]],
):
    # Command tagged "chat_completion_non_streaming"; the tag is both the
    # only permitted value and the default, so callers may omit it.
    command_type: Literal[
        "chat_completion_non_streaming"
    ] = "chat_completion_non_streaming"
|
|
|
|
|
|
# Every command type accepted from outside, discriminated on ``command_type``.
# Only one variant is declared so far.
ExternalCommand = Annotated[
    ChatCompletionNonStreamingCommand,
    Field(discriminator="command_type"),
]

# Reusable adapter for validating raw payloads into an ``ExternalCommand``.
ExternalCommandParser: TypeAdapter[ExternalCommand] = TypeAdapter(ExternalCommand)
|
|
|
|
|
|
class MasterBackgroundServices(str, Enum):
    """Names of the background tasks tracked by ``MasterStateManager``."""

    # Key under which the command-processing task is registered.
    MAIN_LOOP = "main_loop"
|
|
|
|
class StateManager[T: EventCategories]:
|
|
state: State[T]
|
|
queue: Queue[Event[T]]
|
|
manager: AsyncUpdateStateFromEvents[T]
|
|
|
|
def __init__(
|
|
self,
|
|
state: State[T],
|
|
queue: Queue[Event[T]],
|
|
) -> None:
|
|
...
|
|
|
|
class MasterStateManager:
|
|
"""Thread-safe manager for MasterState with independent event loop."""
|
|
|
|
def __init__(
|
|
self,
|
|
initial_state: MasterState,
|
|
event_processor: EventFetcherProtocol[EventCategories],
|
|
event_publisher: EventPublisher[EventCategories],
|
|
logger: Logger,
|
|
):
|
|
self._state = initial_state
|
|
self._state_lock = Lock()
|
|
self._command_queue: Queue[ExternalCommand] = Queue()
|
|
self._services: dict[MasterBackgroundServices, Task[None]] = {}
|
|
self._logger = logger
|
|
|
|
async def read_state(self) -> MasterState:
|
|
"""Get a thread-safe snapshot of the current state."""
|
|
async with self._state_lock:
|
|
return self._state.model_copy(deep=True)
|
|
|
|
async def send_command(
|
|
self, command: ExternalCommand
|
|
) -> Response | StreamingResponse:
|
|
"""Send a command to the background event loop."""
|
|
if self._services[MasterBackgroundServices.MAIN_LOOP]:
|
|
self._command_queue.put(command)
|
|
return Response(status_code=200)
|
|
else:
|
|
raise RuntimeError("State manager is not running")
|
|
|
|
async def start(self) -> None:
|
|
"""Start the background event loop."""
|
|
for service in MasterBackgroundServices:
|
|
match service:
|
|
case MasterBackgroundServices.MAIN_LOOP:
|
|
if self._services[service]:
|
|
raise RuntimeError("State manager is already running")
|
|
self._services[service]: Task[None] = create_task(self._event_loop())
|
|
log(self._logger, MasterStateManagerStartedLogEntry())
|
|
case _:
|
|
raise ValueError(f"Unknown service: {service}")
|
|
|
|
async def stop(self) -> None:
|
|
"""Stop the background event loop and persist state."""
|
|
if not self._services[MasterBackgroundServices.MAIN_LOOP]:
|
|
raise RuntimeError("State manager is not running")
|
|
|
|
for service in self._services.values():
|
|
service.cancel()
|
|
try:
|
|
await service
|
|
except CancelledError:
|
|
pass
|
|
|
|
log(self._logger, MasterStateManagerStoppedLogEntry())
|
|
|
|
async def _event_loop(self) -> None:
|
|
"""Independent event loop for processing commands and mutating state."""
|
|
while True:
|
|
try:
|
|
async with self._state_lock:
|
|
match EventCategories:
|
|
case EventCategories.InstanceEventTypes:
|
|
events_one = self._event_processor.get_events_to_apply(
|
|
self._state.data_plane_network_state
|
|
)
|
|
case EventCategories.InstanceStateEventTypes:
|
|
events_one = self._event_processor.get_events_to_apply(
|
|
self._state.control_plane_network_state
|
|
)
|
|
case _:
|
|
raise ValueError(
|
|
f"Unknown event category: {event_category}"
|
|
)
|
|
command = self._command_queue.get(timeout=5.0)
|
|
match command:
|
|
case ChatCompletionNonStreamingCommand():
|
|
log(
|
|
self._logger,
|
|
MasterCommandReceivedLogEntry(
|
|
command_name=command.command_type
|
|
),
|
|
)
|
|
case _:
|
|
log(
|
|
self._logger,
|
|
MasterInvalidCommandReceivedLogEntry(
|
|
command_name=command.command_type
|
|
),
|
|
)
|
|
except CancelledError:
|
|
break
|
|
except Exception as e:
|
|
log(self._logger, MasterStateManagerErrorLogEntry(error=str(e)))
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up logging and the master state manager for the app's lifetime.

    On entry this configures the "master" logger, routes telemetry, metrics
    and cluster log records to their own queues, starts one listener per
    queue, restores the master state and starts the state manager. On exit
    the state manager is stopped, including when an exception is thrown into
    the context manager.

    Args:
        app: Application whose ``state.master_state_manager`` is populated.
    """
    logger = configure_logger("master")

    telemetry_queue: Queue[LogRecord] = Queue()
    metrics_queue: Queue[LogRecord] = Queue()
    cluster_queue: Queue[LogRecord] = Queue()

    # Each queue receives only the records of its own log-entry type.
    routed_queues = (
        (LogEntryType.telemetry, telemetry_queue),
        (LogEntryType.metrics, metrics_queue),
        (LogEntryType.cluster, cluster_queue),
    )
    for entry_type, record_queue in routed_queues:
        attach_to_queue(
            logger,
            [FilterLogByType(log_types={entry_type})],
            record_queue,
        )

    # TODO: Add handlers
    listeners = [
        create_queue_listener(record_queue, []) for _, record_queue in routed_queues
    ]
    for listener in listeners:
        listener.start()
    # TODO(review): the listeners are never stopped on shutdown; stop them
    # here once the type returned by create_queue_listener is confirmed.

    initial_state = get_master_state(logger)
    # NOTE(review): MasterStateManager.__init__ in this file also requires
    # ``event_processor`` and ``event_publisher``; this two-argument call does
    # not match that signature -- supply them once they are available.
    app.state.master_state_manager = MasterStateManager(initial_state, logger)
    await app.state.master_state_manager.start()

    try:
        yield
    finally:
        # Runs even when the body raises, so the background task is always
        # cancelled.
        await app.state.master_state_manager.stop()
|
|
|
|
|
|
# ASGI application; ``lifespan`` handles start-up and shutdown.
app = FastAPI(
    lifespan=lifespan,
)
|
|
|
|
|
|
@app.get("/topology/control_plane")
def get_control_plane_topology():
    # Placeholder handler: always answers with the same greeting payload.
    placeholder = {"message": "Hello, World!"}
    return placeholder
|
|
|
|
|
|
@app.get("/topology/data_plane")
def get_data_plane_topology():
    # Placeholder handler: always answers with the same greeting payload.
    placeholder = {"message": "Hello, World!"}
    return placeholder
|
|
|
|
|
|
@app.get("/instances/list")
def list_instances():
    # Placeholder handler: always answers with the same greeting payload.
    placeholder = {"message": "Hello, World!"}
    return placeholder
|
|
|
|
|
|
@app.post("/instances/create")
def create_instance(model_id: ModelId) -> InstanceId:
    # Stub: the body is a bare ellipsis, so no instance is created yet.
    ...
|
|
|
|
|
|
@app.get("/instance/{instance_id}/read")
def get_instance(instance_id: InstanceId) -> Instance:
    # Stub: the body is a bare ellipsis, so no instance is looked up yet.
    ...
|
|
|
|
|
|
@app.delete("/instance/{instance_id}/delete")
def remove_instance(instance_id: InstanceId) -> None:
    # Stub: the body is a bare ellipsis, so nothing is removed yet.
    ...
|
|
|
|
|
|
@app.get("/model/{model_id}/metadata")
def get_model_data(model_id: ModelId) -> ModelInfo:
    # Stub: the body is a bare ellipsis, so no metadata is returned yet.
    ...
|
|
|
|
|
|
@app.post("/model/{model_id}/instances")
def get_instances_by_model(model_id: ModelId) -> list[Instance]:
    # Stub: the body is a bare ellipsis, so no instances are returned yet.
    # NOTE(review): this read-style handler is registered with POST while the
    # other lookups in this file use GET -- confirm the method is intended.
    ...
|