# mirror of https://github.com/exo-explore/exo.git
# synced 2025-12-23 22:27:50 -05:00
# 198 lines, 7.1 KiB, Python
import asyncio
|
|
import time
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Callable, List, Sequence, final
|
|
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
|
from shared.models.model_cards import MODEL_CARDS
|
|
from shared.models.model_meta import get_model_meta
|
|
from shared.types.api import (
|
|
ChatCompletionMessage,
|
|
ChatCompletionResponse,
|
|
CreateInstanceResponse,
|
|
CreateInstanceTaskParams,
|
|
DeleteInstanceResponse,
|
|
StreamingChoiceResponse,
|
|
)
|
|
from shared.types.common import CommandId
|
|
from shared.types.events import ChunkGenerated, Event
|
|
from shared.types.events.chunks import TokenChunk
|
|
from shared.types.events.commands import (
|
|
ChatCompletionCommand,
|
|
Command,
|
|
CommandType,
|
|
CreateInstanceCommand,
|
|
DeleteInstanceCommand,
|
|
)
|
|
from shared.types.events.components import EventFromEventLog
|
|
from shared.types.state import State
|
|
from shared.types.tasks import ChatCompletionTaskParams
|
|
from shared.types.worker.common import InstanceId
|
|
from shared.types.worker.instances import Instance
|
|
|
|
|
|
def chunk_to_response(chunk: TokenChunk, response_id: str = 'abc', model: str = 'idk') -> ChatCompletionResponse:
    """Wrap one generated token chunk in an OpenAI-style streaming response.

    Args:
        chunk: Token chunk produced by a worker; supplies the delta text and
            the finish reason.
        response_id: Value for the response ``id`` field. Defaults to the
            placeholder ``'abc'`` used by the current streaming path —
            TODO: thread a real request/command id through here.
        model: Model name reported to the client. Defaults to the placeholder
            ``'idk'`` — TODO: thread the actual model id through here.

    Returns:
        A ``ChatCompletionResponse`` with a single streaming choice whose
        delta carries ``chunk.text``.
    """
    return ChatCompletionResponse(
        id=response_id,
        created=int(time.time()),  # unix timestamp, per the OpenAI schema
        model=model,
        choices=[
            StreamingChoiceResponse(
                index=0,
                delta=ChatCompletionMessage(
                    role='assistant',
                    content=chunk.text,
                ),
                finish_reason=chunk.finish_reason,
            )
        ],
    )
