Worker tests on staging 1

Test plan
This commit is contained in:
rltakashige
2025-11-21 15:22:40 +00:00
committed by GitHub
parent b45cbdeecd
commit de50811313
8 changed files with 980 additions and 0 deletions

View File

@@ -5,6 +5,10 @@ from typing import Generator
import pytest
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
@@ -19,3 +23,21 @@ def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
def reset_event_loop():
"""Reset the event loop for each test to ensure clean state."""
# This ensures each test gets a fresh event loop state
def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
),
device_rank=device_rank,
world_size=world_size,
start_layer=0,
end_layer=32,
n_layers=32,
)

View File

@@ -0,0 +1,45 @@
from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
def test_apply_node_download_progress():
state = State()
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
)
new_state = apply_node_download_progress(
NodeDownloadProgress(download_progress=event), state
)
assert new_state == State(downloads={NodeId("node-1"): [event]})
def test_apply_two_node_download_progress():
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_B_ID, device_rank=0, world_size=2)
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
)
state = State(downloads={NodeId("node-1"): [event1]})
new_state = apply_node_download_progress(
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})

View File

View File

@@ -0,0 +1,71 @@
from dataclasses import dataclass
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@dataclass(frozen=True)
class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
class OtherTask(BaseTask):
pass
# TODO: Is this actually better than using Mock/Fake dataclasses?
# e.g. commit d01cd292344df15759070966826a6c027945792b
def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
),
device_rank=device_rank,
world_size=world_size,
start_layer=0,
end_layer=32,
n_layers=32,
)
def get_shard_assignments(
model_id: ModelId,
node_to_runner: dict[NodeId, RunnerId],
runner_to_shard: dict[RunnerId, ShardMetadata],
) -> ShardAssignments:
return ShardAssignments(
model_id=model_id,
node_to_runner=node_to_runner,
runner_to_shard=runner_to_shard,
)
def get_mlx_ring_instance(
instance_id: InstanceId,
model_id: ModelId,
node_to_runner: dict[NodeId, RunnerId],
runner_to_shard: dict[RunnerId, ShardMetadata],
) -> Instance:
return MlxRingInstance(
instance_id=instance_id,
shard_assignments=get_shard_assignments(
model_id, node_to_runner, runner_to_shard
),
hosts=[],
)

View File

@@ -0,0 +1,207 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
)
from exo.worker.tests.unittests.test_plan.conftest import (
FakeRunnerSupervisor,
get_mlx_ring_instance,
get_pipeline_shard_metadata,
)
def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
"""
When a runner is waiting for a model and its shard is not in the
local download_status map, plan() should emit DownloadModel.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=download_status,
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, plan_mod.DownloadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.shard_metadata == shard
def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
"""
When all shards for an instance are DownloadCompleted (globally) and
all runners are in waiting/loading/loaded states, plan() should emit
LoadModel once.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_request_download_when_shard_already_downloaded():
"""
If the local shard already has a DownloadCompleted entry, plan()
should not re-emit DownloadModel while global state is still catching up.
"""
shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
# Local status claims the shard is downloaded already
local_download_status = {
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
"""
LoadModel should not be emitted while some shards are still missing from
the global_download_status.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
}
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [], # NODE_B has no downloads completed yet
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None

View File

@@ -0,0 +1,194 @@
from typing import Any
import exo.worker.plan as plan_mod
from exo.shared.types.tasks import Shutdown
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerReady,
RunnerStatus,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
)
from .conftest import (
FakeRunnerSupervisor,
get_mlx_ring_instance,
get_pipeline_shard_metadata,
)
def test_plan_kills_runner_when_instance_missing():
"""
If a local runner's instance is no longer present in state,
plan() should return a Shutdown for that runner.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())
runners = {RUNNER_1_ID: runner}
instances: dict[InstanceId, Instance] = {}
all_runners = {RUNNER_1_ID: RunnerReady()}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, Shutdown)
assert result.instance_id == INSTANCE_1_ID
assert result.runner_id == RUNNER_1_ID
def test_plan_kills_runner_when_sibling_failed():
"""
If a sibling runner in the same instance has failed, the local runner
should be shut down.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerFailed(error_message="boom"),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, Shutdown)
assert result.instance_id == INSTANCE_1_ID
assert result.runner_id == RUNNER_1_ID
def test_plan_creates_runner_when_missing_for_node():
"""
If shard_assignments specify a runner for this node but we don't have
a local supervisor yet, plan() should emit a CreateRunner.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
runners: dict[Any, Any] = {} # nothing local yet
instances = {INSTANCE_1_ID: instance}
all_runners: dict[Any, Any] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners,
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
# We patched plan_mod.CreateRunner → CreateRunner
assert isinstance(result, plan_mod.CreateRunner)
assert result.instance_id == INSTANCE_1_ID
assert isinstance(result.bound_instance, BoundInstance)
assert result.bound_instance.instance is instance
assert result.bound_instance.bound_runner_id == RUNNER_1_ID
def test_plan_does_not_create_runner_when_supervisor_already_present():
"""
If we already have a local supervisor for the runner assigned to this node,
plan() should not emit a CreateRunner again.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerReady()}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
def test_plan_does_not_create_runner_for_unassigned_node():
"""
If this node does not appear in shard_assignments.node_to_runner,
plan() should not try to create a runner on this node.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_2_ID: shard},
)
runners: dict[RunnerId, FakeRunnerSupervisor] = {} # no local runners
instances = {INSTANCE_1_ID: instance}
all_runners: dict[RunnerId, RunnerStatus] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None

