Files
exo/worker/tests/test_worker_handlers.py
Alex Cheema a241c92dd1 Glue
2025-07-25 13:10:29 +01:00

228 lines
7.5 KiB
Python

## Tests for worker state handlers
from pathlib import Path
from typing import Callable
import pytest
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
Event,
RunnerDeleted,
RunnerStatusUpdated,
TaskStateUpdated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import Task, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance
from shared.types.worker.ops import (
AssignRunnerOp,
DownloadOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerUpOp,
UnassignRunnerOp,
)
from shared.types.worker.runners import (
FailedRunnerStatus,
LoadedRunnerStatus,
ReadyRunnerStatus,
RunningRunnerStatus,
)
from worker.main import Worker
@pytest.fixture
def user_message() -> str:
    """Override the default user message with a question whose expected answer is "42".

    Tests in this module check that the model's generated text contains '42'.
    """
    return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
@pytest.mark.asyncio
async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance], tmp_path: Path):
    """AssignRunnerOp should register the runner on the worker in the Ready state.

    Exactly one RunnerStatusUpdated event carrying ReadyRunnerStatus is expected,
    and the worker's own record of the runner must also be ReadyRunnerStatus.
    """
    runner_id = RunnerId()
    instance_obj: Instance = instance(worker.node_id, runner_id)

    assign_op = AssignRunnerOp(
        runner_id=runner_id,
        shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
        hosts=instance_obj.hosts,
        instance_id=instance_obj.instance_id,
    )

    events: list[Event] = []
    async for event in worker._execute_op(assign_op):  # type: ignore[misc]
        events.append(event)

    # We should have a single status update saying the runner is 'ready'.
    assert len(events) == 1
    assert isinstance(events[0], RunnerStatusUpdated)
    assert isinstance(events[0].runner_status, ReadyRunnerStatus)

    # And the runner should be assigned, with its locally tracked status also Ready.
    assert runner_id in worker.assigned_runners
    assert isinstance(worker.assigned_runners[runner_id].status, ReadyRunnerStatus)
@pytest.mark.asyncio
async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path):
    """UnassignRunnerOp should remove the runner and emit exactly one RunnerDeleted event."""
    worker, runner_id, _ = worker_with_assigned_runner

    unassign_op = UnassignRunnerOp(runner_id=runner_id)

    events: list[Event] = []
    async for event in worker._execute_op(unassign_op):  # type: ignore[misc]
        events.append(event)

    # The runner is gone (and it was the only one), and a single RunnerDeleted event was emitted.
    assert runner_id not in worker.assigned_runners
    assert len(worker.assigned_runners) == 0
    assert len(events) == 1
    assert isinstance(events[0], RunnerDeleted)
@pytest.mark.asyncio
async def test_runner_up_op(
    worker_with_assigned_runner: tuple[Worker, RunnerId, Instance],
    chat_completion_task: Task,
    tmp_path: Path,
):
    """RunnerUpOp should load the runner, which must then be healthy and able to answer a prompt.

    Expects a single RunnerStatusUpdated(LoadedRunnerStatus) event, then streams a
    completion from the runner's supervisor and checks the text contains '42'.
    The runner is always stopped at the end, even if an assertion fails.
    """
    worker, runner_id, _ = worker_with_assigned_runner

    runner_up_op = RunnerUpOp(runner_id=runner_id)

    try:
        events: list[Event] = []
        async for event in worker._execute_op(runner_up_op):  # type: ignore[misc]
            events.append(event)

        assert len(events) == 1
        assert isinstance(events[0], RunnerStatusUpdated)
        assert isinstance(events[0].runner_status, LoadedRunnerStatus)

        # Is the runner actually running? Look up *this* runner's supervisor
        # rather than whichever entry happens to come first in the mapping.
        supervisor = worker.assigned_runners[runner_id].runner
        assert supervisor is not None
        assert supervisor.healthy

        parts: list[str] = []
        async for chunk in supervisor.stream_response(task=chat_completion_task):
            if isinstance(chunk, TokenChunk):
                parts.append(chunk.text)
        full_response = ''.join(parts)

        assert "42" in full_response, (
            f"Expected '42' in response, but got: {full_response}"
        )
    finally:
        # Stop the runner even when an assertion above failed, so it is not left running.
        runner = worker.assigned_runners[runner_id].runner
        if runner is not None:
            await runner.astop()
@pytest.mark.asyncio
async def test_runner_down_op(worker_with_running_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path):
    """Bringing the fixture's running runner down should report it as Ready via one status event."""
    worker, runner_id, _ = worker_with_running_runner

    op = RunnerDownOp(runner_id=runner_id)
    emitted: list[Event] = [ev async for ev in worker._execute_op(op)]  # type: ignore[misc]

    assert len(emitted) == 1
    (only_event,) = emitted
    assert isinstance(only_event, RunnerStatusUpdated)
    assert isinstance(only_event.runner_status, ReadyRunnerStatus)
