import asyncio
from collections.abc import AsyncGenerator
from types import CoroutineType
from typing import Any, Awaitable, Callable, Final

import pytest
from _pytest.monkeypatch import MonkeyPatch

# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import (
    ChunkGenerated,
    InstanceCreated,
    InstanceDeleted,
    RunnerStatusUpdated,
    TaskCreated,
    TaskStateUpdated,
    TaskFailed,
)
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
    Instance,
    InstanceStatus,
)
from shared.types.worker.runners import FailedRunnerStatus
from worker.main import Worker
from worker.runner.runner_supervisor import RunnerSupervisor


MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")

# Define constant IDs for deterministic test cases
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")


@pytest.fixture
def user_message():
    """Override this fixture in tests to customize the message"""
    return "Who is the longest ruling monarch of England?"


# TODO: Make this all monkeypatched instead.


async def test_stream_response_failed_always(
    monkeypatch: MonkeyPatch,
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
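    """stream_response that always raises: the worker should record repeated runner
    failures, mark the task as FAILED, and eventually delete the instance."""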
    worker, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    async def mock_stream_response(
        self: RunnerSupervisor,
        task: Task,
        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
    ) -> AsyncGenerator[GenerationChunk]:
        raise RuntimeError("Simulated stream response failure")
        # Unreachable, but the bare `yield` makes this an async generator, matching the
        # signature of the real stream_response.
        return
        yield

    monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)

    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    await global_events.append_events(
        [
            InstanceCreated(instance=instance_value),
            TaskCreated(task_id=task.task_id, task=task)
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(5.)
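    # Expect three failed attempts, each surfacing a FailedRunnerStatus and a FAILED
    # task state, after which the instance is torn down.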
    events = await global_events.get_events_since(0)

    assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
    assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 3
    assert any(isinstance(x.event, InstanceDeleted) for x in events)

    await global_events.append_events(
        [
            InstanceDeleted(
                instance_id=instance_value.instance_id,
            ),
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(0.3)


async def test_stream_response_failed_once(
    monkeypatch: MonkeyPatch,
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
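    """stream_response fails once and then succeeds: the single failure should be
    recorded in the event log, and the retried task should complete and stream a
    normal response."""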
    failed_already = False
    original_stream_response = RunnerSupervisor.stream_response

    async def mock_stream_response(
        self: RunnerSupervisor,
        task: Task,
        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
    ) -> AsyncGenerator[GenerationChunk]:
        nonlocal failed_already
        if not failed_already:
            failed_already = True
            raise RuntimeError("Simulated stream response failure")
        else:
            async for event in original_stream_response(self, task, request_started_callback):
                yield event
            return

    monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)

    worker, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    await global_events.append_events(
        [
            InstanceCreated(instance=instance_value),
            TaskCreated(task_id=task.task_id, task=task)
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(5.)

    # TODO: Ideally this test would have tooling to scroll through the state history and
    # assert that, at some point, error_type and error_message were not None and the
    # failure count was nonzero, since we reset the failures back to zero once there has
    # been a successful inference.
    assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
    assert worker.state.tasks[TASK_1_ID].error_type is None
    assert worker.state.tasks[TASK_1_ID].error_message is None

    events = await global_events.get_events_since(0)
    assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
    assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
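    # The retried attempt should complete normally and stream a real answer.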
    response_string = ''
    events = await global_events.get_events_since(0)

    seen_task_started, seen_task_finished = False, False
    for wrapped_event in events:
        event = wrapped_event.event
        if isinstance(event, TaskStateUpdated):
            if event.task_status == TaskStatus.RUNNING:
                seen_task_started = True
            if event.task_status == TaskStatus.COMPLETE:
                seen_task_finished = True

        if isinstance(event, ChunkGenerated):
            assert isinstance(event.chunk, TokenChunk)
            response_string += event.chunk.text

    assert 'elizabeth' in response_string.lower()
    assert seen_task_started
    assert seen_task_finished

    await global_events.append_events(
        [
            InstanceDeleted(
                instance_id=instance_value.instance_id,
            ),
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(0.3)


async def test_stream_response_timeout(
    monkeypatch: MonkeyPatch,
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
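    """stream_response that hangs instead of yielding: the worker should time the task
    out and emit a TaskFailed event carrying a TimeoutError."""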
    async def mock_stream_response(
        self: RunnerSupervisor,
        task: Task,
        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
    ) -> AsyncGenerator[GenerationChunk]:
        # TODO: Also add a test where we yield a few chunks and then time out.
        print('sleeping starting')
        await asyncio.sleep(4.)
        print('sleeping finished')
        # The `yield` is unreachable but makes this an async generator; it returns
        # without producing any chunks.
        return
        yield

    monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)

    worker, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    await global_events.append_events(
        [
            InstanceCreated(instance=instance_value),
            TaskCreated(task_id=task.task_id, task=task)
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(7.)
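    # The mocked stream sleeps for 4s without yielding, which is presumably longer than
    # the worker's stream timeout, so the task should fail with a TimeoutError.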
    # Failures are reset to zero once there is a successful inference.

    # print('ASSERTION ERR:')
    # print(worker.assigned_runners[RUNNER_1_ID].failures[1][1])

    assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
    assert worker.state.tasks[TASK_1_ID].error_type is None
    assert worker.state.tasks[TASK_1_ID].error_message is None

    events = await global_events.get_events_since(0)
    print(events)
    assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
    assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
    assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1

    await global_events.append_events(
        [
            InstanceDeleted(
                instance_id=instance_value.instance_id,
            ),
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(0.3)