mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Test Supervisor Errors.
This commit is contained in:
@@ -52,7 +52,6 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group:
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_RING_VERBOSE"] = "1"
|
||||
|
||||
# Initialize distributed
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
runner_print(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from shared.types.events import (
|
||||
RunnerStatusUpdated,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskFailed,
|
||||
TaskStateUpdated,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
@@ -28,7 +29,7 @@ from shared.types.events import (
|
||||
)
|
||||
from shared.types.profiling import NodePerformanceProfile
|
||||
from shared.types.state import State
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from shared.types.topology import Connection, Node
|
||||
from shared.types.worker.common import NodeStatus, RunnerId
|
||||
from shared.types.worker.instances import Instance, InstanceId, InstanceStatus
|
||||
@@ -63,7 +64,23 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
|
||||
if event.task_id not in state.tasks:
|
||||
return state
|
||||
|
||||
updated_task = state.tasks[event.task_id].model_copy(update={"task_status": event.task_status})
|
||||
update: dict[str, TaskStatus | None] = {
|
||||
"task_status": event.task_status,
|
||||
}
|
||||
if event.task_status != TaskStatus.FAILED:
|
||||
update["error_type"] = None
|
||||
update["error_message"] = None
|
||||
|
||||
updated_task = state.tasks[event.task_id].model_copy(update=update)
|
||||
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
@event_apply.register(TaskFailed)
|
||||
def apply_task_failed(event: TaskFailed, state: State) -> State:
|
||||
if event.task_id not in state.tasks:
|
||||
return state
|
||||
|
||||
updated_task = state.tasks[event.task_id].model_copy(update={"error_type": event.error_type, "error_message": event.error_message})
|
||||
new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}
|
||||
return state.model_copy(update={"tasks": new_tasks})
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class _EventType(str, Enum):
|
||||
# Task Events
|
||||
TaskCreated = "TaskCreated"
|
||||
TaskStateUpdated = "TaskStateUpdated"
|
||||
TaskFailed = "TaskFailed"
|
||||
TaskDeleted = "TaskDeleted"
|
||||
|
||||
# Streaming Events
|
||||
@@ -119,6 +120,13 @@ class TaskStateUpdated(_BaseEvent[_EventType.TaskStateUpdated]):
|
||||
task_status: TaskStatus
|
||||
|
||||
|
||||
class TaskFailed(_BaseEvent[_EventType.TaskFailed]):
|
||||
event_type: Literal[_EventType.TaskFailed] = _EventType.TaskFailed
|
||||
task_id: TaskId
|
||||
error_type: str
|
||||
error_message: str
|
||||
|
||||
|
||||
class InstanceCreated(_BaseEvent[_EventType.InstanceCreated]):
|
||||
event_type: Literal[_EventType.InstanceCreated] = _EventType.InstanceCreated
|
||||
instance: Instance
|
||||
@@ -202,6 +210,7 @@ _Event = Union[
|
||||
Heartbeat,
|
||||
TaskCreated,
|
||||
TaskStateUpdated,
|
||||
TaskFailed,
|
||||
TaskDeleted,
|
||||
InstanceCreated,
|
||||
InstanceActivated,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -31,4 +31,7 @@ class ChatCompletionTask(BaseModel):
|
||||
task_status: TaskStatus
|
||||
task_params: ChatCompletionTaskParams
|
||||
|
||||
error_type: Optional[str] = Field(default=None)
|
||||
error_message: Optional[str] = Field(default=None)
|
||||
|
||||
Task = Annotated[ChatCompletionTask, Field(discriminator="task_type")]
|
||||
|
||||
@@ -51,6 +51,7 @@ RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage
|
||||
|
||||
|
||||
class RunnerResponseType(str, Enum):
|
||||
InitializedResponse = "initialized_response"
|
||||
GenerationResponse = "generation_response"
|
||||
FinishedResponse = "finished_response"
|
||||
PrintResponse = "print_response"
|
||||
@@ -64,6 +65,13 @@ class BaseRunnerResponse(BaseModel, Generic[RRT]):
|
||||
pass
|
||||
|
||||
|
||||
class InitializedResponse(BaseRunnerResponse[RunnerResponseType.InitializedResponse]):
|
||||
type: Literal[RunnerResponseType.InitializedResponse] = Field(
|
||||
default=RunnerResponseType.InitializedResponse, frozen=True
|
||||
)
|
||||
time_taken: float
|
||||
|
||||
|
||||
class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]):
|
||||
type: Literal[RunnerResponseType.GenerationResponse] = Field(
|
||||
default=RunnerResponseType.GenerationResponse, frozen=True
|
||||
@@ -97,7 +105,7 @@ class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]):
|
||||
|
||||
|
||||
RunnerResponse = Annotated[
|
||||
GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse,
|
||||
InitializedResponse | GenerationResponse | PrintResponse | FinishedResponse | ErrorResponse,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
RunnerResponseTypeAdapter: TypeAdapter[RunnerResponse] = TypeAdapter(RunnerResponse)
|
||||
|
||||
152
worker/main.py
152
worker/main.py
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
@@ -15,15 +16,17 @@ from shared.types.common import Host, NodeId
|
||||
from shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
InstanceDeleted,
|
||||
InstanceId,
|
||||
NodePerformanceMeasured,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskFailed,
|
||||
TaskStateUpdated,
|
||||
)
|
||||
from shared.types.profiling import NodePerformanceProfile
|
||||
from shared.types.state import State
|
||||
from shared.types.tasks import TaskStatus
|
||||
from shared.types.tasks import TaskId, TaskStatus
|
||||
from shared.types.worker.common import RunnerId
|
||||
from shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
@@ -68,6 +71,7 @@ class AssignedRunner(BaseModel):
|
||||
hosts: list[Host]
|
||||
|
||||
status: RunnerStatus
|
||||
failures: list[tuple[float, Exception]] = []
|
||||
runner: Optional[RunnerSupervisor] # set if the runner is 'up'
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -141,14 +145,36 @@ class Worker:
|
||||
yield
|
||||
|
||||
async def _execute_runner_up_op(
|
||||
self, op: RunnerUpOp
|
||||
self, op: RunnerUpOp, initialize_timeout: Optional[float] = None
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
assigned_runner = self.assigned_runners[op.runner_id]
|
||||
|
||||
assigned_runner.runner = await RunnerSupervisor.create(
|
||||
model_shard_meta=assigned_runner.shard_metadata,
|
||||
hosts=assigned_runner.hosts,
|
||||
)
|
||||
# TODO: This should be dynamic, based on the size of the model.
|
||||
if not initialize_timeout:
|
||||
GBPS = 10
|
||||
|
||||
shard = assigned_runner.shard_metadata
|
||||
weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes
|
||||
|
||||
initialize_timeout = weights_size_kb / (1024**2 * GBPS) + 2.0 # Add a constant 2.0 to ensure connection can be made as well
|
||||
|
||||
try:
|
||||
assigned_runner.runner = await asyncio.wait_for(
|
||||
RunnerSupervisor.create(
|
||||
model_shard_meta=assigned_runner.shard_metadata,
|
||||
hosts=assigned_runner.hosts,
|
||||
logger=self.logger,
|
||||
),
|
||||
timeout=initialize_timeout,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
import traceback
|
||||
|
||||
tb = traceback.format_exc()
|
||||
e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}")
|
||||
async for event in self._fail_runner(e=e, runner_id=op.runner_id):
|
||||
yield event
|
||||
return
|
||||
|
||||
if assigned_runner.runner.healthy:
|
||||
assigned_runner.status = LoadedRunnerStatus()
|
||||
@@ -161,8 +187,9 @@ class Worker:
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
assigned_runner = self.assigned_runners[op.runner_id]
|
||||
|
||||
assert isinstance(assigned_runner.runner, RunnerSupervisor)
|
||||
await assigned_runner.runner.astop()
|
||||
if isinstance(assigned_runner.runner, RunnerSupervisor):
|
||||
await assigned_runner.runner.astop()
|
||||
|
||||
assigned_runner.runner = None
|
||||
|
||||
assigned_runner.status = ReadyRunnerStatus()
|
||||
@@ -287,9 +314,6 @@ class Worker:
|
||||
assigned_runner = self.assigned_runners[op.runner_id]
|
||||
|
||||
async def inner_execute(queue: asyncio.Queue[Event]) -> None:
|
||||
assert assigned_runner.runner is not None
|
||||
assert assigned_runner.runner.healthy
|
||||
|
||||
async def running_callback(queue: asyncio.Queue[Event]) -> None:
|
||||
# Called when the MLX process has been kicked off
|
||||
assigned_runner.status = RunningRunnerStatus()
|
||||
@@ -302,6 +326,9 @@ class Worker:
|
||||
))
|
||||
|
||||
try:
|
||||
assert assigned_runner.runner is not None
|
||||
assert assigned_runner.runner.healthy
|
||||
|
||||
async for chunk in assigned_runner.runner.stream_response(
|
||||
task=op.task,
|
||||
request_started_callback=partial(running_callback, queue)):
|
||||
@@ -325,34 +352,44 @@ class Worker:
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# TODO: What log level?
|
||||
self.logger.log(2, f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
|
||||
|
||||
if assigned_runner.shard_metadata.device_rank == 0:
|
||||
await queue.put(TaskStateUpdated(
|
||||
task_id=op.task.task_id,
|
||||
task_status=TaskStatus.FAILED,
|
||||
))
|
||||
|
||||
assigned_runner.runner = None
|
||||
assigned_runner.status = FailedRunnerStatus(error_message=str(e))
|
||||
await queue.put(assigned_runner.status_update_event())
|
||||
# An exception occurs in the runner supervisor
|
||||
self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
|
||||
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
|
||||
await queue.put(event)
|
||||
|
||||
queue: Queue[Event] = asyncio.Queue()
|
||||
task = asyncio.create_task(inner_execute(queue))
|
||||
|
||||
# TODO: Initial (prefil) timeout can be dynamic
|
||||
# model_kb = assigned_runner.shard_metadata.model_meta.storage_size_kilobytes
|
||||
|
||||
try:
|
||||
# Yield items from the queue
|
||||
# timeout = 30.
|
||||
timeout = 3.
|
||||
while True:
|
||||
item: Event = await asyncio.wait_for(queue.get(), timeout=5)
|
||||
item: Event = await asyncio.wait_for(queue.get(), timeout=timeout)
|
||||
yield item
|
||||
timeout = 2.
|
||||
if isinstance(item, RunnerStatusUpdated) and isinstance(
|
||||
item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
|
||||
):
|
||||
if isinstance(item.runner_status, LoadedRunnerStatus):
|
||||
assigned_runner.failures = []
|
||||
|
||||
break
|
||||
except TimeoutError as e:
|
||||
# Runner supervisor doesn't respond in time; so we put the runner & task into a failed state
|
||||
self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.')
|
||||
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
|
||||
yield event
|
||||
finally:
|
||||
# Ensure the task is cleaned up
|
||||
await task
|
||||
try:
|
||||
await asyncio.wait_for(task, timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Timed out waiting for task cleanup after inference execution.")
|
||||
|
||||
|
||||
## Operation Planner
|
||||
|
||||
@@ -381,6 +418,10 @@ class Worker:
|
||||
def plan(self, state: State) -> RunnerOp | None:
|
||||
# Compare state to worker 'mood'
|
||||
|
||||
# for runner_id, assigned_runner in self.assigned_runners.items():
|
||||
# if len(assigned_runner.failures) == 3:
|
||||
# raise Exception('Too many error occurred in assigned runner - assumed to be recurrent and unrecoverable.\nErrors are as follows: {assigned_runner.failures}')
|
||||
|
||||
# First, unassign assigned runners that are no longer in the state.
|
||||
for runner_id, _ in self.assigned_runners.items():
|
||||
runner_ids: list[RunnerId] = [
|
||||
@@ -512,7 +553,9 @@ class Worker:
|
||||
continue # The only previous state to get to Running is from Loaded
|
||||
|
||||
for _, task in state.tasks.items():
|
||||
if task.instance_id == instance_id and task.task_status == TaskStatus.PENDING:
|
||||
if task.instance_id == instance_id and (
|
||||
task.task_status == TaskStatus.PENDING or task.task_status == TaskStatus.FAILED
|
||||
):
|
||||
if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1):
|
||||
return ExecuteTaskOp(runner_id=runner_id, task=task)
|
||||
else:
|
||||
@@ -530,17 +573,56 @@ class Worker:
|
||||
return None
|
||||
|
||||
|
||||
async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]:
|
||||
if runner_id in self.assigned_runners:
|
||||
assigned_runner = self.assigned_runners[runner_id]
|
||||
|
||||
assigned_runner.runner = None
|
||||
assigned_runner.status = FailedRunnerStatus(error_message=str(e))
|
||||
assigned_runner.failures.append(
|
||||
(
|
||||
time.time(),
|
||||
e
|
||||
)
|
||||
)
|
||||
|
||||
# Reset failure count back to 0 when succesful
|
||||
if len(assigned_runner.failures) >= 3:
|
||||
# Too many retries. We will emit a DeleteInstance
|
||||
yield InstanceDeleted(
|
||||
instance_id=assigned_runner.instance_id
|
||||
)
|
||||
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
|
||||
async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]:
|
||||
if runner_id in self.assigned_runners:
|
||||
yield TaskStateUpdated(
|
||||
task_id=task_id,
|
||||
task_status=TaskStatus.FAILED,
|
||||
)
|
||||
|
||||
yield TaskFailed(
|
||||
task_id=task_id,
|
||||
error_type=str(type(e)),
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
async for event in self._fail_runner(e, runner_id):
|
||||
yield event
|
||||
|
||||
|
||||
async def event_publisher(self, event: Event) -> None:
|
||||
assert self.worker_events is not None
|
||||
await self.worker_events.append_events([event], self.node_id)
|
||||
print(f"published event: {event}")
|
||||
self.logger.info(f"published event: {event}")
|
||||
|
||||
# Handle state updates
|
||||
async def run(self):
|
||||
assert self.global_events is not None
|
||||
|
||||
while True:
|
||||
_rank = list(self.assigned_runners.values())[0].shard_metadata.device_rank if self.assigned_runners else None
|
||||
# 1. get latest events
|
||||
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
|
||||
|
||||
@@ -555,8 +637,18 @@ class Worker:
|
||||
|
||||
# run the op, synchronously blocking for now
|
||||
if op is not None:
|
||||
async for event in self._execute_op(op):
|
||||
await self.event_publisher(event)
|
||||
try:
|
||||
async for event in self._execute_op(op):
|
||||
await self.event_publisher(event)
|
||||
except Exception as e:
|
||||
# execeute_task_op already has its own exception handling here. So we assume we had an exception in one of the other op types.
|
||||
# we therefore just fail the runner.
|
||||
self.logger.warning(f"Encountered exception when executing worker op {op}: {e}. \n Runner will be spun down and retried.")
|
||||
async for event in self._fail_runner(
|
||||
e,
|
||||
runner_id=op.runner_id,
|
||||
):
|
||||
await self.event_publisher(event)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
if len(events) > 0:
|
||||
|
||||
@@ -47,9 +47,13 @@ async def runner_read_message() -> RunnerMessage:
|
||||
|
||||
|
||||
def runner_write_response(obj: RunnerResponse) -> None:
|
||||
encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
|
||||
_ = sys.stdout.buffer.write(encoded)
|
||||
_ = sys.stdout.buffer.flush()
|
||||
try:
|
||||
encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
|
||||
_ = sys.stdout.buffer.write(encoded)
|
||||
_ = sys.stdout.buffer.flush()
|
||||
except BrokenPipeError:
|
||||
# Supervisor has closed the pipe, silently exit
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
async def supervisor_read_response(
|
||||
@@ -83,6 +87,10 @@ def runner_print(text: str) -> None:
|
||||
|
||||
|
||||
def runner_write_error(error: Exception) -> None:
|
||||
# Skip writing error if it's a BrokenPipeError - supervisor is already gone
|
||||
if isinstance(error, BrokenPipeError):
|
||||
sys.exit(0)
|
||||
|
||||
error_response: ErrorResponse = ErrorResponse(
|
||||
type=RunnerResponseType.ErrorResponse,
|
||||
error_type=type(error).__name__,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from functools import partial
|
||||
from typing import Callable, cast
|
||||
@@ -17,6 +18,7 @@ from shared.types.worker.commands_runner import (
|
||||
ExitMessage,
|
||||
FinishedResponse,
|
||||
GenerationResponse,
|
||||
InitializedResponse,
|
||||
RunnerMessage,
|
||||
SetupMessage,
|
||||
)
|
||||
@@ -98,23 +100,24 @@ async def _mlx_generate(
|
||||
async def main():
|
||||
try:
|
||||
runner_print("hello from the runner")
|
||||
|
||||
# Get setup info from worker
|
||||
init_message = await runner_read_message()
|
||||
setup_message = ensure_type(init_message, SetupMessage)
|
||||
model_shard_meta = setup_message.model_shard_meta
|
||||
hosts = setup_message.hosts
|
||||
|
||||
setup_start_time = time.time()
|
||||
|
||||
mlx_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
runner_print(f"got here; {hosts}")
|
||||
|
||||
model, tokenizer, sampler = await loop.run_in_executor(
|
||||
mlx_executor,
|
||||
partial(initialize_mlx, model_shard_meta=model_shard_meta, hosts=hosts),
|
||||
)
|
||||
|
||||
runner_write_response(InitializedResponse(time_taken=time.time() - setup_start_time))
|
||||
|
||||
while True:
|
||||
message: RunnerMessage = await runner_read_message()
|
||||
match message:
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import contextlib
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from logging import Logger
|
||||
from types import CoroutineType
|
||||
from typing import Any, Callable
|
||||
|
||||
@@ -14,6 +15,7 @@ from shared.types.worker.commands_runner import (
|
||||
ExitMessage,
|
||||
FinishedResponse,
|
||||
GenerationResponse,
|
||||
InitializedResponse,
|
||||
PrintResponse,
|
||||
RunnerResponse,
|
||||
SetupMessage,
|
||||
@@ -54,6 +56,7 @@ class RunnerSupervisor:
|
||||
cls,
|
||||
model_shard_meta: ShardMetadata,
|
||||
hosts: list[Host],
|
||||
logger: Logger
|
||||
) -> "RunnerSupervisor":
|
||||
"""
|
||||
Create and initialize a RunnerSupervisor instance.
|
||||
@@ -66,7 +69,7 @@ class RunnerSupervisor:
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=sys.stderr,
|
||||
stderr=sys.stderr
|
||||
)
|
||||
)
|
||||
|
||||
@@ -79,6 +82,21 @@ class RunnerSupervisor:
|
||||
),
|
||||
)
|
||||
|
||||
while True:
|
||||
line: RunnerResponse | None = await supervisor_read_response(
|
||||
runner_process
|
||||
)
|
||||
if line is None or isinstance(line, PrintResponse):
|
||||
# print(line)
|
||||
continue
|
||||
elif isinstance(line, ErrorResponse):
|
||||
raise Exception(line.error_type, line.error_message, line.traceback or "")
|
||||
else:
|
||||
assert isinstance(line, InitializedResponse)
|
||||
logger.info(f'Runner initialized in {line.time_taken} seconds')
|
||||
print(f'Runner initialized in {line.time_taken} seconds')
|
||||
break
|
||||
|
||||
return cls(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts,
|
||||
@@ -203,6 +221,8 @@ class RunnerSupervisor:
|
||||
token_id=token,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
case InitializedResponse():
|
||||
raise ValueError('Initialized Response read during streaming flow')
|
||||
case FinishedResponse():
|
||||
break
|
||||
case PrintResponse(text=text):
|
||||
|
||||
189
worker/tests/test_runner_connection.py
Normal file
189
worker/tests/test_runner_connection.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import asyncio
|
||||
import os
|
||||
from logging import Logger
|
||||
from typing import Callable, Final
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
|
||||
from shared.types.common import Host, NodeId
|
||||
from shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from shared.types.models import ModelId
|
||||
from shared.types.tasks import Task
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
|
||||
from shared.types.worker.runners import FailedRunnerStatus
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.download.shard_downloader import NoopShardDownloader
|
||||
from worker.main import Worker
|
||||
|
||||
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
|
||||
|
||||
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
|
||||
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
|
||||
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
|
||||
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
|
||||
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
|
||||
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
|
||||
TASK_1_ID: Final = "55555555-5555-4555-8555-555555555555"
|
||||
TASK_2_ID: Final = "66666666-6666-4666-8666-666666666666"
|
||||
|
||||
@pytest.fixture
|
||||
def user_message() -> str:
|
||||
return "What is the capital of Japan?"
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("DETAILED", "").lower() != "true",
|
||||
reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set"
|
||||
)
|
||||
async def check_runner_connection(
|
||||
logger: Logger,
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
hosts: Callable[[int], list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, str], Task],
|
||||
) -> bool:
|
||||
# Track all tasks and workers for cleanup
|
||||
tasks: list[asyncio.Task[None]] = []
|
||||
workers: list[Worker] = []
|
||||
|
||||
try:
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
shard_downloader = NoopShardDownloader()
|
||||
|
||||
global_events = event_log_manager.global_events
|
||||
await global_events.delete_all_events()
|
||||
|
||||
worker1 = Worker(
|
||||
NODE_A,
|
||||
logger=logger,
|
||||
shard_downloader=shard_downloader,
|
||||
worker_events=global_events,
|
||||
global_events=global_events,
|
||||
)
|
||||
workers.append(worker1)
|
||||
task1 = asyncio.create_task(worker1.run())
|
||||
tasks.append(task1)
|
||||
|
||||
worker2 = Worker(
|
||||
NODE_B,
|
||||
logger=logger,
|
||||
shard_downloader=shard_downloader,
|
||||
worker_events=global_events,
|
||||
global_events=global_events,
|
||||
)
|
||||
workers.append(worker2)
|
||||
task2 = asyncio.create_task(worker2.run())
|
||||
tasks.append(task2)
|
||||
|
||||
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_id,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: pipeline_shard_meta(2, 0),
|
||||
RUNNER_2_ID: pipeline_shard_meta(2, 1)
|
||||
},
|
||||
node_to_runner={
|
||||
NODE_A: RUNNER_1_ID,
|
||||
NODE_B: RUNNER_2_ID
|
||||
}
|
||||
)
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2)
|
||||
)
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
from worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
async def wait_for_runner_supervisor(worker: Worker, timeout: float = 5.0) -> RunnerSupervisor | None:
|
||||
end = asyncio.get_event_loop().time() + timeout
|
||||
while True:
|
||||
assigned_runners = list(worker.assigned_runners.values())
|
||||
if assigned_runners:
|
||||
runner = assigned_runners[0].runner
|
||||
if isinstance(runner, RunnerSupervisor):
|
||||
print('breaking because success')
|
||||
return runner
|
||||
if isinstance(assigned_runners[0].status, FailedRunnerStatus):
|
||||
print('breaking because failed')
|
||||
return runner
|
||||
if asyncio.get_event_loop().time() > end:
|
||||
raise TimeoutError("RunnerSupervisor was not set within timeout")
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0)
|
||||
ret = runner_supervisor is not None and runner_supervisor.healthy
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ret
|
||||
finally:
|
||||
# Cancel all worker tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for cancellation to complete
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check Running status
|
||||
|
||||
def test_runner_connection_stress(
|
||||
logger: Logger,
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
hosts: Callable[[int], list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, str], Task],
|
||||
) -> None:
|
||||
total_runs = 100
|
||||
successes = 0
|
||||
|
||||
for _ in range(total_runs):
|
||||
# Create a fresh event loop for each iteration
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(check_runner_connection(
|
||||
logger=logger,
|
||||
pipeline_shard_meta=pipeline_shard_meta,
|
||||
hosts=hosts,
|
||||
chat_completion_task=chat_completion_task,
|
||||
))
|
||||
if result:
|
||||
successes += 1
|
||||
finally:
|
||||
# Cancel all running tasks
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Run the event loop briefly to allow cancellation to complete
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
|
||||
# Close the event loop
|
||||
loop.close()
|
||||
|
||||
print(f"Runner connection successes: {successes} / {total_runs}")
|
||||
48
worker/tests/test_spinup_timeout.py
Normal file
48
worker/tests/test_spinup_timeout.py
Normal file
@@ -0,0 +1,48 @@
|
||||
## Tests for worker state handlers
|
||||
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.types.events import (
|
||||
Event,
|
||||
)
|
||||
from shared.types.events._events import RunnerStatusUpdated
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.worker.common import RunnerId
|
||||
from shared.types.worker.instances import Instance, InstanceId
|
||||
from shared.types.worker.ops import (
|
||||
RunnerUpOp,
|
||||
)
|
||||
from shared.types.worker.runners import FailedRunnerStatus
|
||||
from worker.main import Worker
|
||||
|
||||
# To enable this test, run pytest with: ENABLE_SPINUP_TIMEOUT_TEST=true pytest
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("DETAILED", "").lower() != "true",
|
||||
reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_up_op_timeout(
|
||||
worker_with_assigned_runner: tuple[Worker, RunnerId, Instance],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task],
|
||||
monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
worker, runner_id, _ = worker_with_assigned_runner
|
||||
|
||||
runner_up_op = RunnerUpOp(runner_id=runner_id)
|
||||
|
||||
# _execute_runner_up_op should throw a TimeoutError with a short timeout
|
||||
events: list[Event] = []
|
||||
async for event in worker._execute_runner_up_op(runner_up_op, initialize_timeout=0.2): # type: ignore[misc]
|
||||
events.append(event)
|
||||
|
||||
assert isinstance(events[-1], RunnerStatusUpdated)
|
||||
assert isinstance(events[-1].runner_status, FailedRunnerStatus)
|
||||
assert events[-1].runner_status.error_message is not None
|
||||
assert 'timeout' in events[-1].runner_status.error_message.lower()
|
||||
|
||||
del worker.assigned_runners[list(worker.assigned_runners.keys())[0]]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -30,6 +31,7 @@ async def test_supervisor_single_node_response(
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task],
|
||||
tmp_path: Path,
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
@@ -40,6 +42,7 @@ async def test_supervisor_single_node_response(
|
||||
supervisor = await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -68,18 +71,25 @@ async def test_supervisor_two_node_response(
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task],
|
||||
tmp_path: Path,
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
instance_id = InstanceId()
|
||||
supervisor_0 = await RunnerSupervisor.create(
|
||||
model_shard_meta=pipeline_shard_meta(2, 0),
|
||||
hosts=hosts(2, offset=15),
|
||||
create_supervisor_0 = asyncio.create_task(
|
||||
RunnerSupervisor.create(
|
||||
model_shard_meta=pipeline_shard_meta(2, 0),
|
||||
hosts=hosts(2, offset=15),
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
supervisor_1 = await RunnerSupervisor.create(
|
||||
model_shard_meta=pipeline_shard_meta(2, 1),
|
||||
hosts=hosts(2, offset=15),
|
||||
create_supervisor_1 = asyncio.create_task(
|
||||
RunnerSupervisor.create(
|
||||
model_shard_meta=pipeline_shard_meta(2, 1),
|
||||
hosts=hosts(2, offset=15),
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
supervisor_0, supervisor_1 = await asyncio.gather(create_supervisor_0, create_supervisor_1)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -124,6 +134,7 @@ async def test_supervisor_early_stopping(
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task],
|
||||
tmp_path: Path,
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
@@ -132,6 +143,7 @@ async def test_supervisor_early_stopping(
|
||||
supervisor = await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
task = chat_completion_task(instance_id, TaskId())
|
||||
@@ -176,6 +188,7 @@ async def test_supervisor_early_stopping(
|
||||
async def test_supervisor_handles_terminated_runner(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
logger: Logger,
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that the supervisor handles a terminated runner"""
|
||||
@@ -184,6 +197,7 @@ async def test_supervisor_handles_terminated_runner(
|
||||
supervisor = await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Terminate the runner
|
||||
@@ -201,6 +215,7 @@ async def test_supervisor_handles_killed_runner(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
tmp_path: Path,
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that the supervisor handles a killed runner"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
@@ -208,6 +223,7 @@ async def test_supervisor_handles_killed_runner(
|
||||
supervisor = await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert supervisor.healthy
|
||||
|
||||
251
worker/tests/test_supervisor_errors.py
Normal file
251
worker/tests/test_supervisor_errors.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from types import CoroutineType
|
||||
from typing import Any, Awaitable, Callable, Final
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events import (
|
||||
ChunkGenerated,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskCreated,
|
||||
TaskStateUpdated,
|
||||
TaskFailed,
|
||||
)
|
||||
from shared.types.events.chunks import GenerationChunk, TokenChunk
|
||||
from shared.types.models import ModelId
|
||||
from shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceStatus,
|
||||
)
|
||||
from shared.types.worker.runners import FailedRunnerStatus
|
||||
from worker.main import Worker
|
||||
from worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
|
||||
|
||||
# Define constant IDs for deterministic test cases
|
||||
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
|
||||
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
|
||||
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
|
||||
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
|
||||
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
|
||||
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
|
||||
TASK_1_ID: Final[TaskId] = TaskId("55555555-5555-4555-8555-555555555555")
|
||||
TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
|
||||
|
||||
@pytest.fixture
|
||||
def user_message():
|
||||
"""Override this fixture in tests to customize the message"""
|
||||
return "Who is the longest ruling monarch of England?"
|
||||
|
||||
# TODO: Make this all monkeypatched instead.
|
||||
|
||||
async def test_stream_response_failed_always(
|
||||
monkeypatch: MonkeyPatch,
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task]
|
||||
):
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
async def mock_stream_response(
|
||||
self: RunnerSupervisor,
|
||||
task: Task,
|
||||
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
|
||||
) -> AsyncGenerator[GenerationChunk]:
|
||||
raise RuntimeError("Simulated stream response failure")
|
||||
return
|
||||
yield
|
||||
|
||||
monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
|
||||
|
||||
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(instance=instance_value),
|
||||
TaskCreated(task_id=task.task_id, task=task)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(5.)
|
||||
|
||||
|
||||
events = await global_events.get_events_since(0)
|
||||
|
||||
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
|
||||
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 3
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance_value.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
async def test_stream_response_failed_once(
|
||||
monkeypatch: MonkeyPatch,
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task]
|
||||
):
|
||||
failed_already = False
|
||||
original_stream_response = RunnerSupervisor.stream_response
|
||||
|
||||
async def mock_stream_response(
|
||||
self: RunnerSupervisor,
|
||||
task: Task,
|
||||
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
|
||||
) -> AsyncGenerator[GenerationChunk]:
|
||||
nonlocal failed_already
|
||||
if not failed_already:
|
||||
failed_already = True
|
||||
raise RuntimeError("Simulated stream response failure")
|
||||
else:
|
||||
async for event in original_stream_response(self, task, request_started_callback):
|
||||
yield event
|
||||
return
|
||||
|
||||
monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
|
||||
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(instance=instance_value),
|
||||
TaskCreated(task_id=task.task_id, task=task)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(5.)
|
||||
|
||||
# TODO: The ideal with this test is if we had some tooling to scroll through the state, and say
|
||||
# 'asser that there was a time that the error_type, error_message was not none and the failure count was nonzero'
|
||||
|
||||
# as we reset the failures back to zero when we have a successful inference.
|
||||
assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
|
||||
assert worker.state.tasks[TASK_1_ID].error_type is None
|
||||
assert worker.state.tasks[TASK_1_ID].error_message is None
|
||||
|
||||
events = await global_events.get_events_since(0)
|
||||
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
|
||||
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
|
||||
|
||||
response_string = ''
|
||||
events = await global_events.get_events_since(0)
|
||||
|
||||
seen_task_started, seen_task_finished = False, False
|
||||
for wrapped_event in events:
|
||||
event = wrapped_event.event
|
||||
if isinstance(event, TaskStateUpdated):
|
||||
if event.task_status == TaskStatus.RUNNING:
|
||||
seen_task_started = True
|
||||
if event.task_status == TaskStatus.COMPLETE:
|
||||
seen_task_finished = True
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
response_string += event.chunk.text
|
||||
|
||||
assert 'elizabeth' in response_string.lower()
|
||||
assert seen_task_started
|
||||
assert seen_task_finished
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance_value.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
async def test_stream_response_timeout(
|
||||
monkeypatch: MonkeyPatch,
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task]
|
||||
):
|
||||
async def mock_stream_response(
|
||||
self: RunnerSupervisor,
|
||||
task: Task,
|
||||
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
|
||||
) -> AsyncGenerator[GenerationChunk]:
|
||||
# TODO: Also a test where we yield a few chunks and then time out.
|
||||
print('sleeping starting')
|
||||
await asyncio.sleep(4.)
|
||||
print('sleeping finished')
|
||||
return
|
||||
yield
|
||||
|
||||
monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)
|
||||
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(instance=instance_value),
|
||||
TaskCreated(task_id=task.task_id, task=task)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(7.)
|
||||
|
||||
|
||||
# as we reset the failures back to zero when we have a successful inference.
|
||||
|
||||
# print('ASSERTION ERR:')
|
||||
# print(worker.assigned_runners[RUNNER_1_ID].failures[1][1])
|
||||
|
||||
assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
|
||||
assert worker.state.tasks[TASK_1_ID].error_type is None
|
||||
assert worker.state.tasks[TASK_1_ID].error_message is None
|
||||
|
||||
events = await global_events.get_events_since(0)
|
||||
print(events)
|
||||
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
|
||||
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
|
||||
assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance_value.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
@@ -11,6 +11,7 @@ from shared.types.events import (
|
||||
Event,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskFailed,
|
||||
TaskStateUpdated,
|
||||
)
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
@@ -217,7 +218,7 @@ async def test_execute_task_fails(
|
||||
async for event in worker._execute_op(execute_task_op): # type: ignore[misc]
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 4
|
||||
assert len(events) == 5
|
||||
|
||||
print(events)
|
||||
|
||||
@@ -230,5 +231,7 @@ async def test_execute_task_fails(
|
||||
assert isinstance(events[2], TaskStateUpdated)
|
||||
assert events[2].task_status == TaskStatus.FAILED # Task marked as failed.
|
||||
|
||||
assert isinstance(events[3], RunnerStatusUpdated)
|
||||
assert isinstance(events[3].runner_status, FailedRunnerStatus) # It should have failed.
|
||||
assert isinstance(events[3], TaskFailed)
|
||||
|
||||
assert isinstance(events[4], RunnerStatusUpdated)
|
||||
assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed.
|
||||
@@ -7,7 +7,8 @@ import pytest
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
|
||||
from shared.types.common import Host, NodeId
|
||||
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from shared.types.common import CommandId, Host, NodeId
|
||||
from shared.types.events import (
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
@@ -17,7 +18,7 @@ from shared.types.events import (
|
||||
)
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.models import ModelId
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -117,7 +118,7 @@ async def test_runner_assigned_active(
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
assert len(worker.assigned_runners) == 1
|
||||
assert RUNNER_1_ID in worker.assigned_runners
|
||||
@@ -200,7 +201,7 @@ async def test_runner_unassigns(
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# already tested by test_runner_assigned_active
|
||||
assert len(worker.assigned_runners) == 1
|
||||
@@ -354,6 +355,102 @@ async def test_2_runner_inference(
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
async def test_2_runner_multi_message(
|
||||
logger: Logger,
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
hosts: Callable[[int], list[Host]],
|
||||
):
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
shard_downloader = NoopShardDownloader()
|
||||
|
||||
global_events = event_log_manager.global_events
|
||||
await global_events.delete_all_events()
|
||||
|
||||
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(worker1.run())
|
||||
|
||||
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(worker2.run())
|
||||
|
||||
## Instance
|
||||
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_id,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: pipeline_shard_meta(2, 0),
|
||||
RUNNER_2_ID: pipeline_shard_meta(2, 1)
|
||||
},
|
||||
node_to_runner={
|
||||
NODE_A: RUNNER_1_ID,
|
||||
NODE_B: RUNNER_2_ID
|
||||
}
|
||||
)
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2)
|
||||
)
|
||||
|
||||
# Task - we have three messages here, which is what the task is about
|
||||
|
||||
completion_create_params = ChatCompletionTaskParams(
|
||||
model="gpt-4",
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content='What is the capital of France?'),
|
||||
ChatCompletionMessage(role="assistant", content='The capital of France is Paris.'),
|
||||
ChatCompletionMessage(role="user", content='Ok great. Now write me a haiku about what you can do there.'),
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
task = ChatCompletionTask(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=CommandId(),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=completion_create_params
|
||||
)
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance
|
||||
),
|
||||
TaskCreated(
|
||||
task_id=task.task_id,
|
||||
task=task
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
|
||||
|
||||
assert seen_task_started
|
||||
assert seen_task_finished
|
||||
assert any(keyword in response_string.lower() for keyword in ('kiss', 'paris', 'art', 'love'))
|
||||
|
||||
|
||||
idx = await global_events.get_last_idx()
|
||||
await asyncio.sleep(1.0)
|
||||
events = await global_events.get_events_since(idx)
|
||||
assert len(events) == 0
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
|
||||
async def test_runner_respawn(
|
||||
|
||||
Reference in New Issue
Block a user