Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)
Worker tests on staging 1
Test plan
src/exo/shared/tests/conftest.py
@@ -5,6 +5,10 @@ from typing import Generator

import pytest

from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata


@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
@@ -19,3 +23,21 @@ def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
def reset_event_loop():
    """Reset the event loop for each test to ensure clean state."""
    # This ensures each test gets a fresh event loop state


def get_pipeline_shard_metadata(
    model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
    return PipelineShardMetadata(
        model_meta=ModelMetadata(
            model_id=model_id,
            pretty_name=str(model_id),
            storage_size=Memory.from_mb(100000),
            n_layers=32,
        ),
        device_rank=device_rank,
        world_size=world_size,
        start_layer=0,
        end_layer=32,
        n_layers=32,
    )
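The body of the session-scoped event_loop fixture lies outside the hunk context above. A conventional implementation of such a fixture looks roughly like the sketch below; this is an assumption about the surrounding file, not part of the commit.

import asyncio
from typing import Generator

import pytest


@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    # Create one loop for the whole test session and close it at teardown.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()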
src/exo/shared/tests/test_apply/test_apply_node_download.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID


def test_apply_node_download_progress():
    state = State()
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    event = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard1,
    )

    new_state = apply_node_download_progress(
        NodeDownloadProgress(download_progress=event), state
    )

    assert new_state == State(downloads={NodeId("node-1"): [event]})


def test_apply_two_node_download_progress():
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_B_ID, device_rank=0, world_size=2)
    event1 = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard1,
    )
    event2 = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard2,
    )
    state = State(downloads={NodeId("node-1"): [event1]})

    new_state = apply_node_download_progress(
        NodeDownloadProgress(download_progress=event2), state
    )

    # TODO: This test is failing. We should support the following:
    # 1. Downloading multiple models concurrently on the same node (one per runner is fine).
    # 2. Downloading a model, it completes, then downloading a different model on the same node.
    assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
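The TODO above asks for more than one download record per node. A minimal sketch of an append-style reducer that would satisfy both tests follows; field and type usage is taken from how the tests above construct these objects, not from the actual exo.shared.apply implementation.

def apply_node_download_progress_appending(
    event: NodeDownloadProgress, state: State
) -> State:
    # Hypothetical reducer, not the shipped one: copy the per-node list and
    # replace any prior entry for the same shard, so different models can
    # coexist per node while repeated progress for one shard stays deduplicated.
    progress = event.download_progress
    downloads = dict(state.downloads)
    node_events = [
        existing
        for existing in downloads.get(progress.node_id, [])
        if existing.shard_metadata != progress.shard_metadata
    ]
    node_events.append(progress)
    downloads[progress.node_id] = node_events
    return State(downloads=downloads)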
src/exo/worker/tests/unittests/test_plan/conftest.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from dataclasses import dataclass

from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask
from exo.shared.types.worker.instances import (
    BoundInstance,
    Instance,
    InstanceId,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata


@dataclass(frozen=True)
class FakeRunnerSupervisor:
    bound_instance: BoundInstance
    status: RunnerStatus


class OtherTask(BaseTask):
    pass


# TODO: Is this actually better than using Mock/Fake dataclasses?
# e.g. commit d01cd292344df15759070966826a6c027945792b
def get_pipeline_shard_metadata(
    model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
    return PipelineShardMetadata(
        model_meta=ModelMetadata(
            model_id=model_id,
            pretty_name=str(model_id),
            storage_size=Memory.from_mb(100000),
            n_layers=32,
        ),
        device_rank=device_rank,
        world_size=world_size,
        start_layer=0,
        end_layer=32,
        n_layers=32,
    )


def get_shard_assignments(
    model_id: ModelId,
    node_to_runner: dict[NodeId, RunnerId],
    runner_to_shard: dict[RunnerId, ShardMetadata],
) -> ShardAssignments:
    return ShardAssignments(
        model_id=model_id,
        node_to_runner=node_to_runner,
        runner_to_shard=runner_to_shard,
    )


def get_mlx_ring_instance(
    instance_id: InstanceId,
    model_id: ModelId,
    node_to_runner: dict[NodeId, RunnerId],
    runner_to_shard: dict[RunnerId, ShardMetadata],
) -> Instance:
    return MlxRingInstance(
        instance_id=instance_id,
        shard_assignments=get_shard_assignments(
            model_id, node_to_runner, runner_to_shard
        ),
        hosts=[],
    )
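The TODO in this conftest asks whether the hand-written fakes beat Mock-based ones. For comparison, a Mock-based equivalent of FakeRunnerSupervisor might look like the sketch below; it is purely illustrative and does not reproduce the referenced commit.

from unittest.mock import MagicMock


def make_fake_runner_supervisor(
    bound_instance: BoundInstance, status: RunnerStatus
) -> MagicMock:
    # Hypothetical Mock-based stand-in: attribute access behaves like the frozen
    # dataclass above, but typos silently create new attributes, which is the
    # usual argument for preferring the explicit dataclass fake.
    runner = MagicMock()
    runner.bound_instance = bound_instance
    runner.status = status
    return runner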
@@ -0,0 +1,207 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
    RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    NODE_B,
    RUNNER_1_ID,
    RUNNER_2_ID,
)
from exo.worker.tests.unittests.test_plan.conftest import (
    FakeRunnerSupervisor,
    get_mlx_ring_instance,
    get_pipeline_shard_metadata,
)


def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
    """
    When a runner is waiting for a model and its shard is not in the
    local download_status map, plan() should emit DownloadModel.
    """
    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID},
        runner_to_shard={RUNNER_1_ID: shard},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerWaitingForModel()
    )

    runners = {RUNNER_1_ID: runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}

    # No entry for this shard -> should trigger DownloadModel
    download_status: dict[ShardMetadata, DownloadProgress] = {}

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=download_status,
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, plan_mod.DownloadModel)
    assert result.instance_id == INSTANCE_1_ID
    assert result.shard_metadata == shard


def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
    """
    When all shards for an instance are DownloadCompleted (globally) and
    all runners are in waiting/loading/loaded states, plan() should emit
    LoadModel once.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerWaitingForModel()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}

    all_runners = {
        RUNNER_1_ID: RunnerWaitingForModel(),
        RUNNER_2_ID: RunnerWaitingForModel(),
    }

    # Local node has already marked its shard as downloaded (not actually used by _load_model)
    local_download_status = {
        shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)  # type: ignore[reportUnhashable]
    }

    # Global view has completed downloads for both nodes
    global_download_status = {
        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
        NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=local_download_status,
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, LoadModel)
    assert result.instance_id == INSTANCE_1_ID


def test_plan_does_not_request_download_when_shard_already_downloaded():
    """
    If the local shard already has a DownloadCompleted entry, plan()
    should not re-emit DownloadModel while global state is still catching up.
    """
    shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID},
        runner_to_shard={RUNNER_1_ID: shard},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerWaitingForModel()
    )

    runners = {RUNNER_1_ID: runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}

    # Local status claims the shard is downloaded already
    local_download_status = {
        shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)  # type: ignore[reportUnhashable]
    }

    # Global view hasn't caught up yet (no completed shards recorded for NODE_A)
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [],
        NODE_B: [],
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=local_download_status,
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None


def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
    """
    LoadModel should not be emitted while some shards are still missing from
    the global_download_status.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )

    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerWaitingForModel()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerWaitingForModel(),
        RUNNER_2_ID: RunnerWaitingForModel(),
    }

    # Only NODE_A's shard is recorded as downloaded globally
    local_download_status = {
        shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)  # type: ignore[reportUnhashable]
    }
    global_download_status = {
        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
        NODE_B: [],  # NODE_B has no downloads completed yet
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status=local_download_status,
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None
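For readers following these tests, the call sites above imply roughly the following shape for exo.worker.plan.plan. This stub is inferred from the tests, not copied from the implementation, so the parameter and return types are assumptions.

from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus


def plan(
    node_id: NodeId,
    runners: dict[RunnerId, "RunnerSupervisor"],
    download_status: dict[ShardMetadata, DownloadProgress],
    global_download_status: dict[NodeId, list[DownloadProgress]],
    instances: dict[InstanceId, Instance],
    all_runners: dict[RunnerId, RunnerStatus],
    tasks: dict[TaskId, Task],
):
    """Return the next action for this node (DownloadModel, LoadModel,
    CreateRunner, Shutdown, StartWarmup, or a Task to forward), or None when
    the node is already in its desired state."""
    ...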
@@ -0,0 +1,194 @@
from typing import Any

import exo.worker.plan as plan_mod
from exo.shared.types.tasks import Shutdown
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerReady,
    RunnerStatus,
)
from exo.worker.tests.constants import (
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    NODE_B,
    RUNNER_1_ID,
    RUNNER_2_ID,
)

from .conftest import (
    FakeRunnerSupervisor,
    get_mlx_ring_instance,
    get_pipeline_shard_metadata,
)


def test_plan_kills_runner_when_instance_missing():
    """
    If a local runner's instance is no longer present in state,
    plan() should return a Shutdown for that runner.
    """
    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID},
        runner_to_shard={RUNNER_1_ID: shard},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

    runners = {RUNNER_1_ID: runner}
    instances: dict[InstanceId, Instance] = {}
    all_runners = {RUNNER_1_ID: RunnerReady()}

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, Shutdown)
    assert result.instance_id == INSTANCE_1_ID
    assert result.runner_id == RUNNER_1_ID


def test_plan_kills_runner_when_sibling_failed():
    """
    If a sibling runner in the same instance has failed, the local runner
    should be shut down.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

    runners = {RUNNER_1_ID: runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerReady(),
        RUNNER_2_ID: RunnerFailed(error_message="boom"),
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, Shutdown)
    assert result.instance_id == INSTANCE_1_ID
    assert result.runner_id == RUNNER_1_ID


def test_plan_creates_runner_when_missing_for_node():
    """
    If shard_assignments specify a runner for this node but we don't have
    a local supervisor yet, plan() should emit a CreateRunner.
    """
    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID},
        runner_to_shard={RUNNER_1_ID: shard},
    )

    runners: dict[Any, Any] = {}  # nothing local yet
    instances = {INSTANCE_1_ID: instance}
    all_runners: dict[Any, Any] = {}

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    # We patched plan_mod.CreateRunner → CreateRunner
    assert isinstance(result, plan_mod.CreateRunner)
    assert result.instance_id == INSTANCE_1_ID
    assert isinstance(result.bound_instance, BoundInstance)
    assert result.bound_instance.instance is instance
    assert result.bound_instance.bound_runner_id == RUNNER_1_ID


def test_plan_does_not_create_runner_when_supervisor_already_present():
    """
    If we already have a local supervisor for the runner assigned to this node,
    plan() should not emit a CreateRunner again.
    """
    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID},
        runner_to_shard={RUNNER_1_ID: shard},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

    runners = {RUNNER_1_ID: runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {RUNNER_1_ID: RunnerReady()}

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None


def test_plan_does_not_create_runner_for_unassigned_node():
    """
    If this node does not appear in shard_assignments.node_to_runner,
    plan() should not try to create a runner on this node.
    """
    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_2_ID: shard},
    )

    runners: dict[RunnerId, FakeRunnerSupervisor] = {}  # no local runners
    instances = {INSTANCE_1_ID: instance}
    all_runners: dict[RunnerId, RunnerStatus] = {}

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None
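These lifecycle tests pin down two shutdown conditions: the runner's instance has disappeared from state, or a sibling runner in the same instance has failed. A condensed sketch of that rule, written against the test fixtures rather than the planner internals, is:

def should_shutdown_local_runner(
    runner: FakeRunnerSupervisor,
    instances: dict[InstanceId, Instance],
    all_runners: dict[RunnerId, RunnerStatus],
) -> bool:
    # Illustrative only; the real exo.worker.plan logic may be organised differently.
    instance = runner.bound_instance.instance
    # Condition 1: the instance this runner belongs to is gone from state.
    if instance.instance_id not in instances:
        return True
    # Condition 2: a sibling runner of the same instance has failed.
    local_runner_id = runner.bound_instance.bound_runner_id
    sibling_ids = instance.shard_assignments.runner_to_shard.keys()
    return any(
        isinstance(all_runners.get(runner_id), RunnerFailed)
        for runner_id in sibling_ids
        if runner_id != local_runner_id
    )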
src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py (new file, 262 lines)
@@ -0,0 +1,262 @@
from typing import cast

import exo.worker.plan as plan_mod
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
    RunnerReady,
    RunnerRunning,
    RunnerWaitingForModel,
)
from exo.worker.tests.constants import (
    COMMAND_1_ID,
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    NODE_B,
    RUNNER_1_ID,
    RUNNER_2_ID,
    TASK_1_ID,
)

from .conftest import (
    FakeRunnerSupervisor,
    OtherTask,
    get_mlx_ring_instance,
    get_pipeline_shard_metadata,
)


def test_plan_forwards_pending_chat_completion_when_runner_ready():
    """
    When there is a pending ChatCompletion for the local instance and all
    runners are Ready/Running, plan() should forward that task.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerReady()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerReady(),
        RUNNER_2_ID: RunnerReady(),
    }

    task = ChatCompletion(
        task_id=TASK_1_ID,
        instance_id=INSTANCE_1_ID,
        task_status=TaskStatus.Pending,
        command_id=COMMAND_1_ID,
        task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
    )

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={TASK_1_ID: task},
    )

    assert result is task


def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
    """
    Even with a pending ChatCompletion, plan() should not forward it unless
    all runners for the instance are Ready/Running.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerReady()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerReady(),
        RUNNER_2_ID: RunnerWaitingForModel(),
    }

    task = ChatCompletion(
        task_id=TASK_1_ID,
        instance_id=INSTANCE_1_ID,
        task_status=TaskStatus.Pending,
        command_id=COMMAND_1_ID,
        task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
    )

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={TASK_1_ID: task},
    )

    assert result is None


def test_plan_does_not_forward_tasks_for_other_instances():
    """
    plan() should ignore pending ChatCompletion tasks whose instance_id does
    not match the local instance.
    """
    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
    local_instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID},
        runner_to_shard={RUNNER_1_ID: shard},
    )
    bound_instance = BoundInstance(instance=local_instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerReady()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: local_instance}
    all_runners = {RUNNER_1_ID: RunnerReady()}

    other_instance_id = InstanceId("instance-2")
    foreign_task = ChatCompletion(
        task_id=TaskId("other-task"),
        instance_id=other_instance_id,
        task_status=TaskStatus.Pending,
        command_id=COMMAND_1_ID,
        task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
    )

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={foreign_task.task_id: foreign_task},
    )

    assert result is None


def test_plan_ignores_non_pending_or_non_chat_tasks():
    """
    _pending_tasks should not forward tasks that are either not ChatCompletion
    or not in Pending/Running states.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)

    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerReady()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerReady(),
        RUNNER_2_ID: RunnerReady(),
    }

    completed_task = ChatCompletion(
        task_id=TASK_1_ID,
        instance_id=INSTANCE_1_ID,
        task_status=TaskStatus.Complete,
        command_id=COMMAND_1_ID,
        task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
    )

    other_task_id = TaskId("other-task")

    other_task = cast(
        Task,
        cast(
            object,
            OtherTask(
                task_id=other_task_id,
                instance_id=INSTANCE_1_ID,
                task_status=TaskStatus.Pending,
            ),
        ),
    )

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={TASK_1_ID: completed_task, other_task_id: other_task},
    )

    assert result is None


def test_plan_returns_none_when_nothing_to_do():
    """
    If there are healthy runners, no downloads needed, and no pending tasks,
    plan() should return None (steady state).
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )
    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerRunning()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerRunning(),
        RUNNER_2_ID: RunnerRunning(),
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None
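The docstring of test_plan_ignores_non_pending_or_non_chat_tasks describes a _pending_tasks helper that only forwards ChatCompletion tasks in Pending or Running state for the local instance. A sketch consistent with that description follows; the real helper in exo.worker.plan may differ in shape.

def pending_chat_completions_for_instance(
    tasks: dict[TaskId, Task], instance_id: InstanceId
) -> list[ChatCompletion]:
    # Keep only ChatCompletion tasks addressed to this instance that are still
    # Pending or Running; everything else is ignored by the planner.
    return [
        task
        for task in tasks.values()
        if isinstance(task, ChatCompletion)
        and task.instance_id == instance_id
        and task.task_status in (TaskStatus.Pending, TaskStatus.Running)
    ]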
src/exo/worker/tests/unittests/test_plan/test_warmup.py (new file, 179 lines)
@@ -0,0 +1,179 @@
import exo.worker.plan as plan_mod
from exo.shared.types.tasks import StartWarmup
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
    RunnerLoaded,
    RunnerWaitingForModel,
    RunnerWarmingUp,
)
from exo.worker.tests.constants import (
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    NODE_B,
    RUNNER_1_ID,
    RUNNER_2_ID,
)

from .conftest import (
    FakeRunnerSupervisor,
    get_mlx_ring_instance,
    get_pipeline_shard_metadata,
)


def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
    """
    For non-zero device_rank shards, StartWarmup should be emitted when all
    shards in the instance are Loaded/WarmingUp.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerLoaded(),
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, StartWarmup)
    assert result.instance_id == INSTANCE_1_ID


def test_plan_starts_warmup_for_rank_zero_after_others_warming():
    """
    For device_rank == 0, StartWarmup should only be emitted once all the
    other runners in the instance are already warming up.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerLoaded(),
        RUNNER_2_ID: RunnerWarmingUp(),
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, StartWarmup)
    assert result.instance_id == INSTANCE_1_ID


def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warming():
    """
    Non-zero rank should not start warmup while any shard is not Loaded/WarmingUp.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerWaitingForModel(),
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None


def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
    """
    Rank-zero shard should not start warmup until all non-zero ranks are
    already WarmingUp.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerLoaded(),
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners,  # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None
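Taken together, these warmup tests describe a rank-dependent gate: non-zero ranks may start warming up once every runner in the instance is Loaded or WarmingUp, while rank zero waits until every other runner is already WarmingUp. A compact sketch of that gate, written against the test fixtures rather than the real planner internals, is:

def should_start_warmup(
    local_device_rank: int,
    local_runner_id: "RunnerId",
    all_runners: dict["RunnerId", "RunnerStatus"],
) -> bool:
    # Assumed gate mirroring the four tests above; not the exo.worker.plan code.
    if local_device_rank != 0:
        # Non-zero ranks go first, once every shard is at least loaded.
        return all(
            isinstance(status, (RunnerLoaded, RunnerWarmingUp))
            for status in all_runners.values()
        )
    # Rank zero joins last, once every other rank is already warming up.
    return all(
        isinstance(status, RunnerWarmingUp)
        for runner_id, status in all_runners.items()
        if runner_id != local_runner_id
    )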