New API!
master/api.py | 162
@@ -1,25 +1,157 @@
-from typing import Protocol
-
-from shared.types.graphs.topology import Topology
-from shared.types.models import ModelId, ModelMetadata
-from shared.types.worker.common import InstanceId
-from shared.types.worker.downloads import DownloadProgress
-from shared.types.worker.instances import Instance
-
-
-class ClusterAPI(Protocol):
-    def get_topology(self) -> Topology: ...
-
-    def list_instances(self) -> list[Instance]: ...
-
-    def get_instance(self, instance_id: InstanceId) -> Instance: ...
-
-    def create_instance(self, model_id: ModelId) -> InstanceId: ...
-
-    def remove_instance(self, instance_id: InstanceId) -> None: ...
-
-    def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...
-
-    def download_model(self, model_id: ModelId) -> None: ...
-
-    def get_download_progress(self, model_id: ModelId) -> DownloadProgress: ...
+import asyncio
+import time
+from asyncio.queues import Queue
+from collections.abc import AsyncGenerator
+from typing import List, Optional, Sequence, final
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+
+from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.types.events.chunks import TokenChunk
+from shared.types.events.components import EventFromEventLog
+from shared.types.events.events import ChunkGenerated
+from shared.types.events.registry import Event
+from shared.types.tasks.common import ChatCompletionTaskParams
+from shared.types.tasks.request import APIRequest, RequestId
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class StreamingChoiceResponse(BaseModel):
+    index: int
+    delta: Message
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[StreamingChoiceResponse]
+
+
+def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
+    return ChatCompletionResponse(
+        id='abc',
+        created=int(time.time()),
+        model='idk',
+        choices=[
+            StreamingChoiceResponse(
+                index=0,
+                delta=Message(
+                    role='assistant',
+                    content=chunk.text
+                ),
+                finish_reason=chunk.finish_reason
+            )
+        ]
+    )
+
+
+@final
+class API:
+    def __init__(self, command_queue: Queue[APIRequest], global_events: AsyncSQLiteEventStorage) -> None:
+        self._app = FastAPI()
+        self._setup_routes()
+        self.command_queue = command_queue
+        self.global_events = global_events
+
+    def _setup_routes(self) -> None:
+        # self._app.get("/topology/control_plane")(self.get_control_plane_topology)
+        # self._app.get("/topology/data_plane")(self.get_data_plane_topology)
+        # self._app.get("/instances/list")(self.list_instances)
+        # self._app.post("/instances/create")(self.create_instance)
+        # self._app.get("/instance/{instance_id}/read")(self.get_instance)
+        # self._app.delete("/instance/{instance_id}/delete")(self.remove_instance)
+        # self._app.get("/model/{model_id}/metadata")(self.get_model_data)
+        # self._app.post("/model/{model_id}/instances")(self.get_instances_by_model)
+        self._app.post("/v1/chat/completions")(self.chat_completions)
+
+    @property
+    def app(self) -> FastAPI:
+        return self._app
+
+    # def get_control_plane_topology(self):
+    #     return {"message": "Hello, World!"}
+
+    # def get_data_plane_topology(self):
+    #     return {"message": "Hello, World!"}
+
+    # def get_model_metadata(self, model_id: ModelId) -> ModelMetadata: ...
+
+    # def download_model(self, model_id: ModelId) -> None: ...
+
+    # def list_instances(self):
+    #     return {"message": "Hello, World!"}
+
+    # def create_instance(self, model_id: ModelId) -> InstanceId: ...
+
+    # def get_instance(self, instance_id: InstanceId) -> Instance: ...
+
+    # def remove_instance(self, instance_id: InstanceId) -> None: ...
+
+    # def get_model_data(self, model_id: ModelId) -> ModelInfo: ...
+
+    # def get_instances_by_model(self, model_id: ModelId) -> list[Instance]: ...
+
+    async def _generate_chat_stream(self, payload: ChatCompletionTaskParams) -> AsyncGenerator[str, None]:
+        """Generate chat completion stream as JSON strings."""
+        prev_idx = await self.global_events.get_last_idx()
+
+        # At the moment, we just create the task in the API.
+        # In the future, a `Request` will be created here and they will be bundled into `Task` objects by the master.
+        request_id = RequestId()
+
+        request = APIRequest(
+            request_id=request_id,
+            request_params=payload,
+        )
+        await self.command_queue.put(request)
+
+        finished = False
+        while not finished:
+            await asyncio.sleep(0.01)
+
+            events: Sequence[EventFromEventLog[Event]] = await self.global_events.get_events_since(prev_idx)
+            # TODO: Can do this with some better functionality to tail the event log into an AsyncGenerator.
+            prev_idx = events[-1].idx_in_log if events else prev_idx
+
+            for wrapped_event in events:
+                event = wrapped_event.event
+                if isinstance(event, ChunkGenerated) and event.request_id == request_id:
+                    assert isinstance(event.chunk, TokenChunk)
+                    chunk_response: ChatCompletionResponse = chunk_to_response(event.chunk)
+                    print(chunk_response)
+                    yield f"data: {chunk_response.model_dump_json()}\n\n"
+
+                    if event.chunk.finish_reason is not None:
+                        yield "data: [DONE]"
+                        finished = True
+
+        return
+
+    async def chat_completions(self, payload: ChatCompletionTaskParams) -> StreamingResponse:
+        """Handle chat completions with proper streaming response."""
+        return StreamingResponse(
+            self._generate_chat_stream(payload),
+            media_type="text/plain"
+        )
+
+
+def start_fastapi_server(
+    command_queue: Queue[APIRequest],
+    global_events: AsyncSQLiteEventStorage,
+    host: str = "0.0.0.0",
+    port: int = 8000,
+):
+    api = API(command_queue, global_events)
+    uvicorn.run(api.app, host=host, port=port)
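For orientation, the SSE-style stream `_generate_chat_stream` emits looks roughly like this on the wire. The `id` and `model` values are the hard-coded placeholders from `chunk_to_response`, the `created` timestamp is illustrative, and the `text0`/`text11` contents come from the fake token task added in master/main.py below:

data: {"id": "abc", "object": "chat.completion", "created": 1752000000, "model": "idk", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "text0"}, "finish_reason": null}]}

data: {"id": "abc", "object": "chat.completion", "created": 1752000000, "model": "idk", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "text11"}, "finish_reason": "stop"}]}

data: [DONE]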
master/main.py | 244
@@ -1,171 +1,97 @@
-from contextlib import asynccontextmanager
-from logging import Logger, LogRecord
-from queue import Queue as PQueue
-
-from fastapi import FastAPI
-
-from master.env import MasterEnvironmentSchema
-from master.logging import (
-    MasterUninitializedLogEntry,
-)
-from shared.constants import EXO_MASTER_STATE
-from shared.event_loops.main import NodeEventLoopProtocol
-from shared.logger import (
-    FilterLogByType,
-    LogEntryType,
-    attach_to_queue,
-    configure_logger,
-    create_queue_listener,
-    log,
-)
-from shared.types.models import ModelId, ModelMetadata
-from shared.types.state import State
-from shared.types.worker.common import InstanceId
-from shared.types.worker.instances import Instance
-
-
-# Restore State
-def get_state(logger: Logger) -> State:
-    if EXO_MASTER_STATE.exists():
-        with open(EXO_MASTER_STATE, "r") as f:
-            return State.model_validate_json(f.read())
-    else:
-        log(logger, MasterUninitializedLogEntry())
-        return State()
-
-
-# FastAPI Dependencies
-def check_env_vars_defined(data: object, logger: Logger) -> MasterEnvironmentSchema:
-    if not isinstance(data, MasterEnvironmentSchema):
-        raise RuntimeError("Environment Variables Not Found")
-    return data
-
-
-def get_state_dependency(data: object, logger: Logger) -> State:
-    if not isinstance(data, State):
-        raise RuntimeError("Master State Not Found")
-    return data
-
-
-# Takes Care Of All States And Events Related To The Master
-class MasterEventLoopProtocol(NodeEventLoopProtocol): ...
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    logger = configure_logger("master")
-
-    telemetry_queue: PQueue[LogRecord] = PQueue()
-    metrics_queue: PQueue[LogRecord] = PQueue()
-    cluster_queue: PQueue[LogRecord] = PQueue()
-
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.telemetry}),
-        ],
-        telemetry_queue,
-    )
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.metrics}),
-        ],
-        metrics_queue,
-    )
-    attach_to_queue(
-        logger,
-        [
-            FilterLogByType(log_types={LogEntryType.cluster}),
-        ],
-        cluster_queue,
-    )
-
-    # TODO: Add Handlers For Pushing Logs To Remote Services
-    telemetry_listener = create_queue_listener(telemetry_queue, [])
-    metrics_listener = create_queue_listener(metrics_queue, [])
-    cluster_listener = create_queue_listener(cluster_queue, [])
-
-    telemetry_listener.start()
-    metrics_listener.start()
-    cluster_listener.start()
-
-    # # Get validated environment
-    # env = get_validated_env(MasterEnvironmentSchema, logger)
-
-    # # Initialize event log manager (creates both worker and global event DBs)
-    # event_log_config = EventLogConfig()  # Uses default config
-    # event_log_manager = EventLogManager(
-    #     config=event_log_config,
-    #     logger=logger
-    # )
-    # await event_log_manager.initialize()
-
-    # # Store for use in API handlers
-    # app.state.event_log_manager = event_log_manager
-
-    # # Initialize forwarder if configured
-    # if env.FORWARDER_BINARY_PATH:
-    #     forwarder_supervisor = ForwarderSupervisor(
-    #         forwarder_binary_path=env.FORWARDER_BINARY_PATH,
-    #         logger=logger
-    #     )
-    #     # Start as replica by default (until elected)
-    #     await forwarder_supervisor.start_as_replica()
-
-    #     # Create election callbacks for Rust election system
-    #     election_callbacks = ElectionCallbacks(
-    #         forwarder_supervisor=forwarder_supervisor,
-    #         logger=logger
-    #     )
-
-    #     # Make callbacks available for Rust code to invoke
-    #     app.state.election_callbacks = election_callbacks
-
-    #     # Log status
-    #     logger.info(
-    #         f"Forwarder supervisor initialized. Running: {forwarder_supervisor.is_running}"
-    #     )
-    # else:
-    #     logger.warning("No forwarder binary path configured")
-    #     forwarder_supervisor = None
-
-    # initial_state = get_master_state(logger)
-    # app.state.master_event_loop = MasterEventLoop()
-    # await app.state.master_event_loop.start()
-
-    yield
-
-    # await app.state.master_event_loop.stop()
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.get("/topology")
-def get_topology():
-    return {"message": "Hello, World!"}
-
-
-@app.get("/instances/list")
-def list_instances():
-    return {"message": "Hello, World!"}
-
-
-@app.post("/instances/create")
-def create_instance(model_id: ModelId) -> InstanceId: ...
-
-
-@app.get("/instance/{instance_id}/read")
-def get_instance(instance_id: InstanceId) -> Instance: ...
-
-
-@app.delete("/instance/{instance_id}/delete")
-def remove_instance(instance_id: InstanceId) -> None: ...
-
-
-@app.get("/model/{model_id}/metadata")
-def get_model_metadata(model_id: ModelId) -> ModelMetadata: ...
-
-
-@app.post("/model/{model_id}/instances")
-def get_instances_by_model(model_id: ModelId) -> list[Instance]: ...
+import asyncio
+import threading
+from asyncio.queues import Queue
+from logging import Logger
+
+from master.api import start_fastapi_server
+from shared.db.sqlite.config import EventLogConfig
+from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.db.sqlite.event_log_manager import EventLogManager
+from shared.types.common import NodeId
+from shared.types.events.chunks import TokenChunk
+from shared.types.events.events import ChunkGenerated
+from shared.types.tasks.request import APIRequest, RequestId
+
+
+## TODO: Hook this up properly
+async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, request_id: RequestId):
+    model_id = "testmodelabc"
+
+    for i in range(10):
+        await asyncio.sleep(0.1)
+
+        # Create the event with proper types and consistent IDs
+        chunk_event = ChunkGenerated(
+            request_id=request_id,
+            chunk=TokenChunk(
+                request_id=request_id,  # Use the same request_id
+                idx=i,
+                model=model_id,  # Use the same model_id
+                text=f'text{i}',
+                token_id=i
+            )
+        )
+
+        # ChunkGenerated needs to be cast to the expected BaseEvent type
+        await events_log.append_events(
+            [chunk_event],
+            origin=NodeId()
+        )
+
+    await asyncio.sleep(0.1)
+
+    # The final chunk carries the finish_reason
+    chunk_event = ChunkGenerated(
+        request_id=request_id,
+        chunk=TokenChunk(
+            request_id=request_id,  # Use the same request_id
+            idx=11,
+            model=model_id,  # Use the same model_id
+            text=f'text{11}',
+            token_id=11,
+            finish_reason='stop'
+        )
+    )
+
+    # ChunkGenerated needs to be cast to the expected BaseEvent type
+    await events_log.append_events(
+        [chunk_event],
+        origin=NodeId()
+    )
+
+
+async def main():
+    logger = Logger(name='master_logger')
+
+    event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
+    await event_log_manager.initialize()
+    global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
+
+    command_queue: Queue[APIRequest] = asyncio.Queue()
+
+    api_thread = threading.Thread(
+        target=start_fastapi_server,
+        args=(
+            command_queue,
+            global_events,
+        ),
+        daemon=True
+    )
+    api_thread.start()
+    print('Running FastAPI server in a separate thread. Listening on port 8000.')
+
+    while True:
+        # master loop
+        if not command_queue.empty():
+            command = await command_queue.get()
+            print(command)
+            await fake_tokens_task(global_events, request_id=command.request_id)
+
+        await asyncio.sleep(0.01)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
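A minimal way to exercise this end to end, assuming the server from main() is listening on port 8000. This is an illustrative sketch, not part of the commit: httpx is not a dependency of this repo, and the model string is ignored by the fake token task above:

import asyncio
import json

import httpx  # assumed available; not among this commit's dependencies


async def demo() -> None:
    payload = {
        "model": "testmodelabc",  # illustrative; the fake token task ignores it
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://0.0.0.0:8000/v1/chat/completions", json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line.removeprefix("data: ")
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                print(chunk["choices"][0]["delta"]["content"], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(demo())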
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
 dependencies = [
     "exo-shared",
     "fastapi>=0.116.0",
+    "uvicorn>=0.35.0",
 ]

 [build-system]
master/tests/api_utils_test.py | 78 (new file)
@@ -0,0 +1,78 @@
import asyncio
import functools
from typing import (
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Coroutine,
    ParamSpec,
    TypeVar,
    final,
)

import openai
import pytest
from openai._streaming import AsyncStream
from openai.types.chat import (
    ChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice

from master.main import main as master_main

_P = ParamSpec("_P")
_R = TypeVar("_R")

OPENAI_API_KEY: str = "<YOUR_OPENAI_API_KEY>"
OPENAI_API_URL: str = "http://0.0.0.0:8000/v1"


def with_master_main(
    func: Callable[_P, Awaitable[_R]]
) -> Callable[_P, Coroutine[Any, Any, _R]]:
    @pytest.mark.asyncio
    @functools.wraps(func)
    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        master_task = asyncio.create_task(master_main())
        try:
            return await func(*args, **kwargs)
        finally:
            master_task.cancel()
            with pytest.raises(asyncio.CancelledError):
                await master_task
    return wrapper


@final
class ChatMessage:
    """Strictly-typed chat message for the OpenAI API."""
    def __init__(self, role: str, content: str) -> None:
        self.role = role
        self.content = content

    def to_openai(self) -> ChatCompletionMessageParam:
        if self.role == "user":
            return {"role": "user", "content": self.content}  # ChatCompletionUserMessageParam
        elif self.role == "assistant":
            return {"role": "assistant", "content": self.content}  # ChatCompletionAssistantMessageParam
        elif self.role == "system":
            return {"role": "system", "content": self.content}  # ChatCompletionSystemMessageParam
        else:
            raise ValueError(f"Unsupported role: {self.role}")


async def stream_chatgpt_response(
    messages: list[ChatMessage],
    model: str = "gpt-3.5-turbo",
) -> AsyncGenerator[Choice, None]:
    client = openai.AsyncOpenAI(
        api_key=OPENAI_API_KEY,
        base_url=OPENAI_API_URL,
    )
    openai_messages: list[ChatCompletionMessageParam] = [m.to_openai() for m in messages]
    stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
        model=model,
        messages=openai_messages,
        stream=True,
    )
    async for chunk in stream:
        for choice in chunk.choices:
            yield choice
master/tests/test_api.py | 47 (new file)
@@ -0,0 +1,47 @@
import asyncio

import pytest

from master.tests.api_utils_test import (
    ChatMessage,
    stream_chatgpt_response,
    with_master_main,
)


@with_master_main
@pytest.mark.asyncio
async def test_master_api_multiple_response_sequential() -> None:
    messages = [
        ChatMessage(role="user", content="Hello, who are you?")
    ]
    token_count = 0
    text: str = ""
    async for choice in stream_chatgpt_response(messages):
        print(choice, flush=True)
        if choice.delta and choice.delta.content:
            text += choice.delta.content
            token_count += 1
        if choice.finish_reason:
            break

    assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
    assert len(text) > 0, "Expected non-empty response text"

    await asyncio.sleep(0.1)

    messages = [
        ChatMessage(role="user", content="What time is it in France?")
    ]
    token_count = 0
    text = ""  # re-initialize, do not redeclare the type
    async for choice in stream_chatgpt_response(messages):
        print(choice, flush=True)
        if choice.delta and choice.delta.content:
            text += choice.delta.content
            token_count += 1
        if choice.finish_reason:
            break

    assert token_count >= 3, f"Expected at least 3 tokens, got {token_count}"
    assert len(text) > 0, "Expected non-empty response text"
@@ -116,6 +116,27 @@ class AsyncSQLiteEventStorage:
             ))

         return events

+    async def get_last_idx(self) -> int:
+        if self._closed:
+            raise RuntimeError("Storage is closed")
+
+        assert self._engine is not None
+
+        async with AsyncSession(self._engine) as session:
+            result = await session.execute(
+                text("SELECT rowid, origin, event_data FROM events ORDER BY rowid DESC LIMIT 1"),
+                {}
+            )
+            rows = result.fetchall()
+
+        if len(rows) == 0:
+            return 0
+        if len(rows) == 1:
+            row = rows[0]
+            return cast(int, row[0])
+        else:
+            raise AssertionError("There should have been at most 1 row returned from this SQL query.")
+
     async def close(self) -> None:
         """Close the storage connection and cleanup resources."""
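get_last_idx pairs with get_events_since to let a reader tail the log by polling, which is exactly what _generate_chat_stream in master/api.py does. A condensed sketch of that pattern, using only names introduced in this commit (the polling interval is illustrative):

import asyncio

from shared.db.sqlite.connector import AsyncSQLiteEventStorage


async def tail_events(storage: AsyncSQLiteEventStorage):
    # Start from the current end of the log so only new events are seen.
    prev_idx = await storage.get_last_idx()
    while True:
        await asyncio.sleep(0.01)  # illustrative polling interval
        events = await storage.get_events_since(prev_idx)
        if events:
            prev_idx = events[-1].idx_in_log
        for wrapped in events:
            yield wrapped.event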
@@ -211,12 +232,12 @@ class AsyncSQLiteEventStorage:
         try:
             async with AsyncSession(self._engine) as session:
                 for event, origin in batch:
                     stored_event = StoredEvent(
                         origin=str(origin.uuid),
-                        event_type=str(event.event_type),
+                        event_type=event.event_type,
                         event_id=str(event.event_id),
-                        event_data=event.model_dump(mode='json')  # mode='json' ensures UUID conversion
+                        event_data=event.model_dump(mode='json')  # Serialize UUIDs and other objects to JSON-compatible strings
                     )
                     session.add(stored_event)
@@ -11,12 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
 from shared.types.common import NodeId
-from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
+from shared.types.events.chunks import ChunkType, TokenChunk
 from shared.types.events.events import (
     ChunkGenerated,
     EventType,
 )
 from shared.types.tasks.common import TaskId
+from shared.types.tasks.request import RequestId

 # Type ignore comment for all protected member access in this test file
 # pyright: reportPrivateUsage=false
@@ -162,6 +162,41 @@ class TestAsyncSQLiteEventStorage:
         await storage.close()

+    @pytest.mark.asyncio
+    async def test_get_last_idx(self, temp_db_path: Path, sample_node_id: NodeId) -> None:
+        """Test that rowid returns correctly from the db."""
+        default_config = EventLogConfig()
+        storage = AsyncSQLiteEventStorage(
+            db_path=temp_db_path,
+            batch_size=default_config.batch_size,
+            batch_timeout_ms=default_config.batch_timeout_ms,
+            debounce_ms=default_config.debounce_ms,
+            max_age_ms=default_config.max_age_ms,
+        )
+        await storage.start()
+
+        # Insert multiple records
+        test_records = [
+            {"event_type": "test_event_1", "data": "first"},
+            {"event_type": "test_event_2", "data": "second"},
+            {"event_type": "test_event_3", "data": "third"}
+        ]
+
+        assert storage._engine is not None
+        async with AsyncSession(storage._engine) as session:
+            for record in test_records:
+                await session.execute(
+                    text("INSERT INTO events (origin, event_type, event_id, event_data) VALUES (:origin, :event_type, :event_id, :event_data)"),
+                    {
+                        "origin": str(sample_node_id.uuid),
+                        "event_type": record["event_type"],
+                        "event_id": str(uuid4()),
+                        "event_data": json.dumps(record)
+                    }
+                )
+            await session.commit()
+
+        last_idx = await storage.get_last_idx()
+        assert last_idx == 3
+
+        await storage.close()
+
     @pytest.mark.asyncio
     async def test_rowid_with_multiple_origins(self, temp_db_path: Path) -> None:
         """Test rowid sequence across multiple origins."""
@@ -404,22 +439,19 @@ class TestAsyncSQLiteEventStorage:
         await storage.start()

         # Create a ChunkGenerated event with nested TokenChunk
-        task_id = TaskId(uuid=uuid4())
-        chunk_data = TokenChunkData(
-            text="Hello, world!",
-            token_id=42,
-            finish_reason="stop"
-        )
-        token_chunk = TokenChunk(
-            chunk_data=chunk_data,
-            chunk_type=ChunkType.token,
-            task_id=task_id,
-            idx=0,
-            model="test-model"
-        )
+        request_id = RequestId(uuid=uuid4())
+        token_chunk = TokenChunk(
+            text="Hello, world!",
+            token_id=42,
+            finish_reason="stop",
+            chunk_type=ChunkType.token,
+            request_id=request_id,
+            idx=0,
+            model="test-model"
+        )

         chunk_generated_event = ChunkGenerated(
-            task_id=task_id,
+            request_id=request_id,
             chunk=token_chunk
         )
@@ -441,19 +473,19 @@ class TestAsyncSQLiteEventStorage:
         retrieved_event = retrieved_event_wrapper.event
         assert isinstance(retrieved_event, ChunkGenerated)
         assert retrieved_event.event_type == EventType.ChunkGenerated
-        assert retrieved_event.task_id == task_id
+        assert retrieved_event.request_id == request_id

         # Verify the nested chunk was deserialized correctly
         retrieved_chunk = retrieved_event.chunk
         assert isinstance(retrieved_chunk, TokenChunk)
         assert retrieved_chunk.chunk_type == ChunkType.token
-        assert retrieved_chunk.task_id == task_id
+        assert retrieved_chunk.request_id == request_id
         assert retrieved_chunk.idx == 0
         assert retrieved_chunk.model == "test-model"

         # Verify the chunk data
-        assert retrieved_chunk.chunk_data.text == "Hello, world!"
-        assert retrieved_chunk.chunk_data.token_id == 42
-        assert retrieved_chunk.chunk_data.finish_reason == "stop"
+        assert retrieved_chunk.text == "Hello, world!"
+        assert retrieved_chunk.token_id == 42
+        assert retrieved_chunk.finish_reason == "stop"

         await storage.close()
@@ -1,11 +1,34 @@
-from typing import Literal
+from typing import Any, Literal

 from pydantic import BaseModel

-from shared.types.tasks.common import ChatCompletionTaskParams, TaskId
-
-
-class ChatTask(BaseModel):
-    task_id: TaskId
-    kind: Literal["chat"] = "chat"
-    task_data: ChatCompletionTaskParams
+
+class ChatCompletionMessage(BaseModel):
+    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
+    content: str | None = None
+    name: str | None = None
+    tool_calls: list[dict[str, Any]] | None = None
+    tool_call_id: str | None = None
+    function_call: dict[str, Any] | None = None
+
+
+class ChatCompletionTaskParams(BaseModel):
+    model: str
+    frequency_penalty: float | None = None
+    messages: list[ChatCompletionMessage]
+    logit_bias: dict[str, int] | None = None
+    logprobs: bool | None = None
+    top_logprobs: int | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    presence_penalty: float | None = None
+    response_format: dict[str, Any] | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool = False
+    temperature: float | None = None
+    top_p: float | None = None
+    tools: list[dict[str, Any]] | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    parallel_tool_calls: bool | None = None
+    user: str | None = None
@@ -1,13 +1,11 @@
 from enum import Enum
 from typing import Annotated, Literal

-# from openai.types.chat.chat_completion import ChatCompletion
-# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel, Field, TypeAdapter

 from shared.openai_compat import FinishReason
 from shared.types.models import ModelId
-from shared.types.tasks.common import TaskId
+from shared.types.tasks.request import RequestId


 class ChunkType(str, Enum):
@@ -17,38 +15,21 @@ class ChunkType(str, Enum):

 class BaseChunk[ChunkTypeT: ChunkType](BaseModel):
     chunk_type: ChunkTypeT
-    task_id: TaskId
+    request_id: RequestId
     idx: int
     model: ModelId


-###
-
-
-class TokenChunkData(BaseModel):
-    text: str
-    token_id: int
-    finish_reason: FinishReason | None = None
-
-
-class ImageChunkData(BaseModel):
-    data: bytes
-
-
-###
-
-
 class TokenChunk(BaseChunk[ChunkType.token]):
-    chunk_data: TokenChunkData
     chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True)
+    text: str
+    token_id: int
+    finish_reason: FinishReason | None = None


 class ImageChunk(BaseChunk[ChunkType.image]):
-    chunk_data: ImageChunkData
     chunk_type: Literal[ChunkType.image] = Field(default=ChunkType.image, frozen=True)
+    data: bytes

-
-###

 GenerationChunk = Annotated[TokenChunk | ImageChunk, Field(discriminator="chunk_type")]
 GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(GenerationChunk)
@@ -60,10 +41,8 @@ GenerationChunkTypeAdapter: TypeAdapter[GenerationChunk] = TypeAdapter(Generatio
 # my_chunk: dict[str, Any] = TokenChunk(
 #     task_id=TaskId('nicerid'),
 #     idx=0,
-#     chunk_data=TokenChunkData(
-#         text='hello',
-#         token_id=12,
-#     ),
+#     text='hello',
+#     token_id=12,
 #     chunk_type=ChunkType.token,
 #     model='llama-3.1',
 # ).model_dump()
@@ -17,6 +17,7 @@ from shared.types.graphs.topology import (
 )
 from shared.types.profiling.common import NodePerformanceProfile
 from shared.types.tasks.common import Task, TaskId, TaskStatus
+from shared.types.tasks.request import RequestId
 from shared.types.worker.common import InstanceId, NodeStatus
 from shared.types.worker.instances import InstanceParams, TypeOfInstance
 from shared.types.worker.runners import RunnerId, RunnerStatus
@@ -111,7 +112,7 @@ class WorkerDisconnected(BaseEvent[EventType.WorkerDisconnected]):

 class ChunkGenerated(BaseEvent[EventType.ChunkGenerated]):
     event_type: Literal[EventType.ChunkGenerated] = EventType.ChunkGenerated
     task_id: TaskId
+    request_id: RequestId
     chunk: GenerationChunk
@@ -1,8 +1,8 @@
 from enum import Enum
 from typing import Any, Literal

 from pydantic import BaseModel

+from shared.types.api import ChatCompletionTaskParams
 from shared.types.common import NewUUID
 from shared.types.worker.common import InstanceId
@@ -10,11 +10,9 @@ from shared.types.worker.common import InstanceId

 class TaskId(NewUUID):
     pass


 class TaskType(str, Enum):
     ChatCompletion = "ChatCompletion"


 class TaskStatus(str, Enum):
     Pending = "Pending"
     Running = "Running"
@@ -22,42 +20,10 @@ class TaskStatus(str, Enum):
     Failed = "Failed"


-class ChatCompletionMessage(BaseModel):
-    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
-    content: str | None = None
-    name: str | None = None
-    tool_calls: list[dict[str, Any]] | None = None
-    tool_call_id: str | None = None
-    function_call: dict[str, Any] | None = None
-
-
-class ChatCompletionTaskParams(BaseModel):
-    task_type: Literal[TaskType.ChatCompletion] = TaskType.ChatCompletion
-    model: str
-    frequency_penalty: float | None = None
-    messages: list[ChatCompletionMessage]
-    logit_bias: dict[str, int] | None = None
-    logprobs: bool | None = None
-    top_logprobs: int | None = None
-    max_tokens: int | None = None
-    n: int | None = None
-    presence_penalty: float | None = None
-    response_format: dict[str, Any] | None = None
-    seed: int | None = None
-    stop: str | list[str] | None = None
-    stream: bool = False
-    temperature: float | None = None
-    top_p: float | None = None
-    tools: list[dict[str, Any]] | None = None
-    tool_choice: str | dict[str, Any] | None = None
-    parallel_tool_calls: bool | None = None
-    user: str | None = None
-
-
 class Task(BaseModel):
     task_id: TaskId
-    task_type: TaskType  # redundant atm as we only have 1 task type.
+    instance_id: InstanceId
+    task_type: TaskType
     task_status: TaskStatus
     task_params: ChatCompletionTaskParams
shared/types/tasks/request.py | 12 (new file)
@@ -0,0 +1,12 @@
from pydantic import BaseModel

from shared.types.api import ChatCompletionTaskParams
from shared.types.common import NewUUID


class RequestId(NewUUID):
    pass


class APIRequest(BaseModel):
    request_id: RequestId
    request_params: ChatCompletionTaskParams
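For reference, an APIRequest as the new /v1/chat/completions handler builds one (field values illustrative; the handler generates a fresh RequestId per request):

from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.tasks.request import APIRequest, RequestId

request = APIRequest(
    request_id=RequestId(),
    request_params=ChatCompletionTaskParams(
        model="testmodelabc",  # illustrative model string
        messages=[ChatCompletionMessage(role="user", content="Hello")],
        stream=True,
    ),
)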
uv.lock | 24 (generated)
@@ -154,6 +154,15 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
 ]

+[[package]]
+name = "click"
+version = "8.2.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" },
+]
+
 [[package]]
 name = "distro"
 version = "1.9.0"
@@ -219,12 +228,14 @@ source = { editable = "master" }
 dependencies = [
     { name = "exo-shared", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "uvicorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]

 [package.metadata]
 requires-dist = [
     { name = "exo-shared", editable = "shared" },
     { name = "fastapi", specifier = ">=0.116.0" },
+    { name = "uvicorn", specifier = ">=0.35.0" },
 ]

 [[package]]
@@ -1129,6 +1140,19 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
 ]

+[[package]]
+name = "uvicorn"
+version = "0.35.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "click", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/5e/42/e0e305207bb88c6b8d3061399c6a961ffe5fbb7e2aa63c9234df7259e9cd/uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01", size = 78473, upload-time = "2025-06-28T16:15:46.058Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d2/e2/dc81b1bd1dcfe91735810265e9d26bc8ec5da45b4c0f6237e286819194c3/uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a", size = 66406, upload-time = "2025-06-28T16:15:44.816Z" },
+]
+
 [[package]]
 name = "yarl"
 version = "1.20.1"
@@ -236,12 +236,15 @@ class Worker:
             assigned_runner.status = RunningRunnerStatus()
             await queue.put(assigned_runner.status_update_event())

             try:
                 async for chunk in assigned_runner.runner.stream_response(
                         task=op.task,
                         request_started_callback=partial(running_callback, queue)):
                     await queue.put(ChunkGenerated(
                         task_id=op.task.task_id,
+                        # todo: at some point we will no longer have a bijection between task_id and row_id.
+                        # So we probably want to store a mapping between these two in our Worker object.
+                        request_id=chunk.request_id,
                         chunk=chunk
                     ))
@@ -5,11 +5,12 @@ from collections.abc import AsyncGenerator
 from types import CoroutineType
 from typing import Any, Callable

-from shared.types.events.chunks import GenerationChunk, TokenChunk, TokenChunkData
+from shared.types.events.chunks import GenerationChunk, TokenChunk
 from shared.types.tasks.common import (
     ChatCompletionTaskParams,
     Task,
 )
+from shared.types.tasks.request import RequestId
 from shared.types.worker.commands_runner import (
     ChatTaskMessage,
     ErrorResponse,
@@ -183,14 +184,12 @@ class RunnerSupervisor:
                     text=text, token=token, finish_reason=finish_reason
                 ):
                     yield TokenChunk(
-                        task_id=task.task_id,
+                        request_id=RequestId(uuid=task.task_id.uuid),
                         idx=token,
                         model=self.model_shard_meta.model_meta.model_id,
-                        chunk_data=TokenChunkData(
-                            text=text,
-                            token_id=token,
-                            finish_reason=finish_reason,
-                        ),
+                        text=text,
+                        token_id=token,
+                        finish_reason=finish_reason,
                     )
                 case FinishedResponse():
                     break
@@ -6,12 +6,11 @@ from typing import Callable
 import pytest

+from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
 from shared.types.common import NodeId
 from shared.types.models import ModelId, ModelMetadata
 from shared.types.state import State
 from shared.types.tasks.common import (
-    ChatCompletionMessage,
-    ChatCompletionTaskParams,
     Task,
     TaskId,
     TaskStatus,
@@ -45,9 +45,9 @@ async def test_supervisor_single_node_response(

     async for chunk in supervisor.stream_response(task=chat_task):
         if isinstance(chunk, TokenChunk):
-            full_response += chunk.chunk_data.text
-            if chunk.chunk_data.finish_reason:
-                stop_reason = chunk.chunk_data.finish_reason
+            full_response += chunk.text
+            if chunk.finish_reason:
+                stop_reason = chunk.finish_reason

     # Case-insensitive check for Paris in the response
     assert "paris" in full_response.lower(), (
@@ -87,13 +87,13 @@ async def test_supervisor_two_node_response(
         nonlocal full_response_0
         async for chunk in supervisor_0.stream_response(task=chat_task):
             if isinstance(chunk, TokenChunk):
-                full_response_0 += chunk.chunk_data.text
+                full_response_0 += chunk.text

     async def collect_response_1():
         nonlocal full_response_1
         async for chunk in supervisor_1.stream_response(task=chat_task):
             if isinstance(chunk, TokenChunk):
-                full_response_1 += chunk.chunk_data.text
+                full_response_1 += chunk.text

     # Run both stream responses simultaneously
     _ = await asyncio.gather(collect_response_0(), collect_response_1())
@@ -148,10 +148,10 @@ async def test_supervisor_early_stopping(

     async for chunk in supervisor.stream_response(task=chat_task):
         if isinstance(chunk, TokenChunk):
-            full_response += chunk.chunk_data.text
+            full_response += chunk.text
             count += 1
-            if chunk.chunk_data.finish_reason:
-                stop_reason = chunk.chunk_data.finish_reason
+            if chunk.finish_reason:
+                stop_reason = chunk.finish_reason

     print(f"full_response: {full_response}")
@@ -7,7 +7,7 @@ from typing import Callable
 import pytest

 from shared.types.common import NodeId
-from shared.types.events.chunks import TokenChunk, TokenChunkData
+from shared.types.events.chunks import TokenChunk
 from shared.types.events.events import ChunkGenerated, RunnerStatusUpdated
 from shared.types.events.registry import Event
 from shared.types.tasks.common import Task
@@ -107,7 +107,7 @@ async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId,

     async for chunk in supervisor.stream_response(task=chat_task):
         if isinstance(chunk, TokenChunk):
-            full_response += chunk.chunk_data.text
+            full_response += chunk.text

     assert "42" in full_response.lower(), (
         f"Expected '42' in response, but got: {full_response}"
@@ -175,7 +175,7 @@ async def test_execute_task_op(
     assert isinstance(events[-1].runner_status, LoadedRunnerStatus)  # It should not have failed.

     gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)]
-    text_chunks: list[TokenChunkData] = [x.chunk.chunk_data for x in gen_events if isinstance(x.chunk.chunk_data, TokenChunkData)]
+    text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)]
     assert len(text_chunks) == len(events) - 2

     output_text = ''.join([x.text for x in text_chunks])