View File

@@ -0,0 +1,262 @@
from typing import cast
import exo.worker.plan as plan_mod
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerReady,
RunnerRunning,
RunnerWaitingForModel,
)
from exo.worker.tests.constants import (
COMMAND_1_ID,
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
)
from .conftest import (
FakeRunnerSupervisor,
OtherTask,
get_mlx_ring_instance,
get_pipeline_shard_metadata,
)
def test_plan_forwards_pending_chat_completion_when_runner_ready():
"""
When there is a pending ChatCompletion for the local instance and all
runners are Ready/Running, plan() should forward that task.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerReady()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerReady(),
}
task = ChatCompletion(
task_id=TASK_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={TASK_1_ID: task},
)
assert result is task
def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
"""
Even with a pending ChatCompletion, plan() should not forward it unless
all runners for the instance are Ready/Running.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerReady()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerWaitingForModel(),
}
task = ChatCompletion(
task_id=TASK_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={TASK_1_ID: task},
)
assert result is None
def test_plan_does_not_forward_tasks_for_other_instances():
"""
plan() should ignore pending ChatCompletion tasks whose instance_id does
not match the local instance.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
local_instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
bound_instance = BoundInstance(instance=local_instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerReady()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: local_instance}
all_runners = {RUNNER_1_ID: RunnerReady()}
other_instance_id = InstanceId("instance-2")
foreign_task = ChatCompletion(
task_id=TaskId("other-task"),
instance_id=other_instance_id,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={foreign_task.task_id: foreign_task},
)
assert result is None
def test_plan_ignores_non_pending_or_non_chat_tasks():
"""
_pending_tasks should not forward tasks that are either not ChatCompletion
or not in Pending/Running states.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerReady()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerReady(),
}
completed_task = ChatCompletion(
task_id=TASK_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Complete,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
other_task_id = TaskId("other-task")
other_task = cast(
Task,
cast(
object,
OtherTask(
task_id=other_task_id,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
),
),
)
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={TASK_1_ID: completed_task, other_task_id: other_task},
)
assert result is None
def test_plan_returns_none_when_nothing_to_do():
"""
If there are healthy runners, no downloads needed, and no pending tasks,
plan() should return None (steady state).
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerRunning()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerRunning(),
RUNNER_2_ID: RunnerRunning(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None

View File

@@ -0,0 +1,179 @@
import exo.worker.plan as plan_mod
from exo.shared.types.tasks import StartWarmup
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerLoaded,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
)
from .conftest import (
FakeRunnerSupervisor,
get_mlx_ring_instance,
get_pipeline_shard_metadata,
)
def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
"""
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_starts_warmup_for_rank_zero_after_others_warming():
"""
For device_rank == 0, StartWarmup should only be emitted once all the
other runners in the instance are already warming up.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warming():
"""
Non-zero rank should not start warmup while any shard is not Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None