@pytest.mark.asyncio
async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], tmp_path: Path):
    """Executing a DownloadOp for an assigned runner should emit download status events.

    Expected: one or more download status updates, ending with a final status
    indicating the download completed.
    """
    worker, runner_id, instance_obj = worker_with_assigned_runner

    download_op = DownloadOp(
        instance_id=instance_obj.instance_id,
        runner_id=runner_id,
        shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
        hosts=instance_obj.hosts,
    )

    events: list[Event] = []
    async for event in worker._execute_op(download_op):  # type: ignore[misc]
        events.append(event)
    # Captured by pytest; shown in the report if an assertion below fails.
    print(f'{events=}')

    # Without at least one assertion this test could never fail.
    assert len(events) > 0, "DownloadOp emitted no events"
    # TODO: assert that the final event reports DownloadCompleted once the
    # download status types are imported into this module.
@pytest.mark.asyncio
async def test_execute_task_op(
    worker_with_running_runner: tuple[Worker, RunnerId, Instance],
    chat_completion_task: Task,
    tmp_path: Path,
):
    """ExecuteTaskOp on a running runner should stream tokens and complete the task.

    Expected event sequence:
      [0]   RunnerStatusUpdated -> RunningRunnerStatus
      [1]   TaskStateUpdated    -> TaskStatus.RUNNING
      ...   ChunkGenerated(TokenChunk), one per generated token
      [-2]  TaskStateUpdated    -> TaskStatus.COMPLETE
      [-1]  RunnerStatusUpdated -> LoadedRunnerStatus
    The runner is always stopped at the end, even if an assertion fails.
    """
    worker, runner_id, _ = worker_with_running_runner

    execute_task_op = ExecuteTaskOp(runner_id=runner_id, task=chat_completion_task)

    try:
        events: list[Event] = []
        async for event in worker._execute_op(execute_task_op):  # type: ignore[misc]
            events.append(event)
        # Print before asserting so the events appear in the failure report.
        print(f'{events=}')

        assert len(events) > 20

        assert isinstance(events[0], RunnerStatusUpdated)
        assert isinstance(events[0].runner_status, RunningRunnerStatus)
        assert isinstance(events[1], TaskStateUpdated)
        assert events[1].task_status == TaskStatus.RUNNING  # It tried to start.

        assert isinstance(events[-2], TaskStateUpdated)
        assert events[-2].task_status == TaskStatus.COMPLETE  # It finished.
        assert isinstance(events[-1], RunnerStatusUpdated)
        assert isinstance(events[-1].runner_status, LoadedRunnerStatus)  # It should not have failed.

        # Everything except the 2 leading and 2 trailing status/state updates is a token chunk.
        gen_events: list[ChunkGenerated] = [x for x in events if isinstance(x, ChunkGenerated)]
        text_chunks: list[TokenChunk] = [x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)]
        assert len(text_chunks) == len(events) - 4

        output_text = ''.join(x.text for x in text_chunks)
        assert '42' in output_text
    finally:
        # Stop the runner even when an assertion above failed, so it is not left running.
        runner = worker.assigned_runners[runner_id].runner
        if runner is not None:
            await runner.astop()
@pytest.mark.asyncio
async def test_execute_task_fails(
    worker_with_running_runner: tuple[Worker, RunnerId, Instance],
    chat_completion_task: Task,
    tmp_path: Path,
):
    """A forced-failure prompt should mark both the task and the runner as failed.

    The prompt 'Artificial prompt: EXO RUNNER MUST FAIL' is presumably recognised
    by the runner as a deliberate failure trigger (confirm in the runner code).
    Expected event sequence (exactly four events):
      [0] RunnerStatusUpdated -> RunningRunnerStatus
      [1] TaskStateUpdated    -> TaskStatus.RUNNING
      [2] TaskStateUpdated    -> TaskStatus.FAILED
      [3] RunnerStatusUpdated -> FailedRunnerStatus
    """
    worker, runner_id, _ = worker_with_running_runner

    # NOTE(review): this mutates the fixture's task in place; that is only safe if
    # chat_completion_task is function-scoped. Confirm in conftest.
    messages = chat_completion_task.task_params.messages
    messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'

    execute_task_op = ExecuteTaskOp(runner_id=runner_id, task=chat_completion_task)

    events: list[Event] = []
    async for event in worker._execute_op(execute_task_op):  # type: ignore[misc]
        events.append(event)
    # Print before asserting so the events appear in the failure report.
    print(events)

    assert len(events) == 4

    assert isinstance(events[0], RunnerStatusUpdated)
    assert isinstance(events[0].runner_status, RunningRunnerStatus)  # It tried to start.
    assert isinstance(events[1], TaskStateUpdated)
    assert events[1].task_status == TaskStatus.RUNNING  # It tried to start.
    assert isinstance(events[2], TaskStateUpdated)
    assert events[2].task_status == TaskStatus.FAILED  # Task marked as failed.
    assert isinstance(events[3], RunnerStatusUpdated)
    assert isinstance(events[3].runner_status, FailedRunnerStatus)  # It should have failed.