mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-16 17:11:57 -05:00
fix: defer batch prefill for true continuous batching
Move sync_and_insert_pending() out of the per-task loop so all concurrently-arrived requests share a single batched prefill pass. Previously each request was prefilled individually as it arrived, serializing what should be a parallel operation. Also break early from the generation TimeBudget loop when new tasks are waiting, so they get inserted sooner rather than blocking for the full 0.5s budget. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -454,8 +454,13 @@ def _run_generation_steps(
|
||||
gpt_oss_trackers: dict[CommandId, GptOssTracker],
|
||||
thinking_first_token: dict[CommandId, bool],
|
||||
tokenizer: TokenizerWrapper | None,
|
||||
) -> None:
|
||||
"""Run one TimeBudget cycle of batch generation."""
|
||||
task_receiver: MpReceiver[Task] | None = None,
|
||||
) -> list[Task]:
|
||||
"""Run one TimeBudget cycle of batch generation.
|
||||
|
||||
Returns any tasks collected via early exit so the caller can process them.
|
||||
"""
|
||||
early_tasks: list[Task] = []
|
||||
if batch_engine.has_pending_inserts:
|
||||
batch_engine.sync_and_insert_pending()
|
||||
for _ in TimeBudget(budget=0.5, group=group): # pyright: ignore[reportArgumentType]
|
||||
@@ -499,7 +504,14 @@ def _run_generation_steps(
|
||||
thinking_first_token,
|
||||
tokenizer,
|
||||
)
|
||||
# After processing a step, check if new tasks arrived so they can be
|
||||
# batched sooner rather than waiting for the full time budget.
|
||||
if task_receiver is not None and not early_tasks:
|
||||
early_tasks = task_receiver.collect()
|
||||
if early_tasks:
|
||||
break
|
||||
batch_engine.sync_completions()
|
||||
return early_tasks
|
||||
|
||||
|
||||
def _drain_batch_engine(
|
||||
@@ -605,7 +617,7 @@ def main(
|
||||
if batch_engine is not None and (
|
||||
batch_engine.has_active_requests or batch_engine.has_pending_inserts
|
||||
):
|
||||
_run_generation_steps(
|
||||
early_tasks = _run_generation_steps(
|
||||
batch_engine,
|
||||
event_sender,
|
||||
device_rank,
|
||||
@@ -617,6 +629,7 @@ def main(
|
||||
gpt_oss_trackers,
|
||||
thinking_first_token,
|
||||
tokenizer,
|
||||
task_receiver,
|
||||
)
|
||||
# Update runner status based on remaining work
|
||||
if batch_engine.has_active_requests:
|
||||
@@ -632,8 +645,8 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
# Non-blocking poll for new tasks
|
||||
new_tasks = task_receiver.collect()
|
||||
# Use early-collected tasks from generation loop, plus any additional
|
||||
new_tasks = early_tasks + task_receiver.collect()
|
||||
if not new_tasks and batch_engine.has_active_requests:
|
||||
continue
|
||||
else:
|
||||
@@ -788,7 +801,6 @@ def main(
|
||||
batch_engine.queue_request(
|
||||
command_id, task.task_id, task_params
|
||||
)
|
||||
batch_engine.sync_and_insert_pending()
|
||||
in_flight_tasks[command_id] = task.task_id
|
||||
|
||||
# GPT-OSS: use dedicated tracker (handles thinking + tool calls)
|
||||
@@ -821,6 +833,7 @@ def main(
|
||||
thinking_first_token[command_id] = False
|
||||
current_status = RunnerRunning(
|
||||
active_requests=batch_engine.active_count
|
||||
+ batch_engine.pending_insert_count
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
@@ -830,7 +843,7 @@ def main(
|
||||
)
|
||||
emit_task_completion = False
|
||||
logger.info(
|
||||
f"runner running with {batch_engine.active_count} active requests"
|
||||
f"runner running with {batch_engine.active_count} active + {batch_engine.pending_insert_count} pending requests"
|
||||
)
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
@@ -1020,6 +1033,11 @@ def main(
|
||||
gc.collect()
|
||||
break
|
||||
else:
|
||||
# Batch-insert all queued requests at once for efficient prefill.
|
||||
# By deferring sync_and_insert_pending() to here (instead of per-task),
|
||||
# all concurrently-arrived tasks share a single prefill pass.
|
||||
if batch_engine is not None and batch_engine.has_pending_inserts:
|
||||
batch_engine.sync_and_insert_pending()
|
||||
continue
|
||||
break
|
||||
|
||||
|
||||
@@ -102,6 +102,10 @@ class ScriptedBatchEngine:
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending)
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
done: list[int] = []
|
||||
@@ -174,6 +178,10 @@ class FakeBatchEngineWithTokens:
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
done: list[int] = []
|
||||
|
||||
@@ -200,6 +200,10 @@ class FakeBatchEngine:
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
Reference in New Issue
Block a user