Files
exo/master/api.py
Alex Cheema a241c92dd1 Glue
2025-07-25 13:10:29 +01:00

198 lines
7.1 KiB
Python

import asyncio
import time
from collections.abc import AsyncGenerator
from typing import Callable, List, Sequence, final
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.api import (
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceResponse,
CreateInstanceTaskParams,
DeleteInstanceResponse,
StreamingChoiceResponse,
)
from shared.types.common import CommandId
from shared.types.events import ChunkGenerated, Event
from shared.types.events.chunks import TokenChunk
from shared.types.events.commands import (
ChatCompletionCommand,
Command,
CommandType,
CreateInstanceCommand,
DeleteInstanceCommand,
)
from shared.types.events.components import EventFromEventLog
from shared.types.state import State
from shared.types.tasks import ChatCompletionTaskParams
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import Instance
def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
return ChatCompletionResponse(
id='abc',
created=int(time.time()),
model='idk',
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role='assistant',
content=chunk.text
),
finish_reason=chunk.finish_reason
)
]
)
@final
class API:
def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage, get_state: Callable[[], State]) -> None:
self._app = FastAPI()
self._setup_routes()
self.command_buffer = command_buffer
self.global_events = global_events
self.get_state = get_state
def _setup_routes(self) -> None:
# self._app.get("/topology/control_plane")(self.get_control_plane_topology)
# self._app.get("/topology/data_plane")(self.get_data_plane_topology)
# self._app.get("/instances/list")(self.list_instances)
self._app.post("/instances/create")(self.create_instance)
self._app.get("/instance/{instance_id}")(self.get_instance)
self._app.delete("/instance/{instance_id}")(self.delete_instance)
# self._app.get("/model/{model_id}/metadata")(self.get_model_data)
# self._app.post("/model/{model_id}/instances")(self.get_instances_by_model)
self._app.post("/v1/chat/completions")(self.chat_completions)
@property
def app(self) -> FastAPI:
return self._app
# def get_control_plane_topology(self):
# return {"message": "Hello, World!"}
# def get_data_plane_topology(self):
# return {"message": "Hello, World!"}
# def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...
# def download_model(self, model_id: ModelId) -> None: ...
# def list_instances(self):
# return {"message": "Hello, World!"}
async def create_instance(self, payload: CreateInstanceTaskParams) -> CreateInstanceResponse:
if payload.model_id in MODEL_CARDS:
model_card = MODEL_CARDS[payload.model_id]
model_meta = model_card.metadata
else:
model_meta = await get_model_meta(payload.model_id)
command = CreateInstanceCommand(
command_id=CommandId(),
command_type=CommandType.CREATE_INSTANCE,
model_meta=model_meta,
instance_id=InstanceId(),
)
self.command_buffer.append(command)
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
instance_id=command.instance_id,
)
def get_instance(self, instance_id: InstanceId) -> Instance:
state = self.get_state()
if instance_id not in state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
return state.instances[instance_id]
def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
if instance_id not in self.get_state().instances:
raise HTTPException(status_code=404, detail="Instance not found")
command = DeleteInstanceCommand(
command_id=CommandId(),
command_type=CommandType.DELETE_INSTANCE,
instance_id=instance_id,
)
self.command_buffer.append(command)
return DeleteInstanceResponse(
message="Command received.",
command_id=command.command_id,
instance_id=instance_id,
)
# def get_model_data(self, model_id: ModelId) -> ModelInfo: ...
# def get_instances_by_model(self, model_id: ModelId) -> list[Instance]: ...
async def _generate_chat_stream(self, payload: ChatCompletionTaskParams) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
events = await self.global_events.get_events_since(0)
prev_idx = await self.global_events.get_last_idx()
# At the moment, we just create the task in the API.
# In the future, a `Request` will be created here and they will be bundled into `Task` objects by the master.
command_id=CommandId()
request = ChatCompletionCommand(
command_id=command_id,
command_type=CommandType.CHAT_COMPLETION,
request_params=payload,
)
self.command_buffer.append(request)
finished = False
while not finished:
await asyncio.sleep(0.01)
events: Sequence[EventFromEventLog[Event]] = await self.global_events.get_events_since(prev_idx)
# TODO: Can do this with some better functionality to tail event log into an AsyncGenerator.
prev_idx = events[-1].idx_in_log if events else prev_idx
for wrapped_event in events:
event = wrapped_event.event
if isinstance(event, ChunkGenerated) and event.command_id == command_id:
assert isinstance(event.chunk, TokenChunk)
chunk_response: ChatCompletionResponse = chunk_to_response(event.chunk)
print(chunk_response)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if event.chunk.finish_reason is not None:
yield "data: [DONE]"
finished = True
return
async def chat_completions(self, payload: ChatCompletionTaskParams) -> StreamingResponse:
"""Handle chat completions with proper streaming response."""
return StreamingResponse(
self._generate_chat_stream(payload),
media_type="text/plain"
)
def start_fastapi_server(
command_buffer: List[Command],
global_events: AsyncSQLiteEventStorage,
get_state: Callable[[], State],
host: str = "0.0.0.0",
port: int = 8000,
):
api = API(command_buffer, global_events, get_state)
uvicorn.run(api.app, host=host, port=port)