fix: defer batch prefill for true continuous batching

Move sync_and_insert_pending() out of the per-task loop so all
concurrently-arrived requests share a single batched prefill pass.
Previously each request was prefilled individually as it arrived,
serializing what should be a parallel operation.
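
In sketch form (a simplified, hypothetical rendering of the task loop;
in the real code command_id/task_params come from destructuring each
task in main()'s match/case, elided here):

    # Before: every task paid for its own prefill pass.
    for task in new_tasks:
        batch_engine.queue_request(command_id, task.task_id, task_params)
        batch_engine.sync_and_insert_pending()  # prefill per task, serialized

    # After: queue all tasks first, then prefill them together.
    for task in new_tasks:
        batch_engine.queue_request(command_id, task.task_id, task_params)
    if batch_engine.has_pending_inserts:
        batch_engine.sync_and_insert_pending()  # one shared prefill pass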

Also break early from the generation TimeBudget loop when new tasks
are waiting, so they get inserted sooner rather than blocking for
the full 0.5s budget.
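
Roughly, as a fragment of _run_generation_steps (per-step logic elided;
run_one_step is a placeholder, not a real helper):

    early_tasks: list[Task] = []
    for _ in TimeBudget(budget=0.5, group=group):
        run_one_step()  # placeholder for the existing generation-step logic
        if task_receiver is not None and not early_tasks:
            early_tasks = task_receiver.collect()  # non-blocking poll
            if early_tasks:
                break  # stop early so these tasks join the next prefill
    return early_tasks  # caller inserts these before the next cycle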

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Alex Cheema
2026-02-15 17:52:04 -08:00
parent 41d9d2a61f
commit 35973b8698
3 changed files with 37 additions and 7 deletions


@@ -454,8 +454,13 @@ def _run_generation_steps(
     gpt_oss_trackers: dict[CommandId, GptOssTracker],
     thinking_first_token: dict[CommandId, bool],
     tokenizer: TokenizerWrapper | None,
-) -> None:
-    """Run one TimeBudget cycle of batch generation."""
+    task_receiver: MpReceiver[Task] | None = None,
+) -> list[Task]:
+    """Run one TimeBudget cycle of batch generation.
+
+    Returns any tasks collected via early exit so the caller can process them.
+    """
+    early_tasks: list[Task] = []
     if batch_engine.has_pending_inserts:
         batch_engine.sync_and_insert_pending()
     for _ in TimeBudget(budget=0.5, group=group):  # pyright: ignore[reportArgumentType]
@@ -499,7 +504,14 @@ def _run_generation_steps(
             thinking_first_token,
             tokenizer,
         )
+        # After processing a step, check if new tasks arrived so they can be
+        # batched sooner rather than waiting for the full time budget.
+        if task_receiver is not None and not early_tasks:
+            early_tasks = task_receiver.collect()
+            if early_tasks:
+                break
     batch_engine.sync_completions()
+    return early_tasks
 
 
 def _drain_batch_engine(
@@ -605,7 +617,7 @@ def main(
             if batch_engine is not None and (
                 batch_engine.has_active_requests or batch_engine.has_pending_inserts
            ):
-                _run_generation_steps(
+                early_tasks = _run_generation_steps(
                     batch_engine,
                     event_sender,
                     device_rank,
@@ -617,6 +629,7 @@
                     gpt_oss_trackers,
                     thinking_first_token,
                     tokenizer,
+                    task_receiver,
                 )
                 # Update runner status based on remaining work
                 if batch_engine.has_active_requests:
@@ -632,8 +645,8 @@
                         runner_id=runner_id, runner_status=current_status
                     )
                 )
-            # Non-blocking poll for new tasks
-            new_tasks = task_receiver.collect()
+            # Use early-collected tasks from generation loop, plus any additional
+            new_tasks = early_tasks + task_receiver.collect()
             if not new_tasks and batch_engine.has_active_requests:
                 continue
             else:
@@ -788,7 +801,6 @@ def main(
                                 batch_engine.queue_request(
                                     command_id, task.task_id, task_params
                                 )
-                                batch_engine.sync_and_insert_pending()
                                 in_flight_tasks[command_id] = task.task_id
 
                                 # GPT-OSS: use dedicated tracker (handles thinking + tool calls)
@@ -821,6 +833,7 @@ def main(
                                     thinking_first_token[command_id] = False
                                 current_status = RunnerRunning(
                                     active_requests=batch_engine.active_count
+                                    + batch_engine.pending_insert_count
                                 )
                                 event_sender.send(
                                     RunnerStatusUpdated(
@@ -830,7 +843,7 @@
                                 )
                                 emit_task_completion = False
                                 logger.info(
-                                    f"runner running with {batch_engine.active_count} active requests"
+                                    f"runner running with {batch_engine.active_count} active + {batch_engine.pending_insert_count} pending requests"
                                 )
                             case ImageGeneration(
                                 task_params=task_params, command_id=command_id
@@ -1020,6 +1033,11 @@ def main(
                             gc.collect()
                         break
                 else:
+                    # Batch-insert all queued requests at once for efficient prefill.
+                    # By deferring sync_and_insert_pending() to here (instead of per-task),
+                    # all concurrently-arrived tasks share a single prefill pass.
+                    if batch_engine is not None and batch_engine.has_pending_inserts:
+                        batch_engine.sync_and_insert_pending()
                     continue
                 break
 


@@ -102,6 +102,10 @@ class ScriptedBatchEngine:
     def has_pending_inserts(self) -> bool:
         return bool(self._pending)
 
+    @property
+    def pending_insert_count(self) -> int:
+        return len(self._pending)
+
     def step(self) -> list[BatchedGenerationResponse]:
         results: list[BatchedGenerationResponse] = []
         done: list[int] = []
@@ -174,6 +178,10 @@ class FakeBatchEngineWithTokens:
     def has_pending_inserts(self) -> bool:
         return bool(self._pending_inserts)
 
+    @property
+    def pending_insert_count(self) -> int:
+        return len(self._pending_inserts)
+
     def step(self) -> list[BatchedGenerationResponse]:
         results: list[BatchedGenerationResponse] = []
         done: list[int] = []


@@ -200,6 +200,10 @@ class FakeBatchEngine:
     def has_pending_inserts(self) -> bool:
         return bool(self._pending_inserts)
 
+    @property
+    def pending_insert_count(self) -> int:
+        return len(self._pending_inserts)
+
     @property
     def active_count(self) -> int:
         return len(self._active_requests)