Compare commits

..

1 Commit

Author SHA1 Message Date
Evan
2a0b9dc727 move runner ack after status update 2026-02-05 00:07:43 +00:00
3 changed files with 14 additions and 12 deletions

View File

@@ -292,14 +292,9 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# Design note: this is a state race in disguise — the task status doesn't get updated to Completed fast enough.
# Realistically, the task status should be set to Completed by the LAST runner, so this is a true race;
# the proper solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check that the ordering aligns with MLX distributed's expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard

View File

@@ -140,12 +140,12 @@ def main(
with task_receiver as tasks:
for task in tasks:
if task.task_id in seen:
logger.warning("repeat task - potential error")
logger.warning("repeat task - currently a logic bug, please report")
continue
seen.add(task.task_id)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
@@ -157,6 +157,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
group = initialize_mlx(bound_instance)
logger.info("runner connected")
@@ -173,6 +174,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
def on_model_load_timeout() -> None:
event_sender.send(
@@ -215,6 +217,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
@@ -254,6 +257,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
@@ -385,6 +389,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
image_index = 0
@@ -447,6 +452,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
image_index = 0
@@ -502,6 +508,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
current_status = RunnerShutdown()
case _:
raise ValueError(

View File

@@ -201,29 +201,29 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskAcknowledged(task_id=LOAD_TASK_ID),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -231,10 +231,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),