Files
exo/worker/tests/test_supervisor_errors.py
2025-07-30 13:30:54 +01:00

251 lines
9.4 KiB
Python

import asyncio
from collections.abc import AsyncGenerator
from types import CoroutineType
from typing import Any, Awaitable, Callable, Final
import pytest
from _pytest.monkeypatch import MonkeyPatch
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
TaskCreated,
TaskStateUpdated,
TaskFailed,
)
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
InstanceStatus,
)
from shared.types.worker.runners import FailedRunnerStatus
from worker.main import Worker
from worker.runner.runner_supervisor import RunnerSupervisor
# Fixed identifiers used throughout this module so that test cases are
# deterministic from run to run.

# Node IDs
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")

# Runner IDs
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")

# Instance IDs
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")

# Model IDs (both names currently refer to the same model string)
MODEL_A_ID: Final[ModelId] = "mlx-community/Llama-3.2-1B-Instruct-4bit"
MODEL_B_ID: Final[ModelId] = "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Task IDs
TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
@pytest.fixture
def user_message():
    """Provide the default user message.

    Override this fixture in a test to customize the message.
    """
    default_message = "Who is the longest ruling monarch of England?"
    return default_message
# TODO: Make this all monkeypatched instead.
async def test_stream_response_failed_always(
    monkeypatch: MonkeyPatch,
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
    """Check the event log when every ``stream_response`` call raises.

    ``RunnerSupervisor.stream_response`` is patched with a stub that always
    raises ``RuntimeError``. After an instance and a task are published, the
    event log is expected to contain exactly three failed-runner status
    updates, exactly three FAILED task-state updates, and at least one
    ``InstanceDeleted`` event.
    """
    # Only the event store is inspected here; the worker itself is not queried.
    _worker, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    async def mock_stream_response(
        self: RunnerSupervisor,
        task: Task,
        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
    ) -> AsyncGenerator[GenerationChunk]:
        raise RuntimeError("Simulated stream response failure")
        # Never reached: the bare `yield` only exists so that this stub is an
        # async generator function, matching its return annotation.
        yield

    monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)

    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    await global_events.append_events(
        [
            InstanceCreated(instance=instance_value),
            TaskCreated(task_id=task.task_id, task=task),
        ],
        origin=MASTER_NODE_ID,
    )

    try:
        # NOTE(review): presumably long enough for the worker to exhaust its
        # retries; the exact counts asserted below depend on it.
        await asyncio.sleep(5.0)

        events = await global_events.get_events_since(0)

        failed_runner_updates = sum(
            1
            for x in events
            if isinstance(x.event, RunnerStatusUpdated)
            and isinstance(x.event.runner_status, FailedRunnerStatus)
        )
        failed_task_updates = sum(
            1
            for x in events
            if isinstance(x.event, TaskStateUpdated)
            and x.event.task_status == TaskStatus.FAILED
        )

        assert failed_runner_updates == 3
        assert failed_task_updates == 3
        assert any(isinstance(x.event, InstanceDeleted) for x in events)
    finally:
        # Publish the deletion even when an assertion above fails, so the
        # teardown step is never skipped.
        await global_events.append_events(
            [
                InstanceDeleted(
                    instance_id=instance_value.instance_id,
                ),
            ],
            origin=MASTER_NODE_ID,
        )
        await asyncio.sleep(0.3)
