This commit is contained in:
Evan Quiney
2025-12-05 16:43:11 +00:00
committed by GitHub
parent 9e0a1c23ef
commit 7312a7e000

View File

@@ -174,11 +174,15 @@ def _ready_to_warmup(
shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
runner_id = runner.bound_instance.bound_runner_id
world_size = shard.world_size
is_runner_loaded = isinstance(runner.status, RunnerLoaded)
# Rank != 0
all_runners_loaded_or_warming_up = all(
assert device_rank < world_size
assert device_rank >= 0
# Rank != n-1
accepting_ranks_ready = device_rank != world_size - 1 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -186,17 +190,14 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank= 0
all_other_runners_warming_up = all(
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
)
nonzero_rank_ready = device_rank != 0 and all_runners_loaded_or_warming_up
zero_rank_ready = device_rank == 0 and all_other_runners_warming_up
if is_runner_loaded and (nonzero_rank_ready or zero_rank_ready):
if is_runner_loaded and (accepting_ranks_ready or connecting_rank_ready):
return StartWarmup(instance_id=instance.instance_id)
return None