Test Supervisor Errors.

Matt Beton
2025-07-30 13:30:54 +01:00
committed by GitHub
parent ff3d11c748
commit b350ededb2
15 changed files with 819 additions and 56 deletions

View File

@@ -52,7 +52,6 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
# Initialize distributed
group = mx.distributed.init(backend="ring", strict=True)
runner_print(f"Rank {rank} mlx distributed initialization complete")

View File

@@ -19,6 +19,7 @@ from shared.types.events import (
RunnerStatusUpdated,
TaskCreated,
TaskDeleted,
TaskFailed,
TaskStateUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
@@ -28,7 +29,7 @@ from shared.types.events import (
)
from shared.types.profiling import NodePerformanceProfile
from shared.types.state import State
from shared.types.tasks import Task, TaskId
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.topology import Connection, Node
from shared.types.worker.common import NodeStatus, RunnerId
from shared.types.worker.instances import Instance, InstanceId, InstanceStatus
@@ -63,7 +64,23 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
if event.task_id not in state.tasks:
return state
updated_task = state.tasks[event.task_id].model_copy(update={"task_status": event.task_status})
update: dict[str, TaskStatus | None] = {
"task_status": event.task_status,
}
if event.task_status != TaskStatus.FAILED:
update["error_type"] = None
update["error_message"] = None
updated_task = state.tasks[event.task_id].model_copy(update=update)
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}
return state.model_copy(update={"tasks": new_tasks})
@event_apply.register(TaskFailed)
def apply_task_failed(event: TaskFailed, state: State) -> State:
if event.task_id not in state.tasks:
return state
updated_task = state.tasks[event.task_id].model_copy(update={"error_type": event.error_type, "error_message": event.error_message})
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}
return state.model_copy(update={"tasks": new_tasks})

View File

@@ -49,6 +49,7 @@ class _EventType(str, Enum):
# Task Events
TaskCreated = "TaskCreated"
TaskStateUpdated = "TaskStateUpdated"
TaskFailed = "TaskFailed"
TaskDeleted = "TaskDeleted"
# Streaming Events
@@ -119,6 +120,13 @@ class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]):
task_status: TaskStatus
class TaskFailed(_BaseEvent[_EventType.TaskFailed]):
event_type: Literal[_EventType.TaskFailed] = _EventType.TaskFailed
task_id: TaskId
error_type: str
error_message: str
class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]):
event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated
instance: Instance
@@ -202,6 +210,7 @@ _Event = Union[
Heartbeat,
TaskCreated,
TaskStateUpdated,
TaskFailed,
TaskDeleted,
InstanceCreated,
InstanceActivated,

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated, Literal
from typing import Annotated, Literal, Optional
from pydantic import BaseModel, Field
@@ -31,4 +31,7 @@ class ChatCompletionTask(BaseModel):
task_status: TaskStatus
task_params: ChatCompletionTaskParams
error_type: Optional[str] = Field(default=None)
error_message: Optional[str] = Field(default=None)
Task = Annotated[ChatCompletionTask, Field(discriminator="task_type")]
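Because both new fields are Optional with None defaults, task payloads written before this change still validate; only a failure populates them. A minimal backward-compatibility sketch, where existing_task is a hypothetical ChatCompletionTask instance:

# Hypothetical: re-validate a payload that predates the error fields.
old_payload = existing_task.model_dump(exclude={"error_type", "error_message"})
restored = ChatCompletionTask.model_validate(old_payload)
assert restored.error_type is None and restored.error_message is None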

View File

@@ -51,6 +51,7 @@ RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage
class RunnerResponseType(str, Enum):
InitializedResponse = "initialized_response"
GenerationResponse = "generation_response"
FinishedResponse = "finished_response"
PrintResponse = "print_response"
@@ -64,6 +65,13 @@ class BaseRunnerResponse(BaseModel, Generic[RRT]):
pass
class InitializedResponse(BaseRunnerResponse[RunnerResponseType.InitializedResponse]):
type: Literal[RunnerResponseType.InitializedResponse] = Field(
default=RunnerResponseType.InitializedResponse, frozen=True
)
time_taken: float
class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]):
type: Literal[RunnerResponseType.GenerationResponse] = Field(
default=RunnerResponseType.GenerationResponse, frozen=True
@@ -97,7 +105,7 @@ class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]):
RunnerResponse = Annotated[
GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse,
InitializedResponse | GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse,
Field(discriminator="type"),
]
RunnerResponseTypeAdapter: TypeAdapter[RunnerResponse] = TypeAdapter(RunnerResponse)
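With InitializedResponse added to the discriminated union, the supervisor can decode it from the same newline-delimited JSON stream as every other runner response. A small decoding sketch (the raw line is hypothetical, and assumes BaseRunnerResponse adds no required fields, as in the hunk above):

# Hypothetical wire line (newline framing stripped) as produced by
# runner_write_response(InitializedResponse(time_taken=3.2)).
raw = b'{"type": "initialized_response", "time_taken": 3.2}'
parsed = RunnerResponseTypeAdapter.validate_json(raw)
assert isinstance(parsed, InitializedResponse)
assert parsed.time_taken == 3.2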

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import time
from asyncio import Queue
from copy import deepcopy
from functools import partial
@@ -15,15 +16,17 @@ from shared.types.common import Host, NodeId
from shared.types.events import (
ChunkGenerated,
Event,
InstanceDeleted,
InstanceId,
NodePerformanceMeasured,
RunnerDeleted,
RunnerStatusUpdated,
TaskFailed,
TaskStateUpdated,
)
from shared.types.profiling import NodePerformanceProfile
from shared.types.state import State
from shared.types.tasks import TaskStatus
from shared.types.tasks import TaskId, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.downloads import (
DownloadCompleted,
@@ -68,6 +71,7 @@ class AssignedRunner(BaseModel):
hosts: list[Host]
status: RunnerStatus
failures: list[tuple[float, Exception]] = []
runner: Optional[RunnerSupervisor] # set if the runner is 'up'
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -141,14 +145,36 @@ class Worker:
yield
async def _execute_runner_up_op(
self, op: RunnerUpOp
self, op: RunnerUpOp, initialize_timeout: Optional[float] = None
) -> AsyncGenerator[Event, None]:
assigned_runner = self.assigned_runners[op.runner_id]
assigned_runner.runner = await RunnerSupervisor.create(
model_shard_meta=assigned_runner.shard_metadata,
hosts=assigned_runner.hosts,
)
# TODO: This should be dynamic, based on the size of the model.
if not initialize_timeout:
GBPS = 10
shard = assigned_runner.shard_metadata
weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes
initialize_timeout = weights_size_kb / (1024**2 * GBPS) + 2.0 # Add a constant 2.0 to ensure connection can be made as well
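# Worked example (illustrative numbers, not taken from this repo): a shard covering half
# the layers of a model whose weights total ~1_400_000 KB gives weights_size_kb ≈ 700_000,
# so the transfer term is 700_000 / (1024**2 * 10) ≈ 0.07 s and initialize_timeout ≈ 2.07 s.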
try:
assigned_runner.runner = await asyncio.wait_for(
RunnerSupervisor.create(
model_shard_meta=assigned_runner.shard_metadata,
hosts=assigned_runner.hosts,
logger=self.logger,
),
timeout=initialize_timeout,
)
except TimeoutError as e:
import traceback
tb = traceback.format_exc()
e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}")
async for event in self._fail_runner(e=e, runner_id=op.runner_id):
yield event
return
if assigned_runner.runner.healthy:
assigned_runner.status = LoadedRunnerStatus()
@@ -161,8 +187,9 @@ class Worker:
) -> AsyncGenerator[Event, None]:
assigned_runner = self.assigned_runners[op.runner_id]
assert isinstance(assigned_runner.runner, RunnerSupervisor)
await assigned_runner.runner.astop()
if isinstance(assigned_runner.runner, RunnerSupervisor):
await assigned_runner.runner.astop()
assigned_runner.runner = None
assigned_runner.status = ReadyRunnerStatus()
@@ -287,9 +314,6 @@ class Worker:
assigned_runner = self.assigned_runners[op.runner_id]
async def inner_execute(queue: asyncio.Queue[Event]) -> None:
assert assigned_runner.runner is not None
assert assigned_runner.runner.healthy
async def running_callback(queue: asyncio.Queue[Event]) -> None:
# Called when the MLX process has been kicked off
assigned_runner.status = RunningRunnerStatus()
@@ -302,6 +326,9 @@ class Worker:
))
try:
assert assigned_runner.runner is not None
assert assigned_runner.runner.healthy
async for chunk in assigned_runner.runner.stream_response(
task=op.task,
request_started_callback=partial(running_callback, queue)):
@@ -325,34 +352,44 @@ class Worker:
except Exception as e:
# TODO: What log level?
self.logger.log(2, f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
if assigned_runner.shard_metadata.device_rank == 0:
await queue.put(TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.FAILED,
))
assigned_runner.runner = None
assigned_runner.status = FailedRunnerStatus(error_message=str(e))
await queue.put(assigned_runner.status_update_event())
# An exception occurred in the runner supervisor
self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
await queue.put(event)
queue: Queue[Event] = asyncio.Queue()
task = asyncio.create_task(inner_execute(queue))
# TODO: Initial (prefill) timeout can be dynamic
# model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes
try:
# Yield items from the queue
# timeout = 30.
timeout = 3.
while True:
item: Event = await asyncio.wait_for(queue.get(), timeout=5)
item: Event = await asyncio.wait_for(queue.get(), timeout=timeout)
yield item
timeout = 2.
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
):
if isinstance(item.runner_status, LoadedRunnerStatus):
assigned_runner.failures = []
break
except TimeoutError as e:
# Runner supervisor didn't respond in time, so we put the runner & task into a failed state
self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.')
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
yield event
finally:
# Ensure the task is cleaned up
await task
try:
await asyncio.wait_for(task, timeout=5)
except asyncio.TimeoutError:
self.logger.warning("Timed out waiting for task cleanup after inference execution.")
## Operation Planner
@@ -381,6 +418,10 @@ class Worker:
def plan(self, state: State) -> RunnerOp | None:
# Compare state to worker 'mood'
# for runner_id, assigned_runner in self.assigned_runners.items():
# if len(assigned_runner.failures) == 3:
# raise Exception(f'Too many errors occurred in assigned runner - assumed to be recurrent and unrecoverable.\nErrors are as follows: {assigned_runner.failures}')
# First, unassign assigned runners that are no longer in the state.
for runner_id, _ in self.assigned_runners.items():
runner_ids: list[RunnerId] = [
@@ -512,7 +553,9 @@ class Worker:
continue # The only previous state to get to Running is from Loaded
for _, task in state.tasks.items():
if task.instance_id == instance_id and task.task_status == TaskStatus.PENDING:
if task.instance_id == instance_id and (
task.task_status == TaskStatus.PENDING or task.task_status == TaskStatus.FAILED
):
if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1):
return ExecuteTaskOp(runner_id=runner_id, task=task)
else:
@@ -530,17 +573,56 @@ class Worker:
return None
async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]:
if runner_id in self.assigned_runners:
assigned_runner = self.assigned_runners[runner_id]
assigned_runner.runner = None
assigned_runner.status = FailedRunnerStatus(error_message=str(e))
assigned_runner.failures.append(
(
time.time(),
e
)
)
# Reset failure count back to 0 when successful
if len(assigned_runner.failures) >= 3:
# Too many retries. We will emit an InstanceDeleted
yield InstanceDeleted(
instance_id=assigned_runner.instance_id
)
yield assigned_runner.status_update_event()
async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]:
if runner_id in self.assigned_runners:
yield TaskStateUpdated(
task_id=task_id,
task_status=TaskStatus.FAILED,
)
yield TaskFailed(
task_id=task_id,
error_type=str(type(e)),
error_message=str(e)
)
async for event in self._fail_runner(e, runner_id):
yield event
async def event_publisher(self, event: Event) -> None:
assert self.worker_events is not None
await self.worker_events.append_events([event], self.node_id)
print(f"published event: {event}")
self.logger.info(f"published event: {event}")
# Handle state updates
async def run(self):
assert self.global_events is not None
while True:
_rank = list(self.assigned_runners.values())[0].shard_metadata.device_rank if self.assigned_runners else None
# 1. get latest events
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
@@ -555,8 +637,18 @@ class Worker:
# run the op, synchronously blocking for now
if op is not None:
async for event in self._execute_op(op):
await self.event_publisher(event)
try:
async for event in self._execute_op(op):
await self.event_publisher(event)
except Exception as e:
# execute_task_op already has its own exception handling, so we assume the exception came from one of the other op types.
# we therefore just fail the runner.
self.logger.warning(f"Encountered exception when executing worker op {op}: {e}. \n Runner will be spun down and retried.")
async for event in self._fail_runner(
e,
runner_id=op.runner_id,
):
await self.event_publisher(event)
await asyncio.sleep(0.01)
if len(events) > 0:
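The retry budget in _fail_runner above is three: each failure is timestamped and appended, and once three have accumulated the worker escalates by emitting InstanceDeleted before the usual status update. A minimal sketch of driving that path directly, with a hypothetical worker and runner_id, using the private generator the tests also call:

async def drive_failures(worker: Worker, runner_id: RunnerId) -> list[Event]:
    # Hypothetical helper: fail the same runner three times and collect every emitted event.
    emitted: list[Event] = []
    for _ in range(3):
        async for event in worker._fail_runner(RuntimeError("boom"), runner_id=runner_id):
            emitted.append(event)
    return emitted

# Expected shape: three RunnerStatusUpdated events carrying FailedRunnerStatus, plus an
# InstanceDeleted emitted on the third pass once len(assigned_runner.failures) reaches 3.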

View File

@@ -47,9 +47,13 @@ async def runner_read_message() -> RunnerMessage:
def runner_write_response(obj: RunnerResponse) -> None:
encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
_ = sys.stdout.buffer.write(encoded)
_ = sys.stdout.buffer.flush()
try:
encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
_ = sys.stdout.buffer.write(encoded)
_ = sys.stdout.buffer.flush()
except BrokenPipeError:
# Supervisor has closed the pipe, silently exit
sys.exit(0)
async def supervisor_read_response(
@@ -83,6 +87,10 @@ def runner_print(text: str) -> None:
def runner_write_error(error: Exception) -> None:
# Skip writing error if it's a BrokenPipeError - supervisor is already gone
if isinstance(error, BrokenPipeError):
sys.exit(0)
error_response: ErrorResponse = ErrorResponse(
type=RunnerResponseType.ErrorResponse,
error_type=type(error).__name__,

View File

@@ -1,5 +1,6 @@
import asyncio
import concurrent.futures
import time
from collections.abc import AsyncGenerator
from functools import partial
from typing import Callable, cast
@@ -17,6 +18,7 @@ from shared.types.worker.commands_runner import (
ExitMessage,
FinishedResponse,
GenerationResponse,
InitializedResponse,
RunnerMessage,
SetupMessage,
)
@@ -98,23 +100,24 @@ async def _mlx_generate(
async def main():
try:
runner_print("hello from the runner")
# Get setup info from worker
init_message = await runner_read_message()
setup_message = ensure_type(init_message, SetupMessage)
model_shard_meta = setup_message.model_shard_meta
hosts = setup_message.hosts
setup_start_time = time.time()
mlx_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
loop = asyncio.get_running_loop()
runner_print(f"got here; {hosts}")
model, tokenizer, sampler = await loop.run_in_executor(
mlx_executor,
partial(initialize_mlx, model_shard_meta=model_shard_meta, hosts=hosts),
)
runner_write_response(InitializedResponse(time_taken=time.time() - setup_start_time))
while True:
message: RunnerMessage = await runner_read_message()
match message:

View File

@@ -2,6 +2,7 @@ import asyncio
import contextlib
import sys
from collections.abc import AsyncGenerator
from logging import Logger
from types import CoroutineType
from typing import Any, Callable
@@ -14,6 +15,7 @@ from shared.types.worker.commands_runner import (
ExitMessage,
FinishedResponse,
GenerationResponse,
InitializedResponse,
PrintResponse,
RunnerResponse,
SetupMessage,
@@ -54,6 +56,7 @@ class RunnerSupervisor:
cls,
model_shard_meta: ShardMetadata,
hosts: list[Host],
logger: Logger
) -> "RunnerSupervisor":
"""
Create and initialize a RunnerSupervisor instance.
@@ -66,7 +69,7 @@ class RunnerSupervisor:
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=sys.stderr,
stderr=sys.stderr
)
)
@@ -79,6 +82,21 @@ class RunnerSupervisor:
),
)
while True:
line: RunnerResponse | None = await supervisor_read_response(
runner_process
)
if line is None or isinstance(line, PrintResponse):
# print(line)
continue
elif isinstance(line, ErrorResponse):
raise Exception(line.error_type, line.error_message, line.traceback or "")
else:
assert isinstance(line, InitializedResponse)
logger.info(f'Runner initialized in {line.time_taken} seconds')
print(f'Runner initialized in {line.time_taken} seconds')
break
return cls(
model_shard_meta=model_shard_meta,
hosts=hosts,
@@ -203,6 +221,8 @@ class RunnerSupervisor:
token_id=token,
finish_reason=finish_reason,
)
case InitializedResponse():
raise ValueError('Initialized Response read during streaming flow')
case FinishedResponse():
break
case PrintResponse(text=text):

View File

@@ -0,0 +1,189 @@
import asyncio
import os
from logging import Logger
from typing import Callable, Final
import pytest
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import Host, NodeId
from shared.types.events import InstanceCreated, InstanceDeleted
from shared.types.models import ModelId
from shared.types.tasks import Task
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
from shared.types.worker.runners import FailedRunnerStatus
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import Worker
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final = "55555555-5555-4555-8555-555555555555"
TASK_2_ID: Final = "66666666-6666-4666-8666-666666666666"
@pytest.fixture
def user_message() -> str:
return "What is the capital of Japan?"
@pytest.mark.skipif(
os.environ.get("DETAILED", "").lower() != "true",
reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set"
)
async def check_runner_connection(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, str], Task],
) -> bool:
# Track all tasks and workers for cleanup
tasks: list[asyncio.Task[None]] = []
workers: list[Worker] = []
try:
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
shard_downloader = NoopShardDownloader()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(
NODE_A,
logger=logger,
shard_downloader=shard_downloader,
worker_events=global_events,
global_events=global_events,
)
workers.append(worker1)
task1 = asyncio.create_task(worker1.run())
tasks.append(task1)
worker2 = Worker(
NODE_B,
logger=logger,
shard_downloader=shard_downloader,
worker_events=global_events,
global_events=global_events,
)
workers.append(worker2)
task2 = asyncio.create_task(worker2.run())
tasks.append(task2)
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
],
origin=MASTER_NODE_ID
)
from worker.runner.runner_supervisor import RunnerSupervisor
async def wait_for_runner_supervisor(worker: Worker, timeout: float = 5.0) -> RunnerSupervisor | None:
end = asyncio.get_event_loop().time() + timeout
while True:
assigned_runners = list(worker.assigned_runners.values())
if assigned_runners:
runner = assigned_runners[0].runner
if isinstance(runner, RunnerSupervisor):
print('breaking because success')
return runner
if isinstance(assigned_runners[0].status, FailedRunnerStatus):
print('breaking because failed')
return runner
if asyncio.get_event_loop().time() > end:
raise TimeoutError("RunnerSupervisor was not set within timeout")
await asyncio.sleep(0.001)
runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0)
ret = runner_supervisor is not None and runner_supervisor.healthy
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.5)
return ret
finally:
# Cancel all worker tasks
for task in tasks:
task.cancel()
# Wait for cancellation to complete
await asyncio.gather(*tasks, return_exceptions=True)
# Check Running status
def test_runner_connection_stress(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, str], Task],
) -> None:
total_runs = 100
successes = 0
for _ in range(total_runs):
# Create a fresh event loop for each iteration
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(check_runner_connection(
logger=logger,
pipeline_shard_meta=pipeline_shard_meta,
hosts=hosts,
chat_completion_task=chat_completion_task,
))
if result:
successes += 1
finally:
# Cancel all running tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# Run the event loop briefly to allow cancellation to complete
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
# Close the event loop
loop.close()
print(f"Runner connection successes: {successes} / {total_runs}")

View File

@@ -0,0 +1,48 @@
## Tests for worker state handlers
import os
from typing import Callable
import pytest
from shared.types.events import (
Event,
)
from shared.types.events._events import RunnerStatusUpdated
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance, InstanceId
from shared.types.worker.ops import (
RunnerUpOp,
)
from shared.types.worker.runners import FailedRunnerStatus
from worker.main import Worker
# To enable this test, run pytest with: DETAILED=true pytest
@pytest.mark.skipif(
os.environ.get("DETAILED", "").lower() != "true",
reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set"
)
@pytest.mark.asyncio
async def test_runner_up_op_timeout(
worker_with_assigned_runner: tuple[Worker, RunnerId, Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
monkeypatch: pytest.MonkeyPatch
):
worker, runner_id, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=runner_id)
# _execute_runner_up_op should throw a TimeoutError with a short timeout
events: list[Event] = []
async for event in worker._execute_runner_up_op(runner_up_op, initialize_timeout=0.2): # type: ignore[misc]
events.append(event)
assert isinstance(events[-1], RunnerStatusUpdated)
assert isinstance(events[-1].runner_status, FailedRunnerStatus)
assert events[-1].runner_status.error_message is not None
assert 'timeout' in events[-1].runner_status.error_message.lower()
del worker.assigned_runners[list(worker.assigned_runners.keys())[0]]

View File

@@ -1,4 +1,5 @@
import asyncio
from logging import Logger
from pathlib import Path
from typing import Callable
@@ -30,6 +31,7 @@ async def test_supervisor_single_node_response(
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
tmp_path: Path,
logger: Logger,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
@@ -40,6 +42,7 @@ async def test_supervisor_single_node_response(
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
logger=logger,
)
try:
@@ -68,18 +71,25 @@ async def test_supervisor_two_node_response(
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
tmp_path: Path,
logger: Logger,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
instance_id = InstanceId()
supervisor_0 = await RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, 0),
hosts=hosts(2, offset=15),
create_supervisor_0 = asyncio.create_task(
RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, 0),
hosts=hosts(2, offset=15),
logger=logger,
)
)
supervisor_1 = await RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, 1),
hosts=hosts(2, offset=15),
create_supervisor_1 = asyncio.create_task(
RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, 1),
hosts=hosts(2, offset=15),
logger=logger,
)
)
supervisor_0, supervisor_1 = await asyncio.gather(create_supervisor_0, create_supervisor_1)
await asyncio.sleep(0.1)
@@ -124,6 +134,7 @@ async def test_supervisor_early_stopping(
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
tmp_path: Path,
logger: Logger,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
@@ -132,6 +143,7 @@ async def test_supervisor_early_stopping(
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
logger=logger,
)
task = chat_completion_task(instance_id, TaskId())
@@ -176,6 +188,7 @@ async def test_supervisor_early_stopping(
async def test_supervisor_handles_terminated_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
logger: Logger,
tmp_path: Path,
):
"""Test that the supervisor handles a terminated runner"""
@@ -184,6 +197,7 @@ async def test_supervisor_handles_terminated_runner(
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
logger=logger,
)
# Terminate the runner
@@ -201,6 +215,7 @@ async def test_supervisor_handles_killed_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
tmp_path: Path,
logger: Logger,
):
"""Test that the supervisor handles a killed runner"""
model_shard_meta = pipeline_shard_meta(1, 0)
@@ -208,6 +223,7 @@ async def test_supervisor_handles_killed_runner(
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
logger=logger,
)
assert supervisor.healthy

View File

@@ -0,0 +1,251 @@
import asyncio
from collections.abc import AsyncGenerator
from types import CoroutineType
from typing import Any, Awaitable, Callable, Final
import pytest
from _pytest.monkeypatch import MonkeyPatch
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import (
ChunkGenerated,
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
TaskCreated,
TaskStateUpdated,
TaskFailed,
)
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
InstanceStatus,
)
from shared.types.worker.runners import FailedRunnerStatus
from worker.main import Worker
from worker.runner.runner_supervisor import RunnerSupervisor
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
# Define constant IDs for deterministic test cases
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
@pytest.fixture
def user_message():
"""Override this fixture in tests to customize the message"""
return "Who is the longest ruling monarch of England?"
# TODO: Make this all monkeypatched instead.
async def test_stream_response_failed_always(
monkeypatch: MonkeyPatch,
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
async def mock_stream_response(
self: RunnerSupervisor,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
) -> AsyncGenerator[GenerationChunk]:
raise RuntimeError("Simulated stream response failure")
return
yield
monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(instance=instance_value),
TaskCreated(task_id=task.task_id, task=task)
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(5.)
events = await global_events.get_events_since(0)
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 3
assert any([isinstance(x.event, InstanceDeleted) for x in events])
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.3)
async def test_stream_response_failed_once(
monkeypatch: MonkeyPatch,
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
failed_already = False
original_stream_response = RunnerSupervisor.stream_response
async def mock_stream_response(
self: RunnerSupervisor,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
) -> AsyncGenerator[GenerationChunk]:
nonlocal failed_already
if not failed_already:
failed_already = True
raise RuntimeError("Simulated stream response failure")
else:
async for event in original_stream_response(self, task, request_started_callback):
yield event
return
monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(instance=instance_value),
TaskCreated(task_id=task.task_id, task=task)
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(5.)
# TODO: Ideally this test would have tooling to scroll through the state history and
# assert that there was a time when error_type and error_message were not None and the failure count was nonzero,
# since we reset the failures back to zero after a successful inference.
assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
assert worker.state.tasks[TASK_1_ID].error_type is None
assert worker.state.tasks[TASK_1_ID].error_message is None
events = await global_events.get_events_since(0)
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
response_string = ''
events = await global_events.get_events_since(0)
seen_task_started, seen_task_finished = False, False
for wrapped_event in events:
event = wrapped_event.event
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.RUNNING:
seen_task_started = True
if event.task_status == TaskStatus.COMPLETE:
seen_task_finished = True
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
response_string += event.chunk.text
assert 'elizabeth' in response_string.lower()
assert seen_task_started
assert seen_task_finished
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.3)
async def test_stream_response_timeout(
monkeypatch: MonkeyPatch,
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
async def mock_stream_response(
self: RunnerSupervisor,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
) -> AsyncGenerator[GenerationChunk]:
# TODO: Also a test where we yield a few chunks and then time out.
print('sleeping starting')
await asyncio.sleep(4.)
print('sleeping finished')
return
yield
monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(instance=instance_value),
TaskCreated(task_id=task.task_id, task=task)
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(7.)
# as we reset the failures back to zero when we have a successful inference.
# print('ASSERTION ERR:')
# print(worker.assigned_runners[RUNNER_1_ID].failures[1][1])
assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
assert worker.state.tasks[TASK_1_ID].error_type is None
assert worker.state.tasks[TASK_1_ID].error_message is None
events = await global_events.get_events_since(0)
print(events)
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.3)

View File

@@ -11,6 +11,7 @@ from shared.types.events import (
Event,
RunnerDeleted,
RunnerStatusUpdated,
TaskFailed,
TaskStateUpdated,
)
from shared.types.events.chunks import TokenChunk
@@ -217,7 +218,7 @@ async def test_execute_task_fails(
async for event in worker._execute_op(execute_task_op): # type: ignore[misc]
events.append(event)
assert len(events) == 4
assert len(events) == 5
print(events)
@@ -230,5 +231,7 @@ async def test_execute_task_fails(
assert isinstance(events[2], TaskStateUpdated)
assert events[2].task_status == TaskStatus.FAILED # Task marked as failed.
assert isinstance(events[3], RunnerStatusUpdated)
assert isinstance(events[3].runner_status, FailedRunnerStatus) # It should have failed.
assert isinstance(events[3], TaskFailed)
assert isinstance(events[4], RunnerStatusUpdated)
assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed.

View File

@@ -7,7 +7,8 @@ import pytest
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import Host, NodeId
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import CommandId, Host, NodeId
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
@@ -17,7 +18,7 @@ from shared.types.events import (
)
from shared.types.events.chunks import TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId
from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import (
Instance,
@@ -117,7 +118,7 @@ async def test_runner_assigned_active(
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.1)
await asyncio.sleep(1.0)
assert len(worker.assigned_runners) == 1
assert RUNNER_1_ID in worker.assigned_runners
@@ -200,7 +201,7 @@ async def test_runner_unassigns(
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.1)
await asyncio.sleep(0.5)
# already tested by test_runner_assigned_active
assert len(worker.assigned_runners) == 1
@@ -354,6 +355,102 @@ async def test_2_runner_inference(
await asyncio.sleep(2.0)
async def test_2_runner_multi_message(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
shard_downloader = NoopShardDownloader()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker1.run())
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker2.run())
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
# Task: we pass three messages here, which is what this test exercises
completion_create_params = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(role="user", content='What is the capital of France?'),
ChatCompletionMessage(role="assistant", content='The capital of France is Paris.'),
ChatCompletionMessage(role="user", content='Ok great. Now write me a haiku about what you can do there.'),
],
stream=True,
)
task = ChatCompletionTask(
task_id=TASK_1_ID,
command_id=CommandId(),
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=completion_create_params
)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert any(keyword in response_string.lower() for keyword in ('kiss', 'paris', 'art', 'love'))
idx = await global_events.get_last_idx()
await asyncio.sleep(1.0)
events = await global_events.get_events_since(idx)
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(2.0)
async def test_runner_respawn(