diff --git a/read_events.py b/read_events.py
new file mode 100644
index 00000000..d63ad636
--- /dev/null
+++ b/read_events.py
@@ -0,0 +1,24 @@
+import asyncio
+from logging import Logger, getLogger
+
+from worker.main import get_node_id
+from shared.types.common import NodeId
+from shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig
+
+async def main():
+    node_id: NodeId = get_node_id()
+    logger: Logger = getLogger('worker_log')
+
+    event_log_manager: EventLogManager = EventLogManager(EventLogConfig(), logger)
+    await event_log_manager.initialize()
+
+    events = await event_log_manager.global_events.get_events_since(0)
+
+    for wrapped_event in events:
+        event = wrapped_event.event
+        event_type = type(event).__name__.replace('_', ' ').title()
+        attributes = ', '.join(f"{key}={value!r}" for key, value in vars(event).items())
+        print(f"{event_type}: {attributes}")
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/shared/apply/apply.py b/shared/apply/apply.py
index 85289c00..0cf79e40 100644
--- a/shared/apply/apply.py
+++ b/shared/apply/apply.py
@@ -106,7 +106,7 @@ def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> Sta
     return state.model_copy(update={"runners": new_runners})
 
 @event_apply.register(RunnerDeleted)
-def apply_runner_deleted(event: RunnerStatusUpdated, state: State) -> State:
+def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
     new_runners: Mapping[RunnerId, RunnerStatus] = {rid: rs for rid, rs in state.runners.items() if rid != event.runner_id}
     return state.model_copy(update={"runners": new_runners})
 
diff --git a/shared/types/worker/ops.py b/shared/types/worker/ops.py
index fb4a7521..97787fba 100644
--- a/shared/types/worker/ops.py
+++ b/shared/types/worker/ops.py
@@ -15,6 +15,7 @@ class RunnerOpType(str, Enum):
     UNASSIGN_RUNNER = "unassign_runner"
     RUNNER_UP = "runner_up"
     RUNNER_DOWN = "runner_down"
+    RUNNER_FAILED = "runner_failed"
     DOWNLOAD = "download"
     CHAT_COMPLETION = "chat_completion"
 
@@ -42,6 +43,10 @@ class RunnerDownOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_DOWN]]):
     op_type: Literal[RunnerOpType.RUNNER_DOWN] = Field(default=RunnerOpType.RUNNER_DOWN, frozen=True)
     runner_id: RunnerId
 
+class RunnerFailedOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_FAILED]]):
+    op_type: Literal[RunnerOpType.RUNNER_FAILED] = Field(default=RunnerOpType.RUNNER_FAILED, frozen=True)
+    runner_id: RunnerId
+
 class DownloadOp(BaseRunnerOp[Literal[RunnerOpType.DOWNLOAD]]):
     op_type: Literal[RunnerOpType.DOWNLOAD] = Field(default=RunnerOpType.DOWNLOAD, frozen=True)
     instance_id: InstanceId
@@ -62,6 +67,7 @@ RunnerOp = Annotated[
         UnassignRunnerOp,
         RunnerUpOp,
         RunnerDownOp,
+        RunnerFailedOp,
         DownloadOp,
         ExecuteTaskOp,
     ],
diff --git a/worker/main.py b/worker/main.py
index 0196116c..e41ab847 100644
--- a/worker/main.py
+++ b/worker/main.py
@@ -38,6 +38,7 @@ from shared.types.worker.ops import (
     DownloadOp,
     ExecuteTaskOp,
     RunnerDownOp,
+    RunnerFailedOp,
     RunnerOp,
     RunnerOpType,
     RunnerUpOp,
@@ -162,6 +163,19 @@ class Worker:
 
         assigned_runner.status = ReadyRunnerStatus()
         yield assigned_runner.status_update_event()
+        return
+
+    async def _execute_runner_failed_op(
+        self, op: RunnerFailedOp
+    ) -> AsyncGenerator[Event, None]:
+        '''
+        This runner has been detected as failed, so we move it into the 'failed' state, triggering the rest of the instance to spin down.
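+        Other nodes in the instance observe the failed status via the event log and spin their own runners down (see the spin-down logic below).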
+        '''
+        assigned_runner = self.assigned_runners[op.runner_id]
+
+        assigned_runner.status = FailedRunnerStatus()
+        yield assigned_runner.status_update_event()
 
     async def _execute_download_op(
         self, op: DownloadOp
@@ -309,6 +322,8 @@ class Worker:
                 event_generator = self._execute_runner_up_op(op)
             case RunnerOpType.RUNNER_DOWN:
                 event_generator = self._execute_runner_down_op(op)
+            case RunnerOpType.RUNNER_FAILED:
+                event_generator = self._execute_runner_failed_op(op)
             case RunnerOpType.DOWNLOAD:
                 event_generator = self._execute_download_op(op)
             case RunnerOpType.CHAT_COMPLETION:
@@ -331,6 +346,12 @@ class Worker:
             if runner_id not in runner_ids:
                 return UnassignRunnerOp(runner_id=runner_id)
 
+        for runner_id, assigned_runner in self.assigned_runners.items():
+            if assigned_runner.runner is not None and \
+               not assigned_runner.runner.healthy and \
+               not isinstance(assigned_runner.status, FailedRunnerStatus):
+                return RunnerFailedOp(runner_id=runner_id)
+
         # Then spin down active runners
         for _instance_id, instance in state.instances.items():
             for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
@@ -346,7 +367,9 @@ class Worker:
         # If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down
         # TODO: We need to limit number of retries if we keep failing.
         for _instance_id, instance in state.instances.items():
-            if self.node_id in instance.shard_assignments.node_to_runner:
+            if self.node_id in instance.shard_assignments.node_to_runner and \
+               instance.shard_assignments.node_to_runner[self.node_id] in self.assigned_runners and \
+               not isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, ReadyRunnerStatus): # make sure that our runner has not already been spun down into ready state
                 other_node_in_instance_has_failed = False
                 for runner_id in instance.shard_assignments.runner_to_shard:
                     if runner_id in state.runners and \
@@ -362,13 +385,17 @@ class Worker:
         for _instance_id, instance in state.instances.items():
             if self.node_id in instance.shard_assignments.node_to_runner and \
                instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \
-               isinstance(state.runners[instance.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
-
+               isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, FailedRunnerStatus):
+
                 num_spundown_nodes = 0
                 for runner_id in instance.shard_assignments.runner_to_shard:
-                    if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \
+                    if runner_id in state.runners and isinstance(state.runners[runner_id], ReadyRunnerStatus) and \
                         runner_id not in self.assigned_runners:
                         num_spundown_nodes += 1
+                    # Suggested: count by comparing against our own runner id rather
+                    # than assigned_runners membership:
+                    #   if runner_id != instance.shard_assignments.node_to_runner[self.node_id]:
+                    #       num_spundown_nodes += 1
 
                 if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1:
                     # All the other nodes are spun down - so now we can spin down too.
@@ -421,7 +448,8 @@ class Worker:
 
             # Need to assert all other runners are ready before we can spin up.
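+            # Note: runners not yet present in state.runners do not block spin-up here.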
             ready_to_spin = True
             for runner_id in instance.shard_assignments.node_to_runner.values():
-                if state.runners[runner_id].runner_status != RunnerStatusType.Ready:
+                if runner_id in state.runners and state.runners[runner_id].runner_status != RunnerStatusType.Ready:
                     ready_to_spin = False
 
             if ready_to_spin:
@@ -438,7 +465,7 @@ class Worker:
                 continue
             # The only previous state to get to Running is from Loaded
             for _, task in state.tasks.items():
-                if task.instance_id == instance_id:
+                if task.instance_id == instance_id and task.task_status == TaskStatus.PENDING:
                     if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1):
                         return ExecuteTaskOp(runner_id=runner_id, task=task)
                     else:
@@ -465,11 +492,8 @@ class Worker:
         assert self.global_events is not None
 
         while True:
             # 1. get latest events
             events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
-            if len(events) == 0:
-                await asyncio.sleep(0.01)
-                continue
 
             # 2. for each event, apply it to the state and run sagas
             for event_from_log in events:
diff --git a/worker/runner/runner.py b/worker/runner/runner.py
index 99d6a2e5..d5a1fbb2 100644
--- a/worker/runner/runner.py
+++ b/worker/runner/runner.py
@@ -91,8 +91,8 @@ async def _mlx_generate(
             runner_print(item.text)
         yield item
 
-    # TODO: There is a big bug on this line!
-    assert future.done()
+    # Wait for the executor thread to complete
+    await future
 
 
 async def main():
diff --git a/worker/runner/runner_supervisor.py b/worker/runner/runner_supervisor.py
index 3d1b0553..54d380d2 100644
--- a/worker/runner/runner_supervisor.py
+++ b/worker/runner/runner_supervisor.py
@@ -88,12 +88,22 @@ class RunnerSupervisor:
 
     async def astop(self) -> None:
         async def terminate() -> None:
-            self.runner_process.terminate()
-            _ = await self.runner_process.wait()
+            # Check if process is already dead before trying to terminate
+            if self.runner_process.returncode is None:
+                self.runner_process.terminate()
+
+            # Wait for the process to exit (or confirm it's already exited)
+            try:
+                _ = await asyncio.wait_for(self.runner_process.wait(), timeout=1.0)
+            except asyncio.TimeoutError:
+                # If terminate didn't work, force kill
+                if self.runner_process.returncode is None:
+                    self.runner_process.kill()
+                _ = await self.runner_process.wait()
 
         if not self.healthy:
             print("Runner process is not healthy, killing...")
             await terminate()
 
         if self.runner_process.stdout is not None:
             while True:
@@ -107,15 +118,20 @@
             except asyncio.TimeoutError:
                 break
 
-        try:
-            # Give the process a moment to exit gracefully
-            await supervisor_write_message(
-                proc=self.runner_process, message=ExitMessage()
-            )
-            _ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
-        except asyncio.TimeoutError:
-            print("Runner process did not terminate, killing...")
-            await terminate()
+        # Only try to send ExitMessage if process is still alive
+        if self.runner_process.returncode is None:
+            try:
+                # Give the process a moment to exit gracefully
+                await supervisor_write_message(
+                    proc=self.runner_process, message=ExitMessage()
+                )
+                _ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
+            except asyncio.TimeoutError:
+                print("Runner process did not terminate, killing...")
+                await terminate()
+            except Exception:
+                # If we can't write to the process (e.g., broken pipe), it's probably already dead
+                pass
 
         self.running = False
 
@@ -124,7 +140,7 @@ class RunnerSupervisor:
         self.running = False
 
     def __del__(self) -> None:
-        if not self.running:
+        if self.running:
             print(
                 "Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process."
             )
diff --git a/worker/tests/conftest.py b/worker/tests/conftest.py
index 38ed90d8..ad76fdab 100644
--- a/worker/tests/conftest.py
+++ b/worker/tests/conftest.py
@@ -99,14 +99,16 @@ def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
     )
 
 @pytest.fixture
-def chat_completion_task(completion_create_params: ChatCompletionTaskParams) -> ChatCompletionTask:
-    return ChatCompletionTask(
-        task_id=TaskId(),
-        instance_id=InstanceId(),
-        task_type=TaskType.CHAT_COMPLETION,
-        task_status=TaskStatus.PENDING,
-        task_params=completion_create_params
-    )
+def chat_completion_task(completion_create_params: ChatCompletionTaskParams) -> Callable[[InstanceId], ChatCompletionTask]:
+    def _chat_completion_task(instance_id: InstanceId) -> ChatCompletionTask:
+        return ChatCompletionTask(
+            task_id=TaskId(),
+            instance_id=instance_id,
+            task_type=TaskType.CHAT_COMPLETION,
+            task_status=TaskStatus.PENDING,
+            task_params=completion_create_params
+        )
+    return _chat_completion_task
 
 @pytest.fixture
 def node_id() -> NodeId:
@@ -129,7 +131,7 @@ def logger() -> Logger:
 
 @pytest.fixture
 def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts_one: list[Host]):
-    def _instance(node_id: NodeId, runner_id: RunnerId) -> Instance:
+    def _instance(instance_id: InstanceId, node_id: NodeId, runner_id: RunnerId) -> Instance:
         model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
 
         shard_assignments = ShardAssignments(
@@ -156,10 +158,10 @@ async def worker(node_id: NodeId, logger: Logger):
     return Worker(node_id, logger, worker_events=event_log_manager.global_events, global_events=event_log_manager.global_events)
 
 @pytest.fixture
-async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance]):
+async def worker_with_assigned_runner(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]):
     """Fixture that provides a worker with an already assigned runner."""
 
-    instance_obj: Instance = instance(worker.node_id, RunnerId())
+    instance_obj: Instance = instance(InstanceId(), worker.node_id, RunnerId())
 
     # Extract runner_id from shard assignments
     runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard))
diff --git a/worker/tests/test_serdes.py b/worker/tests/test_serdes.py
index 6e54178b..42af427e 100644
--- a/worker/tests/test_serdes.py
+++ b/worker/tests/test_serdes.py
@@ -9,6 +9,7 @@ from shared.types.worker.commands_runner import (
     RunnerMessageTypeAdapter,
     SetupMessage,
 )
+from shared.types.worker.common import InstanceId
 from shared.types.worker.mlx import Host
 from shared.types.worker.shards import PipelineShardMetadata
 
@@ -37,9 +38,10 @@ def test_supervisor_setup_message_serdes(
 
 def test_supervisor_task_message_serdes(
-    chat_completion_task: Task,
+    chat_completion_task: Callable[[InstanceId], Task],
 ):
+    task = chat_completion_task(InstanceId())
     task_message = ChatTaskMessage(
-        task_data=chat_completion_task.task_params,
+        task_data=task.task_params,
     )
     assert_equal_serdes(task_message, RunnerMessageTypeAdapter)
diff --git a/worker/tests/test_supervisor.py b/worker/tests/test_supervisor.py
index b482e833..5a77eccd 100644
--- a/worker/tests/test_supervisor.py
+++ b/worker/tests/test_supervisor.py
@@ -11,6 +11,7 @@ from shared.types.tasks import (
     Task,
     TaskType,
 )
+from shared.types.worker.common import InstanceId
 from shared.types.worker.mlx import Host
 from shared.types.worker.shards import PipelineShardMetadata
 from worker.runner.runner_supervisor import RunnerSupervisor
@@ -26,11 +27,12 @@ def user_message():
 async def test_supervisor_single_node_response(
     pipeline_shard_meta: Callable[..., PipelineShardMetadata],
     hosts: Callable[..., list[Host]],
-    chat_completion_task: Task,
+    chat_completion_task: Callable[[InstanceId], Task],
     tmp_path: Path,
 ):
     """Test that asking for the capital of France returns 'Paris' in the response"""
     model_shard_meta = pipeline_shard_meta(1, 0)
+    instance_id = InstanceId()
 
     print(f'{model_shard_meta=}')
 
@@ -43,7 +45,7 @@ async def test_supervisor_single_node_response(
     full_response = ""
     stop_reason: FinishReason | None = None
 
-    async for chunk in supervisor.stream_response(task=chat_completion_task):
+    async for chunk in supervisor.stream_response(task=chat_completion_task(instance_id)):
         if isinstance(chunk, TokenChunk):
             full_response += chunk.text
             if chunk.finish_reason:
@@ -63,10 +65,11 @@ async def test_supervisor_single_node_response(
 async def test_supervisor_two_node_response(
     pipeline_shard_meta: Callable[..., PipelineShardMetadata],
     hosts: Callable[..., list[Host]],
-    chat_completion_task: Task,
+    chat_completion_task: Callable[[InstanceId], Task],
     tmp_path: Path,
 ):
     """Test that asking for the capital of France returns 'Paris' in the response"""
+    instance_id = InstanceId()
     supervisor_0 = await RunnerSupervisor.create(
         model_shard_meta=pipeline_shard_meta(2, 0),
         hosts=hosts(2, offset=15),
@@ -85,13 +88,13 @@ async def test_supervisor_two_node_response(
 
     async def collect_response_0():
         nonlocal full_response_0
-        async for chunk in supervisor_0.stream_response(task=chat_completion_task):
+        async for chunk in supervisor_0.stream_response(task=chat_completion_task(instance_id)):
             if isinstance(chunk, TokenChunk):
                 full_response_0 += chunk.text
 
     async def collect_response_1():
         nonlocal full_response_1
-        async for chunk in supervisor_1.stream_response(task=chat_completion_task):
+        async for chunk in supervisor_1.stream_response(task=chat_completion_task(instance_id)):
             if isinstance(chunk, TokenChunk):
                 full_response_1 += chunk.text
 
@@ -118,22 +121,25 @@ async def test_supervisor_two_node_response(
 async def test_supervisor_early_stopping(
     pipeline_shard_meta: Callable[..., PipelineShardMetadata],
     hosts: Callable[..., list[Host]],
-    chat_completion_task: Task,
+    chat_completion_task: Callable[[InstanceId], Task],
     tmp_path: Path,
 ):
     """Test that asking for the capital of France returns 'Paris' in the response"""
     model_shard_meta = pipeline_shard_meta(1, 0)
+    instance_id = InstanceId()
 
     supervisor = await RunnerSupervisor.create(
         model_shard_meta=model_shard_meta,
         hosts=hosts(1, offset=10),
     )
 
+    task = chat_completion_task(instance_id)
+
     max_tokens = 50
-    assert chat_completion_task.task_type == TaskType.CHAT_COMPLETION
-    print(f'chat_completion_task.task_params: {chat_completion_task.task_params}')
-    assert isinstance(chat_completion_task.task_params, ChatCompletionTaskParams)
-    task_params: ChatCompletionTaskParams = chat_completion_task.task_params
+    assert task.task_type == TaskType.CHAT_COMPLETION
+    print(f'chat_completion_task.task_params: {task.task_params}')
+    assert isinstance(task.task_params, ChatCompletionTaskParams)
+    task_params: ChatCompletionTaskParams = task.task_params
 
     try:
         task_params.max_tokens = max_tokens
@@ -146,7 +152,7 @@ async def test_supervisor_early_stopping(
     count = 0
     stop_reason: FinishReason | None = None
 
-    async for chunk in supervisor.stream_response(task=chat_completion_task):
+    async for chunk in supervisor.stream_response(task=task):
         if isinstance(chunk, TokenChunk):
             full_response += chunk.text
             count += 1
@@ -169,7 +175,6 @@ async def test_supervisor_early_stopping(
 async def test_supervisor_handles_terminated_runner(
     pipeline_shard_meta: Callable[..., PipelineShardMetadata],
     hosts: Callable[..., list[Host]],
-    chat_completion_task: Task,
     tmp_path: Path,
 ):
     """Test that the supervisor handles a terminated runner"""
@@ -194,7 +199,6 @@
 async def test_supervisor_handles_killed_runner(
     pipeline_shard_meta: Callable[..., PipelineShardMetadata],
     hosts: Callable[..., list[Host]],
-    chat_completion_task: Task,
     tmp_path: Path,
 ):
     """Test that the supervisor handles a killed runner"""
diff --git a/worker/tests/test_worker_handlers.py b/worker/tests/test_worker_handlers.py
index eb791f2d..ef5c634e 100644
--- a/worker/tests/test_worker_handlers.py
+++ b/worker/tests/test_worker_handlers.py
@@ -16,7 +16,7 @@ from shared.types.events import (
 from shared.types.events.chunks import TokenChunk
 from shared.types.tasks import Task, TaskStatus
 from shared.types.worker.common import RunnerId
-from shared.types.worker.instances import Instance
+from shared.types.worker.instances import Instance, InstanceId
 from shared.types.worker.ops import (
     AssignRunnerOp,
     DownloadOp,
@@ -40,9 +40,9 @@ def user_message():
     return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
 
 @pytest.mark.asyncio
-async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance], tmp_path: Path):
+async def test_assign_op(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], tmp_path: Path):
     runner_id = RunnerId()
-    instance_obj: Instance = instance(worker.node_id, runner_id)
+    instance_obj: Instance = instance(InstanceId(), worker.node_id, runner_id)
 
     assign_op = AssignRunnerOp(
         runner_id=runner_id,
@@ -84,7 +84,7 @@ async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
     assert isinstance(events[0], RunnerDeleted)
 
 @pytest.mark.asyncio
-async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_completion_task: Task, tmp_path: Path):
+async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
     worker, runner_id, _ = worker_with_assigned_runner
 
     runner_up_op = RunnerUpOp(runner_id=runner_id)
@@ -104,7 +104,7 @@ async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
 
     full_response = ''
 
-    async for chunk in supervisor.stream_response(task=chat_completion_task):
+    async for chunk in supervisor.stream_response(task=chat_completion_task(InstanceId())):
         if isinstance(chunk, TokenChunk):
             full_response += chunk.text
 
@@ -153,12 +153,12 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
 @pytest.mark.asyncio
 async def test_execute_task_op(
     worker_with_running_runner: tuple[Worker, RunnerId, Instance],
-    chat_completion_task: Task, tmp_path: Path):
+    chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
     worker, runner_id, _ = worker_with_running_runner
 
     execute_task_op = ExecuteTaskOp(
         runner_id=runner_id,
-        task=chat_completion_task
+        task=chat_completion_task(InstanceId())
     )
 
     events: list[Event] = []
@@ -196,15 +196,16 @@ async def test_execute_task_op(
 @pytest.mark.asyncio
 async def test_execute_task_fails(
     worker_with_running_runner: tuple[Worker, RunnerId, Instance],
-    chat_completion_task: Task, tmp_path: Path):
+    chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
     worker, runner_id, _ = worker_with_running_runner
 
-    messages = chat_completion_task.task_params.messages
+    task = chat_completion_task(InstanceId())
+    messages = task.task_params.messages
     messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'
 
     execute_task_op = ExecuteTaskOp(
         runner_id=runner_id,
-        task=chat_completion_task
+        task=task
     )
 
     events: list[Event] = []
diff --git a/worker/tests/test_worker_integration.py b/worker/tests/test_worker_integration.py
index f83b1013..3041080c 100644
--- a/worker/tests/test_worker_integration.py
+++ b/worker/tests/test_worker_integration.py
@@ -1,27 +1,39 @@
 import asyncio
+from logging import Logger
 from typing import Awaitable, Callable, Final
 
 import pytest
 
+# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
 from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
 from shared.types.common import NodeId
 from shared.types.events import (
     InstanceCreated,
     InstanceDeleted,
     RunnerDeleted,
     RunnerStatusUpdated,
+    TaskCreated,
 )
 from shared.types.events.chunks import TokenChunk
 from shared.types.models import ModelId
 from shared.types.tasks import Task, TaskId
 from shared.types.worker.common import InstanceId, RunnerId
-from shared.types.worker.instances import Instance, InstanceStatus
+from shared.types.worker.instances import (
+    Instance,
+    InstanceStatus,
+    ShardAssignments,
+)
+from shared.types.worker.mlx import Host
 from shared.types.worker.runners import (
+    FailedRunnerStatus,
     LoadedRunnerStatus,
     ReadyRunnerStatus,
     # RunningRunnerStatus,
 )
-from worker.main import Worker
+from shared.types.worker.shards import PipelineShardMetadata
+from worker.main import AssignedRunner, Worker
+from worker.tests.test_worker_integration_utils import read_streaming_response
 
 MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
 NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
@@ -42,14 +54,14 @@ def user_message():
 
 async def test_runner_assigned(
     worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
-    instance: Callable[[NodeId, RunnerId], Instance]
+    instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
 ):
 
     worker, global_events = await worker_running(NODE_A)
 
     print(worker)
 
-    instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
+    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
     instance_value.instance_type = InstanceStatus.INACTIVE
 
     await global_events.append_events(
@@ -79,12 +91,12 @@ async def test_runner_assigned(
 
 async def test_runner_assigned_active(
     worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
-    instance: Callable[[NodeId, RunnerId], Instance],
-    chat_completion_task: Task
+    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
+    chat_completion_task: Callable[[InstanceId], Task]
 ):
     worker, global_events = await worker_running(NODE_A)
 
-    instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
+    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
     instance_value.instance_type = InstanceStatus.ACTIVE
 
     await global_events.append_events(
@@ -118,7 +130,7 @@ async def test_runner_assigned_active(
 
     full_response = ''
 
-    async for chunk in supervisor.stream_response(task=chat_completion_task):
+    async for chunk in supervisor.stream_response(task=chat_completion_task(INSTANCE_1_ID)):
         if isinstance(chunk, TokenChunk):
             full_response += chunk.text
 
@@ -128,11 +140,11 @@
 
 async def test_runner_assigned_wrong_node(
     worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
-    instance: Callable[[NodeId, RunnerId], Instance]
+    instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
 ):
     worker, global_events = await worker_running(NODE_A)
 
-    instance_value = instance(NODE_B, RUNNER_1_ID)
+    instance_value = instance(INSTANCE_1_ID, NODE_B, RUNNER_1_ID)
 
     await global_events.append_events(
         [
@@ -157,11 +169,11 @@
 
 async def test_runner_unassigns(
     worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
-    instance: Callable[[NodeId, RunnerId], Instance]
+    instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
 ):
     worker, global_events = await worker_running(NODE_A)
 
-    instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
+    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
     instance_value.instance_type = InstanceStatus.ACTIVE
 
     await global_events.append_events(
@@ -206,4 +218,257 @@ async def test_runner_unassigns(
     events = await global_events.get_events_since(0)
     assert isinstance(events[-1].event, RunnerDeleted)
     # After deletion, runner should be removed from state.runners
-    assert len(worker.state.runners) == 0
\ No newline at end of file
+    assert len(worker.state.runners) == 0
+
+async def test_runner_inference(
+    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
+    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
+    chat_completion_task: Callable[[InstanceId], Task]
+    ):
+    _worker, global_events = await worker_running(NODE_A)
+
+    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
+    instance_value.instance_type = InstanceStatus.ACTIVE
+
+    task: Task = chat_completion_task(INSTANCE_1_ID)
+    await global_events.append_events(
+        [
+            InstanceCreated(
+                instance=instance_value,
+            ),
+            TaskCreated(
+                task_id=task.task_id,
+                task=task
+            )
+        ],
+        origin=MASTER_NODE_ID
+    )
+
+    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
+
+    assert seen_task_started
+    assert seen_task_finished
+    assert 'tokyo' in response_string.lower()
+
+    await global_events.append_events(
+        [
+            InstanceDeleted(
+                instance_id=instance_value.instance_id,
+            ),
+        ],
+        origin=MASTER_NODE_ID
+    )
+
+    await asyncio.sleep(0.3)
+
+async def test_2_runner_inference(
+    logger: Logger,
+    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
+    hosts: Callable[[int], list[Host]],
+    chat_completion_task: Callable[[InstanceId], Task]
+    ):
+    event_log_manager = EventLogManager(EventLogConfig(), logger)
+    await event_log_manager.initialize()
+
+    global_events = event_log_manager.global_events
+    await global_events.delete_all_events()
+
+    worker1 = Worker(NODE_A, logger=logger, worker_events=global_events, global_events=global_events)
+    asyncio.create_task(worker1.run())
+
+    worker2 = Worker(NODE_B, logger=logger, worker_events=global_events, global_events=global_events)
+    asyncio.create_task(worker2.run())
+
+    ## Instance
+    model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
+
+    shard_assignments = ShardAssignments(
+        model_id=model_id,
+        runner_to_shard={
+            RUNNER_1_ID: pipeline_shard_meta(2, 0),
+            RUNNER_2_ID: pipeline_shard_meta(2, 1)
+        },
+        node_to_runner={
+            NODE_A: RUNNER_1_ID,
+            NODE_B: RUNNER_2_ID
+        }
+    )
+
+    instance = Instance(
+        instance_id=INSTANCE_1_ID,
+        instance_type=InstanceStatus.ACTIVE,
+        shard_assignments=shard_assignments,
+        hosts=hosts(2)
+    )
+
+    task = chat_completion_task(INSTANCE_1_ID)
+    await global_events.append_events(
+        [
+            InstanceCreated(
+                instance=instance
+            ),
+            TaskCreated(
+                task_id=task.task_id,
+                task=task
+            )
+        ],
+        origin=MASTER_NODE_ID
+    )
+
+    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
+
+    assert seen_task_started
+    assert seen_task_finished
+    assert 'tokyo' in response_string.lower()
+
+
+    idx = await global_events.get_last_idx()
+    await asyncio.sleep(1.0)
+    events = await global_events.get_events_since(idx)
+    assert len(events) == 0
+
+    await global_events.append_events(
+        [
+            InstanceDeleted(
+                instance_id=instance.instance_id,
+            ),
+        ],
+        origin=MASTER_NODE_ID
+    )
+
+    await asyncio.sleep(2.0)
+
+
+
+async def test_runner_respawn(
+    logger: Logger,
+    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
+    hosts: Callable[[int], list[Host]],
+    chat_completion_task: Callable[[InstanceId], Task]
+    ):
+    event_log_manager = EventLogManager(EventLogConfig(), logger)
+    await event_log_manager.initialize()
+
+    global_events = event_log_manager.global_events
+    await global_events.delete_all_events()
+
+    worker1 = Worker(NODE_A, logger=logger, worker_events=global_events, global_events=global_events)
+    asyncio.create_task(worker1.run())
+
+    worker2 = Worker(NODE_B, logger=logger, worker_events=global_events, global_events=global_events)
+    asyncio.create_task(worker2.run())
+
+    ## Instance
+    model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
+
+    shard_assignments = ShardAssignments(
+        model_id=model_id,
+        runner_to_shard={
+            RUNNER_1_ID: pipeline_shard_meta(2, 0),
+            RUNNER_2_ID: pipeline_shard_meta(2, 1)
+        },
+        node_to_runner={
+            NODE_A: RUNNER_1_ID,
+            NODE_B: RUNNER_2_ID
+        }
+    )
+
+    instance = Instance(
+        instance_id=INSTANCE_1_ID,
+        instance_type=InstanceStatus.ACTIVE,
+        shard_assignments=shard_assignments,
+        hosts=hosts(2)
+    )
+
+    task = chat_completion_task(INSTANCE_1_ID)
+    await global_events.append_events(
+        [
+            InstanceCreated(
+                instance=instance
+            ),
+            TaskCreated(
+                task_id=task.task_id,
+                task=task
+            )
+        ],
+        origin=MASTER_NODE_ID
+    )
+
+    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
+
+    assert seen_task_started
+    assert seen_task_finished
+    assert 'tokyo' in response_string.lower()
+
+    await asyncio.sleep(0.1)
+
+    idx = await global_events.get_last_idx()
+
+    assigned_runner: AssignedRunner = worker1.assigned_runners[RUNNER_1_ID]
+    assert assigned_runner.runner is not None
+    assigned_runner.runner.runner_process.kill()
+
+    # Wait for the process to actually be detected as dead or cleaned up
+    for _ in range(100):  # Wait up to 1 second
+        await asyncio.sleep(0.01)
+        # The worker may clean up the runner (set to None) when it detects it's dead
+        if assigned_runner.runner is None or not assigned_runner.runner.healthy:
+            break
+    else:
+        raise AssertionError("Runner should have been detected as unhealthy or cleaned up after kill()")
+
+    await asyncio.sleep(5.0)
+
+    events = await global_events.get_events_since(idx)
+    print(f'{events=}')
+    # At least five runner status events are expected below
+    assert isinstance(events[0].event, RunnerStatusUpdated)
+    assert isinstance(events[0].event.runner_status, FailedRunnerStatus)
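+    # Expected recovery sequence, per the asserts below: both runners report
+    # ready again as the instance spins down, then both report loaded as it
+    # is brought back up, before the retry task is served.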
+
+    assert isinstance(events[1].event, RunnerStatusUpdated)
+    assert isinstance(events[1].event.runner_status, ReadyRunnerStatus)
+    assert events[1].event.runner_id == RUNNER_2_ID
+
+    assert isinstance(events[2].event, RunnerStatusUpdated)
+    assert isinstance(events[2].event.runner_status, ReadyRunnerStatus)
+    assert events[2].event.runner_id == RUNNER_1_ID
+
+    print(worker1.state)
+    print(worker2.state)
+
+    for event in [events[3].event, events[4].event]:
+        assert isinstance(event, RunnerStatusUpdated)
+        assert isinstance(event.runner_status, LoadedRunnerStatus)
+
+    task = chat_completion_task(INSTANCE_1_ID)
+    await global_events.append_events(
+        [
+            TaskCreated(
+                task_id=task.task_id,
+                task=task
+            )
+        ],
+        origin=MASTER_NODE_ID
+    )
+
+    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
+
+    assert seen_task_started
+    assert seen_task_finished
+    assert 'tokyo' in response_string.lower()
+
+    await asyncio.sleep(0.1)
+
+    await global_events.append_events(
+        [
+            InstanceDeleted(
+                instance_id=instance.instance_id,
+            ),
+        ],
+        origin=MASTER_NODE_ID
+    )
+
+    await asyncio.sleep(1.0)
\ No newline at end of file
diff --git a/worker/tests/test_worker_integration_utils.py b/worker/tests/test_worker_integration_utils.py
new file mode 100644
index 00000000..5e0b78d8
--- /dev/null
+++ b/worker/tests/test_worker_integration_utils.py
@@ -0,0 +1,42 @@
+import asyncio
+from typing import Tuple
+
+from shared.db.sqlite.connector import AsyncSQLiteEventStorage
+from shared.types.events import ChunkGenerated, TaskStateUpdated
+from shared.types.events.chunks import TokenChunk
+from shared.types.tasks import TaskStatus
+
+
+async def read_streaming_response(global_events: AsyncSQLiteEventStorage) -> Tuple[bool, bool, str]:
+    # Read off all events - these should be our ChunkGenerated events
+    seen_task_started, seen_task_finished = 0, 0
+    response_string = ''
+    finish_reason: str | None = None
+
+    idx = 0
+    while not finish_reason:
+        events = await global_events.get_events_since(idx)
+        if len(events) == 0:
+            await asyncio.sleep(0.01)
+            continue
+        idx = events[-1].idx_in_log
+
+        for wrapped_event in events:
+            event = wrapped_event.event
+            if isinstance(event, TaskStateUpdated):
+                if event.task_status == TaskStatus.RUNNING:
+                    seen_task_started += 1
+                if event.task_status == TaskStatus.COMPLETE:
+                    seen_task_finished += 1
+
+            if isinstance(event, ChunkGenerated):
+                assert isinstance(event.chunk, TokenChunk)
+                response_string += event.chunk.text
+                if event.chunk.finish_reason:
+                    finish_reason = event.chunk.finish_reason
+
+        await asyncio.sleep(0.2)
+
+    print(f'event log: {await global_events.get_events_since(0)}')
+
+    return seen_task_started == 1, seen_task_finished == 1, response_string
\ No newline at end of file
diff --git a/worker/tests/test_worker_plan.py b/worker/tests/test_worker_plan.py
index 8f00b84b..120e3895 100644
--- a/worker/tests/test_worker_plan.py
+++ b/worker/tests/test_worker_plan.py
@@ -481,7 +481,23 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
             )
         },
         runners={RUNNER_1_ID: LoadedRunnerStatus()},
-        tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
+        tasks={
+            TASK_1_ID: ChatCompletionTask(
+                task_id=TASK_1_ID,
+                task_type=TaskType.CHAT_COMPLETION,
+                task_status=TaskStatus.PENDING,
+                task_params=ChatCompletionTaskParams(
+                    model=str(MODEL_A_ID),
+                    messages=[
+                        ChatCompletionMessage(
+                            role="user",
+                            content="Hello, world!"
+                        )
+                    ]
+                ),
+                instance_id=INSTANCE_1_ID
+            )
+        },
     ),
     expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask(
         task_id=TASK_1_ID,