mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Inference Integration Test
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
This commit is contained in:
25
read_events.py
Normal file
25
read_events.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
|
||||
|
||||
from worker.main import get_node_id
|
||||
from shared.types.common import NodeId
|
||||
from shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig
|
||||
|
||||
async def main():
|
||||
node_id: NodeId = get_node_id()
|
||||
logger: Logger = Logger('worker_log')
|
||||
|
||||
event_log_manager: EventLogManager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
|
||||
events = await event_log_manager.global_events.get_events_since(0)
|
||||
|
||||
for wrapped_event in events:
|
||||
event = wrapped_event.event
|
||||
event_type = type(event).__name__.replace('_', ' ').title()
|
||||
attributes = ', '.join(f"{key}={value!r}" for key, value in vars(event).items())
|
||||
print(f"{event_type}: {attributes}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -106,7 +106,7 @@ def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> Sta
|
||||
return state.model_copy(update={"runners": new_runners})
|
||||
|
||||
@event_apply.register(RunnerDeleted)
|
||||
def apply_runner_deleted(event: RunnerStatusUpdated, state: State) -> State:
|
||||
def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
new_runners: Mapping[RunnerId, RunnerStatus] = {rid: rs for rid, rs in state.runners.items() if rid != event.runner_id}
|
||||
return state.model_copy(update={"runners": new_runners})
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ class RunnerOpType(str, Enum):
|
||||
UNASSIGN_RUNNER = "unassign_runner"
|
||||
RUNNER_UP = "runner_up"
|
||||
RUNNER_DOWN = "runner_down"
|
||||
RUNNER_FAILED = "runner_failed"
|
||||
DOWNLOAD = "download"
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
|
||||
@@ -42,6 +43,10 @@ class RunnerDownOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_DOWN]]):
|
||||
op_type: Literal[RunnerOpType.RUNNER_DOWN] = Field(default=RunnerOpType.RUNNER_DOWN, frozen=True)
|
||||
runner_id: RunnerId
|
||||
|
||||
class RunnerFailedOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_FAILED]]):
|
||||
op_type: Literal[RunnerOpType.RUNNER_FAILED] = Field(default=RunnerOpType.RUNNER_FAILED, frozen=True)
|
||||
runner_id: RunnerId
|
||||
|
||||
class DownloadOp(BaseRunnerOp[Literal[RunnerOpType.DOWNLOAD]]):
|
||||
op_type: Literal[RunnerOpType.DOWNLOAD] = Field(default=RunnerOpType.DOWNLOAD, frozen=True)
|
||||
instance_id: InstanceId
|
||||
@@ -62,6 +67,7 @@ RunnerOp = Annotated[
|
||||
UnassignRunnerOp,
|
||||
RunnerUpOp,
|
||||
RunnerDownOp,
|
||||
RunnerFailedOp,
|
||||
DownloadOp,
|
||||
ExecuteTaskOp,
|
||||
],
|
||||
|
||||
@@ -38,6 +38,7 @@ from shared.types.worker.ops import (
|
||||
DownloadOp,
|
||||
ExecuteTaskOp,
|
||||
RunnerDownOp,
|
||||
RunnerFailedOp,
|
||||
RunnerOp,
|
||||
RunnerOpType,
|
||||
RunnerUpOp,
|
||||
@@ -162,6 +163,18 @@ class Worker:
|
||||
|
||||
assigned_runner.status = ReadyRunnerStatus()
|
||||
yield assigned_runner.status_update_event()
|
||||
return
|
||||
|
||||
async def _execute_runner_failed_op(
|
||||
self, op: RunnerFailedOp
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
'''
|
||||
We detected that this runner has failed. So we'll put it into 'failed' state now, triggering the rest of the instance to spin down.
|
||||
'''
|
||||
assigned_runner = self.assigned_runners[op.runner_id]
|
||||
|
||||
assigned_runner.status = FailedRunnerStatus()
|
||||
yield self.assigned_runners[op.runner_id].status_update_event()
|
||||
|
||||
async def _execute_download_op(
|
||||
self, op: DownloadOp
|
||||
@@ -309,6 +322,8 @@ class Worker:
|
||||
event_generator = self._execute_runner_up_op(op)
|
||||
case RunnerOpType.RUNNER_DOWN:
|
||||
event_generator = self._execute_runner_down_op(op)
|
||||
case RunnerOpType.RUNNER_FAILED:
|
||||
event_generator = self._execute_runner_failed_op(op)
|
||||
case RunnerOpType.DOWNLOAD:
|
||||
event_generator = self._execute_download_op(op)
|
||||
case RunnerOpType.CHAT_COMPLETION:
|
||||
@@ -331,6 +346,12 @@ class Worker:
|
||||
if runner_id not in runner_ids:
|
||||
return UnassignRunnerOp(runner_id=runner_id)
|
||||
|
||||
for runner_id, assigned_runner in self.assigned_runners.items():
|
||||
if assigned_runner.runner is not None and \
|
||||
not assigned_runner.runner.healthy and \
|
||||
not isinstance(assigned_runner.status, FailedRunnerStatus):
|
||||
return RunnerFailedOp(runner_id=runner_id)
|
||||
|
||||
# Then spin down active runners
|
||||
for _instance_id, instance in state.instances.items():
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
|
||||
@@ -346,7 +367,9 @@ class Worker:
|
||||
# If we are part of an instance that has a dead node - and we aren't the dead node - we should spin down
|
||||
# TODO: We need to limit number of retries if we keep failing.
|
||||
for _instance_id, instance in state.instances.items():
|
||||
if self.node_id in instance.shard_assignments.node_to_runner:
|
||||
if self.node_id in instance.shard_assignments.node_to_runner and \
|
||||
instance.shard_assignments.node_to_runner[self.node_id] in self.assigned_runners and \
|
||||
not isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, ReadyRunnerStatus): # make sure that our runner has not already been spun down into ready state
|
||||
other_node_in_instance_has_failed = False
|
||||
for runner_id in instance.shard_assignments.runner_to_shard:
|
||||
if runner_id in state.runners and \
|
||||
@@ -362,13 +385,17 @@ class Worker:
|
||||
for _instance_id, instance in state.instances.items():
|
||||
if self.node_id in instance.shard_assignments.node_to_runner and \
|
||||
instance.shard_assignments.node_to_runner[self.node_id] in state.runners and \
|
||||
isinstance(state.runners[instance.shard_assignments.node_to_runner[self.node_id]], FailedRunnerStatus):
|
||||
|
||||
isinstance(self.assigned_runners[instance.shard_assignments.node_to_runner[self.node_id]].status, FailedRunnerStatus):
|
||||
|
||||
num_spundown_nodes = 0
|
||||
for runner_id in instance.shard_assignments.runner_to_shard:
|
||||
if isinstance(state.runners[runner_id], ReadyRunnerStatus) and \
|
||||
runner_id not in self.assigned_runners:
|
||||
num_spundown_nodes += 1
|
||||
# Suggested:
|
||||
# if runner_id in state.runners and isinstance(state.runners[runner_id], ReadyRunnerStatus):
|
||||
# if runner_id != instance.shard_assignments.node_to_runner[self.node_id]:
|
||||
# num_spundown_nodes += 1
|
||||
|
||||
if num_spundown_nodes == next(iter(instance.shard_assignments.runner_to_shard.values())).world_size - 1:
|
||||
# All the other nodes are spun down - so now we can spin down too.
|
||||
@@ -421,7 +448,7 @@ class Worker:
|
||||
# Need to assert all other runners are ready before we can spin up.
|
||||
ready_to_spin = True
|
||||
for runner_id in instance.shard_assignments.node_to_runner.values():
|
||||
if state.runners[runner_id].runner_status != RunnerStatusType.Ready:
|
||||
if runner_id in state.runners and state.runners[runner_id].runner_status != RunnerStatusType.Ready:
|
||||
ready_to_spin = False
|
||||
|
||||
if ready_to_spin:
|
||||
@@ -438,7 +465,7 @@ class Worker:
|
||||
continue # The only previous state to get to Running is from Loaded
|
||||
|
||||
for _, task in state.tasks.items():
|
||||
if task.instance_id == instance_id:
|
||||
if task.instance_id == instance_id and task.task_status == TaskStatus.PENDING:
|
||||
if (runner.shard_metadata.device_rank >= 1 or runner.shard_metadata.world_size == 1):
|
||||
return ExecuteTaskOp(runner_id=runner_id, task=task)
|
||||
else:
|
||||
@@ -465,11 +492,9 @@ class Worker:
|
||||
assert self.global_events is not None
|
||||
|
||||
while True:
|
||||
_rank = list(self.assigned_runners.values())[0].shard_metadata.device_rank if self.assigned_runners else None
|
||||
# 1. get latest events
|
||||
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
|
||||
if len(events) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# 2. for each event, apply it to the state and run sagas
|
||||
for event_from_log in events:
|
||||
|
||||
@@ -91,8 +91,8 @@ async def _mlx_generate(
|
||||
runner_print(item.text)
|
||||
yield item
|
||||
|
||||
# TODO: There is a big bug on this line!
|
||||
assert future.done()
|
||||
# Wait for the executor thread to complete
|
||||
await future
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@@ -88,12 +88,23 @@ class RunnerSupervisor:
|
||||
|
||||
async def astop(self) -> None:
|
||||
async def terminate() -> None:
|
||||
self.runner_process.terminate()
|
||||
_ = await self.runner_process.wait()
|
||||
# Check if process is already dead before trying to terminate
|
||||
if self.runner_process.returncode is None:
|
||||
self.runner_process.terminate()
|
||||
|
||||
# Wait for the process to exit (or confirm it's already exited)
|
||||
try:
|
||||
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
# If terminate didn't work, force kill
|
||||
if self.runner_process.returncode is None:
|
||||
self.runner_process.kill()
|
||||
_ = await self.runner_process.wait()
|
||||
|
||||
if not self.healthy:
|
||||
print("Runner process is not healthy, killing...")
|
||||
await terminate()
|
||||
print('terminated')
|
||||
|
||||
if self.runner_process.stdout is not None:
|
||||
while True:
|
||||
@@ -107,15 +118,20 @@ class RunnerSupervisor:
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
try:
|
||||
# Give the process a moment to exit gracefully
|
||||
await supervisor_write_message(
|
||||
proc=self.runner_process, message=ExitMessage()
|
||||
)
|
||||
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
print("Runner process did not terminate, killing...")
|
||||
await terminate()
|
||||
# Only try to send ExitMessage if process is still alive
|
||||
if self.runner_process.returncode is None:
|
||||
try:
|
||||
# Give the process a moment to exit gracefully
|
||||
await supervisor_write_message(
|
||||
proc=self.runner_process, message=ExitMessage()
|
||||
)
|
||||
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
print("Runner process did not terminate, killing...")
|
||||
await terminate()
|
||||
except Exception:
|
||||
# If we can't write to the process (e.g., broken pipe), it's probably already dead
|
||||
pass
|
||||
|
||||
self.running = False
|
||||
|
||||
@@ -124,7 +140,7 @@ class RunnerSupervisor:
|
||||
self.running = False
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not self.running:
|
||||
if self.running:
|
||||
print(
|
||||
"Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process."
|
||||
)
|
||||
|
||||
@@ -99,14 +99,16 @@ def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def chat_completion_task(completion_create_params: ChatCompletionTaskParams) -> ChatCompletionTask:
|
||||
return ChatCompletionTask(
|
||||
task_id=TaskId(),
|
||||
instance_id=InstanceId(),
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=completion_create_params
|
||||
)
|
||||
def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
|
||||
def _chat_completion_task(instance_id: InstanceId) -> ChatCompletionTask:
|
||||
return ChatCompletionTask(
|
||||
task_id=TaskId(),
|
||||
instance_id=instance_id,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=completion_create_params
|
||||
)
|
||||
return _chat_completion_task
|
||||
|
||||
@pytest.fixture
|
||||
def node_id() -> NodeId:
|
||||
@@ -129,7 +131,7 @@ def logger() -> Logger:
|
||||
|
||||
@pytest.fixture
|
||||
def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts_one: list[Host]):
|
||||
def _instance(node_id: NodeId, runner_id: RunnerId) -> Instance:
|
||||
def _instance(instance_id: InstanceId, node_id: NodeId, runner_id: RunnerId) -> Instance:
|
||||
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
@@ -156,10 +158,10 @@ async def worker(node_id: NodeId, logger: Logger):
|
||||
return Worker(node_id, logger, worker_events=event_log_manager.global_events, global_events=event_log_manager.global_events)
|
||||
|
||||
@pytest.fixture
|
||||
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance]):
|
||||
async def worker_with_assigned_runner(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance]):
|
||||
"""Fixture that provides a worker with an already assigned runner."""
|
||||
|
||||
instance_obj: Instance = instance(worker.node_id, RunnerId())
|
||||
instance_obj: Instance = instance(InstanceId(), worker.node_id, RunnerId())
|
||||
|
||||
# Extract runner_id from shard assignments
|
||||
runner_id = next(iter(instance_obj.shard_assignments.runner_to_shard))
|
||||
|
||||
@@ -9,6 +9,7 @@ from shared.types.worker.commands_runner import (
|
||||
RunnerMessageTypeAdapter,
|
||||
SetupMessage,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
@@ -37,9 +38,10 @@ def test_supervisor_setup_message_serdes(
|
||||
|
||||
|
||||
def test_supervisor_task_message_serdes(
|
||||
chat_completion_task: Task,
|
||||
chat_completion_task: Callable[[InstanceId], Task],
|
||||
):
|
||||
task = chat_completion_task(InstanceId())
|
||||
task_message = ChatTaskMessage(
|
||||
task_data=chat_completion_task.task_params,
|
||||
task_data=task.task_params,
|
||||
)
|
||||
assert_equal_serdes(task_message, RunnerMessageTypeAdapter)
|
||||
|
||||
@@ -11,6 +11,7 @@ from shared.types.tasks import (
|
||||
Task,
|
||||
TaskType,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.runner.runner_supervisor import RunnerSupervisor
|
||||
@@ -26,11 +27,12 @@ def user_message():
|
||||
async def test_supervisor_single_node_response(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Task,
|
||||
chat_completion_task: Callable[[InstanceId], Task],
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
instance_id = InstanceId()
|
||||
|
||||
print(f'{model_shard_meta=}')
|
||||
|
||||
@@ -43,7 +45,7 @@ async def test_supervisor_single_node_response(
|
||||
full_response = ""
|
||||
stop_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task):
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task(instance_id)):
|
||||
if isinstance(chunk, TokenChunk):
|
||||
full_response += chunk.text
|
||||
if chunk.finish_reason:
|
||||
@@ -63,10 +65,11 @@ async def test_supervisor_single_node_response(
|
||||
async def test_supervisor_two_node_response(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Task,
|
||||
chat_completion_task: Callable[[InstanceId], Task],
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
instance_id = InstanceId()
|
||||
supervisor_0 = await RunnerSupervisor.create(
|
||||
model_shard_meta=pipeline_shard_meta(2, 0),
|
||||
hosts=hosts(2, offset=15),
|
||||
@@ -85,13 +88,13 @@ async def test_supervisor_two_node_response(
|
||||
|
||||
async def collect_response_0():
|
||||
nonlocal full_response_0
|
||||
async for chunk in supervisor_0.stream_response(task=chat_completion_task):
|
||||
async for chunk in supervisor_0.stream_response(task=chat_completion_task(instance_id)):
|
||||
if isinstance(chunk, TokenChunk):
|
||||
full_response_0 += chunk.text
|
||||
|
||||
async def collect_response_1():
|
||||
nonlocal full_response_1
|
||||
async for chunk in supervisor_1.stream_response(task=chat_completion_task):
|
||||
async for chunk in supervisor_1.stream_response(task=chat_completion_task(instance_id)):
|
||||
if isinstance(chunk, TokenChunk):
|
||||
full_response_1 += chunk.text
|
||||
|
||||
@@ -118,22 +121,25 @@ async def test_supervisor_two_node_response(
|
||||
async def test_supervisor_early_stopping(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Task,
|
||||
chat_completion_task: Callable[[InstanceId], Task],
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
instance_id = InstanceId()
|
||||
|
||||
supervisor = await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
)
|
||||
|
||||
task = chat_completion_task(instance_id)
|
||||
|
||||
max_tokens = 50
|
||||
assert chat_completion_task.task_type == TaskType.CHAT_COMPLETION
|
||||
print(f'chat_completion_task.task_params: {chat_completion_task.task_params}')
|
||||
assert isinstance(chat_completion_task.task_params, ChatCompletionTaskParams)
|
||||
task_params: ChatCompletionTaskParams = chat_completion_task.task_params
|
||||
assert task.task_type == TaskType.CHAT_COMPLETION
|
||||
print(f'chat_completion_task.task_params: {task.task_params}')
|
||||
assert isinstance(task.task_params, ChatCompletionTaskParams)
|
||||
task_params: ChatCompletionTaskParams = task.task_params
|
||||
|
||||
try:
|
||||
task_params.max_tokens = max_tokens
|
||||
@@ -146,7 +152,7 @@ async def test_supervisor_early_stopping(
|
||||
count = 0
|
||||
stop_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task):
|
||||
async for chunk in supervisor.stream_response(task=task):
|
||||
if isinstance(chunk, TokenChunk):
|
||||
full_response += chunk.text
|
||||
count += 1
|
||||
@@ -169,7 +175,6 @@ async def test_supervisor_early_stopping(
|
||||
async def test_supervisor_handles_terminated_runner(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Task,
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that the supervisor handles a terminated runner"""
|
||||
@@ -194,7 +199,6 @@ async def test_supervisor_handles_terminated_runner(
|
||||
async def test_supervisor_handles_killed_runner(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Task,
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that the supervisor handles a killed runner"""
|
||||
|
||||
@@ -16,7 +16,7 @@ from shared.types.events import (
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.tasks import Task, TaskStatus
|
||||
from shared.types.worker.common import RunnerId
|
||||
from shared.types.worker.instances import Instance
|
||||
from shared.types.worker.instances import Instance, InstanceId
|
||||
from shared.types.worker.ops import (
|
||||
AssignRunnerOp,
|
||||
DownloadOp,
|
||||
@@ -40,9 +40,9 @@ def user_message():
|
||||
return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assign_op(worker: Worker, instance: Callable[[NodeId, RunnerId], Instance], tmp_path: Path):
|
||||
async def test_assign_op(worker: Worker, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], tmp_path: Path):
|
||||
runner_id = RunnerId()
|
||||
instance_obj: Instance = instance(worker.node_id, runner_id)
|
||||
instance_obj: Instance = instance(InstanceId(), worker.node_id, runner_id)
|
||||
|
||||
assign_op = AssignRunnerOp(
|
||||
runner_id=runner_id,
|
||||
@@ -84,7 +84,7 @@ async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
|
||||
assert isinstance(events[0], RunnerDeleted)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_completion_task: Task, tmp_path: Path):
|
||||
async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId, Instance], chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
|
||||
worker, runner_id, _ = worker_with_assigned_runner
|
||||
|
||||
runner_up_op = RunnerUpOp(runner_id=runner_id)
|
||||
@@ -104,7 +104,7 @@ async def test_runner_up_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
|
||||
|
||||
full_response = ''
|
||||
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task):
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task(InstanceId())):
|
||||
if isinstance(chunk, TokenChunk):
|
||||
full_response += chunk.text
|
||||
|
||||
@@ -153,12 +153,12 @@ async def test_download_op(worker_with_assigned_runner: tuple[Worker, RunnerId,
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_op(
|
||||
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
|
||||
chat_completion_task: Task, tmp_path: Path):
|
||||
chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
|
||||
worker, runner_id, _ = worker_with_running_runner
|
||||
|
||||
execute_task_op = ExecuteTaskOp(
|
||||
runner_id=runner_id,
|
||||
task=chat_completion_task
|
||||
task=chat_completion_task(InstanceId())
|
||||
)
|
||||
|
||||
events: list[Event] = []
|
||||
@@ -196,15 +196,16 @@ async def test_execute_task_op(
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_fails(
|
||||
worker_with_running_runner: tuple[Worker, RunnerId, Instance],
|
||||
chat_completion_task: Task, tmp_path: Path):
|
||||
chat_completion_task: Callable[[InstanceId], Task], tmp_path: Path):
|
||||
worker, runner_id, _ = worker_with_running_runner
|
||||
|
||||
messages = chat_completion_task.task_params.messages
|
||||
task = chat_completion_task(InstanceId())
|
||||
messages = task.task_params.messages
|
||||
messages[0].content = 'Artificial prompt: EXO RUNNER MUST FAIL'
|
||||
|
||||
execute_task_op = ExecuteTaskOp(
|
||||
runner_id=runner_id,
|
||||
task=chat_completion_task
|
||||
task=task
|
||||
)
|
||||
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -1,27 +1,39 @@
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
from typing import Awaitable, Callable, Final
|
||||
|
||||
import pytest
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events import (
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskCreated,
|
||||
)
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.models import ModelId
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import Instance, InstanceStatus
|
||||
from shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceStatus,
|
||||
ShardAssignments,
|
||||
)
|
||||
from shared.types.worker.mlx import Host
|
||||
from shared.types.worker.runners import (
|
||||
FailedRunnerStatus,
|
||||
LoadedRunnerStatus,
|
||||
ReadyRunnerStatus,
|
||||
# RunningRunnerStatus,
|
||||
)
|
||||
from worker.main import Worker
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.main import AssignedRunner, Worker
|
||||
from worker.tests.test_worker_integration_utils import read_streaming_response
|
||||
|
||||
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
@@ -42,14 +54,14 @@ def user_message():
|
||||
|
||||
async def test_runner_assigned(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[NodeId, RunnerId], Instance]
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
|
||||
):
|
||||
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
print(worker)
|
||||
|
||||
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.INACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
@@ -79,12 +91,12 @@ async def test_runner_assigned(
|
||||
|
||||
async def test_runner_assigned_active(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[NodeId, RunnerId], Instance],
|
||||
chat_completion_task: Task
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
chat_completion_task: Callable[[InstanceId], Task]
|
||||
):
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
@@ -118,7 +130,7 @@ async def test_runner_assigned_active(
|
||||
|
||||
full_response = ''
|
||||
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task):
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task(INSTANCE_1_ID)):
|
||||
if isinstance(chunk, TokenChunk):
|
||||
full_response += chunk.text
|
||||
|
||||
@@ -128,11 +140,11 @@ async def test_runner_assigned_active(
|
||||
|
||||
async def test_runner_assigned_wrong_node(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[NodeId, RunnerId], Instance]
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
|
||||
):
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value = instance(NODE_B, RUNNER_1_ID)
|
||||
instance_value = instance(INSTANCE_1_ID, NODE_B, RUNNER_1_ID)
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
@@ -157,11 +169,11 @@ async def test_runner_assigned_wrong_node(
|
||||
|
||||
async def test_runner_unassigns(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[NodeId, RunnerId], Instance]
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
|
||||
):
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(NODE_A, RUNNER_1_ID)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
@@ -206,4 +218,254 @@ async def test_runner_unassigns(
|
||||
events = await global_events.get_events_since(0)
|
||||
assert isinstance(events[-1].event, RunnerDeleted)
|
||||
# After deletion, runner should be removed from state.runners
|
||||
assert len(worker.state.runners) == 0
|
||||
assert len(worker.state.runners) == 0
|
||||
|
||||
async def test_runner_inference(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
chat_completion_task: Callable[[InstanceId], Task]
|
||||
):
|
||||
_worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
task: Task = chat_completion_task(INSTANCE_1_ID)
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value,
|
||||
),
|
||||
TaskCreated(
|
||||
task_id=task.task_id,
|
||||
task=task
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
|
||||
|
||||
assert seen_task_started
|
||||
assert seen_task_finished
|
||||
assert 'tokyo' in response_string.lower()
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance_value.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
async def test_2_runner_inference(
|
||||
logger: Logger,
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
hosts: Callable[[int], list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId], Task]
|
||||
):
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
|
||||
global_events = event_log_manager.global_events
|
||||
await global_events.delete_all_events()
|
||||
|
||||
worker1 = Worker(NODE_A, logger=logger, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(worker1.run())
|
||||
|
||||
worker2 = Worker(NODE_B, logger=logger, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(worker2.run())
|
||||
|
||||
## Instance
|
||||
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_id,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: pipeline_shard_meta(2, 0),
|
||||
RUNNER_2_ID: pipeline_shard_meta(2, 1)
|
||||
},
|
||||
node_to_runner={
|
||||
NODE_A: RUNNER_1_ID,
|
||||
NODE_B: RUNNER_2_ID
|
||||
}
|
||||
)
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2)
|
||||
)
|
||||
|
||||
task = chat_completion_task(INSTANCE_1_ID)
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance
|
||||
),
|
||||
TaskCreated(
|
||||
task_id=task.task_id,
|
||||
task=task
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
|
||||
|
||||
assert seen_task_started
|
||||
assert seen_task_finished
|
||||
assert 'tokyo' in response_string.lower()
|
||||
|
||||
|
||||
idx = await global_events.get_last_idx()
|
||||
await asyncio.sleep(1.0)
|
||||
events = await global_events.get_events_since(idx)
|
||||
assert len(events) == 0
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
|
||||
|
||||
async def test_runner_respawn(
|
||||
logger: Logger,
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
hosts: Callable[[int], list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId], Task]
|
||||
):
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
|
||||
global_events = event_log_manager.global_events
|
||||
await global_events.delete_all_events()
|
||||
|
||||
worker1 = Worker(NODE_A, logger=logger, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(worker1.run())
|
||||
|
||||
worker2 = Worker(NODE_B, logger=logger, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(worker2.run())
|
||||
|
||||
## Instance
|
||||
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_id,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: pipeline_shard_meta(2, 0),
|
||||
RUNNER_2_ID: pipeline_shard_meta(2, 1)
|
||||
},
|
||||
node_to_runner={
|
||||
NODE_A: RUNNER_1_ID,
|
||||
NODE_B: RUNNER_2_ID
|
||||
}
|
||||
)
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2)
|
||||
)
|
||||
|
||||
task = chat_completion_task(INSTANCE_1_ID)
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance
|
||||
),
|
||||
TaskCreated(
|
||||
task_id=task.task_id,
|
||||
task=task
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
|
||||
|
||||
assert seen_task_started
|
||||
assert seen_task_finished
|
||||
assert 'tokyo' in response_string.lower()
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
idx = await global_events.get_last_idx()
|
||||
|
||||
assigned_runner: AssignedRunner = worker1.assigned_runners[RUNNER_1_ID]
|
||||
assert assigned_runner.runner is not None
|
||||
assigned_runner.runner.runner_process.kill()
|
||||
|
||||
# Wait for the process to actually be detected as dead or cleaned up
|
||||
for _ in range(100): # Wait up to 1 second
|
||||
await asyncio.sleep(0.01)
|
||||
# The worker may clean up the runner (set to None) when it detects it's dead
|
||||
if assigned_runner.runner and not assigned_runner.runner.healthy:
|
||||
break
|
||||
else:
|
||||
raise AssertionError("Runner should have been detected as unhealthy or cleaned up after kill()")
|
||||
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
events = await global_events.get_events_since(idx)
|
||||
print(f'{events=}')
|
||||
# assert len(events) == 2
|
||||
assert isinstance(events[0].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[0].event.runner_status, FailedRunnerStatus)
|
||||
|
||||
assert isinstance(events[1].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[1].event.runner_status, ReadyRunnerStatus)
|
||||
assert events[1].event.runner_id == RUNNER_2_ID
|
||||
|
||||
assert isinstance(events[2].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[2].event.runner_status, ReadyRunnerStatus)
|
||||
assert events[2].event.runner_id == RUNNER_1_ID
|
||||
|
||||
print(worker1.state)
|
||||
print(worker2.state)
|
||||
|
||||
for event in [events[3].event, events[4].event]:
|
||||
assert isinstance(event, RunnerStatusUpdated)
|
||||
assert isinstance(event.runner_status, LoadedRunnerStatus)
|
||||
|
||||
task = chat_completion_task(INSTANCE_1_ID)
|
||||
await global_events.append_events(
|
||||
[
|
||||
TaskCreated(
|
||||
task_id=task.task_id,
|
||||
task=task
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
|
||||
|
||||
assert seen_task_started
|
||||
assert seen_task_finished
|
||||
assert 'tokyo' in response_string.lower()
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
44
worker/tests/test_worker_integration_utils.py
Normal file
44
worker/tests/test_worker_integration_utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
from typing import Tuple
|
||||
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.types.events import ChunkGenerated, TaskStateUpdated
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.tasks import TaskStatus
|
||||
|
||||
|
||||
async def read_streaming_response(global_events: AsyncSQLiteEventStorage) -> Tuple[bool, bool, str]:
|
||||
# Read off all events - these should be our GenerationChunk events
|
||||
seen_task_started, seen_task_finished = 0, 0
|
||||
response_string = ''
|
||||
finish_reason: str | None = None
|
||||
|
||||
idx = 0
|
||||
while not finish_reason:
|
||||
events = await global_events.get_events_since(idx)
|
||||
if len(events) == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
idx = events[-1].idx_in_log
|
||||
|
||||
for wrapped_event in events:
|
||||
event = wrapped_event.event
|
||||
if isinstance(event, TaskStateUpdated):
|
||||
if event.task_status == TaskStatus.RUNNING:
|
||||
seen_task_started += 1
|
||||
if event.task_status == TaskStatus.COMPLETE:
|
||||
seen_task_finished += 1
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
response_string += event.chunk.text
|
||||
if event.chunk.finish_reason:
|
||||
finish_reason = event.chunk.finish_reason
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
print(f'event log: {await global_events.get_events_since(0)}')
|
||||
|
||||
return seen_task_started == 1, seen_task_finished == 1, response_string
|
||||
@@ -481,7 +481,23 @@ def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
|
||||
)
|
||||
},
|
||||
runners={RUNNER_1_ID: LoadedRunnerStatus()},
|
||||
tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
|
||||
tasks={
|
||||
TASK_1_ID: ChatCompletionTask(
|
||||
task_id=TASK_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=str(MODEL_A_ID),
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="user",
|
||||
content="Hello, world!"
|
||||
)
|
||||
]
|
||||
),
|
||||
instance_id=INSTANCE_1_ID
|
||||
)
|
||||
},
|
||||
),
|
||||
expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask(
|
||||
task_id=TASK_1_ID,
|
||||
|
||||
Reference in New Issue
Block a user