Files
exo/worker/tests/test_integration/integration_utils.py
Matt Beton 1fe4ed3442 Worker Exception & Timeout Refactor
Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Seth Howes <sethshowes@gmail.com>
2025-08-02 08:28:37 -07:00

81 lines
2.8 KiB
Python

import asyncio
from typing import Callable, Optional, Tuple, TypeVar
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.events import ChunkGenerated, TaskStateUpdated
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import TaskId, TaskStatus
async def _find_task_start_idx(global_events: AsyncSQLiteEventStorage, task_id: TaskId) -> int:
    """Poll the log until ``task_id`` reports RUNNING; return the index just before that event.

    The result is intended to be fed to ``get_events_since`` so that the
    RUNNING event itself is included in the next read (``get_events_since``
    appears to be exclusive of the index it is given).
    """
    idx = 0
    while True:
        events = await global_events.get_events_since(idx)
        for wrapped_event in events:
            event = wrapped_event.event
            if (
                isinstance(event, TaskStateUpdated)
                and event.task_status == TaskStatus.RUNNING
                and event.task_id == task_id
            ):
                return wrapped_event.idx_in_log - 1
        # Nothing matched yet: skip what has already been scanned and back off
        # instead of hammering the database in a hot loop.
        if events:
            idx = events[-1].idx_in_log
        await asyncio.sleep(0.01)


async def read_streaming_response(
    global_events: AsyncSQLiteEventStorage,
    filter_task: Optional[TaskId] = None,
) -> Tuple[bool, bool, str]:
    """Collect a streamed token response from the global event log.

    Reads ``ChunkGenerated`` events, concatenating their token text, until a
    chunk carrying a ``finish_reason`` is seen. Along the way it counts the
    RUNNING and COMPLETE ``TaskStateUpdated`` transitions.

    Args:
        global_events: Event storage to poll.
        filter_task: If ``None``, start reading from the current end of the
            log, so only events appended after this call are seen. Otherwise,
            wait for ``filter_task`` to become RUNNING, start from that event,
            and only count state transitions belonging to that task.

    Returns:
        ``(started_exactly_once, finished_exactly_once, response_text)``.
    """
    if filter_task is None:
        idx = await global_events.get_last_idx()
    else:
        idx = await _find_task_start_idx(global_events, filter_task)
    print(f'START IDX {idx}')

    def is_tracked(event: TaskStateUpdated) -> bool:
        # With a filter, ignore state transitions of unrelated tasks so they
        # cannot skew the started/finished counts.
        return filter_task is None or event.task_id == filter_task

    seen_task_started, seen_task_finished = 0, 0
    response_parts: list[str] = []
    finish_reason: str | None = None
    while not finish_reason:
        events = await global_events.get_events_since(idx)
        if not events:
            await asyncio.sleep(0.01)
            continue
        idx = events[-1].idx_in_log
        for wrapped_event in events:
            event = wrapped_event.event
            if isinstance(event, TaskStateUpdated) and is_tracked(event):
                if event.task_status == TaskStatus.RUNNING:
                    seen_task_started += 1
                if event.task_status == TaskStatus.COMPLETE:
                    seen_task_finished += 1
            # NOTE(review): chunks are not filtered by task because the link
            # between ChunkGenerated and a task is not visible here.
            if isinstance(event, ChunkGenerated):
                assert isinstance(event.chunk, TokenChunk)
                response_parts.append(event.chunk.text)
                if event.chunk.finish_reason:
                    finish_reason = event.chunk.finish_reason

    # Grace period for trailing events. The COMPLETE transition may be written
    # after the final chunk, so actually read what arrived during the wait;
    # otherwise the finished flag depends on batching luck.
    await asyncio.sleep(0.2)
    for wrapped_event in await global_events.get_events_since(idx):
        event = wrapped_event.event
        if (
            isinstance(event, TaskStateUpdated)
            and is_tracked(event)
            and event.task_status == TaskStatus.COMPLETE
        ):
            seen_task_finished += 1

    print(f'event log: {await global_events.get_events_since(0)}')
    return seen_task_started == 1, seen_task_finished == 1, ''.join(response_parts)
T = TypeVar("T")


async def until_event_with_timeout(
    global_events: AsyncSQLiteEventStorage,
    event_type: type[T],
    multiplicity: int = 1,
    condition: Callable[[T], bool] = lambda x: True,
    timeout: float | None = None,
) -> None:
    """Wait until ``multiplicity`` matching events have been appended to the log.

    Only events appended after this call starts are considered. An event
    matches if it is an instance of ``event_type`` and ``condition(event)`` is
    true. The log is polled every 10 ms.

    Args:
        global_events: Event storage to poll.
        event_type: Type of event to wait for.
        multiplicity: Number of matching events required. At least one match
            is always awaited.
        condition: Extra predicate applied to candidate events.
        timeout: Maximum seconds to wait. ``None`` (the default) waits forever,
            as before.

    Raises:
        asyncio.TimeoutError: If ``timeout`` elapses before enough events match.
    """
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout
    idx = await global_events.get_last_idx()
    times_seen = 0
    while True:
        events = await global_events.get_events_since(idx)
        for wrapped_event in events:
            if isinstance(wrapped_event.event, event_type) and condition(wrapped_event.event):
                times_seen += 1
                if times_seen >= multiplicity:
                    return
        if events:
            idx = events[-1].idx_in_log
        if deadline is not None and loop.time() >= deadline:
            raise asyncio.TimeoutError(
                f"timed out after {timeout}s waiting for {multiplicity} "
                f"{event_type.__name__} event(s); saw {times_seen}"
            )
        await asyncio.sleep(0.01)