This commit is contained in:
Matt Beton
2025-07-22 21:21:12 +01:00
committed by GitHub
parent 596d9fc9d0
commit 14b3c4a6be
18 changed files with 527 additions and 284 deletions

View File

@@ -1,25 +1,157 @@
-from typing import Protocol
-
-from shared.types.graphs.topology import Topology
-from shared.types.models import ModelId, ModelMetadata
-from shared.types.worker.common import InstanceId
-from shared.types.worker.downloads import DownloadProgress
-from shared.types.worker.instances import Instance
-
-
-class ClusterAPI(Protocol):
-    def get_topology(self) -> Topology: ...
-
-    def list_instances(self) -> list[Instance]: ...
-
-    def get_instance(self, instance_id: InstanceId) -> Instance: ...
-
-    def create_instance(self, model_id: ModelId) -> InstanceId: ...
-
-    def remove_instance(self, instance_id: InstanceId) -> None: ...
-
-    def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...
-
-    def download_model(self, model_id: ModelId) -> None: ...
-
-    def get_download_progress(self, model_id: ModelId) -> DownloadProgress: ...
+import asyncio
+import time
+from asyncio.queues import Queue
+from collections.abc import AsyncGenerator
+from typing import List, Optional, Sequence, final
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+
+from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.types.events.chunks import TokenChunk
+from shared.types.events.components import EventFromEventLog
+from shared.types.events.events import ChunkGenerated
+from shared.types.events.registry import Event
+from shared.types.tasks.common import ChatCompletionTaskParams
+from shared.types.tasks.request import APIRequest, RequestId
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class StreamingChoiceResponse(BaseModel):
+    index: int
+    delta: Message
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[StreamingChoiceResponse]
+
+
+def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
+    return ChatCompletionResponse(
+        id='abc',
+        created=int(time.time()),
+        model='idk',
+        choices=[
+            StreamingChoiceResponse(
+                index=0,
+                delta=Message(
+                    role='assistant',
+                    content=chunk.text
+                ),
+                finish_reason=chunk.finish_reason
+            )
+        ]
+    )
+
+
+@final
+class API:
+    def __init__(self, command_queue: Queue[APIRequest], global_events: AsyncSQLiteEventStorage) -> None:
+        self._app = FastAPI()
+        self._setup_routes()
+
+        self.command_queue = command_queue
+        self.global_events = global_events
+
+    def _setup_routes(self) -> None:
+        # self._app.get("/topology/control_plane")(self.get_control_plane_topology)
+        # self._app.get("/topology/data_plane")(self.get_data_plane_topology)
+        # self._app.get("/instances/list")(self.list_instances)
+        # self._app.post("/instances/create")(self.create_instance)
+        # self._app.get("/instance/{instance_id}/read")(self.get_instance)
+        # self._app.delete("/instance/{instance_id}/delete")(self.remove_instance)
+        # self._app.get("/model/{model_id}/metadata")(self.get_model_data)
+        # self._app.post("/model/{model_id}/instances")(self.get_instances_by_model)
+        self._app.post("/v1/chat/completions")(self.chat_completions)
+
+    @property
+    def app(self) -> FastAPI:
+        return self._app
+
+    # def get_control_plane_topology(self):
+    #     return {"message": "Hello, World!"}
+
+    # def get_data_plane_topology(self):
+    #     return {"message": "Hello, World!"}
+
+    # def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...
+
+    # def download_model(self, model_id: ModelId) -> None: ...
+
+    # def list_instances(self):
+    #     return {"message": "Hello, World!"}
+
+    # def create_instance(self, model_id: ModelId) -> InstanceId: ...
+
+    # def get_instance(self, instance_id: InstanceId) -> Instance: ...
+
+    # def remove_instance(self, instance_id: InstanceId) -> None: ...
+
+    # def get_model_data(self, model_id: ModelId) -> ModelInfo: ...
+
+    # def get_instances_by_model(self, model_id: ModelId) -> list[Instance]: ...
+
+    async def _generate_chat_stream(self, payload: ChatCompletionTaskParams) -> AsyncGenerator[str, None]:
+        """Generate chat completion stream as JSON strings."""
+        events = await self.global_events.get_events_since(0)
+        prev_idx = await self.global_events.get_last_idx()
+
+        # At the moment, we just create the task in the API.
+        # In the future, a `Request` will be created here and they will be bundled into `Task` objects by the master.
+        request_id = RequestId()
+        request = APIRequest(
+            request_id=request_id,
+            request_params=payload,
+        )
+        await self.command_queue.put(request)
+
+        finished = False
+        while not finished:
+            await asyncio.sleep(0.01)
+            events: Sequence[EventFromEventLog[Event]] = await self.global_events.get_events_since(prev_idx)
+            # TODO: Can do this with some better functionality to tail event log into an AsyncGenerator.
+            prev_idx = events[-1].idx_in_log if events else prev_idx
+
+            for wrapped_event in events:
+                event = wrapped_event.event
+                if isinstance(event, ChunkGenerated) and event.request_id == request_id:
+                    assert isinstance(event.chunk, TokenChunk)
+                    chunk_response: ChatCompletionResponse = chunk_to_response(event.chunk)
+                    print(chunk_response)
+                    yield f"data: {chunk_response.model_dump_json()}\n\n"
+
+                    if event.chunk.finish_reason is not None:
+                        yield "data: [DONE]"
+                        finished = True
+                        return
+
+    async def chat_completions(self, payload: ChatCompletionTaskParams) -> StreamingResponse:
+        """Handle chat completions with proper streaming response."""
+        return StreamingResponse(
+            self._generate_chat_stream(payload),
+            media_type="text/plain"
+        )
+
+
+def start_fastapi_server(
+    command_queue: Queue[APIRequest],
+    global_events: AsyncSQLiteEventStorage,
+    host: str = "0.0.0.0",
+    port: int = 8000,
+):
+    api = API(command_queue, global_events)
+    uvicorn.run(api.app, host=host, port=port)
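For reference, a minimal client-side sketch (not part of this commit) of how the `data: ...` / `data: [DONE]` stream emitted by `_generate_chat_stream` can be consumed. It assumes the server above is listening on localhost:8000 and uses `httpx` purely for illustration; any streaming HTTP client works, and the payload shape follows `ChatCompletionTaskParams`.

import asyncio
import json

import httpx  # assumed dependency for this sketch; not added by the commit


async def consume_chat_stream() -> None:
    payload = {
        "model": "testmodelabc",  # hypothetical model id
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/v1/chat/completions", json=payload) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                # Each chunk mirrors the ChatCompletionResponse model defined above.
                print(chunk["choices"][0]["delta"]["content"], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(consume_chat_stream())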

View File

@@ -1,171 +1,97 @@
-from contextlib import asynccontextmanager
-from logging import Logger, LogRecord
-from queue import Queue as PQueue
-
-from fastapi import FastAPI
-
-from master.env import MasterEnvironmentSchema
-from master.logging import (
-    MasterUninitializedLogEntry,
-)
-from shared.constants import EXO_MASTER_STATE
-from shared.event_loops.main import NodeEventLoopProtocol
-from shared.logger import (
-    FilterLogByType,
-    LogEntryType,
-    attach_to_queue,
-    configure_logger,
-    create_queue_listener,
-    log,
-)
-from shared.types.models import ModelId, ModelMetadata
-from shared.types.state import State
-from shared.types.worker.common import InstanceId
-from shared.types.worker.instances import Instance
-
-
-# Restore State
-def get_state(logger: Logger) -> State:
-    if EXO_MASTER_STATE.exists():
-        with open(EXO_MASTER_STATE, "r") as f:
-            return State.model_validate_json(f.read())
-    else:
-        log(logger, MasterUninitializedLogEntry())
-        return State()
-
-
-# FastAPI Dependencies
-def check_env_vars_defined(data: object, logger: Logger) -> MasterEnvironmentSchema:
-    if not isinstance(data, MasterEnvironmentSchema):
-        raise RuntimeError("Environment Variables Not Found")
-    return data
-
-
-def get_state_dependency(data: object, logger: Logger) -> State:
-    if not isinstance(data, State):
-        raise RuntimeError("Master State Not Found")
-    return data
-
-
-# Takes Care Of All States And Events Related To The Master
-class MasterEventLoopProtocol(NodeEventLoopProtocol): ...
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    logger = configure_logger("master")
-
-    telemetry_queue: PQueue[LogRecord] = PQueue()
-    metrics_queue: PQueue[LogRecord] = PQueue()
-    cluster_queue: PQueue[LogRecord] = PQueue()
-
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.telemetry}),
-        ],
-        telemetry_queue,
-    )
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.metrics}),
-        ],
-        metrics_queue,
-    )
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.cluster}),
-        ],
-        cluster_queue,
-    )
-
-    # TODO: Add Handlers For Pushing Logs To Remote Services
-    telemetry_listener = create_queue_listener(telemetry_queue, [])
-    metrics_listener = create_queue_listener(metrics_queue, [])
-    cluster_listener = create_queue_listener(cluster_queue, [])
-    telemetry_listener.start()
-    metrics_listener.start()
-    cluster_listener.start()
-
-    # # Get validated environment
-    # env = get_validated_env(MasterEnvironmentSchema, logger)
-
-    # # Initialize event log manager (creates both worker and global event DBs)
-    # event_log_config = EventLogConfig() # Uses default config
-    # event_log_manager = EventLogManager(
-    #     config=event_log_config,
-    #     logger=logger
-    # )
-    # await event_log_manager.initialize()
-
-    # # Store for use in API handlers
-    # app.state.event_log_manager = event_log_manager
-
-    # # Initialize forwarder if configured
-    # if env.FORWARDER_BINARY_PATH:
-    #     forwarder_supervisor = ForwarderSupervisor(
-    #         forwarder_binary_path=env.FORWARDER_BINARY_PATH,
-    #         logger=logger
-    #     )
-    #     # Start as replica by default (until elected)
-    #     await forwarder_supervisor.start_as_replica()
-
-    #     # Create election callbacks for Rust election system
-    #     election_callbacks = ElectionCallbacks(
-    #         forwarder_supervisor=forwarder_supervisor,
-    #         logger=logger
-    #     )
-
-    #     # Make callbacks available for Rust code to invoke
-    #     app.state.election_callbacks = election_callbacks
-
-    #     # Log status
-    #     logger.info(
-    #         f"Forwarder supervisor initialized. Running: {forwarder_supervisor.is_running}"
-    #     )
-    # else:
-    #     logger.warning("No forwarder binary path configured")
-    #     forwarder_supervisor = None
-
-    # initial_state = get_master_state(logger)
-    # app.state.master_event_loop = MasterEventLoop()
-    # await app.state.master_event_loop.start()
-
-    yield
-
-    # await app.state.master_event_loop.stop()
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.get("/topology")
-def get_topology():
-    return {"message": "Hello, World!"}
-
-
-@app.get("/instances/list")
-def list_instances():
-    return {"message": "Hello, World!"}
-
-
-@app.post("/instances/create")
-def create_instance(model_id: ModelId) -> InstanceId: ...
-
-
-@app.get("/instance/{instance_id}/read")
-def get_instance(instance_id: InstanceId) -> Instance: ...
-
-
-@app.delete("/instance/{instance_id}/delete")
-def remove_instance(instance_id: InstanceId) -> None: ...
-
-
-@app.get("/model/{model_id}/metadata")
-def get_model_metadata(model_id: ModelId) -> ModelMetadata: ...
-
-
-@app.post("/model/{model_id}/instances")
-def get_instances_by_model(model_id: ModelId) -> list[Instance]: ...
+import asyncio
+import threading
+from asyncio.queues import Queue
+from logging import Logger
+
+from master.api import start_fastapi_server
+from shared.db.sqlite.config import EventLogConfig
+from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.db.sqlite.event_log_manager import EventLogManager
+from shared.types.common import NodeId
+from shared.types.events.chunks import TokenChunk
+from shared.types.events.events import ChunkGenerated
+from shared.types.tasks.request import APIRequest, RequestId
+
+
+## TODO: Hook this up properly
+async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, request_id: RequestId):
+    model_id = "testmodelabc"
+
+    for i in range(10):
+        await asyncio.sleep(0.1)
+
+        # Create the event with proper types and consistent IDs
+        chunk_event = ChunkGenerated(
+            request_id=request_id,
+            chunk=TokenChunk(
+                request_id=request_id, # Use the same task_id
+                idx=i,
+                model=model_id, # Use the same model_id
+                text=f'text{i}',
+                token_id=i
+            )
+        )
+
+        # ChunkGenerated needs to be cast to the expected BaseEvent type
+        await events_log.append_events(
+            [chunk_event],
+            origin=NodeId()
+        )
+
+    await asyncio.sleep(0.1)
+
+    # Create the event with proper types and consistent IDs
+    chunk_event = ChunkGenerated(
+        request_id=request_id,
+        chunk=TokenChunk(
+            request_id=request_id, # Use the same task_id
+            idx=11,
+            model=model_id, # Use the same model_id
+            text=f'text{11}',
+            token_id=11,
+            finish_reason='stop'
+        )
+    )
+
+    # ChunkGenerated needs to be cast to the expected BaseEvent type
+    await events_log.append_events(
+        [chunk_event],
+        origin=NodeId()
+    )
+
+
+async def main():
+    logger = Logger(name='master_logger')
+
+    event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
+    await event_log_manager.initialize()
+    global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
+
+    command_queue: Queue[APIRequest] = asyncio.Queue()
+
+    api_thread = threading.Thread(
+        target=start_fastapi_server,
+        args=(
+            command_queue,
+            global_events,
+        ),
+        daemon=True
+    )
+    api_thread.start()
+    print('Running FastAPI server in a separate thread. Listening on port 8000.')
+
+    while True:
+        # master loop
+        if not command_queue.empty():
+            command = await command_queue.get()
+            print(command)
+            await fake_tokens_task(global_events, request_id=command.request_id)
+
+        await asyncio.sleep(0.01)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

View File

@@ -7,6 +7,7 @@ requires-python = ">=3.13"
dependencies = [
"exo-shared",
"fastapi>=0.116.0",
"uvicorn>=0.35.0",
]
[build-system]

View File

@@ -0,0 +1,78 @@
import asyncio
import functools
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
ParamSpec,
TypeVar,
final,
)
import openai
import pytest
from openai._streaming import AsyncStream
from openai.types.chat import (
ChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice
from master.main import main as master_main
_P = ParamSpec("_P")
_R = TypeVar("_R")
OPENAI_API_KEY: str = "<YOUR_OPENAI_API_KEY>"
OPENAI_API_URL: str = "http://0.0.0.0:8000/v1"
def with_master_main(
func: Callable[_P, Awaitable[_R]]
) -> Callable[_P, Coroutine[Any, Any, _R]]:
@pytest.mark.asyncio
@functools.wraps(func)
async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
master_task = asyncio.create_task(master_main())
try:
return await func(*args, **kwargs)
finally:
master_task.cancel()
with pytest.raises(asyncio.CancelledError):
await master_task
return wrapper
@final
class ChatMessage:
"""Strictly-typed chat message for OpenAI API."""
def __init__(self, role: str, content: str) -> None:
self.role = role
self.content = content
def to_openai(self) -> ChatCompletionMessageParam:
if self.role == "user":
return {"role": "user", "content": self.content} # type: ChatCompletionUserMessageParam
elif self.role == "assistant":
return {"role": "assistant", "content": self.content} # type: ChatCompletionAssistantMessageParam
elif self.role == "system":
return {"role": "system", "content": self.content} # type: ChatCompletionSystemMessageParam
else:
raise ValueError(f"Unsupported role: {self.role}")
async def stream_chatgpt_response(
messages: list[ChatMessage],
model: str = "gpt-3.5-turbo",
) -> AsyncGenerator[Choice, None]:
client = openai.AsyncOpenAI(
api_key=OPENAI_API_KEY,
base_url=OPENAI_API_URL,
)
openai_messages: list[ChatCompletionMessageParam] = [m.to_openai() for m in messages]
stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
model=model,
messages=openai_messages,
stream=True,
)
async for chunk in stream:
for choice in chunk.choices:
yield choice

47
master/tests/test_api.py Normal file
View File

@@ -0,0 +1,47 @@
import asyncio
import pytest
from master.tests.api_utils_test import (
ChatMessage,
stream_chatgpt_response,
with_master_main,
)
@with_master_main
@pytest.mark.asyncio
async def test_master_api_multiple_response_sequential() -> None:
messages = [
ChatMessage(role="user", content="Hello, who are you?")
]
token_count = 0
text: str = ""
async for choice in stream_chatgpt_response(messages):
print(choice, flush=True)
if choice.delta and choice.delta.content:
text += choice.delta.content
token_count += 1
if choice.finish_reason:
break
assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
assert len(text) > 0, "Expected non-empty response text"
await asyncio.sleep(0.1)
messages = [
ChatMessage(role="user", content="What time is it in France?")
]
token_count = 0
text = "" # re-initialize, do not redeclare type
async for choice in stream_chatgpt_response(messages):
print(choice, flush=True)
if choice.delta and choice.delta.content:
text += choice.delta.content
token_count += 1
if choice.finish_reason:
break
assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
assert len(text) > 0, "Expected non-empty response text"

View File

@@ -116,6 +116,27 @@ class AsyncSQLiteEventStorage:
))
return events
async def get_last_idx(self) -> int:
if self._closed:
raise RuntimeError("Storage is closed")
assert self._engine is not None
async with AsyncSession(self._engine) as session:
result = await session.execute(
text("SELECT rowid, origin, event_data FROM events ORDER BY rowid DESC LIMIT 1"),
{}
)
rows = result.fetchall()
if len(rows) == 0:
return 0
if len(rows) == 1:
row = rows[0]
return cast(int, row[0])
else:
raise AssertionError("There should have been at most 1 row returned from this SQL query.")
async def close(self) -> None:
"""Close the storage connection and cleanup resources."""
@@ -211,12 +232,12 @@ class AsyncSQLiteEventStorage:
try:
    async with AsyncSession(self._engine) as session:
        for event, origin in batch:
            stored_event = StoredEvent(
                origin=str(origin.uuid),
-                event_type=str(event.event_type),
+                event_type=event.event_type,
                event_id=str(event.event_id),
-                event_data=event.model_dump(mode='json') # mode='json' ensures UUID conversion
+                event_data=event.model_dump(mode='json') # Serialize UUIDs and other objects to JSON-compatible strings
            )
            session.add(stored_event)
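The new `get_last_idx` pairs naturally with the existing `get_events_since` used by the API's polling loop. A hypothetical helper (not part of this commit) that tails the log into an AsyncGenerator, along the lines of the TODO in `_generate_chat_stream`, might look roughly like this:

import asyncio
from collections.abc import AsyncGenerator

from shared.db.sqlite.connector import AsyncSQLiteEventStorage


async def tail_events(storage: AsyncSQLiteEventStorage, poll_interval: float = 0.01) -> AsyncGenerator[object, None]:
    # Start from the current end of the log, then poll for anything newer.
    last_idx = await storage.get_last_idx()
    while True:
        await asyncio.sleep(poll_interval)
        events = await storage.get_events_since(last_idx)
        if events:
            last_idx = events[-1].idx_in_log
            for wrapped in events:
                yield wrapped.event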

View File

@@ -11,12 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
from shared.types.common import NodeId
from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
from shared.types.events.chunks import ChunkType, TokenChunk
from shared.types.events.events import (
ChunkGenerated,
EventType,
)
from shared.types.tasks.common import TaskId
from shared.types.tasks.request import RequestId
# Type ignore comment for all protected member access in this test file
# pyright: reportPrivateUsage=false
@@ -162,6 +162,41 @@ class TestAsyncSQLiteEventStorage:
await storage.close()
@pytest.mark.asyncio
async def test_get_last_idx(self, temp_db_path: Path, sample_node_id: NodeId) -> None:
"""Test that rowid returns correctly from db."""
default_config = EventLogConfig()
storage = AsyncSQLiteEventStorage(db_path=temp_db_path, batch_size=default_config.batch_size, batch_timeout_ms=default_config.batch_timeout_ms, debounce_ms=default_config.debounce_ms, max_age_ms=default_config.max_age_ms)
await storage.start()
# Insert multiple records
test_records = [
{"event_type": "test_event_1", "data": "first"},
{"event_type": "test_event_2", "data": "second"},
{"event_type": "test_event_3", "data": "third"}
]
assert storage._engine is not None
async with AsyncSession(storage._engine) as session:
for record in test_records:
await session.execute(
text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
{
"origin": str(sample_node_id.uuid),
"event_type": record["event_type"],
"event_id": str(uuid4()),
"event_data": json.dumps(record)
}
)
await session.commit()
last_idx = await storage.get_last_idx()
assert last_idx == 3
await storage.close()
@pytest.mark.asyncio
async def test_rowid_with_multiple_origins(self, temp_db_path: Path) -> None:
"""Test rowid sequence across multiple origins."""
@@ -404,22 +439,19 @@ class TestAsyncSQLiteEventStorage:
await storage.start()
# Create a ChunkGenerated event with nested TokenChunk
-task_id = TaskId(uuid=uuid4())
-chunk_data = TokenChunkData(
-    text="Hello, world!",
-    token_id=42,
-    finish_reason="stop"
-)
-token_chunk = TokenChunk(
-    chunk_data=chunk_data,
-    chunk_type=ChunkType.token,
-    task_id=task_id,
-    idx=0,
-    model="test-model"
-)
-chunk_generated_event = ChunkGenerated(
-    task_id=task_id,
-    chunk=token_chunk
-)
+request_id = RequestId(uuid=uuid4())
+token_chunk = TokenChunk(
+    text="Hello, world!",
+    token_id=42,
+    finish_reason="stop",
+    chunk_type=ChunkType.token,
+    request_id=request_id,
+    idx=0,
+    model="test-model"
+)
+chunk_generated_event = ChunkGenerated(
+    request_id=request_id,
+    chunk=token_chunk
+)
@@ -441,19 +473,19 @@ class TestAsyncSQLiteEventStorage:
retrieved_event = retrieved_event_wrapper.event
assert isinstance(retrieved_event, ChunkGenerated)
assert retrieved_event.event_type == EventType.ChunkGenerated
-assert retrieved_event.task_id == task_id
+assert retrieved_event.request_id == request_id

# Verify the nested chunk was deserialized correctly
retrieved_chunk = retrieved_event.chunk
assert isinstance(retrieved_chunk, TokenChunk)
assert retrieved_chunk.chunk_type == ChunkType.token
-assert retrieved_chunk.task_id == task_id
+assert retrieved_chunk.request_id == request_id
assert retrieved_chunk.idx == 0
assert retrieved_chunk.model == "test-model"

# Verify the chunk data
-assert retrieved_chunk.chunk_data.text == "Hello, world!"
-assert retrieved_chunk.chunk_data.token_id == 42
-assert retrieved_chunk.chunk_data.finish_reason == "stop"
+assert retrieved_chunk.text == "Hello, world!"
+assert retrieved_chunk.token_id == 42
+assert retrieved_chunk.finish_reason == "stop"
await storage.close()

View File

@@ -1,11 +1,34 @@
-from typing import Literal
+from typing import Any, Literal

from pydantic import BaseModel

-from shared.types.tasks.common import ChatCompletionTaskParams, TaskId

-class ChatTask(BaseModel):
-    task_id: TaskId
-    kind: Literal["chat"] = "chat"
-    task_data: ChatCompletionTaskParams

+class ChatCompletionMessage(BaseModel):
+    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
+    content: str | None = None
+    name: str | None = None
+    tool_calls: list[dict[str, Any]] | None = None
+    tool_call_id: str | None = None
+    function_call: dict[str, Any] | None = None

+class ChatCompletionTaskParams(BaseModel):
+    model: str
+    frequency_penalty: float | None = None
+    messages: list[ChatCompletionMessage]
+    logit_bias: dict[str, int] | None = None
+    logprobs: bool | None = None
+    top_logprobs: int | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    presence_penalty: float | None = None
+    response_format: dict[str, Any] | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool = False
+    temperature: float | None = None
+    top_p: float | None = None
+    tools: list[dict[str, Any]] | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    parallel_tool_calls: bool | None = None
+    user: str | None = None
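For context, a minimal construction sketch (not part of the diff) showing how the relocated ChatCompletionTaskParams is meant to be populated; the model id is hypothetical.

from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams

params = ChatCompletionTaskParams(
    model="testmodelabc",  # hypothetical model id
    messages=[ChatCompletionMessage(role="user", content="Hello, who are you?")],
    stream=True,
)
print(params.model_dump_json(exclude_none=True))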

View File

@@ -1,13 +1,11 @@
from enum import Enum
from typing import Annotated, Literal

-# from openai.types.chat.chat_completion import ChatCompletion
-# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from pydantic import BaseModel, Field, TypeAdapter

from shared.openai_compat import FinishReason
from shared.types.models import ModelId
-from shared.types.tasks.common import TaskId
+from shared.types.tasks.request import RequestId
class ChunkType(str, Enum):
@@ -17,38 +15,21 @@ class ChunkType(str, Enum):
class BaseChunk[ChunkTypeT: ChunkType](BaseModel):
    chunk_type: ChunkTypeT
-    task_id: TaskId
+    request_id: RequestId
    idx: int
    model: ModelId

-###
-
-class TokenChunkData(BaseModel):
-    text: str
-    token_id: int
-    finish_reason: FinishReason | None = None
-
-class ImageChunkData(BaseModel):
-    data: bytes
-
-###
-
-class TokenChunk(BaseChunk[ChunkType.token]):
-    chunk_data: TokenChunkData
-    chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True)
+class TokenChunk(BaseChunk[ChunkType.token]):
+    chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True)
+    text: str
+    token_id: int
+    finish_reason: FinishReason | None = None

class ImageChunk(BaseChunk[ChunkType.image]):
-    chunk_data: ImageChunkData
    chunk_type: Literal[ChunkType.image] = Field(default=ChunkType.image, frozen=True)
+    data: bytes

-###

GenerationChunk = Annotated[TokenChunk | ImageChunk, Field(discriminator="chunk_type")]
GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(GenerationChunk)
@@ -60,10 +41,8 @@ GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(Generatio
# my_chunk: dict[str, Any] = TokenChunk(
#     task_id=TaskId('nicerid'),
#     idx=0,
-#     chunk_data=TokenChunkData(
-#         text='hello',
-#         token_id=12,
-#     ),
+#     text='hello',
+#     token_id=12,
#     chunk_type=ChunkType.token,
#     model='llama-3.1',
# ).model_dump()
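A small round-trip sketch (not part of the commit) of how the flattened TokenChunk and the chunk_type discriminator are used by GenerationChunkTypeAdapter. It assumes RequestId accepts a uuid keyword as elsewhere in this diff, that plain strings are accepted for ModelId, and that model_dump() round-trips through validate_python; the exact serialization depends on how NewUUID is defined.

from uuid import uuid4

from shared.types.events.chunks import GenerationChunkTypeAdapter, TokenChunk
from shared.types.tasks.request import RequestId

chunk = TokenChunk(
    request_id=RequestId(uuid=uuid4()),
    idx=0,
    model="test-model",
    text="Hello, world!",
    token_id=42,
    finish_reason="stop",
)
# The discriminated union re-dispatches on chunk_type during validation.
round_tripped = GenerationChunkTypeAdapter.validate_python(chunk.model_dump())
assert isinstance(round_tripped, TokenChunk)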

View File

@@ -17,6 +17,7 @@ from shared.types.graphs.topology import (
)
from shared.types.profiling.common import NodePerformanceProfile
from shared.types.tasks.common import Task, TaskId, TaskStatus
from shared.types.tasks.request import RequestId
from shared.types.worker.common import InstanceId, NodeStatus
from shared.types.worker.instances import InstanceParams, TypeOfInstance
from shared.types.worker.runners import RunnerId, RunnerStatus
@@ -111,7 +112,7 @@ class WorkerDisconnected(BaseEvent[EventType.WorkerDisconnected]):
class ChunkGenerated(BaseEvent[EventType.ChunkGenerated]):
event_type: Literal[EventType.ChunkGenerated] = EventType.ChunkGenerated
task_id: TaskId
request_id: RequestId
chunk: GenerationChunk

View File

@@ -1,8 +1,8 @@
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel
from shared.types.api import ChatCompletionTaskParams
from shared.types.common import NewUUID
from shared.types.worker.common import InstanceId
@@ -10,11 +10,9 @@ from shared.types.worker.common import InstanceId
class TaskId(NewUUID):
pass
class TaskType(str, Enum):
ChatCompletion = "ChatCompletion"
class TaskStatus(str, Enum):
Pending = "Pending"
Running = "Running"
@@ -22,42 +20,10 @@ class TaskStatus(str, Enum):
Failed = "Failed"
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: str | None = None
name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_call_id: str | None = None
function_call: dict[str, Any] | None = None
class ChatCompletionTaskParams(BaseModel):
task_type: Literal[TaskType.ChatCompletion] = TaskType.ChatCompletion
model: str
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
logit_bias: dict[str, int] | None = None
logprobs: bool | None = None
top_logprobs: int | None = None
max_tokens: int | None = None
n: int | None = None
presence_penalty: float | None = None
response_format: dict[str, Any] | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
class Task(BaseModel):
    task_id: TaskId
-    task_type: TaskType # redundant atm as we only have 1 task type.
    instance_id: InstanceId
+    task_type: TaskType
    task_status: TaskStatus
    task_params: ChatCompletionTaskParams

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
from shared.types.api import ChatCompletionTaskParams
from shared.types.common import NewUUID
class RequestId(NewUUID):
pass
class APIRequest(BaseModel):
request_id: RequestId
request_params: ChatCompletionTaskParams

24
uv.lock generated
View File

@@ -154,6 +154,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
]
[[package]]
name = "click"
version = "8.2.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" },
]
[[package]]
name = "distro"
version = "1.9.0"
@@ -219,12 +228,14 @@ source = { editable = "master" }
dependencies = [
{ name = "exo-shared", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "uvicorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[package.metadata]
requires-dist = [
{ name = "exo-shared", editable = "shared" },
{ name = "fastapi", specifier = ">=0.116.0" },
{ name = "uvicorn", specifier = ">=0.35.0" },
]
[[package]]
@@ -1129,6 +1140,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
]
[[package]]
name = "uvicorn"
version = "0.35.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5e/42/e0e305207bb88c6b8d3061399c6a961ffe5fbb7e2aa63c9234df7259e9cd/uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01", size = 78473, upload-time = "2025-06-28T16:15:46.058Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d2/e2/dc81b1bd1dcfe91735810265e9d26bc8ec5da45b4c0f6237e286819194c3/uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a", size = 66406, upload-time = "2025-06-28T16:15:44.816Z" },
]
[[package]]
name = "yarl"
version = "1.20.1"

View File

@@ -236,12 +236,15 @@ class Worker:
assigned_runner.status = RunningRunnerStatus()
await queue.put(assigned_runner.status_update_event())
try:
async for chunk in assigned_runner.runner.stream_response(
task=op.task,
request_started_callback=partial(running_callback, queue)):
await queue.put(ChunkGenerated(
-    task_id=op.task.task_id,
+    # todo: at some point we will no longer have a bijection between task_id and row_id.
+    # So we probably want to store a mapping between these two in our Worker object.
+    request_id=chunk.request_id,
    chunk=chunk
))

View File

@@ -5,11 +5,12 @@ from collections.abc import AsyncGenerator
from types import CoroutineType
from typing import Any, Callable
from shared.types.events.chunks import GenerationChunk, TokenChunk, TokenChunkData
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.tasks.common import (
ChatCompletionTaskParams,
Task,
)
from shared.types.tasks.request import RequestId
from shared.types.worker.commands_runner import (
ChatTaskMessage,
ErrorResponse,
@@ -183,14 +184,12 @@ class RunnerSupervisor:
text=text, token=token, finish_reason=finish_reason
):
yield TokenChunk(
-    task_id=task.task_id,
+    request_id=RequestId(uuid=task.task_id.uuid),
    idx=token,
    model=self.model_shard_meta.model_meta.model_id,
-    chunk_data=TokenChunkData(
-        text=text,
-        token_id=token,
-        finish_reason=finish_reason,
-    ),
+    text=text,
+    token_id=token,
+    finish_reason=finish_reason,
)
case FinishedResponse():
break

View File

@@ -6,12 +6,11 @@ from typing import Callable
import pytest
+from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import NodeId
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
from shared.types.tasks.common import (
-    ChatCompletionMessage,
-    ChatCompletionTaskParams,
    Task,
    TaskId,
    TaskStatus,

View File

@@ -45,9 +45,9 @@ async def test_supervisor_single_node_response(
async for chunk in supervisor.stream_response(task=chat_task):
if isinstance(chunk, TokenChunk):
-full_response += chunk.chunk_data.text
-if chunk.chunk_data.finish_reason:
-    stop_reason = chunk.chunk_data.finish_reason
+full_response += chunk.text
+if chunk.finish_reason:
+    stop_reason = chunk.finish_reason
# Case-insensitive check for Paris in the response
assert "paris" in full_response.lower(), (
@@ -87,13 +87,13 @@ async def test_supervisor_two_node_response(
nonlocal full_response_0
async for chunk in supervisor_0.stream_response(task=chat_task):
if isinstance(chunk, TokenChunk):
full_response_0 += chunk.chunk_data.text
full_response_0 += chunk.text
async def collect_response_1():
nonlocal full_response_1
async for chunk in supervisor_1.stream_response(task=chat_task):
if isinstance(chunk, TokenChunk):
full_response_1 += chunk.chunk_data.text
full_response_1 += chunk.text
# Run both stream responses simultaneously
_ = await asyncio.gather(collect_response_0(), collect_response_1())
@@ -148,10 +148,10 @@ async def test_supervisor_early_stopping(
async for chunk in supervisor.stream_response(task=chat_task):
if isinstance(chunk, TokenChunk):
-full_response += chunk.chunk_data.text
+full_response += chunk.text
count += 1
-if chunk.chunk_data.finish_reason:
-    stop_reason = chunk.chunk_data.finish_reason
+if chunk.finish_reason:
+    stop_reason = chunk.finish_reason
print(f"full_response: {full_response}")

View File

@@ -7,7 +7,7 @@ from typing import Callable
import pytest
from shared.types.common import NodeId
from shared.types.events.chunks import TokenChunk, TokenChunkData
from shared.types.events.chunks import TokenChunk
from shared.types.events.events import ChunkGenerated, RunnerStatusUpdated
from shared.types.events.registry import Event
from shared.types.tasks.common import Task
@@ -107,7 +107,7 @@ async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
async for chunk in supervisor.stream_response(task=chat_task):
if isinstance(chunk, TokenChunk):
full_response += chunk.chunk_data.text
full_response += chunk.text
assert "42" in full_response.lower(), (
f"Expected '42' in response, but got: {full_response}"
@@ -175,7 +175,7 @@ async def test_execute_task_op(
assert isinstance(events[-1].runner_status, LoadedRunnerStatus) # It should not have failed.
gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)]
text_chunks: list[TokenChunkData] = [x.chunk.chunk_data for x in gen_events if isinstance(x.chunk.chunk_data, TokenChunkData)]
text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)]
assert len(text_chunks) == len(events) - 2
output_text = ''.join([x.text for x in text_chunks])