from __future__ import annotations

import logging
import tempfile
from pathlib import Path

import pytest

from shared.types.api import ChatCompletionMessage
from shared.types.state import State
from shared.types.tasks import (
    ChatCompletionTask,
    ChatCompletionTaskParams,
    TaskStatus,
    TaskType,
)
from shared.types.worker.common import NodeStatus
from shared.types.worker.downloads import DownloadPending
from shared.types.worker.instances import Instance, InstanceStatus
from shared.types.worker.ops import (
    AssignRunnerOp,
    DownloadOp,
    ExecuteTaskOp,
    RunnerDownOp,
    RunnerUpOp,
    UnassignRunnerOp,
)
from shared.types.worker.runners import (
    AssignedRunnerStatus,
    DownloadingRunnerStatus,
    FailedRunnerStatus,
    LoadedRunnerStatus,
    ReadyRunnerStatus,
    RunningRunnerStatus,
    ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.download_utils import build_model_path
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import AssignedRunner, Worker

from .test_worker_plan_utils import (
    COMMAND_1_ID,
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    NODE_B,
    RUNNER_1_ID,
    RUNNER_2_ID,
    TASK_1_ID,
    InProcessRunner,
    PlanTestCase,
    make_downloading_status,
    make_model_meta,
    make_shard_metadata,
)

"""
|
|
The idea with these tests is to define declaratively the input and expected output of the worker.plan function.
|
|
|
|
We initialize a Worker with InProcessRunners. We then construct a State which gets passed to Worker.plan.
|
|
We then check what operation is returned by Worker.plan.
|
|
"""
|
|
|
|
def _get_test_cases(tmp_path: Path) -> list[PlanTestCase]:
    # The `model_path` for `RUNNER_1_ID` must exist for the `DownloadOp` test case to pass validation.
    (tmp_path / f"model_for_runner_{RUNNER_1_ID}").mkdir(exist_ok=True, parents=True)
    model_a_meta = make_model_meta(MODEL_A_ID)
    return [
        PlanTestCase(
            description="no runners -> no-op",
            in_process_runners=[],
            state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
            expected_op=None,
        ),

        # I don't think this should ever happen, as if it's currently downloading then the worker loop will be blocked.
        # Potentially useful for future compatibility when the worker becomes non-blocking.
        PlanTestCase(
            description="runner state assigned, runner is assigned and downloading -> no-op",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=make_downloading_status(NODE_A),
                    downloaded=False,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.INACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
            ),
            expected_op=None,
        ),

        PlanTestCase(
            description="runner state downloading, runner is downloading -> no-op",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=make_downloading_status(NODE_A),
                    downloaded=False,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.INACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: make_downloading_status(NODE_A)},
            ),
            expected_op=None,
        ),

        PlanTestCase(
            description="ready runner, model present -> no-op",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=ReadyRunnerStatus(),
                    downloaded=True,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.INACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: ReadyRunnerStatus()},
            ),
            expected_op=None,
        ),

        PlanTestCase(
            description="runner assigned and not in state -> AssignRunnerOp",
            in_process_runners=[],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,  # Either active or inactive should yield the same.
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: AssignedRunnerStatus()},
            ),
            expected_op=AssignRunnerOp(
                instance_id=INSTANCE_1_ID,
                runner_id=RUNNER_1_ID,
                shard_metadata=PipelineShardMetadata(
                    device_rank=0,
                    world_size=1,
                    model_meta=model_a_meta,
                    start_layer=0,
                    end_layer=1,
                    n_layers=1,
                ),
                hosts=[]
            ),
        ),

        PlanTestCase(
            description="runner assigned but no longer in state -> UnassignRunnerOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=AssignedRunnerStatus(),
                    downloaded=False,
                )
            ],
            state=State(node_status={NODE_A: NodeStatus.Idle}, instances={}, runners={}),
            expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
        ),

        PlanTestCase(
            description="runner state assigned, runner is assigned, not downloaded -> expect DownloadOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=AssignedRunnerStatus(),
                    downloaded=False,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: AssignedRunnerStatus()},
            ),
            expected_op=DownloadOp(
                runner_id=RUNNER_1_ID,
                instance_id=INSTANCE_1_ID,
                shard_metadata=PipelineShardMetadata(
                    device_rank=0,
                    world_size=1,
                    model_meta=model_a_meta,
                    start_layer=0,
                    end_layer=1,
                    n_layers=1,
                ),
                hosts=[],
            ),
        ),

        PlanTestCase(
            description="ready runner (and state up) -> expect RunnerUpOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=ReadyRunnerStatus(),
                    downloaded=True,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: ReadyRunnerStatus()},
                tasks={},
            ),
            expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
        ),

        PlanTestCase(
            description="1 ready, 1 downloading (and state up) -> no-op",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=ReadyRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=DownloadingRunnerStatus(
                        download_progress=DownloadPending(node_id=NODE_A)
                    ),
                    downloaded=False,
                    device_rank=1,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: DownloadingRunnerStatus(download_progress=DownloadPending(node_id=NODE_A))},
                tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
            ),
            expected_op=None
        ),

        PlanTestCase(
            description="2 ready runners (and state up) -> expect RunnerUpOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=ReadyRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=ReadyRunnerStatus(),
                    downloaded=True,
                    device_rank=1,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: ReadyRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
                tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
            ),
            expected_op=RunnerUpOp(runner_id=RUNNER_1_ID)
        ),

        PlanTestCase(
            description="loaded runner (and state down) -> expect RunnerDownOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.INACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: LoadedRunnerStatus()},
                tasks={},
            ),
            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
        ),

        PlanTestCase(
            description="failed runner (and state down) -> expect RunnerDownOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=FailedRunnerStatus(),
                    downloaded=True,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.INACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: FailedRunnerStatus()},
                tasks={},
            ),
            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
        ),

        PlanTestCase(
            description="loaded runner, model present, task pending -> expect ExecuteTaskOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                )
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: LoadedRunnerStatus()},
                tasks={
                    TASK_1_ID: ChatCompletionTask(
                        task_id=TASK_1_ID,
                        command_id=COMMAND_1_ID,
                        task_type=TaskType.CHAT_COMPLETION,
                        task_status=TaskStatus.PENDING,
                        task_params=ChatCompletionTaskParams(
                            model=str(MODEL_A_ID),
                            messages=[
                                ChatCompletionMessage(
                                    role="user",
                                    content="Hello, world!"
                                )
                            ]
                        ),
                        instance_id=INSTANCE_1_ID
                    )
                },
            ),
            expected_op=ExecuteTaskOp(runner_id=RUNNER_1_ID, task=ChatCompletionTask(
                task_id=TASK_1_ID,
                command_id=COMMAND_1_ID,
                instance_id=INSTANCE_1_ID,
                task_type=TaskType.CHAT_COMPLETION,
                task_status=TaskStatus.PENDING,
                task_params=ChatCompletionTaskParams(
                    model=str(MODEL_A_ID),
                    messages=[ChatCompletionMessage(role="user", content="Hello, world!")]
                ),
            )),
        ),

        PlanTestCase(
            # We should only run rank 0 once all other ranks are running.
            description="two loaded runners & task, i'm rank 0 -> no-op",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                    device_rank=1,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
                tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
            ),
            expected_op=None
        ),

        PlanTestCase(
            description="two loaded runners & task, i'm rank 1 -> expect ExecuteTaskOp on rank 1",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                    device_rank=1,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=1, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=0, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
                tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
            ),
            expected_op=ExecuteTaskOp(
                runner_id=RUNNER_1_ID,
                task=ChatCompletionTask(
                    task_id=TASK_1_ID,
                    command_id=COMMAND_1_ID,
                    instance_id=INSTANCE_1_ID,
                    task_type=TaskType.CHAT_COMPLETION,
                    task_params=ChatCompletionTaskParams(
                        model=str(MODEL_A_ID),
                        messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
                    ),
                    task_status=TaskStatus.PENDING,
                ),
            ),
        ),

        PlanTestCase(
            description="rank 0 loaded, rank 1 running, i'm rank 0 -> expect ExecuteTaskOp on rank 0",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=RunningRunnerStatus(),
                    downloaded=True,
                    device_rank=1,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Running},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: RunningRunnerStatus()},
                tasks={TASK_1_ID: ChatCompletionTask(task_id=TASK_1_ID, command_id=COMMAND_1_ID, task_type=TaskType.CHAT_COMPLETION, task_status=TaskStatus.PENDING, task_params=ChatCompletionTaskParams(model=str(MODEL_A_ID), messages=[ChatCompletionMessage(role="user", content="Hello, world!")]), instance_id=INSTANCE_1_ID)},
            ),
            expected_op=ExecuteTaskOp(
                runner_id=RUNNER_1_ID,
                task=ChatCompletionTask(
                    task_id=TASK_1_ID,
                    command_id=COMMAND_1_ID,
                    instance_id=INSTANCE_1_ID,
                    task_type=TaskType.CHAT_COMPLETION,
                    task_params=ChatCompletionTaskParams(
                        model=str(MODEL_A_ID),
                        messages=[ChatCompletionMessage(role="user", content="Hello, world!")],
                    ),
                    task_status=TaskStatus.PENDING,
                ),
            ),
        ),

        PlanTestCase(
            description="other runner failed -> RunnerDownOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=FailedRunnerStatus(),
                    downloaded=True,
                    device_rank=1,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: LoadedRunnerStatus(), RUNNER_2_ID: FailedRunnerStatus()},
            ),
            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
        ),

        PlanTestCase(
            description="this runner failed (1 node) -> RunnerDownOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=FailedRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=1),
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: FailedRunnerStatus()},
            ),
            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
        ),

        PlanTestCase(
            description="this runner failed (2 nodes) -> no-op",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=FailedRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=LoadedRunnerStatus(),
                    downloaded=True,
                    device_rank=1,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: LoadedRunnerStatus()},
            ),
            expected_op=None
        ),

        PlanTestCase(
            description="this node failed, other node spun down -> RunnerDownOp",
            in_process_runners=[
                InProcessRunner(
                    runner_id=RUNNER_1_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=FailedRunnerStatus(),
                    downloaded=True,
                    device_rank=0,
                ),
                InProcessRunner(
                    runner_id=RUNNER_2_ID,
                    instance_id=INSTANCE_1_ID,
                    model_id=MODEL_A_ID,
                    status=ReadyRunnerStatus(),
                    downloaded=True,
                    device_rank=1,
                ),
            ],
            state=State(
                node_status={NODE_A: NodeStatus.Idle, NODE_B: NodeStatus.Idle},
                instances={
                    INSTANCE_1_ID: Instance(
                        instance_type=InstanceStatus.ACTIVE,
                        instance_id=INSTANCE_1_ID,
                        shard_assignments=ShardAssignments(
                            model_id=MODEL_A_ID,
                            runner_to_shard={
                                RUNNER_1_ID: make_shard_metadata(device_rank=0, world_size=2),
                                RUNNER_2_ID: make_shard_metadata(device_rank=1, world_size=2)
                            },
                            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}
                        ),
                        hosts=[]
                    )
                },
                runners={RUNNER_1_ID: FailedRunnerStatus(), RUNNER_2_ID: ReadyRunnerStatus()},
            ),
            expected_op=RunnerDownOp(runner_id=RUNNER_1_ID)
        ),
    ]


# ---------------------------------------------------------------------------
# Parametrised test
# ---------------------------------------------------------------------------


# Pre-compute readable identifiers for each case to avoid lambda typing issues.
@pytest.mark.parametrize(
    "case",
    # The cases built here (with a throwaway temp dir) only supply readable test IDs;
    # they are rebuilt against the real tmp_path fixture inside the test itself.
    [pytest.param(c, id=c.id()) for c in _get_test_cases(Path(tempfile.TemporaryDirectory().name))],
)
def test_worker_plan(case: PlanTestCase, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Exercise Worker.plan across declarative scenarios."""

    print(f"----- case: {case.description}")

    # Regenerate test cases with the actual tmp_path fixture
    test_cases = {c.description: c for c in _get_test_cases(tmp_path)}
    case = test_cases[case.description]

    node_id = NODE_A

    logger = logging.getLogger("test_worker_plan")
    shard_downloader = NoopShardDownloader()
    worker = Worker(node_id=node_id, shard_downloader=shard_downloader, worker_events=None, global_events=None, logger=logger)

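    # Records, per model path, whether that runner's shard is treated as already downloaded.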
    path_downloaded_map: dict[str, bool] = {}

    runner_config: InProcessRunner
    for runner_config in case.in_process_runners:

        model_path = tmp_path / f"model_for_runner_{runner_config.runner_id}"
        model_path.mkdir(exist_ok=True, parents=True)

        if len(case.state.instances) == 1:
            instance_id = next(iter(case.state.instances))

            shard_assignments = case.state.instances[instance_id].shard_assignments
            shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]

            # Only add this runner if it belongs to our node
            runner_node = None
            for node, runner in shard_assignments.node_to_runner.items():
                if runner == runner_config.runner_id:
                    runner_node = node
                    break

            if runner_node != node_id:
                # This runner belongs to a different node, skip it
                continue

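        # No instance in the state, so there is no shard assignment to read from; fabricate
        # minimal single-shard metadata so the runner can still be registered with the worker.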
        elif len(case.state.instances) == 0:
            shard_metadata = PipelineShardMetadata(
                device_rank=runner_config.device_rank,
                world_size=1,
                model_meta=make_model_meta(runner_config.model_id),
                start_layer=0,
                end_layer=1,
                n_layers=1,
            )
        else:
            raise Exception('test_worker_plan not currently designed to have more than 1 instance.')

        assigned_runner = AssignedRunner(
            runner_id=runner_config.runner_id,
            instance_id=runner_config.instance_id,
            shard_metadata=shard_metadata,
            hosts=[],
            status=runner_config.status,
            runner=None,
            is_downloaded=runner_config.downloaded
        )
        worker.assigned_runners[runner_config.runner_id] = assigned_runner
        path_downloaded_map[str(build_model_path(shard_metadata.model_meta.model_id))] = runner_config.downloaded

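    # A single planning pass over the declared state should yield exactly the expected operation (or None).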
    op = worker.plan(case.state)
    assert op == case.expected_op