From b350ededb2000cb02600fb2931e92c15bdd5b7ad Mon Sep 17 00:00:00 2001 From: Matt Beton Date: Wed, 30 Jul 2025 13:30:54 +0100 Subject: [PATCH] Test Supervisor Errors. --- engines/mlx/utils_mlx.py | 1 - shared/apply/apply.py | 21 +- shared/types/events/_events.py | 9 + shared/types/tasks.py | 5 +- shared/types/worker/commands_runner.py | 10 +- worker/main.py | 152 +++++++++++--- worker/runner/communication.py | 14 +- worker/runner/runner.py | 9 +- worker/runner/runner_supervisor.py | 22 ++- worker/tests/test_runner_connection.py | 189 ++++++++++++++++++ worker/tests/test_spinup_timeout.py | 48 +++++ worker/tests/test_supervisor.py | 30 ++- worker/tests/test_supervisor_errors.py | 251 ++++++++++++++++++++++++ worker/tests/test_worker_handlers.py | 9 +- worker/tests/test_worker_integration.py | 105 +++++++++- 15 files changed, 819 insertions(+), 56 deletions(-) create mode 100644 worker/tests/test_runner_connection.py create mode 100644 worker/tests/test_spinup_timeout.py create mode 100644 worker/tests/test_supervisor_errors.py diff --git a/engines/mlx/utils_mlx.py b/engines/mlx/utils_mlx.py index 3b7c5147..1b77413f 100644 --- a/engines/mlx/utils_mlx.py +++ b/engines/mlx/utils_mlx.py @@ -52,7 +52,6 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: os.environ["MLX_RANK"] = str(rank) os.environ["MLX_RING_VERBOSE"] = "1" - # Initialize distributed group = mx.distributed.init(backend="ring", strict=True) runner_print(f"Rank {rank} mlx distributed initialization complete") diff --git a/shared/apply/apply.py b/shared/apply/apply.py index 25eb2f27..18914590 100644 --- a/shared/apply/apply.py +++ b/shared/apply/apply.py @@ -19,6 +19,7 @@ from shared.types.events import ( RunnerStatusUpdated, TaskCreated, TaskDeleted, + TaskFailed, TaskStateUpdated, TopologyEdgeCreated, TopologyEdgeDeleted, @@ -28,7 +29,7 @@ from shared.types.events import ( ) from shared.types.profiling import NodePerformanceProfile from shared.types.state import State -from shared.types.tasks import Task, TaskId +from shared.types.tasks import Task, TaskId, TaskStatus from shared.types.topology import Connection, Node from shared.types.worker.common import NodeStatus, RunnerId from shared.types.worker.instances import Instance, InstanceId, InstanceStatus @@ -63,7 +64,23 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State: if event.task_id not in state.tasks: return state - updated_task = state.tasks[event.task_id].model_copy(update={"task_status": event.task_status}) + update: dict[str, TaskStatus | None] = { + "task_status": event.task_status, + } + if event.task_status != TaskStatus.FAILED: + update["error_type"] = None + update["error_message"] = None + + updated_task = state.tasks[event.task_id].model_copy(update=update) + new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task} + return state.model_copy(update={"tasks": new_tasks}) + +@event_apply.register(TaskFailed) +def apply_task_failed(event: TaskFailed, state: State) -> State: + if event.task_id not in state.tasks: + return state + + updated_task = state.tasks[event.task_id].model_copy(update={"error_type": event.error_type, "error_message": event.error_message}) new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task} return state.model_copy(update={"tasks": new_tasks}) diff --git a/shared/types/events/_events.py b/shared/types/events/_events.py index 6ae7d005..cb092909 100644 --- a/shared/types/events/_events.py +++ b/shared/types/events/_events.py @@ -49,6 +49,7 
@@ class _EventType(str, Enum): # Task Events TaskCreated = "TaskCreated" TaskStateUpdated = "TaskStateUpdated" + TaskFailed = "TaskFailed" TaskDeleted = "TaskDeleted" # Streaming Events @@ -119,6 +120,13 @@ class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]): task_status: TaskStatus +class TaskFailed(_BaseEvent[_EventType.TaskFailed]): + event_type: Literal[_EventType.TaskFailed] = _EventType.TaskFailed + task_id: TaskId + error_type: str + error_message: str + + class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]): event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated instance: Instance @@ -202,6 +210,7 @@ _Event = Union[ Heartbeat, TaskCreated, TaskStateUpdated, + TaskFailed, TaskDeleted, InstanceCreated, InstanceActivated, diff --git a/shared/types/tasks.py b/shared/types/tasks.py index 00426ba9..c4958eb2 100644 --- a/shared/types/tasks.py +++ b/shared/types/tasks.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Annotated, Literal +from typing import Annotated, Literal, Optional from pydantic import BaseModel, Field @@ -31,4 +31,7 @@ class ChatCompletionTask(BaseModel): task_status: TaskStatus task_params: ChatCompletionTaskParams + error_type: Optional[str] = Field(default=None) + error_message: Optional[str] = Field(default=None) + Task = Annotated[ChatCompletionTask, Field(discriminator="task_type")] diff --git a/shared/types/worker/commands_runner.py b/shared/types/worker/commands_runner.py index 4a05b09b..3ca0bf22 100644 --- a/shared/types/worker/commands_runner.py +++ b/shared/types/worker/commands_runner.py @@ -51,6 +51,7 @@ RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage class RunnerResponseType(str, Enum): + InitializedResponse = "initialized_response" GenerationResponse = "generation_response" FinishedResponse = "finished_response" PrintResponse = "print_response" @@ -64,6 +65,13 @@ class BaseRunnerResponse(BaseModel, Generic[RRT]): pass +class InitializedResponse(BaseRunnerResponse[RunnerResponseType.InitializedResponse]): + type: Literal[RunnerResponseType.InitializedResponse] = Field( + default=RunnerResponseType.InitializedResponse, frozen=True + ) + time_taken: float + + class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]): type: Literal[RunnerResponseType.GenerationResponse] = Field( default=RunnerResponseType.GenerationResponse, frozen=True @@ -97,7 +105,7 @@ class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]): RunnerResponse = Annotated[ - GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse, + InitializedResponse | GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse, Field(discriminator="type"), ] RunnerResponseTypeAdapter: TypeAdapter[RunnerResponse] = TypeAdapter(RunnerResponse) diff --git a/worker/main.py b/worker/main.py index 42cf9850..bf537302 100644 --- a/worker/main.py +++ b/worker/main.py @@ -1,5 +1,6 @@ import asyncio import logging +import time from asyncio import Queue from copy import deepcopy from functools import partial @@ -15,15 +16,17 @@ from shared.types.common import Host, NodeId from shared.types.events import ( ChunkGenerated, Event, + InstanceDeleted, InstanceId, NodePerformanceMeasured, RunnerDeleted, RunnerStatusUpdated, + TaskFailed, TaskStateUpdated, ) from shared.types.profiling import NodePerformanceProfile from shared.types.state import State -from shared.types.tasks import TaskStatus +from shared.types.tasks import TaskId, TaskStatus from 
shared.types.worker.common import RunnerId from shared.types.worker.downloads import ( DownloadCompleted, @@ -68,6 +71,7 @@ class AssignedRunner(BaseModel): hosts: list[Host] status: RunnerStatus + failures: list[tuple[float, Exception]] = [] runner: Optional[RunnerSupervisor] # set if the runner is 'up' model_config = ConfigDict(arbitrary_types_allowed=True) @@ -141,14 +145,36 @@ class Worker: yield async def _execute_runner_up_op( - self, op: RunnerUpOp + self, op: RunnerUpOp, initialize_timeout: Optional[float] = None ) -> AsyncGenerator[Event, None]: assigned_runner = self.assigned_runners[op.runner_id] - assigned_runner.runner = await RunnerSupervisor.create( - model_shard_meta=assigned_runner.shard_metadata, - hosts=assigned_runner.hosts, - ) + # TODO: This should be dynamic, based on the size of the model. + if not initialize_timeout: + GBPS = 10 + + shard = assigned_runner.shard_metadata + weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes + + initialize_timeout = weights_size_kb / (1024**2 * GBPS) + 2.0 # Add a constant 2.0 to ensure connection can be made as well + + try: + assigned_runner.runner = await asyncio.wait_for( + RunnerSupervisor.create( + model_shard_meta=assigned_runner.shard_metadata, + hosts=assigned_runner.hosts, + logger=self.logger, + ), + timeout=initialize_timeout, + ) + except TimeoutError as e: + import traceback + + tb = traceback.format_exc() + e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}") + async for event in self._fail_runner(e=e, runner_id=op.runner_id): + yield event + return if assigned_runner.runner.healthy: assigned_runner.status = LoadedRunnerStatus() @@ -161,8 +187,9 @@ class Worker: ) -> AsyncGenerator[Event, None]: assigned_runner = self.assigned_runners[op.runner_id] - assert isinstance(assigned_runner.runner, RunnerSupervisor) - await assigned_runner.runner.astop() + if isinstance(assigned_runner.runner, RunnerSupervisor): + await assigned_runner.runner.astop() + assigned_runner.runner = None assigned_runner.status = ReadyRunnerStatus() @@ -287,9 +314,6 @@ class Worker: assigned_runner = self.assigned_runners[op.runner_id] async def inner_execute(queue: asyncio.Queue[Event]) -> None: - assert assigned_runner.runner is not None - assert assigned_runner.runner.healthy - async def running_callback(queue: asyncio.Queue[Event]) -> None: # Called when the MLX process has been kicked off assigned_runner.status = RunningRunnerStatus() @@ -302,6 +326,9 @@ class Worker: )) try: + assert assigned_runner.runner is not None + assert assigned_runner.runner.healthy + async for chunk in assigned_runner.runner.stream_response( task=op.task, request_started_callback=partial(running_callback, queue)): @@ -325,34 +352,44 @@ class Worker: except Exception as e: - # TODO: What log level? - self.logger.log(2, f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}') - - if assigned_runner.shard_metadata.device_rank == 0: - await queue.put(TaskStateUpdated( - task_id=op.task.task_id, - task_status=TaskStatus.FAILED, - )) - - assigned_runner.runner = None - assigned_runner.status = FailedRunnerStatus(error_message=str(e)) - await queue.put(assigned_runner.status_update_event()) + # An exception occurs in the runner supervisor + self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. 
Error: {e}') + async for event in self._fail_task(e, op.runner_id, op.task.task_id): + await queue.put(event) queue: Queue[Event] = asyncio.Queue() task = asyncio.create_task(inner_execute(queue)) + # TODO: Initial (prefill) timeout can be dynamic + # model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes + try: # Yield items from the queue + # timeout = 30. + timeout = 3. while True: - item: Event = await asyncio.wait_for(queue.get(), timeout=5) + item: Event = await asyncio.wait_for(queue.get(), timeout=timeout) yield item + timeout = 2. if isinstance(item, RunnerStatusUpdated) and isinstance( item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus) ): + if isinstance(item.runner_status, LoadedRunnerStatus): + assigned_runner.failures = [] + break + except TimeoutError as e: + # Runner supervisor didn't respond in time, so we put the runner & task into a failed state + self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.') + async for event in self._fail_task(e, op.runner_id, op.task.task_id): + yield event finally: # Ensure the task is cleaned up - await task + try: + await asyncio.wait_for(task, timeout=5) + except asyncio.TimeoutError: + self.logger.warning("Timed out waiting for task cleanup after inference execution.") + ## Operation Planner @@ -381,6 +418,10 @@ class Worker: def plan(self, state: State) -> RunnerOp | None: # Compare state to worker 'mood' + # for runner_id, assigned_runner in self.assigned_runners.items(): + # if len(assigned_runner.failures) == 3: + # raise Exception('Too many errors occurred in assigned runner - assumed to be recurrent and unrecoverable.\nErrors are as follows: {assigned_runner.failures}') + # First, unassign assigned runners that are no longer in the state. for runner_id, _ in self.assigned_runners.items(): runner_ids: list[RunnerId] = [ @@ -512,7 +553,9 @@ class Worker: continue # The only previous state to get to Running is from Loaded for _, task in state.tasks.items(): - if task.instance_id == instance_id and task.task_status == TaskStatus.PENDING: + if task.instance_id == instance_id and ( + task.task_status == TaskStatus.PENDING or task.task_status == TaskStatus.FAILED + ): if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1): return ExecuteTaskOp(runner_id=runner_id, task=task) else: @@ -530,17 +573,56 @@ class Worker: return None + async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]: + if runner_id in self.assigned_runners: + assigned_runner = self.assigned_runners[runner_id] + + assigned_runner.runner = None + assigned_runner.status = FailedRunnerStatus(error_message=str(e)) + assigned_runner.failures.append( + ( + time.time(), + e + ) + ) + + # The failure count is reset to 0 after a successful inference, so three accumulated failures indicate a recurring problem + if len(assigned_runner.failures) >= 3: + # Too many retries. 
We will emit a DeleteInstance + yield InstanceDeleted( + instance_id=assigned_runner.instance_id + ) + + yield assigned_runner.status_update_event() + + + async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]: + if runner_id in self.assigned_runners: + yield TaskStateUpdated( + task_id=task_id, + task_status=TaskStatus.FAILED, + ) + + yield TaskFailed( + task_id=task_id, + error_type=str(type(e)), + error_message=str(e) + ) + + async for event in self._fail_runner(e, runner_id): + yield event + + async def event_publisher(self, event: Event) -> None: assert self.worker_events is not None await self.worker_events.append_events([event], self.node_id) - print(f"published event: {event}") + self.logger.info(f"published event: {event}") # Handle state updates async def run(self): assert self.global_events is not None while True: - _rank = list(self.assigned_runners.values())[0].shard_metadata.device_rank if self.assigned_runners else None # 1. get latest events events = await self.global_events.get_events_since(self.state.last_event_applied_idx) @@ -555,8 +637,18 @@ # run the op, synchronously blocking for now if op is not None: - async for event in self._execute_op(op): - await self.event_publisher(event) + try: + async for event in self._execute_op(op): + await self.event_publisher(event) + except Exception as e: + # _execute_task_op already has its own exception handling, so we assume the exception came from one of the other op types. + # We therefore just fail the runner. + self.logger.warning(f"Encountered exception when executing worker op {op}: {e}. \n Runner will be spun down and retried.") + async for event in self._fail_runner( + e, + runner_id=op.runner_id, + ): + await self.event_publisher(event) await asyncio.sleep(0.01) if len(events) > 0: diff --git a/worker/runner/communication.py b/worker/runner/communication.py index 18001b8f..85efa090 100644 --- a/worker/runner/communication.py +++ b/worker/runner/communication.py @@ -47,9 +47,13 @@ async def runner_read_message() -> RunnerMessage: def runner_write_response(obj: RunnerResponse) -> None: - encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n" - _ = sys.stdout.buffer.write(encoded) - _ = sys.stdout.buffer.flush() + try: + encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n" + _ = sys.stdout.buffer.write(encoded) + _ = sys.stdout.buffer.flush() + except BrokenPipeError: + # Supervisor has closed the pipe, silently exit + sys.exit(0) async def supervisor_read_response( @@ -83,6 +87,10 @@ def runner_print(text: str) -> None: def runner_write_error(error: Exception) -> None: + # Skip writing error if it's a BrokenPipeError - supervisor is already gone + if isinstance(error, BrokenPipeError): + sys.exit(0) + error_response: ErrorResponse = ErrorResponse( type=RunnerResponseType.ErrorResponse, error_type=type(error).__name__, diff --git a/worker/runner/runner.py b/worker/runner/runner.py index d5a1fbb2..f2343e07 100644 --- a/worker/runner/runner.py +++ b/worker/runner/runner.py @@ -1,5 +1,6 @@ import asyncio import concurrent.futures +import time from collections.abc import AsyncGenerator from functools import partial from typing import Callable, cast @@ -17,6 +18,7 @@ from shared.types.worker.commands_runner import ( ExitMessage, FinishedResponse, GenerationResponse, + InitializedResponse, RunnerMessage, SetupMessage, ) @@ -98,23 +100,24 @@ async def main(): try: runner_print("hello from the runner") - # Get setup info
from worker init_message = await runner_read_message() setup_message = ensure_type(init_message, SetupMessage) model_shard_meta = setup_message.model_shard_meta hosts = setup_message.hosts + + setup_start_time = time.time() mlx_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) loop = asyncio.get_running_loop() - runner_print(f"got here; {hosts}") - model, tokenizer, sampler = await loop.run_in_executor( mlx_executor, partial(initialize_mlx, model_shard_meta=model_shard_meta, hosts=hosts), ) + runner_write_response(InitializedResponse(time_taken=time.time() - setup_start_time)) + while True: message: RunnerMessage = await runner_read_message() match message: diff --git a/worker/runner/runner_supervisor.py b/worker/runner/runner_supervisor.py index 8d813697..77d6469f 100644 --- a/worker/runner/runner_supervisor.py +++ b/worker/runner/runner_supervisor.py @@ -2,6 +2,7 @@ import asyncio import contextlib import sys from collections.abc import AsyncGenerator +from logging import Logger from types import CoroutineType from typing import Any, Callable @@ -14,6 +15,7 @@ from shared.types.worker.commands_runner import ( ExitMessage, FinishedResponse, GenerationResponse, + InitializedResponse, PrintResponse, RunnerResponse, SetupMessage, @@ -54,6 +56,7 @@ class RunnerSupervisor: cls, model_shard_meta: ShardMetadata, hosts: list[Host], + logger: Logger ) -> "RunnerSupervisor": """ Create and initialize a RunnerSupervisor instance. @@ -66,7 +69,7 @@ class RunnerSupervisor: *cmd, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, - stderr=sys.stderr, + stderr=sys.stderr ) ) @@ -79,6 +82,21 @@ class RunnerSupervisor: ), ) + while True: + line: RunnerResponse | None = await supervisor_read_response( + runner_process + ) + if line is None or isinstance(line, PrintResponse): + # print(line) + continue + elif isinstance(line, ErrorResponse): + raise Exception(line.error_type, line.error_message, line.traceback or "") + else: + assert isinstance(line, InitializedResponse) + logger.info(f'Runner initialized in {line.time_taken} seconds') + print(f'Runner initialized in {line.time_taken} seconds') + break + return cls( model_shard_meta=model_shard_meta, hosts=hosts, @@ -203,6 +221,8 @@ class RunnerSupervisor: token_id=token, finish_reason=finish_reason, ) + case InitializedResponse(): + raise ValueError('Initialized Response read during streaming flow') case FinishedResponse(): break case PrintResponse(text=text): diff --git a/worker/tests/test_runner_connection.py b/worker/tests/test_runner_connection.py new file mode 100644 index 00000000..c988224b --- /dev/null +++ b/worker/tests/test_runner_connection.py @@ -0,0 +1,189 @@ +import asyncio +import os +from logging import Logger +from typing import Callable, Final + +import pytest + +from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager +from shared.types.common import Host, NodeId +from shared.types.events import InstanceCreated, InstanceDeleted +from shared.types.models import ModelId +from shared.types.tasks import Task +from shared.types.worker.common import InstanceId, RunnerId +from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments +from shared.types.worker.runners import FailedRunnerStatus +from shared.types.worker.shards import PipelineShardMetadata +from worker.download.shard_downloader import NoopShardDownloader +from worker.main import Worker + +MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa") +NODE_A: Final[NodeId] = 
NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") +NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb") + +RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111") +INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222") +RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333") +INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444") +MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit' +MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit' +TASK_1_ID: Final = "55555555-5555-4555-8555-555555555555" +TASK_2_ID: Final = "66666666-6666-4666-8666-666666666666" + +@pytest.fixture +def user_message() -> str: + return "What is the capital of Japan?" + +@pytest.mark.skipif( + os.environ.get("DETAILED", "").lower() != "true", + reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set" +) +async def check_runner_connection( + logger: Logger, + pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], + hosts: Callable[[int], list[Host]], + chat_completion_task: Callable[[InstanceId, str], Task], +) -> bool: + # Track all tasks and workers for cleanup + tasks: list[asyncio.Task[None]] = [] + workers: list[Worker] = [] + + try: + event_log_manager = EventLogManager(EventLogConfig(), logger) + await event_log_manager.initialize() + shard_downloader = NoopShardDownloader() + + global_events = event_log_manager.global_events + await global_events.delete_all_events() + + worker1 = Worker( + NODE_A, + logger=logger, + shard_downloader=shard_downloader, + worker_events=global_events, + global_events=global_events, + ) + workers.append(worker1) + task1 = asyncio.create_task(worker1.run()) + tasks.append(task1) + + worker2 = Worker( + NODE_B, + logger=logger, + shard_downloader=shard_downloader, + worker_events=global_events, + global_events=global_events, + ) + workers.append(worker2) + task2 = asyncio.create_task(worker2.run()) + tasks.append(task2) + + model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') + + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1) + }, + node_to_runner={ + NODE_A: RUNNER_1_ID, + NODE_B: RUNNER_2_ID + } + ) + + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2) + ) + + await global_events.append_events( + [ + InstanceCreated( + instance=instance + ), + ], + origin=MASTER_NODE_ID + ) + + from worker.runner.runner_supervisor import RunnerSupervisor + + async def wait_for_runner_supervisor(worker: Worker, timeout: float = 5.0) -> RunnerSupervisor | None: + end = asyncio.get_event_loop().time() + timeout + while True: + assigned_runners = list(worker.assigned_runners.values()) + if assigned_runners: + runner = assigned_runners[0].runner + if isinstance(runner, RunnerSupervisor): + print('breaking because success') + return runner + if isinstance(assigned_runners[0].status, FailedRunnerStatus): + print('breaking because failed') + return runner + if asyncio.get_event_loop().time() > end: + raise TimeoutError("RunnerSupervisor was not set within timeout") + await asyncio.sleep(0.001) + + runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0) + ret = runner_supervisor is not None and runner_supervisor.healthy + + await global_events.append_events( + [ + 
InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(0.5) + + return ret + finally: + # Cancel all worker tasks + for task in tasks: + task.cancel() + + # Wait for cancellation to complete + await asyncio.gather(*tasks, return_exceptions=True) + +# Check Running status + +def test_runner_connection_stress( + logger: Logger, + pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], + hosts: Callable[[int], list[Host]], + chat_completion_task: Callable[[InstanceId, str], Task], +) -> None: + total_runs = 100 + successes = 0 + + for _ in range(total_runs): + # Create a fresh event loop for each iteration + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete(check_runner_connection( + logger=logger, + pipeline_shard_meta=pipeline_shard_meta, + hosts=hosts, + chat_completion_task=chat_completion_task, + )) + if result: + successes += 1 + finally: + # Cancel all running tasks + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + + # Run the event loop briefly to allow cancellation to complete + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + + # Close the event loop + loop.close() + + print(f"Runner connection successes: {successes} / {total_runs}") diff --git a/worker/tests/test_spinup_timeout.py b/worker/tests/test_spinup_timeout.py new file mode 100644 index 00000000..f8966d8e --- /dev/null +++ b/worker/tests/test_spinup_timeout.py @@ -0,0 +1,48 @@ +## Tests for worker state handlers + +import os +from typing import Callable + +import pytest + +from shared.types.events import ( + Event, +) +from shared.types.events._events import RunnerStatusUpdated +from shared.types.tasks import Task, TaskId +from shared.types.worker.common import RunnerId +from shared.types.worker.instances import Instance, InstanceId +from shared.types.worker.ops import ( + RunnerUpOp, +) +from shared.types.worker.runners import FailedRunnerStatus +from worker.main import Worker + +# To enable this test, run pytest with: ENABLE_SPINUP_TIMEOUT_TEST=true pytest + +@pytest.mark.skipif( + os.environ.get("DETAILED", "").lower() != "true", + reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set" +) +@pytest.mark.asyncio +async def test_runner_up_op_timeout( + worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], + chat_completion_task: Callable[[InstanceId, TaskId], Task], + monkeypatch: pytest.MonkeyPatch + ): + worker, runner_id, _ = worker_with_assigned_runner + + runner_up_op = RunnerUpOp(runner_id=runner_id) + + # _execute_runner_up_op should throw a TimeoutError with a short timeout + events: list[Event] = [] + async for event in worker._execute_runner_up_op(runner_up_op, initialize_timeout=0.2): # type: ignore[misc] + events.append(event) + + assert isinstance(events[-1], RunnerStatusUpdated) + assert isinstance(events[-1].runner_status, FailedRunnerStatus) + assert events[-1].runner_status.error_message is not None + assert 'timeout' in events[-1].runner_status.error_message.lower() + + del worker.assigned_runners[list(worker.assigned_runners.keys())[0]] + diff --git a/worker/tests/test_supervisor.py b/worker/tests/test_supervisor.py index 1db5a7a2..915c7393 100644 --- a/worker/tests/test_supervisor.py +++ b/worker/tests/test_supervisor.py @@ -1,4 +1,5 @@ import asyncio +from logging import Logger from pathlib import Path from typing import Callable @@ -30,6 +31,7 @@ async def 
test_supervisor_single_node_response( hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], tmp_path: Path, + logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" model_shard_meta = pipeline_shard_meta(1, 0) @@ -40,6 +42,7 @@ async def test_supervisor_single_node_response( supervisor = await RunnerSupervisor.create( model_shard_meta=model_shard_meta, hosts=hosts(1, offset=10), + logger=logger, ) try: @@ -68,18 +71,25 @@ async def test_supervisor_two_node_response( hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], tmp_path: Path, + logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" instance_id = InstanceId() - supervisor_0 = await RunnerSupervisor.create( - model_shard_meta=pipeline_shard_meta(2, 0), - hosts=hosts(2, offset=15), + create_supervisor_0 = asyncio.create_task( + RunnerSupervisor.create( + model_shard_meta=pipeline_shard_meta(2, 0), + hosts=hosts(2, offset=15), + logger=logger, + ) ) - - supervisor_1 = await RunnerSupervisor.create( - model_shard_meta=pipeline_shard_meta(2, 1), - hosts=hosts(2, offset=15), + create_supervisor_1 = asyncio.create_task( + RunnerSupervisor.create( + model_shard_meta=pipeline_shard_meta(2, 1), + hosts=hosts(2, offset=15), + logger=logger, + ) ) + supervisor_0, supervisor_1 = await asyncio.gather(create_supervisor_0, create_supervisor_1) await asyncio.sleep(0.1) @@ -124,6 +134,7 @@ async def test_supervisor_early_stopping( hosts: Callable[..., list[Host]], chat_completion_task: Callable[[InstanceId, TaskId], Task], tmp_path: Path, + logger: Logger, ): """Test that asking for the capital of France returns 'Paris' in the response""" model_shard_meta = pipeline_shard_meta(1, 0) @@ -132,6 +143,7 @@ async def test_supervisor_early_stopping( supervisor = await RunnerSupervisor.create( model_shard_meta=model_shard_meta, hosts=hosts(1, offset=10), + logger=logger, ) task = chat_completion_task(instance_id, TaskId()) @@ -176,6 +188,7 @@ async def test_supervisor_early_stopping( async def test_supervisor_handles_terminated_runner( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], + logger: Logger, tmp_path: Path, ): """Test that the supervisor handles a terminated runner""" @@ -184,6 +197,7 @@ async def test_supervisor_handles_terminated_runner( supervisor = await RunnerSupervisor.create( model_shard_meta=model_shard_meta, hosts=hosts(1, offset=10), + logger=logger, ) # Terminate the runner @@ -201,6 +215,7 @@ async def test_supervisor_handles_killed_runner( pipeline_shard_meta: Callable[..., PipelineShardMetadata], hosts: Callable[..., list[Host]], tmp_path: Path, + logger: Logger, ): """Test that the supervisor handles a killed runner""" model_shard_meta = pipeline_shard_meta(1, 0) @@ -208,6 +223,7 @@ async def test_supervisor_handles_killed_runner( supervisor = await RunnerSupervisor.create( model_shard_meta=model_shard_meta, hosts=hosts(1, offset=10), + logger=logger, ) assert supervisor.healthy diff --git a/worker/tests/test_supervisor_errors.py b/worker/tests/test_supervisor_errors.py new file mode 100644 index 00000000..8b13ef62 --- /dev/null +++ b/worker/tests/test_supervisor_errors.py @@ -0,0 +1,251 @@ +import asyncio +from collections.abc import AsyncGenerator +from types import CoroutineType +from typing import Any, Awaitable, Callable, Final + +import pytest +from _pytest.monkeypatch import MonkeyPatch + +# TaskStateUpdated and 
ChunkGenerated are used in test_worker_integration_utils.py +from shared.db.sqlite.connector import AsyncSQLiteEventStorage +from shared.types.common import NodeId +from shared.types.events import ( + ChunkGenerated, + InstanceCreated, + InstanceDeleted, + RunnerStatusUpdated, + TaskCreated, + TaskStateUpdated, + TaskFailed, +) +from shared.types.events.chunks import GenerationChunk, TokenChunk +from shared.types.models import ModelId +from shared.types.tasks import Task, TaskId, TaskStatus +from shared.types.worker.common import InstanceId, RunnerId +from shared.types.worker.instances import ( + Instance, + InstanceStatus, +) +from shared.types.worker.runners import FailedRunnerStatus +from worker.main import Worker +from worker.runner.runner_supervisor import RunnerSupervisor + +MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa") +NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa") +NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb") + +# Define constant IDs for deterministic test cases +RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111") +INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222") +RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333") +INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444") +MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit' +MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit' +TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555") +TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666") + +@pytest.fixture +def user_message(): + """Override this fixture in tests to customize the message""" + return "Who is the longest ruling monarch of England?" + +# TODO: Make this all monkeypatched instead. + +async def test_stream_response_failed_always( + monkeypatch: MonkeyPatch, + worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]], + instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + chat_completion_task: Callable[[InstanceId, TaskId], Task] +): + worker, global_events = await worker_running(NODE_A) + + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE + + async def mock_stream_response( + self: RunnerSupervisor, + task: Task, + request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, + ) -> AsyncGenerator[GenerationChunk]: + raise RuntimeError("Simulated stream response failure") + return + yield + + monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response) + + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( + [ + InstanceCreated(instance=instance_value), + TaskCreated(task_id=task.task_id, task=task) + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(5.) 
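The 5-second sleep above gives the worker time to exhaust its failure budget: each failed attempt is appended to `assigned_runner.failures` by `_fail_runner`, and the third failure emits `InstanceDeleted`, which is what the assertions below count. A minimal, self-contained sketch of that bookkeeping (the `FailureLog` class and `MAX_FAILURES` constant are illustrative stand-ins, not the patch's API):

import time
from dataclasses import dataclass, field

MAX_FAILURES = 3  # after three recorded failures the worker gives up on the instance

@dataclass
class FailureLog:
    failures: list[tuple[float, Exception]] = field(default_factory=list)

    def record(self, exc: Exception) -> bool:
        """Record one failure; return True once the instance should be torn down."""
        self.failures.append((time.time(), exc))
        return len(self.failures) >= MAX_FAILURES

    def reset(self) -> None:
        # Mirrors `assigned_runner.failures = []` once a LoadedRunnerStatus comes back.
        self.failures.clear()

log = FailureLog()
for _ in range(3):
    if log.record(RuntimeError("Simulated stream response failure")):
        print("third failure -> emit InstanceDeleted")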
+ + + events = await global_events.get_events_since(0) + + assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3 + assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 3 + assert any([isinstance(x.event, InstanceDeleted) for x in events]) + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(0.3) + +async def test_stream_response_failed_once( + monkeypatch: MonkeyPatch, + worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]], + instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + chat_completion_task: Callable[[InstanceId, TaskId], Task] +): + failed_already = False + original_stream_response = RunnerSupervisor.stream_response + + async def mock_stream_response( + self: RunnerSupervisor, + task: Task, + request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, + ) -> AsyncGenerator[GenerationChunk]: + nonlocal failed_already + if not failed_already: + failed_already = True + raise RuntimeError("Simulated stream response failure") + else: + async for event in original_stream_response(self, task, request_started_callback): + yield event + return + + monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response) + + worker, global_events = await worker_running(NODE_A) + + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE + + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( + [ + InstanceCreated(instance=instance_value), + TaskCreated(task_id=task.task_id, task=task) + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(5.) + + # TODO: Ideally this test would have tooling to scroll through the state history and + # assert that there was a time when error_type and error_message were not None and the failure count was nonzero, + + # as we reset the failures back to zero when we have a successful inference. 
+ assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0 + assert worker.state.tasks[TASK_1_ID].error_type is None + assert worker.state.tasks[TASK_1_ID].error_message is None + + events = await global_events.get_events_since(0) + assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1 + assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1 + + response_string = '' + events = await global_events.get_events_since(0) + + seen_task_started, seen_task_finished = False, False + for wrapped_event in events: + event = wrapped_event.event + if isinstance(event, TaskStateUpdated): + if event.task_status == TaskStatus.RUNNING: + seen_task_started = True + if event.task_status == TaskStatus.COMPLETE: + seen_task_finished = True + + if isinstance(event, ChunkGenerated): + assert isinstance(event.chunk, TokenChunk) + response_string += event.chunk.text + + assert 'elizabeth' in response_string.lower() + assert seen_task_started + assert seen_task_finished + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(0.3) + + +async def test_stream_response_timeout( + monkeypatch: MonkeyPatch, + worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]], + instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + chat_completion_task: Callable[[InstanceId, TaskId], Task] +): + async def mock_stream_response( + self: RunnerSupervisor, + task: Task, + request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, + ) -> AsyncGenerator[GenerationChunk]: + # TODO: Also a test where we yield a few chunks and then time out. + print('sleeping starting') + await asyncio.sleep(4.) + print('sleeping finished') + return + yield + + monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response) + + worker, global_events = await worker_running(NODE_A) + + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE + + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( + [ + InstanceCreated(instance=instance_value), + TaskCreated(task_id=task.task_id, task=task) + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(7.) + + + # as we reset the failures back to zero when we have a successful inference. 
+ + # print('ASSERTION ERR:') + # print(worker.assigned_runners[RUNNER_1_ID].failures[1][1]) + + assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0 + assert worker.state.tasks[TASK_1_ID].error_type is None + assert worker.state.tasks[TASK_1_ID].error_message is None + + events = await global_events.get_events_since(0) + print(events) + assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1 + assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1 + assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1 + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(0.3) \ No newline at end of file diff --git a/worker/tests/test_worker_handlers.py b/worker/tests/test_worker_handlers.py index ed2fed95..bc145db7 100644 --- a/worker/tests/test_worker_handlers.py +++ b/worker/tests/test_worker_handlers.py @@ -11,6 +11,7 @@ from shared.types.events import ( Event, RunnerDeleted, RunnerStatusUpdated, + TaskFailed, TaskStateUpdated, ) from shared.types.events.chunks import TokenChunk @@ -217,7 +218,7 @@ async def test_execute_task_fails( async for event in worker._execute_op(execute_task_op): # type: ignore[misc] events.append(event) - assert len(events) == 4 + assert len(events) == 5 print(events) @@ -230,5 +231,7 @@ async def test_execute_task_fails( assert isinstance(events[2], TaskStateUpdated) assert events[2].task_status == TaskStatus.FAILED # Task marked as failed. - assert isinstance(events[3], RunnerStatusUpdated) - assert isinstance(events[3].runner_status, FailedRunnerStatus) # It should have failed. \ No newline at end of file + assert isinstance(events[3], TaskFailed) + + assert isinstance(events[4], RunnerStatusUpdated) + assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed. 
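To trace events[2] and events[3] from the assertion above into task state: the apply.py handlers added at the top of this patch store error_type/error_message on TaskFailed and clear them again whenever a later TaskStateUpdated moves the task out of FAILED. A simplified, self-contained sketch of that folding logic (the MiniTask dataclass and the literal status strings below are stand-ins for the real pydantic Task model, not part of the patch):

from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class MiniTask:
    task_status: str
    error_type: Optional[str] = None
    error_message: Optional[str] = None

def apply_task_state_updated(task: MiniTask, task_status: str) -> MiniTask:
    if task_status != "FAILED":
        # Leaving (or never entering) the FAILED state clears the stored error details.
        return replace(task, task_status=task_status, error_type=None, error_message=None)
    return replace(task, task_status=task_status)

def apply_task_failed(task: MiniTask, error_type: str, error_message: str) -> MiniTask:
    return replace(task, error_type=error_type, error_message=error_message)

t = MiniTask(task_status="RUNNING")
t = apply_task_state_updated(t, "FAILED")                    # events[2]: mark the task failed
t = apply_task_failed(t, "RuntimeError", "mock exception")   # events[3]: attach error details
assert t.error_type == "RuntimeError"
t = apply_task_state_updated(t, "PENDING")                   # a retry clears the error fields again
assert t.error_type is None and t.error_message is None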
\ No newline at end of file diff --git a/worker/tests/test_worker_integration.py b/worker/tests/test_worker_integration.py index 63e3abbd..99f8ed05 100644 --- a/worker/tests/test_worker_integration.py +++ b/worker/tests/test_worker_integration.py @@ -7,7 +7,8 @@ import pytest # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py from shared.db.sqlite.connector import AsyncSQLiteEventStorage from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from shared.types.common import Host, NodeId +from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams +from shared.types.common import CommandId, Host, NodeId from shared.types.events import ( InstanceCreated, InstanceDeleted, @@ -17,7 +18,7 @@ from shared.types.events import ( ) from shared.types.events.chunks import TokenChunk from shared.types.models import ModelId -from shared.types.tasks import Task, TaskId +from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType from shared.types.worker.common import InstanceId, RunnerId from shared.types.worker.instances import ( Instance, @@ -117,7 +118,7 @@ async def test_runner_assigned_active( origin=MASTER_NODE_ID ) - await asyncio.sleep(0.1) + await asyncio.sleep(1.0) assert len(worker.assigned_runners) == 1 assert RUNNER_1_ID in worker.assigned_runners @@ -200,7 +201,7 @@ async def test_runner_unassigns( origin=MASTER_NODE_ID ) - await asyncio.sleep(0.1) + await asyncio.sleep(0.5) # already tested by test_runner_assigned_active assert len(worker.assigned_runners) == 1 @@ -354,6 +355,102 @@ async def test_2_runner_inference( await asyncio.sleep(2.0) +async def test_2_runner_multi_message( + logger: Logger, + pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], + hosts: Callable[[int], list[Host]], + ): + event_log_manager = EventLogManager(EventLogConfig(), logger) + await event_log_manager.initialize() + shard_downloader = NoopShardDownloader() + + global_events = event_log_manager.global_events + await global_events.delete_all_events() + + worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) + asyncio.create_task(worker1.run()) + + worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events) + asyncio.create_task(worker2.run()) + + ## Instance + model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit') + + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={ + RUNNER_1_ID: pipeline_shard_meta(2, 0), + RUNNER_2_ID: pipeline_shard_meta(2, 1) + }, + node_to_runner={ + NODE_A: RUNNER_1_ID, + NODE_B: RUNNER_2_ID + } + ) + + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(2) + ) + + # Task - we have three messages here, which is what the task is about + + completion_create_params = ChatCompletionTaskParams( + model="gpt-4", + messages=[ + ChatCompletionMessage(role="user", content='What is the capital of France?'), + ChatCompletionMessage(role="assistant", content='The capital of France is Paris.'), + ChatCompletionMessage(role="user", content='Ok great. 
Now write me a haiku about what you can do there.'), + ], + stream=True, + ) + + task = ChatCompletionTask( + task_id=TASK_1_ID, + command_id=CommandId(), + instance_id=INSTANCE_1_ID, + task_type=TaskType.CHAT_COMPLETION, + task_status=TaskStatus.PENDING, + task_params=completion_create_params + ) + + await global_events.append_events( + [ + InstanceCreated( + instance=instance + ), + TaskCreated( + task_id=task.task_id, + task=task + ) + ], + origin=MASTER_NODE_ID + ) + + seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events) + + assert seen_task_started + assert seen_task_finished + assert any(keyword in response_string.lower() for keyword in ('kiss', 'paris', 'art', 'love')) + + + idx = await global_events.get_last_idx() + await asyncio.sleep(1.0) + events = await global_events.get_events_since(idx) + assert len(events) == 0 + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID + ) + + await asyncio.sleep(2.0) async def test_runner_respawn(