From 14b3c4a6beb29eb7da262bfbbe258f1202e6e074 Mon Sep 17 00:00:00 2001
From: Matt Beton
Date: Tue, 22 Jul 2025 21:21:12 +0100
Subject: [PATCH] New streaming chat completions API

---
 master/api.py                         | 162 +++++++++++++++--
 master/main.py                        | 244 +++++++++-----------------
 master/pyproject.toml                 |   1 +
 master/tests/api_utils_test.py        |  78 ++++++++
 master/tests/test_api.py              |  47 +++++
 shared/db/sqlite/connector.py         |  27 ++-
 shared/tests/test_sqlite_connector.py |  62 +++++--
 shared/types/api.py                   |  35 +++-
 shared/types/events/chunks.py         |  35 +---
 shared/types/events/events.py         |   3 +-
 shared/types/tasks/common.py          |  38 +---
 shared/types/tasks/request.py         |  12 ++
 uv.lock                               |  24 +++
 worker/main.py                        |   5 +-
 worker/runner/runner_supervisor.py    |  13 +-
 worker/tests/conftest.py              |   3 +-
 worker/tests/test_supervisor.py       |  16 +-
 worker/tests/test_worker_handlers.py  |   6 +-
 18 files changed, 527 insertions(+), 284 deletions(-)
 create mode 100644 master/tests/api_utils_test.py
 create mode 100644 master/tests/test_api.py
 create mode 100644 shared/types/tasks/request.py

diff --git a/master/api.py b/master/api.py
index 2751f2df..219b5f57 100644
--- a/master/api.py
+++ b/master/api.py
@@ -1,25 +1,157 @@
-from typing import Protocol
+import asyncio
+import time
+from asyncio.queues import Queue
+from collections.abc import AsyncGenerator
+from typing import List, Optional, Sequence, final
 
-from shared.types.graphs.topology import Topology
-from shared.types.models import ModelId, ModelMetadata
-from shared.types.worker.common import InstanceId
-from shared.types.worker.downloads import DownloadProgress
-from shared.types.worker.instances import Instance
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+
+from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.types.events.chunks import TokenChunk
+from shared.types.events.components import EventFromEventLog
+from shared.types.events.events import ChunkGenerated
+from shared.types.events.registry import Event
+from shared.types.tasks.common import ChatCompletionTaskParams
+from shared.types.tasks.request import APIRequest, RequestId
 
-class ClusterAPI(Protocol):
-    def get_topology(self) -> Topology: ...
+class Message(BaseModel):
+    role: str
+    content: str
+
 
-    def list_instances(self) -> list[Instance]: ...
+class StreamingChoiceResponse(BaseModel):
+    index: int
+    delta: Message
+    finish_reason: Optional[str] = None
+
 
-    def get_instance(self, instance_id: InstanceId) -> Instance: ...
-
-    def create_instance(self, model_id: ModelId) -> InstanceId: ...
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[StreamingChoiceResponse]
+
 
-    def remove_instance(self, instance_id: InstanceId) -> None: ...
+def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
+    return ChatCompletionResponse(
+        id='abc',
+        created=int(time.time()),
+        model='idk',
+        choices=[
+            StreamingChoiceResponse(
+                index=0,
+                delta=Message(
+                    role='assistant',
+                    content=chunk.text
+                ),
+                finish_reason=chunk.finish_reason
+            )
+        ]
+    )
+
 
-    def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...
-
-    def download_model(self, model_id: ModelId) -> None: ...
+@final
+class API:
+    def __init__(self, command_queue: Queue[APIRequest], global_events: AsyncSQLiteEventStorage) -> None:
+        self._app = FastAPI()
+        self._setup_routes()
 
-    def get_download_progress(self, model_id: ModelId) -> DownloadProgress: ...
+        self.command_queue = command_queue
+        self.global_events = global_events
+
+    def _setup_routes(self) -> None:
+        # self._app.get("/topology/control_plane")(self.get_control_plane_topology)
+        # self._app.get("/topology/data_plane")(self.get_data_plane_topology)
+        # self._app.get("/instances/list")(self.list_instances)
+        # self._app.post("/instances/create")(self.create_instance)
+        # self._app.get("/instance/{instance_id}/read")(self.get_instance)
+        # self._app.delete("/instance/{instance_id}/delete")(self.remove_instance)
+        # self._app.get("/model/{model_id}/metadata")(self.get_model_data)
+        # self._app.post("/model/{model_id}/instances")(self.get_instances_by_model)
+        self._app.post("/v1/chat/completions")(self.chat_completions)
+
+    @property
+    def app(self) -> FastAPI:
+        return self._app
+
+    # def get_control_plane_topology(self):
+    #     return {"message": "Hello, World!"}
+
+    # def get_data_plane_topology(self):
+    #     return {"message": "Hello, World!"}
+
+    # def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...
+
+    # def download_model(self, model_id: ModelId) -> None: ...
+
+    # def list_instances(self):
+    #     return {"message": "Hello, World!"}
+
+    # def create_instance(self, model_id: ModelId) -> InstanceId: ...
+
+    # def get_instance(self, instance_id: InstanceId) -> Instance: ...
+
+    # def remove_instance(self, instance_id: InstanceId) -> None: ...
+
+    # def get_model_data(self, model_id: ModelId) -> ModelInfo: ...
+
+    # def get_instances_by_model(self, model_id: ModelId) -> list[Instance]: ...
+
+    async def _generate_chat_stream(self, payload: ChatCompletionTaskParams) -> AsyncGenerator[str, None]:
+        """Generate chat completion stream as SSE-formatted JSON strings."""
+        # Start tailing from the current end of the event log.
+        prev_idx = await self.global_events.get_last_idx()
+
+        # At the moment, we just create the task in the API.
+        # In the future, a `Request` will be created here and they will be bundled into `Task` objects by the master.
+        request_id = RequestId()
+
+        request = APIRequest(
+            request_id=request_id,
+            request_params=payload,
+        )
+        await self.command_queue.put(request)
+
+        finished = False
+        while not finished:
+            await asyncio.sleep(0.01)
+
+            events: Sequence[EventFromEventLog[Event]] = await self.global_events.get_events_since(prev_idx)
+            # TODO: Can do this with some better functionality to tail event log into an AsyncGenerator.
+            prev_idx = events[-1].idx_in_log if events else prev_idx
+
+            for wrapped_event in events:
+                event = wrapped_event.event
+                if isinstance(event, ChunkGenerated) and event.request_id == request_id:
+                    assert isinstance(event.chunk, TokenChunk)
+                    chunk_response: ChatCompletionResponse = chunk_to_response(event.chunk)
+                    print(chunk_response)
+                    yield f"data: {chunk_response.model_dump_json()}\n\n"
+
+                    if event.chunk.finish_reason is not None:
+                        # SSE events are terminated by a blank line.
+                        yield "data: [DONE]\n\n"
+                        finished = True
+
+        return
+
+    async def chat_completions(self, payload: ChatCompletionTaskParams) -> StreamingResponse:
+        """Handle chat completions with proper streaming response."""
+        return StreamingResponse(
+            self._generate_chat_stream(payload),
+            media_type="text/event-stream"
+        )
+
+
+def start_fastapi_server(
+    command_queue: Queue[APIRequest],
+    global_events: AsyncSQLiteEventStorage,
+    host: str = "0.0.0.0",
+    port: int = 8000,
+):
+    api = API(command_queue, global_events)
+
+    uvicorn.run(api.app, host=host, port=port)
\ No newline at end of file
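The polling loop in `_generate_chat_stream` above is what its TODO comment points at: tailing the event log could be factored out into a reusable async generator. A minimal sketch of such a helper (hypothetical `tail_events`, not part of this patch; it assumes only the `get_events_since` and `get_last_idx` methods this patch adds to `shared/db/sqlite/connector.py`):

    import asyncio
    from collections.abc import AsyncGenerator

    from shared.db.sqlite.connector import AsyncSQLiteEventStorage
    from shared.types.events.components import EventFromEventLog
    from shared.types.events.registry import Event


    async def tail_events(
        storage: AsyncSQLiteEventStorage,
        poll_interval: float = 0.01,
    ) -> AsyncGenerator[EventFromEventLog[Event], None]:
        """Yield events appended to the log after the point of subscription."""
        prev_idx = await storage.get_last_idx()
        while True:
            events = await storage.get_events_since(prev_idx)
            if events:
                prev_idx = events[-1].idx_in_log
            for wrapped_event in events:
                yield wrapped_event
            await asyncio.sleep(poll_interval)

With a helper like this, `_generate_chat_stream` would reduce to filtering `ChunkGenerated` events for its own `request_id` and stopping once a chunk carries a `finish_reason`.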
diff --git a/master/main.py b/master/main.py
index 8e4dadeb..37949c27 100644
--- a/master/main.py
+++ b/master/main.py
@@ -1,171 +1,97 @@
-from contextlib import asynccontextmanager
-from logging import Logger, LogRecord
-from queue import Queue as PQueue
+import asyncio
+import threading
+from asyncio.queues import Queue
+from logging import Logger
 
-from fastapi import FastAPI
-
-from master.env import MasterEnvironmentSchema
-from master.logging import (
-    MasterUninitializedLogEntry,
-)
-from shared.constants import EXO_MASTER_STATE
-from shared.event_loops.main import NodeEventLoopProtocol
-from shared.logger import (
-    FilterLogByType,
-    LogEntryType,
-    attach_to_queue,
-    configure_logger,
-    create_queue_listener,
-    log,
-)
-from shared.types.models import ModelId, ModelMetadata
-from shared.types.state import State
-from shared.types.worker.common import InstanceId
-from shared.types.worker.instances import Instance
+from master.api import start_fastapi_server
+from shared.db.sqlite.config import EventLogConfig
+from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.db.sqlite.event_log_manager import EventLogManager
+from shared.types.common import NodeId
+from shared.types.events.chunks import TokenChunk
+from shared.types.events.events import ChunkGenerated
+from shared.types.tasks.request import APIRequest, RequestId
 
-# Restore State
-def get_state(logger: Logger) -> State:
-    if EXO_MASTER_STATE.exists():
-        with open(EXO_MASTER_STATE, "r") as f:
-            return State.model_validate_json(f.read())
-    else:
-        log(logger, MasterUninitializedLogEntry())
-        return State()
+## TODO: Hook this up properly
+async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, request_id: RequestId):
+    model_id = "testmodelabc"
+
+    for i in range(10):
+        await asyncio.sleep(0.1)
+
+        # Create the event with proper types and consistent IDs
+        chunk_event = ChunkGenerated(
+            request_id=request_id,
+            chunk=TokenChunk(
+                request_id=request_id,  # Same request_id as the parent event
+                idx=i,
+                model=model_id,  # Use the same model_id
+                text=f'text{i}',
+                token_id=i
+            )
+        )
+
+        # ChunkGenerated needs to be cast to the expected BaseEvent type
+        await events_log.append_events(
+            [chunk_event],
+            origin=NodeId()
+        )
 
+    await asyncio.sleep(0.1)
 
-# FastAPI Dependencies
-def check_env_vars_defined(data: object, logger: Logger) -> MasterEnvironmentSchema:
-    if not isinstance(data, MasterEnvironmentSchema):
-        raise RuntimeError("Environment Variables Not Found")
-    return data
+    # Create the final event, continuing the token sequence with a stop reason
+    chunk_event = ChunkGenerated(
+        request_id=request_id,
+        chunk=TokenChunk(
+            request_id=request_id,  # Same request_id as the parent event
+            idx=10,
+            model=model_id,  # Use the same model_id
+            text=f'text{10}',
+            token_id=10,
+            finish_reason='stop'
+        )
+    )
 
-
-def get_state_dependency(data: object, logger: Logger) -> State:
-    if not isinstance(data, State):
-        raise RuntimeError("Master State Not Found")
-    return data
-
-
-# Takes Care Of All States And Events Related To The Master
-class MasterEventLoopProtocol(NodeEventLoopProtocol): ...
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    logger = configure_logger("master")
-
-    telemetry_queue: PQueue[LogRecord] = PQueue()
-    metrics_queue: PQueue[LogRecord] = PQueue()
-    cluster_queue: PQueue[LogRecord] = PQueue()
-
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.telemetry}),
-        ],
-        telemetry_queue,
-    )
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.metrics}),
-        ],
-        metrics_queue,
-    )
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.cluster}),
-        ],
-        cluster_queue,
+    # ChunkGenerated needs to be cast to the expected BaseEvent type
+    await events_log.append_events(
+        [chunk_event],
+        origin=NodeId()
     )
-    # TODO: Add Handlers For Pushing Logs To Remote Services
-    telemetry_listener = create_queue_listener(telemetry_queue, [])
-    metrics_listener = create_queue_listener(metrics_queue, [])
-    cluster_listener = create_queue_listener(cluster_queue, [])
-
-    telemetry_listener.start()
-    metrics_listener.start()
-    cluster_listener.start()
-
-    # # Get validated environment
-    # env = get_validated_env(MasterEnvironmentSchema, logger)
-
-    # # Initialize event log manager (creates both worker and global event DBs)
-    # event_log_config = EventLogConfig()  # Uses default config
-    # event_log_manager = EventLogManager(
-    #     config=event_log_config,
-    #     logger=logger
-    # )
-    # await event_log_manager.initialize()
-
-    # # Store for use in API handlers
-    # app.state.event_log_manager = event_log_manager
-
-    # # Initialize forwarder if configured
-    # if env.FORWARDER_BINARY_PATH:
-    #     forwarder_supervisor = ForwarderSupervisor(
-    #         forwarder_binary_path=env.FORWARDER_BINARY_PATH,
-    #         logger=logger
-    #     )
-    #     # Start as replica by default (until elected)
-    #     await forwarder_supervisor.start_as_replica()
-
-    #     # Create election callbacks for Rust election system
-    #     election_callbacks = ElectionCallbacks(
-    #         forwarder_supervisor=forwarder_supervisor,
-    #         logger=logger
-    #     )
-
-    #     # Make callbacks available for Rust code to invoke
-    #     app.state.election_callbacks = election_callbacks
-
-    #     # Log status
-    #     logger.info(
-    #         f"Forwarder supervisor initialized. Running: {forwarder_supervisor.is_running}"
-    #     )
-    # else:
-    #     logger.warning("No forwarder binary path configured")
-    #     forwarder_supervisor = None
-    # initial_state = get_master_state(logger)
-    # app.state.master_event_loop = MasterEventLoop()
-    # await app.state.master_event_loop.start()
-
-    yield
-
-    # await app.state.master_event_loop.stop()
 
-app = FastAPI(lifespan=lifespan)
+async def main():
+    logger = Logger(name='master_logger')
+
+    event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
+    await event_log_manager.initialize()
+    global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
+
+    command_queue: Queue[APIRequest] = asyncio.Queue()
+
+    api_thread = threading.Thread(
+        target=start_fastapi_server,
+        args=(
+            command_queue,
+            global_events,
+        ),
+        daemon=True
+    )
+    api_thread.start()
+    print('Running FastAPI server in a separate thread. Listening on port 8000.')
+
+    while True:
+        # master loop
+        if not command_queue.empty():
+            command = await command_queue.get()
+
+            print(command)
+
+            await fake_tokens_task(global_events, request_id=command.request_id)
+
+        await asyncio.sleep(0.01)
 
-@app.get("/topology")
-def get_topology():
-    return {"message": "Hello, World!"}
-
-
-@app.get("/instances/list")
-def list_instances():
-    return {"message": "Hello, World!"}
-
-
-@app.post("/instances/create")
-def create_instance(model_id: ModelId) -> InstanceId: ...
-
-
-@app.get("/instance/{instance_id}/read")
-def get_instance(instance_id: InstanceId) -> Instance: ...
-
-
-@app.delete("/instance/{instance_id}/delete")
-def remove_instance(instance_id: InstanceId) -> None: ...
-
-
-@app.get("/model/{model_id}/metadata")
-def get_model_metadata(model_id: ModelId) -> ModelMetadata: ...
-
-
-@app.post("/model/{model_id}/instances")
-def get_instances_by_model(model_id: ModelId) -> list[Instance]: ...
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
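One caveat with the layout above: `asyncio.Queue` is not thread-safe across event loops, and here `chat_completions` puts onto the queue from the uvicorn thread's loop while `main()` reads it from its own. A sketch of a single-loop alternative (assuming uvicorn's `Server`/`Config` API and reusing `API` and `fake_tokens_task` from this patch):

    import asyncio
    from asyncio.queues import Queue

    import uvicorn

    from master.api import API
    from master.main import fake_tokens_task
    from shared.db.sqlite.connector import AsyncSQLiteEventStorage
    from shared.types.tasks.request import APIRequest


    async def serve_and_consume(
        command_queue: Queue[APIRequest],
        global_events: AsyncSQLiteEventStorage,
    ) -> None:
        # Run uvicorn on the current loop instead of a daemon thread, so the
        # queue is only ever touched from a single event loop.
        api = API(command_queue, global_events)
        server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=8000))
        server_task = asyncio.create_task(server.serve())

        while not server_task.done():
            command = await command_queue.get()  # waits without polling
            await fake_tokens_task(global_events, request_id=command.request_id)

Awaiting `command_queue.get()` directly would also remove the need for the `empty()` check and the 10 ms sleep in the loop above.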
diff --git a/master/pyproject.toml b/master/pyproject.toml
index b8912679..d1343631 100644
--- a/master/pyproject.toml
+++ b/master/pyproject.toml
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
 dependencies = [
     "exo-shared",
     "fastapi>=0.116.0",
+    "uvicorn>=0.35.0",
 ]
 
 [build-system]
diff --git a/master/tests/api_utils_test.py b/master/tests/api_utils_test.py
new file mode 100644
index 00000000..a51622d1
--- /dev/null
+++ b/master/tests/api_utils_test.py
@@ -0,0 +1,78 @@
+import asyncio
+import functools
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Coroutine,
+    ParamSpec,
+    TypeVar,
+    final,
+)
+
+import openai
+import pytest
+from openai._streaming import AsyncStream
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+)
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice
+
+from master.main import main as master_main
+
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
+OPENAI_API_KEY: str = ""
+OPENAI_API_URL: str = "http://0.0.0.0:8000/v1"
+
+
+def with_master_main(
+    func: Callable[_P, Awaitable[_R]]
+) -> Callable[_P, Coroutine[Any, Any, _R]]:
+    @pytest.mark.asyncio
+    @functools.wraps(func)
+    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+        master_task = asyncio.create_task(master_main())
+        try:
+            return await func(*args, **kwargs)
+        finally:
+            master_task.cancel()
+            with pytest.raises(asyncio.CancelledError):
+                await master_task
+    return wrapper
+
+
+@final
+class ChatMessage:
+    """Strictly-typed chat message for OpenAI API."""
+    def __init__(self, role: str, content: str) -> None:
+        self.role = role
+        self.content = content
+
+    def to_openai(self) -> ChatCompletionMessageParam:
+        if self.role == "user":
+            return {"role": "user", "content": self.content}  # type: ChatCompletionUserMessageParam
+        elif self.role == "assistant":
+            return {"role": "assistant", "content": self.content}  # type: ChatCompletionAssistantMessageParam
+        elif self.role == "system":
+            return {"role": "system", "content": self.content}  # type: ChatCompletionSystemMessageParam
+        else:
+            raise ValueError(f"Unsupported role: {self.role}")
+
+
+async def stream_chatgpt_response(
+    messages: list[ChatMessage],
+    model: str = "gpt-3.5-turbo",
+) -> AsyncGenerator[Choice, None]:
+    client = openai.AsyncOpenAI(
+        api_key=OPENAI_API_KEY,
+        base_url=OPENAI_API_URL,
+    )
+    openai_messages: list[ChatCompletionMessageParam] = [m.to_openai() for m in messages]
+    stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+        model=model,
+        messages=openai_messages,
+        stream=True,
+    )
+    async for chunk in stream:
+        for choice in chunk.choices:
+            yield choice
diff --git a/master/tests/test_api.py b/master/tests/test_api.py
new file mode 100644
index 00000000..7fd01916
--- /dev/null
+++ b/master/tests/test_api.py
@@ -0,0 +1,47 @@
+import asyncio
+
+import pytest
+
+from master.tests.api_utils_test import (
+    ChatMessage,
+    stream_chatgpt_response,
+    with_master_main,
+)
+
+
+@with_master_main
+@pytest.mark.asyncio
+async def test_master_api_multiple_response_sequential() -> None:
+    messages = [
+        ChatMessage(role="user", content="Hello, who are you?")
+    ]
+    token_count = 0
+    text: str = ""
+    async for choice in stream_chatgpt_response(messages):
+        print(choice, flush=True)
+        if choice.delta and choice.delta.content:
+            text += choice.delta.content
+            token_count += 1
+        if choice.finish_reason:
+            break
+
+    assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
+    assert len(text) > 0, "Expected non-empty response text"
+
+    await asyncio.sleep(0.1)
+
+    messages = [
+        ChatMessage(role="user", content="What time is it in France?")
+    ]
+    token_count = 0
+    text = ""  # re-initialize, do not redeclare type
+    async for choice in stream_chatgpt_response(messages):
+        print(choice, flush=True)
+        if choice.delta and choice.delta.content:
+            text += choice.delta.content
+            token_count += 1
+        if choice.finish_reason:
+            break
+
+    assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
+    assert len(text) > 0, "Expected non-empty response text"
diff --git a/shared/db/sqlite/connector.py b/shared/db/sqlite/connector.py
index 44de9efd..4b40cf9b 100644
--- a/shared/db/sqlite/connector.py
+++ b/shared/db/sqlite/connector.py
@@ -116,6 +116,27 @@ class AsyncSQLiteEventStorage:
             ))
 
         return events
+
+    async def get_last_idx(self) -> int:
+        if self._closed:
+            raise RuntimeError("Storage is closed")
+
+        assert self._engine is not None
+
+        async with AsyncSession(self._engine) as session:
+            result = await session.execute(
+                text("SELECT rowid, origin, event_data FROM events ORDER BY rowid DESC LIMIT 1"),
+                {}
+            )
+            rows = result.fetchall()
+
+            if len(rows) == 0:
+                return 0
+            if len(rows) == 1:
+                row = rows[0]
+                return cast(int, row[0])
+            else:
+                raise AssertionError("There should have been at most 1 row returned from this SQL query.")
 
     async def close(self) -> None:
         """Close the storage connection and cleanup resources."""
@@ -211,12 +232,12 @@ class AsyncSQLiteEventStorage:
 
         try:
             async with AsyncSession(self._engine) as session:
-                for event, origin in batch:
+                for event, origin in batch:
                     stored_event = StoredEvent(
                         origin=str(origin.uuid),
-                        event_type=str(event.event_type),
+                        event_type=event.event_type,
                         event_id=str(event.event_id),
-                        event_data=event.model_dump(mode='json')  # mode='json' ensures UUID conversion
+                        event_data=event.model_dump(mode='json')  # Serialize UUIDs and other objects to JSON-compatible strings
                     )
                     session.add(stored_event)
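Since the new query ends in `LIMIT 1`, `fetchall()` can only ever return zero or one rows, so the `AssertionError` branch in `get_last_idx` is unreachable. An equivalent body (a sketch, not part of this patch) lets SQLite do the work with an aggregate, using the same `text` and `AsyncSession` helpers the connector already imports:

    # Sketch of an alternative get_last_idx body. MAX(rowid) over an empty
    # table yields NULL, which arrives in Python as None.
    async def get_last_idx(self) -> int:
        if self._closed:
            raise RuntimeError("Storage is closed")

        assert self._engine is not None

        async with AsyncSession(self._engine) as session:
            result = await session.execute(text("SELECT MAX(rowid) FROM events"))
            max_rowid = result.scalar()

        return int(max_rowid) if max_rowid is not None else 0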
diff --git a/shared/tests/test_sqlite_connector.py b/shared/tests/test_sqlite_connector.py
index c78e51dc..7bd98b40 100644
--- a/shared/tests/test_sqlite_connector.py
+++ b/shared/tests/test_sqlite_connector.py
@@ -11,12 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
 
 from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
 from shared.types.common import NodeId
-from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
+from shared.types.events.chunks import ChunkType, TokenChunk
 from shared.types.events.events import (
     ChunkGenerated,
     EventType,
 )
-from shared.types.tasks.common import TaskId
+from shared.types.tasks.request import RequestId
 
 # Type ignore comment for all protected member access in this test file
 # pyright: reportPrivateUsage=false
@@ -162,6 +162,41 @@ class TestAsyncSQLiteEventStorage:
 
         await storage.close()
 
+
+
+    @pytest.mark.asyncio
+    async def test_get_last_idx(self, temp_db_path: Path, sample_node_id: NodeId) -> None:
+        """Test that rowid returns correctly from db."""
+        default_config = EventLogConfig()
+        storage = AsyncSQLiteEventStorage(db_path=temp_db_path, batch_size=default_config.batch_size, batch_timeout_ms=default_config.batch_timeout_ms, debounce_ms=default_config.debounce_ms, max_age_ms=default_config.max_age_ms)
+        await storage.start()
+
+        # Insert multiple records
+        test_records = [
+            {"event_type": "test_event_1", "data": "first"},
+            {"event_type": "test_event_2", "data": "second"},
+            {"event_type": "test_event_3", "data": "third"}
+        ]
+
+        assert storage._engine is not None
+        async with AsyncSession(storage._engine) as session:
+            for record in test_records:
+                await session.execute(
+                    text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
+                    {
+                        "origin": str(sample_node_id.uuid),
+                        "event_type": record["event_type"],
+                        "event_id": str(uuid4()),
+                        "event_data": json.dumps(record)
+                    }
+                )
+            await session.commit()
+
+        last_idx = await storage.get_last_idx()
+        assert last_idx == 3
+
+        await storage.close()
+
     @pytest.mark.asyncio
     async def test_rowid_with_multiple_origins(self, temp_db_path: Path) -> None:
         """Test rowid sequence across multiple origins."""
@@ -404,22 +439,19 @@ class TestAsyncSQLiteEventStorage:
         await storage.start()
 
         # Create a ChunkGenerated event with nested TokenChunk
-        task_id = TaskId(uuid=uuid4())
-        chunk_data = TokenChunkData(
+        request_id = RequestId(uuid=uuid4())
+        token_chunk = TokenChunk(
             text="Hello, world!",
             token_id=42,
-            finish_reason="stop"
-        )
-        token_chunk = TokenChunk(
-            chunk_data=chunk_data,
+            finish_reason="stop",
             chunk_type=ChunkType.token,
-            task_id=task_id,
+            request_id=request_id,
             idx=0,
             model="test-model"
         )
 
         chunk_generated_event = ChunkGenerated(
-            task_id=task_id,
+            request_id=request_id,
             chunk=token_chunk
         )
 
@@ -441,19 +473,19 @@ class TestAsyncSQLiteEventStorage:
         retrieved_event = retrieved_event_wrapper.event
         assert isinstance(retrieved_event, ChunkGenerated)
         assert retrieved_event.event_type == EventType.ChunkGenerated
-        assert retrieved_event.task_id == task_id
+        assert retrieved_event.request_id == request_id
 
         # Verify the nested chunk was deserialized correctly
         retrieved_chunk = retrieved_event.chunk
         assert isinstance(retrieved_chunk, TokenChunk)
         assert retrieved_chunk.chunk_type == ChunkType.token
-        assert retrieved_chunk.task_id == task_id
+        assert retrieved_chunk.request_id == request_id
         assert retrieved_chunk.idx == 0
         assert retrieved_chunk.model == "test-model"
 
         # Verify the chunk data
-        assert retrieved_chunk.chunk_data.text == "Hello, world!"
-        assert retrieved_chunk.chunk_data.token_id == 42
-        assert retrieved_chunk.chunk_data.finish_reason == "stop"
+        assert retrieved_chunk.text == "Hello, world!"
+        assert retrieved_chunk.token_id == 42
+        assert retrieved_chunk.finish_reason == "stop"
 
         await storage.close()
\ No newline at end of file
diff --git a/shared/types/api.py b/shared/types/api.py
index 8c581c41..37f1a74e 100644
--- a/shared/types/api.py
+++ b/shared/types/api.py
@@ -1,11 +1,34 @@
-from typing import Literal
+from typing import Any, Literal
 
 from pydantic import BaseModel
 
-from shared.types.tasks.common import ChatCompletionTaskParams, TaskId
+
+class ChatCompletionMessage(BaseModel):
+    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
+    content: str | None = None
+    name: str | None = None
+    tool_calls: list[dict[str, Any]] | None = None
+    tool_call_id: str | None = None
+    function_call: dict[str, Any] | None = None
 
-class ChatTask(BaseModel):
-    task_id: TaskId
-    kind: Literal["chat"] = "chat"
-    task_data: ChatCompletionTaskParams
+class ChatCompletionTaskParams(BaseModel):
+    model: str
+    frequency_penalty: float | None = None
+    messages: list[ChatCompletionMessage]
+    logit_bias: dict[str, int] | None = None
+    logprobs: bool | None = None
+    top_logprobs: int | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    presence_penalty: float | None = None
+    response_format: dict[str, Any] | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool = False
+    temperature: float | None = None
+    top_p: float | None = None
+    tools: list[dict[str, Any]] | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    parallel_tool_calls: bool | None = None
+    user: str | None = None
\ No newline at end of file
diff --git a/shared/types/events/chunks.py b/shared/types/events/chunks.py
index 8db92f51..860633e1 100644
--- a/shared/types/events/chunks.py
+++ b/shared/types/events/chunks.py
@@ -1,13 +1,11 @@
 from enum import Enum
 from typing import Annotated, Literal
 
-# from openai.types.chat.chat_completion import ChatCompletion
-# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel, Field, TypeAdapter
 
 from shared.openai_compat import FinishReason
 from shared.types.models import ModelId
-from shared.types.tasks.common import TaskId
+from shared.types.tasks.request import RequestId
 
 
 class ChunkType(str, Enum):
@@ -17,38 +15,21 @@ class ChunkType(str, Enum):
 
 class BaseChunk[ChunkTypeT: ChunkType](BaseModel):
     chunk_type: ChunkTypeT
-    task_id: TaskId
+    request_id: RequestId
     idx: int
     model: ModelId
 
-###
-
-
-class TokenChunkData(BaseModel):
+class TokenChunk(BaseChunk[ChunkType.token]):
+    chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True)
     text: str
     token_id: int
     finish_reason: FinishReason | None = None
 
 
-class ImageChunkData(BaseModel):
-    data: bytes
-
-
-###
-
-
-class TokenChunk(BaseChunk[ChunkType.token]):
-    chunk_data: TokenChunkData
-    chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True)
-
-
 class ImageChunk(BaseChunk[ChunkType.image]):
-    chunk_data: ImageChunkData
     chunk_type: Literal[ChunkType.image] = Field(default=ChunkType.image, frozen=True)
-
-
-###
+    data: bytes
 
 GenerationChunk = Annotated[TokenChunk | ImageChunk, Field(discriminator="chunk_type")]
 GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(GenerationChunk)
@@ -60,10 +41,8 @@ GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(Generatio
 # my_chunk: dict[str, Any] = TokenChunk(
 #     task_id=TaskId('nicerid'),
 #     idx=0,
-#     chunk_data=TokenChunkData(
-#         text='hello',
-#         token_id=12,
-#     ),
+#     text='hello',
+#     token_id=12,
 #     chunk_type=ChunkType.token,
 #     model='llama-3.1',
 # ).model_dump()
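With `TokenChunkData` folded into `TokenChunk`, a chunk now serializes to a flat dict and still round-trips through the discriminated union. A quick check using only types from this patch, mirroring the storage test above:

    from uuid import uuid4

    from shared.types.events.chunks import GenerationChunkTypeAdapter, TokenChunk
    from shared.types.tasks.request import RequestId

    chunk = TokenChunk(
        request_id=RequestId(uuid=uuid4()),
        idx=0,
        model="test-model",
        text="Hello, world!",
        token_id=42,
        finish_reason="stop",
    )
    payload = chunk.model_dump(mode="json")  # flat: no nested chunk_data key
    restored = GenerationChunkTypeAdapter.validate_python(payload)
    assert isinstance(restored, TokenChunk)
    assert restored.text == "Hello, world!"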
diff --git a/shared/types/events/events.py b/shared/types/events/events.py
index 478e82de..dd9a1d5c 100644
--- a/shared/types/events/events.py
+++ b/shared/types/events/events.py
@@ -17,6 +17,7 @@ from shared.types.graphs.topology import (
 )
 from shared.types.profiling.common import NodePerformanceProfile
 from shared.types.tasks.common import Task, TaskId, TaskStatus
+from shared.types.tasks.request import RequestId
 from shared.types.worker.common import InstanceId, NodeStatus
 from shared.types.worker.instances import InstanceParams, TypeOfInstance
 from shared.types.worker.runners import RunnerId, RunnerStatus
@@ -111,7 +112,7 @@ class WorkerDisconnected(BaseEvent[EventType.WorkerDisconnected]):
 
 class ChunkGenerated(BaseEvent[EventType.ChunkGenerated]):
     event_type: Literal[EventType.ChunkGenerated] = EventType.ChunkGenerated
-    task_id: TaskId
+    request_id: RequestId
     chunk: GenerationChunk
 
 
diff --git a/shared/types/tasks/common.py b/shared/types/tasks/common.py
index 8710c5f7..c324c42d 100644
--- a/shared/types/tasks/common.py
+++ b/shared/types/tasks/common.py
@@ -1,8 +1,8 @@
 from enum import Enum
-from typing import Any, Literal
 
 from pydantic import BaseModel
 
+from shared.types.api import ChatCompletionTaskParams
 from shared.types.common import NewUUID
 from shared.types.worker.common import InstanceId
 
@@ -10,11 +10,9 @@ from shared.types.worker.common import InstanceId
 class TaskId(NewUUID):
     pass
 
-
 class TaskType(str, Enum):
     ChatCompletion = "ChatCompletion"
 
-
 class TaskStatus(str, Enum):
     Pending = "Pending"
     Running = "Running"
@@ -22,42 +20,10 @@ class TaskStatus(str, Enum):
     Failed = "Failed"
 
 
-class ChatCompletionMessage(BaseModel):
-    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
-    content: str | None = None
-    name: str | None = None
-    tool_calls: list[dict[str, Any]] | None = None
-    tool_call_id: str | None = None
-    function_call: dict[str, Any] | None = None
-
-
-class ChatCompletionTaskParams(BaseModel):
-    task_type: Literal[TaskType.ChatCompletion] = TaskType.ChatCompletion
-    model: str
-    frequency_penalty: float | None = None
-    messages: list[ChatCompletionMessage]
-    logit_bias: dict[str, int] | None = None
-    logprobs: bool | None = None
-    top_logprobs: int | None = None
-    max_tokens: int | None = None
-    n: int | None = None
-    presence_penalty: float | None = None
-    response_format: dict[str, Any] | None = None
-    seed: int | None = None
-    stop: str | list[str] | None = None
-    stream: bool = False
-    temperature: float | None = None
-    top_p: float | None = None
-    tools: list[dict[str, Any]] | None = None
-    tool_choice: str | dict[str, Any] | None = None
-    parallel_tool_calls: bool | None = None
-    user: str | None = None
-
-
 class Task(BaseModel):
     task_id: TaskId
+    task_type: TaskType  # redundant atm as we only have 1 task type.
     instance_id: InstanceId
-    task_type: TaskType
     task_status: TaskStatus
 
     task_params: ChatCompletionTaskParams
diff --git a/shared/types/tasks/request.py b/shared/types/tasks/request.py
new file mode 100644
index 00000000..a9a267a8
--- /dev/null
+++ b/shared/types/tasks/request.py
@@ -0,0 +1,12 @@
+from pydantic import BaseModel
+
+from shared.types.api import ChatCompletionTaskParams
+from shared.types.common import NewUUID
+
+
+class RequestId(NewUUID):
+    pass
+
+class APIRequest(BaseModel):
+    request_id: RequestId
+    request_params: ChatCompletionTaskParams
\ No newline at end of file
diff --git a/uv.lock b/uv.lock
index e91fab50..d1fc02fc 100644
--- a/uv.lock
+++ b/uv.lock
@@ -154,6 +154,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
 ]
 
+[[package]]
+name = "click"
+version = "8.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" },
+]
+
 [[package]]
 name = "distro"
 version = "1.9.0"
@@ -219,12 +228,14 @@ source = { editable = "master" }
 dependencies = [
     { name = "exo-shared", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "uvicorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]
 
 [package.metadata]
 requires-dist = [
     { name = "exo-shared", editable = "shared" },
     { name = "fastapi", specifier = ">=0.116.0" },
+    { name = "uvicorn", specifier = ">=0.35.0" },
 ]
 
 [[package]]
@@ -1129,6 +1140,19 @@
     { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
 ]
 
+[[package]]
+name = "uvicorn"
+version = "0.35.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "click", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/5e/42/e0e305207bb88c6b8d3061399c6a961ffe5fbb7e2aa63c9234df7259e9cd/uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01", size = 78473, upload-time = "2025-06-28T16:15:46.058Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d2/e2/dc81b1bd1dcfe91735810265e9d26bc8ec5da45b4c0f6237e286819194c3/uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a", size = 66406, upload-time = "2025-06-28T16:15:44.816Z" },
+]
+
 [[package]]
 name = "yarl"
 version = "1.20.1"
diff --git a/worker/main.py b/worker/main.py
index e0295c1b..9bb6121e 100644
--- a/worker/main.py
+++ b/worker/main.py
@@ -236,12 +236,15 @@ class Worker:
 
         assigned_runner.status = RunningRunnerStatus()
         await queue.put(assigned_runner.status_update_event())
+
         try:
             async for chunk in assigned_runner.runner.stream_response(
                 task=op.task,
                 request_started_callback=partial(running_callback, queue)):
                 await queue.put(ChunkGenerated(
-                    task_id=op.task.task_id,
+                    # TODO: at some point we will no longer have a bijection between task_id and request_id.
+                    # So we probably want to store a mapping between these two in our Worker object.
+                    request_id=chunk.request_id,
                     chunk=chunk
                 ))
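The TODO above anticipates the day a task fans out into more than one request. A sketch of the mapping it suggests, with hypothetical names (`RequestRegistry`, `_request_ids`) that are not part of this patch:

    from uuid import UUID

    from shared.types.tasks.common import TaskId
    from shared.types.tasks.request import RequestId


    class RequestRegistry:
        """Illustrative Worker-side bookkeeping for task -> request lookups."""

        def __init__(self) -> None:
            self._request_ids: dict[UUID, RequestId] = {}

        def register(self, task_id: TaskId, request_id: RequestId) -> None:
            self._request_ids[task_id.uuid] = request_id

        def request_for(self, task_id: TaskId) -> RequestId:
            # Fall back to deriving the id the way RunnerSupervisor does today.
            return self._request_ids.get(task_id.uuid) or RequestId(uuid=task_id.uuid)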
diff --git a/worker/runner/runner_supervisor.py b/worker/runner/runner_supervisor.py
index 1df40e47..de527932 100644
--- a/worker/runner/runner_supervisor.py
+++ b/worker/runner/runner_supervisor.py
@@ -5,11 +5,12 @@ from collections.abc import AsyncGenerator
 from types import CoroutineType
 from typing import Any, Callable
 
-from shared.types.events.chunks import GenerationChunk, TokenChunk, TokenChunkData
+from shared.types.events.chunks import GenerationChunk, TokenChunk
 from shared.types.tasks.common import (
     ChatCompletionTaskParams,
     Task,
 )
+from shared.types.tasks.request import RequestId
 from shared.types.worker.commands_runner import (
     ChatTaskMessage,
     ErrorResponse,
@@ -183,14 +184,12 @@ class RunnerSupervisor:
                     text=text, token=token, finish_reason=finish_reason
                 ):
                     yield TokenChunk(
-                        task_id=task.task_id,
+                        request_id=RequestId(uuid=task.task_id.uuid),
                         idx=token,
                         model=self.model_shard_meta.model_meta.model_id,
-                        chunk_data=TokenChunkData(
-                            text=text,
-                            token_id=token,
-                            finish_reason=finish_reason,
-                        ),
+                        text=text,
+                        token_id=token,
+                        finish_reason=finish_reason,
                     )
                 case FinishedResponse():
                     break
diff --git a/worker/tests/conftest.py b/worker/tests/conftest.py
index 955fb81e..25e226c7 100644
--- a/worker/tests/conftest.py
+++ b/worker/tests/conftest.py
@@ -6,12 +6,11 @@ from typing import Callable
 
 import pytest
 
+from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
 from shared.types.common import NodeId
 from shared.types.models import ModelId, ModelMetadata
 from shared.types.state import State
 from shared.types.tasks.common import (
-    ChatCompletionMessage,
-    ChatCompletionTaskParams,
     Task,
     TaskId,
     TaskStatus,
diff --git a/worker/tests/test_supervisor.py b/worker/tests/test_supervisor.py
index 686630e5..4fd1dfeb 100644
--- a/worker/tests/test_supervisor.py
+++ b/worker/tests/test_supervisor.py
@@ -45,9 +45,9 @@ async def test_supervisor_single_node_response(
 
     async for chunk in supervisor.stream_response(task=chat_task):
         if isinstance(chunk, TokenChunk):
-            full_response += chunk.chunk_data.text
-            if chunk.chunk_data.finish_reason:
-                stop_reason = chunk.chunk_data.finish_reason
+            full_response += chunk.text
+            if chunk.finish_reason:
+                stop_reason = chunk.finish_reason
 
     # Case-insensitive check for Paris in the response
     assert "paris" in full_response.lower(), (
@@ -87,13 +87,13 @@ async def test_supervisor_two_node_response(
         nonlocal full_response_0
         async for chunk in supervisor_0.stream_response(task=chat_task):
             if isinstance(chunk, TokenChunk):
-                full_response_0 += chunk.chunk_data.text
+                full_response_0 += chunk.text
 
     async def collect_response_1():
         nonlocal full_response_1
         async for chunk in supervisor_1.stream_response(task=chat_task):
             if isinstance(chunk, TokenChunk):
-                full_response_1 += chunk.chunk_data.text
+                full_response_1 += chunk.text
 
     # Run both stream responses simultaneously
     _ = await asyncio.gather(collect_response_0(), collect_response_1())
@@ -148,10 +148,10 @@ async def test_supervisor_early_stopping(
 
     async for chunk in supervisor.stream_response(task=chat_task):
         if isinstance(chunk, TokenChunk):
-            full_response += chunk.chunk_data.text
+            full_response += chunk.text
             count += 1
-            if chunk.chunk_data.finish_reason:
-                stop_reason = chunk.chunk_data.finish_reason
+            if chunk.finish_reason:
+                stop_reason = chunk.finish_reason
 
     print(f"full_response: {full_response}")
 
diff --git a/worker/tests/test_worker_handlers.py b/worker/tests/test_worker_handlers.py
index 04390658..d70c1ed5 100644
--- a/worker/tests/test_worker_handlers.py
+++ b/worker/tests/test_worker_handlers.py
@@ -7,7 +7,7 @@ from typing import Callable
 import pytest
 
 from shared.types.common import NodeId
-from shared.types.events.chunks import TokenChunk, TokenChunkData
+from shared.types.events.chunks import TokenChunk
 from shared.types.events.events import ChunkGenerated, RunnerStatusUpdated
 from shared.types.events.registry import Event
 from shared.types.tasks.common import Task
@@ -107,7 +107,7 @@ async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
 
     async for chunk in supervisor.stream_response(task=chat_task):
         if isinstance(chunk, TokenChunk):
-            full_response += chunk.chunk_data.text
+            full_response += chunk.text
 
     assert "42" in full_response.lower(), (
         f"Expected '42' in response, but got: {full_response}"
@@ -175,7 +175,7 @@ async def test_execute_task_op(
     assert isinstance(events[-1].runner_status, LoadedRunnerStatus)  # It should not have failed.
 
     gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)]
-    text_chunks: list[TokenChunkData] = [x.chunk.chunk_data for x in gen_events if isinstance(x.chunk.chunk_data, TokenChunkData)]
+    text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)]
     assert len(text_chunks) == len(events) - 2
 
     output_text = ''.join([x.text for x in text_chunks])