async def test_stream_response_failed_once(
    monkeypatch: MonkeyPatch,
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
    """Check recovery when ``stream_response`` fails on its first call only.

    ``RunnerSupervisor.stream_response`` is patched with a stub that raises
    ``RuntimeError`` the first time it is called and delegates to the original
    implementation afterwards. The test expects exactly one failed-runner
    status update and one FAILED task-state update in the event log, followed
    by a completed task whose generated text mentions "elizabeth", with no
    failures or error details left on the worker state.
    """
    failed_already = False
    original_stream_response = RunnerSupervisor.stream_response

    async def mock_stream_response(
        self: RunnerSupervisor,
        task: Task,
        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
    ) -> AsyncGenerator[GenerationChunk]:
        nonlocal failed_already
        if not failed_already:
            # First call: record that we failed, then blow up.
            failed_already = True
            raise RuntimeError("Simulated stream response failure")
        # Every later call: pass through to the real implementation.
        async for event in original_stream_response(self, task, request_started_callback):
            yield event

    monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)

    worker, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    await global_events.append_events(
        [
            InstanceCreated(instance=instance_value),
            TaskCreated(task_id=task.task_id, task=task),
        ],
        origin=MASTER_NODE_ID,
    )

    try:
        await asyncio.sleep(5.0)

        # TODO: The ideal with this test is if we had some tooling to scroll through the state, and say
        # 'assert that there was a time that the error_type, error_message was not none and the failure
        # count was nonzero', as we reset the failures back to zero when we have a successful inference.
        assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
        assert worker.state.tasks[TASK_1_ID].error_type is None
        assert worker.state.tasks[TASK_1_ID].error_message is None

        # One fetch of the log serves both the failure counts and the
        # response reconstruction below.
        events = await global_events.get_events_since(0)

        failed_runner_updates = sum(
            1
            for x in events
            if isinstance(x.event, RunnerStatusUpdated)
            and isinstance(x.event.runner_status, FailedRunnerStatus)
        )
        failed_task_updates = sum(
            1
            for x in events
            if isinstance(x.event, TaskStateUpdated)
            and x.event.task_status == TaskStatus.FAILED
        )
        assert failed_runner_updates == 1
        assert failed_task_updates == 1

        # Walk the log once: note the task lifecycle transitions and collect
        # the generated token text in order.
        chunk_texts: list[str] = []
        seen_task_started, seen_task_finished = False, False
        for wrapped_event in events:
            event = wrapped_event.event
            if isinstance(event, TaskStateUpdated):
                if event.task_status == TaskStatus.RUNNING:
                    seen_task_started = True
                elif event.task_status == TaskStatus.COMPLETE:
                    seen_task_finished = True
            if isinstance(event, ChunkGenerated):
                assert isinstance(event.chunk, TokenChunk)
                chunk_texts.append(event.chunk.text)
        response_string = ''.join(chunk_texts)

        assert 'elizabeth' in response_string.lower()
        assert seen_task_started
        assert seen_task_finished
    finally:
        # Publish the deletion even when an assertion above fails, so the
        # teardown step is never skipped.
        await global_events.append_events(
            [
                InstanceDeleted(
                    instance_id=instance_value.instance_id,
                ),
            ],
            origin=MASTER_NODE_ID,
        )
        await asyncio.sleep(0.3)
async def test_stream_response_timeout(
    monkeypatch: MonkeyPatch,
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
    """Check the events emitted when ``stream_response`` stalls.

    ``RunnerSupervisor.stream_response`` is patched with a stub that sleeps for
    four seconds and then finishes without yielding any chunk. The test
    expects exactly one failed-runner status update, one FAILED task-state
    update and one ``TaskFailed`` event whose ``error_type`` contains
    "timeouterror" (case-insensitive), and no failures or error details left
    on the worker state by the end of the wait.
    """

    async def mock_stream_response(
        self: RunnerSupervisor,
        task: Task,
        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
    ) -> AsyncGenerator[GenerationChunk]:
        # TODO: Also a test where we yield a few chunks and then time out.
        # NOTE(review): presumably longer than the supervisor's response
        # timeout; confirm against RunnerSupervisor.
        await asyncio.sleep(4.0)
        return
        # Never reached: the bare `yield` only exists so that this stub is an
        # async generator function, matching its return annotation.
        yield

    monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)

    worker, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    await global_events.append_events(
        [
            InstanceCreated(instance=instance_value),
            TaskCreated(task_id=task.task_id, task=task),
        ],
        origin=MASTER_NODE_ID,
    )

    try:
        await asyncio.sleep(7.0)

        # End state: no failures recorded on the runner and no error details
        # left on the task.
        assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
        assert worker.state.tasks[TASK_1_ID].error_type is None
        assert worker.state.tasks[TASK_1_ID].error_message is None

        events = await global_events.get_events_since(0)

        failed_runner_updates = sum(
            1
            for x in events
            if isinstance(x.event, RunnerStatusUpdated)
            and isinstance(x.event.runner_status, FailedRunnerStatus)
        )
        failed_task_updates = sum(
            1
            for x in events
            if isinstance(x.event, TaskStateUpdated)
            and x.event.task_status == TaskStatus.FAILED
        )
        timeout_failures = sum(
            1
            for x in events
            if isinstance(x.event, TaskFailed)
            and 'timeouterror' in x.event.error_type.lower()
        )

        assert failed_runner_updates == 1
        assert failed_task_updates == 1
        assert timeout_failures == 1
    finally:
        # Publish the deletion even when an assertion above fails, so the
        # teardown step is never skipped.
        await global_events.append_events(
            [
                InstanceDeleted(
                    instance_id=instance_value.instance_id,
                ),
            ],
            origin=MASTER_NODE_ID,
        )
        await asyncio.sleep(0.3)