|
|
|
|
|
|
@final
class API:
    """HTTP control-plane API.

    Exposes instance lifecycle endpoints plus an OpenAI-compatible
    ``/v1/chat/completions`` streaming endpoint. All writes are expressed as
    ``Command`` objects appended to a shared command buffer (consumed by the
    master loop elsewhere); reads go through ``get_state`` and the global
    event log.
    """

    def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage, get_state: Callable[[], State]) -> None:
        self._app = FastAPI()
        self._setup_routes()

        # Externally-owned collaborators: the buffer is drained by the master
        # loop; the event log is appended to by workers.
        self.command_buffer = command_buffer
        self.global_events = global_events
        self.get_state = get_state

    def _setup_routes(self) -> None:
        """Register the HTTP routes on the FastAPI app."""
        # TODO: re-enable once these handlers are implemented.
        # self._app.get("/topology/control_plane")(self.get_control_plane_topology)
        # self._app.get("/topology/data_plane")(self.get_data_plane_topology)
        # self._app.get("/instances/list")(self.list_instances)
        self._app.post("/instances/create")(self.create_instance)
        self._app.get("/instance/{instance_id}")(self.get_instance)
        self._app.delete("/instance/{instance_id}")(self.delete_instance)
        # self._app.get("/model/{model_id}/metadata")(self.get_model_data)
        # self._app.post("/model/{model_id}/instances")(self.get_instances_by_model)
        self._app.post("/v1/chat/completions")(self.chat_completions)

    @property
    def app(self) -> FastAPI:
        """The underlying FastAPI application (e.g. to hand to uvicorn)."""
        return self._app

    # def get_control_plane_topology(self):
    #     return {"message": "Hello, World!"}

    # def get_data_plane_topology(self):
    #     return {"message": "Hello, World!"}

    # def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...

    # def download_model(self, model_id: ModelId) -> None: ...

    # def list_instances(self):
    #     return {"message": "Hello, World!"}

    async def create_instance(self, payload: CreateInstanceTaskParams) -> CreateInstanceResponse:
        """Enqueue a ``CreateInstanceCommand`` and acknowledge receipt.

        Model metadata is resolved from the local ``MODEL_CARDS`` registry
        when available, otherwise fetched via ``get_model_meta``.
        """
        if payload.model_id in MODEL_CARDS:
            model_meta = MODEL_CARDS[payload.model_id].metadata
        else:
            model_meta = await get_model_meta(payload.model_id)

        command = CreateInstanceCommand(
            command_id=CommandId(),
            command_type=CommandType.CREATE_INSTANCE,
            model_meta=model_meta,
            instance_id=InstanceId(),
        )
        self.command_buffer.append(command)

        return CreateInstanceResponse(
            message="Command received.",
            command_id=command.command_id,
            model_meta=model_meta,
            instance_id=command.instance_id,
        )

    def get_instance(self, instance_id: InstanceId) -> Instance:
        """Return the instance with the given id; 404 if it is unknown."""
        state = self.get_state()
        if instance_id not in state.instances:
            raise HTTPException(status_code=404, detail="Instance not found")
        return state.instances[instance_id]

    def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
        """Enqueue a ``DeleteInstanceCommand`` for an existing instance.

        Raises:
            HTTPException: 404 when the instance id is not in current state.
        """
        if instance_id not in self.get_state().instances:
            raise HTTPException(status_code=404, detail="Instance not found")

        command = DeleteInstanceCommand(
            command_id=CommandId(),
            command_type=CommandType.DELETE_INSTANCE,
            instance_id=instance_id,
        )
        self.command_buffer.append(command)
        return DeleteInstanceResponse(
            message="Command received.",
            command_id=command.command_id,
            instance_id=instance_id,
        )

    # def get_model_data(self, model_id: ModelId) -> ModelInfo: ...

    # def get_instances_by_model(self, model_id: ModelId) -> list[Instance]: ...

    async def _generate_chat_stream(self, payload: ChatCompletionTaskParams) -> AsyncGenerator[str, None]:
        """Yield SSE-framed chat completion chunks for *payload*.

        Enqueues a ``ChatCompletionCommand``, then polls the global event
        log, forwarding each matching ``ChunkGenerated`` event until a chunk
        carries a ``finish_reason``, at which point a ``[DONE]`` sentinel is
        emitted and the stream ends.
        """
        # Tail the log from its current end so only chunks produced after
        # this request are forwarded. (The previous full-scan read of the
        # whole log here was dead code and has been removed.)
        prev_idx = await self.global_events.get_last_idx()

        # At the moment, we just create the task in the API.
        # In the future, a `Request` will be created here and they will be bundled into `Task` objects by the master.
        command_id = CommandId()

        request = ChatCompletionCommand(
            command_id=command_id,
            command_type=CommandType.CHAT_COMPLETION,
            request_params=payload,
        )
        self.command_buffer.append(request)

        finished = False
        while not finished:
            # Simple polling loop.
            # TODO: Can do this with some better functionality to tail event log into an AsyncGenerator.
            await asyncio.sleep(0.01)

            events: Sequence[EventFromEventLog[Event]] = await self.global_events.get_events_since(prev_idx)
            prev_idx = events[-1].idx_in_log if events else prev_idx

            for wrapped_event in events:
                event = wrapped_event.event
                if isinstance(event, ChunkGenerated) and event.command_id == command_id:
                    assert isinstance(event.chunk, TokenChunk)
                    chunk_response: ChatCompletionResponse = chunk_to_response(event.chunk)
                    yield f"data: {chunk_response.model_dump_json()}\n\n"

                    if event.chunk.finish_reason is not None:
                        # Blank line terminates the SSE event frame.
                        yield "data: [DONE]\n\n"
                        finished = True
                        break  # do not emit chunks past the finish marker

    async def chat_completions(self, payload: ChatCompletionTaskParams) -> StreamingResponse:
        """Handle chat completions with a server-sent-events streaming response."""
        return StreamingResponse(
            self._generate_chat_stream(payload),
            # SSE media type, matching the "data: ...\n\n" framing above.
            media_type="text/event-stream",
        )
|
|
|
|
|
|
|
|
def start_fastapi_server(
    command_buffer: List[Command],
    global_events: AsyncSQLiteEventStorage,
    get_state: Callable[[], State],
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Build the API and serve it with uvicorn (blocks until shutdown)."""
    server = API(command_buffer, global_events, get_state)
    uvicorn.run(server.app, host=host, port=port)