Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
7317c87458 skip repeat tasks in runner 2026-02-04 19:19:29 +00:00
2 changed files with 7 additions and 15 deletions

View File

@@ -120,11 +120,12 @@ def main(
with task_receiver as tasks:
for task in tasks:
if task.task_id in seen:
logger.warning("repeat task - potential error")
continue
seen.add(task.task_id)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
@@ -136,7 +137,6 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
group = initialize_mlx(bound_instance)
logger.info("runner connected")
@@ -153,7 +153,6 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
def on_model_load_timeout() -> None:
event_sender.send(
@@ -196,7 +195,6 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
@@ -236,8 +234,6 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
@@ -369,7 +365,6 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
# Generate images using the image generation backend
@@ -435,7 +430,6 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
try:
image_index = 0
@@ -494,8 +488,6 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
current_status = RunnerShutdown()
case _:
raise ValueError(

View File

@@ -201,29 +201,29 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -231,10 +231,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),