Inference Integration Test

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Matt Beton
2025-07-26 20:08:25 +01:00
committed by GitHub
parent 2e4635a8f5
commit 93330f0283
13 changed files with 476 additions and 73 deletions

read_events.py

@@ -0,0 +1,25 @@
import asyncio
from logging import Logger
from worker.main import get_node_id
from shared.types.common import NodeId
from shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig
async def main():
node_id: NodeId = get_node_id()
logger: Logger = Logger('worker_log')
event_log_manager: EventLogManager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
events = await event_log_manager.global_events.get_events_since(0)
for wrapped_event in events:
event = wrapped_event.event
event_type = type(event).__name__.replace('_', ' ').title()
attributes = ', '.join(f"{key}={value!r}" for key, value in vars(event).items())
print(f"{event_type}: {attributes}")
if __name__ == "__main__":
asyncio.run(main())
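
For reference, a minimal sketch of the formatting this script applies to each event, using a hypothetical DemoEvent in place of a real logged event type:

class DemoEvent:
    def __init__(self) -> None:
        self.runner_id = "r1"
        self.status = "ready"

event = DemoEvent()
event_type = type(event).__name__.replace('_', ' ').title()
attributes = ', '.join(f"{key}={value!r}" for key, value in vars(event).items())
print(f"{event_type}: {attributes}")  # -> Demoevent: runner_id='r1', status='ready'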


@@ -106,7 +106,7 @@ def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> Sta
return state.model_copy(update={"runners": new_runners})
@event_apply.register(RunnerDeleted)
def apply_runner_deleted(event: RunnerStatusUpdated, state: State) -> State:
def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {rid: rs for rid, rs in state.runners.items() if rid != event.runner_id}
return state.model_copy(update={"runners": new_runners})
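
The handler body already filtered on RunnerDeleted; only the annotation named the wrong event type. Assuming event_apply is a singledispatch-style registry (which the explicit register(RunnerDeleted) argument suggests), dispatch was keyed on the decorator argument all along, so the fix matters for type checkers and readers rather than runtime behaviour. A self-contained sketch of that rule:

from functools import singledispatch

class EventA: ...
class EventB: ...

@singledispatch
def apply(event: object) -> str:
    raise NotImplementedError

@apply.register(EventB)        # dispatch is keyed on this argument...
def _(event: EventA) -> str:   # ...not on the (wrong) annotation
    return "handled as EventB"

print(apply(EventB()))  # -> handled as EventB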


@@ -15,6 +15,7 @@ class RunnerOpType(str, Enum):
UNASSIGN_RUNNER = "unassign_runner"
RUNNER_UP = "runner_up"
RUNNER_DOWN = "runner_down"
RUNNER_FAILED = "runner_failed"
DOWNLOAD = "download"
CHAT_COMPLETION = "chat_completion"
@@ -42,6 +43,10 @@ class RunnerDownOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_DOWN]]):
op_type: Literal[RunnerOpType.RUNNER_DOWN] = Field(default=RunnerOpType.RUNNER_DOWN, frozen=True)
runner_id: RunnerId
class RunnerFailedOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_FAILED]]):
op_type: Literal[RunnerOpType.RUNNER_FAILED] = Field(default=RunnerOpType.RUNNER_FAILED, frozen=True)
runner_id: RunnerId
class DownloadOp(BaseRunnerOp[Literal[RunnerOpType.DOWNLOAD]]):
op_type: Literal[RunnerOpType.DOWNLOAD] = Field(default=RunnerOpType.DOWNLOAD, frozen=True)
instance_id: InstanceId
@@ -62,6 +67,7 @@ RunnerOp = Annotated[
UnassignRunnerOp,
RunnerUpOp,
RunnerDownOp,
RunnerFailedOp,
DownloadOp,
ExecuteTaskOp,
],
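
Since RunnerOp appears to be a discriminated union over op_type, a wire payload carrying "op_type": "runner_failed" now parses into RunnerFailedOp. A minimal sketch under that assumption (pydantic v2, two variants only):

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter

class RunnerDownOp(BaseModel):
    op_type: Literal["runner_down"] = Field(default="runner_down", frozen=True)
    runner_id: str

class RunnerFailedOp(BaseModel):
    op_type: Literal["runner_failed"] = Field(default="runner_failed", frozen=True)
    runner_id: str

RunnerOp = Annotated[Union[RunnerDownOp, RunnerFailedOp], Field(discriminator="op_type")]

op = TypeAdapter(RunnerOp).validate_python({"op_type": "runner_failed", "runner_id": "r1"})
assert isinstance(op, RunnerFailedOp)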


@@ -38,6 +38,7 @@ from shared.types.worker.ops import (
DownloadOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerFailedOp,
RunnerOp,
RunnerOpType,
RunnerUpOp,
@@ -162,6 +163,18 @@ class Worker:
assigned_runner.status = ReadyRunnerStatus()
yield assigned_runner.status_update_event()
return
async def _execute_runner_failed_op(
self, op: RunnerFailedOp
) -> AsyncGenerator[Event, None]:
'''
We detected that this runner has failed, so we put it into the 'failed' state, triggering the rest of the instance to spin down.
'''
assigned_runner = self.assigned_runners[op.runner_id]
assigned_runner.status = FailedRunnerStatus()
yield self.assigned_runners[op.runner_id].status_update_event()
async def _execute_download_op(
self, op: DownloadOp
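
Each _execute_*_op here is an async generator of events that the op dispatcher drains and appends to the log. A toy sketch of that pattern, with hypothetical names:

import asyncio
from typing import AsyncGenerator

async def execute_runner_failed_op() -> AsyncGenerator[str, None]:
    # stand-in for: mark the runner failed, then emit its status update
    yield "RunnerStatusUpdated(FailedRunnerStatus)"

async def drain() -> None:
    async for event in execute_runner_failed_op():
        print("append to event log:", event)

asyncio.run(drain())
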
@@ -309,6 +322,8 @@ class Worker:
event_generator = self._execute_runner_up_op(op)
case RunnerOpType.RUNNER_DOWN:
event_generator = self._execute_runner_down_op(op)
case RunnerOpType.RUNNER_FAILED:
event_generator = self._execute_runner_failed_op(op)
case RunnerOpType.DOWNLOAD:
event_generator = self._execute_download_op(op)
case RunnerOpType.CHAT_COMPLETION:
@@ -331,6 +346,12 @@ class Worker:
if runner_id not in runner_ids:
return UnassignRunnerOp(runner_id=runner_id)
for runner_id, assigned_runner in self.assigned_runners.items():
if assigned_runner.runner is not None and \
not assigned_runner.runner.healthy and \
not isinstance(assigned_runner.status, FailedRunnerStatus):
return RunnerFailedOp(runner_id=runner_id)
# Then spin down active runners
for _instance_id, instance in state.instances.items():
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
@@ -346,7 +367,9 @@ class Worker:
# If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down
# TODO: We need to limit number of retries if we keep failing.
for _instance_id, instance in state.instances.items():
if self.node_id in instance.shard_assignments.node_to_runner:
if self.node_id in instance.shard_assignments.node_to_runner and \
instance.shard_assignments.node_to_runner[self.node_id] in self.assigned_runners and \
not isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, ReadyRunnerStatus): # make sure that our runner has not already been spun down into ready state
other_node_in_instance_has_failed = False
for runner_id in instance.shard_assignments.runner_to_shard:
if runner_id in state.runners and \
@@ -362,13 +385,17 @@ class Worker:
for _instance_id, instance in state.instances.items():
if self.node_id in instance.shard_assignments.node_to_runner and \
instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \
isinstance(state.runners[instance.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, FailedRunnerStatus):
num_spundown_nodes = 0
for runner_id in instance.shard_assignments.runner_to_shard:
if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \
runner_id not in self.assigned_runners:
num_spundown_nodes += 1
# Suggested:
# if runner_id in state.runners and isinstance(state.runners[runner_id], ReadyRunnerStatus):
# if runner_id != instance.shard_assignments.node_to_runner[self.node_id]:
# num_spundown_nodes += 1
if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1:
# All the other nodes are spun down - so now we can spin down too.
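
A condensed, hypothetical restatement of the barrier above (including the "Suggested" variant): the node whose runner failed may only spin down once every peer runner in the instance reports Ready.

class ReadyRunnerStatus: ...

def peers_spun_down(runner_to_shard: dict, state_runners: dict, my_runner_id: str) -> bool:
    peers = [rid for rid in runner_to_shard if rid != my_runner_id]
    return all(
        rid in state_runners and isinstance(state_runners[rid], ReadyRunnerStatus)
        for rid in peers
    )

# With both peers Ready, the failed node may spin down and respawn:
assert peers_spun_down({"r1": 0, "r2": 1, "r3": 2}, {"r2": ReadyRunnerStatus(), "r3": ReadyRunnerStatus()}, "r1")
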
@@ -421,7 +448,7 @@ class Worker:
# Need to assert all other runners are ready before we can spin up.
ready_to_spin = True
for runner_id in instance.shard_assignments.node_to_runner.values():
if state.runners[runner_id].runner_status != RunnerStatusType.Ready:
if runner_id in state.runners and state.runners[runner_id].runner_status != RunnerStatusType.Ready:
ready_to_spin = False
if ready_to_spin:
@@ -438,7 +465,7 @@ class Worker:
continue # The only previous state to get to Running is from Loaded
for _, task in state.tasks.items():
if task.instance_id == instance_id:
if task.instance_id == instance_id and task.task_status == TaskStatus.PENDING:
if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1):
return ExecuteTaskOp(runner_id=runner_id, task=task)
else:
@@ -465,11 +492,9 @@ class Worker:
assert self.global_events is not None
while True:
_rank = list(self.assigned_runners.values())[0].shard_metadata.device_rank if self.assigned_runners else None
# 1. get latest events
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
if len(events) == 0:
await asyncio.sleep(0.01)
continue
# 2. for each event, apply it to the state and run sagas
for event_from_log in events:


@@ -91,8 +91,8 @@ async def _mlx_generate(
runner_print(item.text)
yield item
# TODO: There is a big bug on this line!
assert future.done()
# Wait for the executor thread to complete
await future
async def main():
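
The replaced assert could fire while the executor thread was still finishing, and even when it passed it never retrieved a raised exception. Assuming the tokens are produced by a blocking call scheduled on an executor, awaiting the future both waits for the thread and propagates anything it threw. A minimal sketch of that shape:

import asyncio

def blocking_generate() -> str:
    # stand-in for blocking generation work on an executor thread
    return "tokens"

async def demo() -> None:
    loop = asyncio.get_running_loop()
    future = loop.run_in_executor(None, blocking_generate)
    # ... stream intermediate items while the thread runs ...
    result = await future  # waits if needed and propagates thread exceptions
    print(result)

asyncio.run(demo())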


@@ -88,12 +88,23 @@ class RunnerSupervisor:
async def astop(self) -> None:
async def terminate() -> None:
self.runner_process.terminate()
_ = await self.runner_process.wait()
# Check if process is already dead before trying to terminate
if self.runner_process.returncode is None:
self.runner_process.terminate()
# Wait for the process to exit (or confirm it's already exited)
try:
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=1.0)
except asyncio.TimeoutError:
# If terminate didn't work, force kill
if self.runner_process.returncode is None:
self.runner_process.kill()
_ = await self.runner_process.wait()
if not self.healthy:
print("Runner process is not healthy, killing...")
await terminate()
print('terminated')
if self.runner_process.stdout is not None:
while True:
@@ -107,15 +118,20 @@ class RunnerSupervisor:
except asyncio.TimeoutError:
break
try:
# Give the process a moment to exit gracefully
await supervisor_write_message(
proc=self.runner_process, message=ExitMessage()
)
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
except asyncio.TimeoutError:
print("Runner process did not terminate, killing...")
await terminate()
# Only try to send ExitMessage if process is still alive
if self.runner_process.returncode is None:
try:
# Give the process a moment to exit gracefully
await supervisor_write_message(
proc=self.runner_process, message=ExitMessage()
)
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
except asyncio.TimeoutError:
print("Runner process did not terminate, killing...")
await terminate()
except Exception:
# If we can't write to the process (e.g., broken pipe), it's probably already dead
pass
self.running = False
@@ -124,7 +140,7 @@ class RunnerSupervisor:
self.running = False
def __del__(self) -> None:
if not self.running:
if self.running:
print(
"Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process."
)
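
The reworked astop follows a standard escalation: only signal a live process, give terminate() a bounded grace period, then kill() and reap. (The __del__ fix inverts its guard so the warning fires when the supervisor is still running at garbage collection, not when it was already stopped.) A condensed sketch of the shutdown path:

import asyncio

async def stop(proc: asyncio.subprocess.Process, grace: float = 1.0) -> None:
    if proc.returncode is None:        # only signal a live process
        proc.terminate()
    try:
        await asyncio.wait_for(proc.wait(), timeout=grace)
    except asyncio.TimeoutError:
        if proc.returncode is None:    # terminate was ignored; escalate
            proc.kill()
        _ = await proc.wait()          # reap so no zombie is left behind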


@@ -99,14 +99,16 @@ def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
)
@pytest.fixture
def chat_completion_task(completion_create_params: ChatCompletionTaskParams) -> ChatCompletionTask:
return ChatCompletionTask(
task_id=TaskId(),
instance_id=InstanceId(),
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=completion_create_params
)
def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
def _chat_completion_task(instance_id: InstanceId) -> ChatCompletionTask:
return ChatCompletionTask(
task_id=TaskId(),
instance_id=instance_id,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=completion_create_params
)
return _chat_completion_task
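
The fixture now returns a factory rather than a prebuilt task, so each test mints tasks bound to a specific instance. A hypothetical test body showing the shape:

def test_task_bound_to_instance(chat_completion_task):
    instance_id = InstanceId()  # InstanceId as imported in the tests below
    task = chat_completion_task(instance_id)
    assert task.instance_id == instance_id
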
@pytest.fixture
def node_id() -> NodeId:
@@ -129,7 +131,7 @@ def logger() -> Logger:
@pytest.fixture
def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts_one: list[Host]):
def _instance(node_id: NodeId, runner_id: RunnerId) -> Instance:
def _instance(instance_id: InstanceId, node_id: NodeId, runner_id: RunnerId) -> Instance:
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
@@ -156,10 +158,10 @@ async def worker(node_id: NodeId, logger: Logger):
return Worker(node_id, logger, worker_events=event_log_manager.global_events, global_events=event_log_manager.global_events)
@pytest.fixture
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance]):
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]):
"""Fixture that provides a worker with an already assigned runner."""
instance_obj: Instance = instance(worker.node_id, RunnerId())
instance_obj: Instance = instance(InstanceId(), worker.node_id, RunnerId())
# Extract runner_id from shard assignments
runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard))


@@ -9,6 +9,7 @@ from shared.types.worker.commands_runner import (
RunnerMessageTypeAdapter,
SetupMessage,
)
from shared.types.worker.common import InstanceId
from shared.types.worker.mlx import Host
from shared.types.worker.shards import PipelineShardMetadata
@@ -37,9 +38,10 @@ def test_supervisor_setup_message_serdes(
def test_supervisor_task_message_serdes(
chat_completion_task: Task,
chat_completion_task: Callable[[InstanceId], Task],
):
task = chat_completion_task(InstanceId())
task_message = ChatTaskMessage(
task_data=chat_completion_task.task_params,
task_data=task.task_params,
)
assert_equal_serdes(task_message, RunnerMessageTypeAdapter)


@@ -11,6 +11,7 @@ from shared.types.tasks import (
Task,
TaskType,
)
from shared.types.worker.common import InstanceId
from shared.types.worker.mlx import Host
from shared.types.worker.shards import PipelineShardMetadata
from worker.runner.runner_supervisor import RunnerSupervisor
@@ -26,11 +27,12 @@ def user_message():
async def test_supervisor_single_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Task,
chat_completion_task: Callable[[InstanceId], Task],
tmp_path: Path,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
instance_id = InstanceId()
print(f'{model_shard_meta=}')
@@ -43,7 +45,7 @@ async def test_supervisor_single_node_response(
full_response = ""
stop_reason: FinishReason | None = None
async for chunk in supervisor.stream_response(task=chat_completion_task):
async for chunk in supervisor.stream_response(task=chat_completion_task(instance_id)):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
if chunk.finish_reason:
@@ -63,10 +65,11 @@ async def test_supervisor_single_node_response(
async def test_supervisor_two_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Task,
chat_completion_task: Callable[[InstanceId], Task],
tmp_path: Path,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
instance_id = InstanceId()
supervisor_0 = await RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, 0),
hosts=hosts(2, offset=15),
@@ -85,13 +88,13 @@ async def test_supervisor_two_node_response(
async def collect_response_0():
nonlocal full_response_0
async for chunk in supervisor_0.stream_response(task=chat_completion_task):
async for chunk in supervisor_0.stream_response(task=chat_completion_task(instance_id)):
if isinstance(chunk, TokenChunk):
full_response_0 += chunk.text
async def collect_response_1():
nonlocal full_response_1
async for chunk in supervisor_1.stream_response(task=chat_completion_task):
async for chunk in supervisor_1.stream_response(task=chat_completion_task(instance_id)):
if isinstance(chunk, TokenChunk):
full_response_1 += chunk.text
@@ -118,22 +121,25 @@ async def test_supervisor_two_node_response(
async def test_supervisor_early_stopping(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Task,
chat_completion_task: Callable[[InstanceId], Task],
tmp_path: Path,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
instance_id = InstanceId()
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
task = chat_completion_task(instance_id)
max_tokens = 50
assert chat_completion_task.task_type == TaskType.CHAT_COMPLETION
print(f'chat_completion_task.task_params: {chat_completion_task.task_params}')
assert isinstance(chat_completion_task.task_params, ChatCompletionTaskParams)
task_params: ChatCompletionTaskParams = chat_completion_task.task_params
assert task.task_type == TaskType.CHAT_COMPLETION
print(f'chat_completion_task.task_params: {task.task_params}')
assert isinstance(task.task_params, ChatCompletionTaskParams)
task_params: ChatCompletionTaskParams = task.task_params
try:
task_params.max_tokens = max_tokens
@@ -146,7 +152,7 @@ async def test_supervisor_early_stopping(
count = 0
stop_reason: FinishReason | None = None
async for chunk in supervisor.stream_response(task=chat_completion_task):
async for chunk in supervisor.stream_response(task=task):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
count += 1
@@ -169,7 +175,6 @@ async def test_supervisor_early_stopping(
async def test_supervisor_handles_terminated_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Task,
tmp_path: Path,
):
"""Test that the supervisor handles a terminated runner"""
@@ -194,7 +199,6 @@ async def test_supervisor_handles_terminated_runner(
async def test_supervisor_handles_killed_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Task,
tmp_path: Path,
):
"""Test that the supervisor handles a killed runner"""


@@ -16,7 +16,7 @@ from shared.types.events import (
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import Task, TaskStatus
from shared.types.worker.common import RunnerId
from shared.types.worker.instances import Instance
from shared.types.worker.instances import Instance, InstanceId
from shared.types.worker.ops import (
AssignRunnerOp,
DownloadOp,
@@ -40,9 +40,9 @@ def user_message():
return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
@pytest.mark.asyncio
async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance], tmp_path: Path):
async def test_assign_op(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], tmp_path: Path):
runner_id = RunnerId()
instance_obj: Instance = instance(worker.node_id, runner_id)
instance_obj: Instance = instance(InstanceId(), worker.node_id, runner_id)
assign_op = AssignRunnerOp(
runner_id=runner_id,
@@ -84,7 +84,7 @@ async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
assert isinstance(events[0], RunnerDeleted)
@pytest.mark.asyncio
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_completion_task: Task, tmp_path: Path):
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
worker, runner_id, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=runner_id)
@@ -104,7 +104,7 @@ async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
full_response = ''
async for chunk in supervisor.stream_response(task=chat_completion_task):
async for chunk in supervisor.stream_response(task=chat_completion_task(InstanceId())):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
@@ -153,12 +153,12 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
@pytest.mark.asyncio
async def test_execute_task_op(
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
chat_completion_task: Task, tmp_path: Path):
chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
worker, runner_id, _ = worker_with_running_runner
execute_task_op = ExecuteTaskOp(
runner_id=runner_id,
task=chat_completion_task
task=chat_completion_task(InstanceId())
)
events: list[Event] = []
@@ -196,15 +196,16 @@ async def test_execute_task_op(
@pytest.mark.asyncio
async def test_execute_task_fails(
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
chat_completion_task: Task, tmp_path: Path):
chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
worker, runner_id, _ = worker_with_running_runner
messages = chat_completion_task.task_params.messages
task = chat_completion_task(InstanceId())
messages = task.task_params.messages
messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'
execute_task_op = ExecuteTaskOp(
runner_id=runner_id,
task=chat_completion_task
task=task
)
events: list[Event] = []


@@ -1,27 +1,39 @@
import asyncio
from logging import Logger
from typing import Awaitable, Callable, Final
import pytest
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import NodeId
from shared.types.events import (
InstanceCreated,
InstanceDeleted,
RunnerDeleted,
RunnerStatusUpdated,
TaskCreated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.models import ModelId
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.instances import (
Instance,
InstanceStatus,
ShardAssignments,
)
from shared.types.worker.mlx import Host
from shared.types.worker.runners import (
FailedRunnerStatus,
LoadedRunnerStatus,
ReadyRunnerStatus,
# RunningRunnerStatus,
)
from worker.main import Worker
from shared.types.worker.shards import PipelineShardMetadata
from worker.main import AssignedRunner, Worker
from worker.tests.test_worker_integration_utils import read_streaming_response
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
@@ -42,14 +54,14 @@ def user_message():
async def test_runner_assigned(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance]
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
):
worker, global_events = await worker_running(NODE_A)
print(worker)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.INACTIVE
await global_events.append_events(
@@ -79,12 +91,12 @@ async def test_runner_assigned(
async def test_runner_assigned_active(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance],
chat_completion_task: Task
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId], Task]
):
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
await global_events.append_events(
@@ -118,7 +130,7 @@ async def test_runner_assigned_active(
full_response = ''
async for chunk in supervisor.stream_response(task=chat_completion_task):
async for chunk in supervisor.stream_response(task=chat_completion_task(INSTANCE_1_ID)):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
@@ -128,11 +140,11 @@ async def test_runner_assigned_active(
async def test_runner_assigned_wrong_node(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance]
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
):
worker, global_events = await worker_running(NODE_A)
instance_value = instance(NODE_B, RUNNER_1_ID)
instance_value = instance(INSTANCE_1_ID, NODE_B, RUNNER_1_ID)
await global_events.append_events(
[
@@ -157,11 +169,11 @@ async def test_runner_assigned_wrong_node(
async def test_runner_unassigns(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[NodeId, RunnerId], Instance]
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
):
worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
await global_events.append_events(
@@ -206,4 +218,254 @@ async def test_runner_unassigns(
events = await global_events.get_events_since(0)
assert isinstance(events[-1].event, RunnerDeleted)
# After deletion, runner should be removed from state.runners
assert len(worker.state.runners) == 0
assert len(worker.state.runners) == 0
async def test_runner_inference(
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId], Task]
):
_worker, global_events = await worker_running(NODE_A)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
task: Task = chat_completion_task(INSTANCE_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance_value,
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(0.3)
async def test_2_runner_inference(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId], Task]
):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker1.run())
worker2 = Worker(NODE_B, logger=logger, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker2.run())
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
task = chat_completion_task(INSTANCE_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
idx = await global_events.get_last_idx()
await asyncio.sleep(1.0)
events = await global_events.get_events_since(idx)
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(2.0)
async def test_runner_respawn(
logger: Logger,
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId], Task]
):
event_log_manager = EventLogManager(EventLogConfig(), logger)
await event_log_manager.initialize()
global_events = event_log_manager.global_events
await global_events.delete_all_events()
worker1 = Worker(NODE_A, logger=logger, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker1.run())
worker2 = Worker(NODE_B, logger=logger, worker_events=global_events, global_events=global_events)
asyncio.create_task(worker2.run())
## Instance
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1)
},
node_to_runner={
NODE_A: RUNNER_1_ID,
NODE_B: RUNNER_2_ID
}
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=shard_assignments,
hosts=hosts(2)
)
task = chat_completion_task(INSTANCE_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance
),
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
await asyncio.sleep(0.1)
idx = await global_events.get_last_idx()
assigned_runner: AssignedRunner = worker1.assigned_runners[RUNNER_1_ID]
assert assigned_runner.runner is not None
assigned_runner.runner.runner_process.kill()
# Wait for the process to actually be detected as dead or cleaned up
for _ in range(100): # Wait up to 1 second
await asyncio.sleep(0.01)
# The worker may clean up the runner (set to None) when it detects it's dead
if assigned_runner.runner and not assigned_runner.runner.healthy:
break
else:
raise AssertionError("Runner should have been detected as unhealthy or cleaned up after kill()")
await asyncio.sleep(5.0)
events = await global_events.get_events_since(idx)
print(f'{events=}')
# assert len(events) == 2
assert isinstance(events[0].event, RunnerStatusUpdated)
assert isinstance(events[0].event.runner_status, FailedRunnerStatus)
assert isinstance(events[1].event, RunnerStatusUpdated)
assert isinstance(events[1].event.runner_status, ReadyRunnerStatus)
assert events[1].event.runner_id == RUNNER_2_ID
assert isinstance(events[2].event, RunnerStatusUpdated)
assert isinstance(events[2].event.runner_status, ReadyRunnerStatus)
assert events[2].event.runner_id == RUNNER_1_ID
print(worker1.state)
print(worker2.state)
for event in [events[3].event, events[4].event]:
assert isinstance(event, RunnerStatusUpdated)
assert isinstance(event.runner_status, LoadedRunnerStatus)
task = chat_completion_task(INSTANCE_1_ID)
await global_events.append_events(
[
TaskCreated(
task_id=task.task_id,
task=task
)
],
origin=MASTER_NODE_ID
)
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert 'tokyo' in response_string.lower()
await asyncio.sleep(0.1)
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID
)
await asyncio.sleep(1.0)


@@ -0,0 +1,44 @@
import asyncio
from typing import Tuple
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.types.events import ChunkGenerated, TaskStateUpdated
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import TaskStatus
async def read_streaming_response(global_events: AsyncSQLiteEventStorage) -> Tuple[bool, bool, str]:
# Read off all events - these should be TaskStateUpdated and ChunkGenerated events
seen_task_started, seen_task_finished = 0, 0
response_string = ''
finish_reason: str | None = None
idx = 0
while not finish_reason:
events = await global_events.get_events_since(idx)
if len(events) == 0:
await asyncio.sleep(0.01)
continue
idx = events[-1].idx_in_log
for wrapped_event in events:
event = wrapped_event.event
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.RUNNING:
seen_task_started += 1
if event.task_status == TaskStatus.COMPLETE:
seen_task_finished += 1
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
response_string += event.chunk.text
if event.chunk.finish_reason:
finish_reason = event.chunk.finish_reason
await asyncio.sleep(0.2)
print(f'event log: {await global_events.get_events_since(0)}')
return seen_task_started == 1, seen_task_finished == 1, response_string
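
Usage sketch, mirroring the integration tests above (inside an async test with a live event log):

async def example(global_events) -> None:
    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
    assert seen_task_started and seen_task_finished
    assert 'tokyo' in response_string.lower()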


@@ -481,7 +481,23 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
)
},
runners={RUNNER_1_ID: LoadedRunnerStatus()},
tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
tasks={
TASK_1_ID: ChatCompletionTask(
task_id=TASK_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
ChatCompletionMessage(
role="user",
content="Hello, world!"
)
]
),
instance_id=INSTANCE_1_ID
)
},
),
expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask(
task_id=TASK_1_ID,