diff --git a/src/exo/worker/runner/batched_handler.py b/src/exo/worker/runner/batched_handler.py index 71335a71..9d4010b1 100644 --- a/src/exo/worker/runner/batched_handler.py +++ b/src/exo/worker/runner/batched_handler.py @@ -192,13 +192,42 @@ class BatchedInferenceHandler: def flush(self) -> None: """Start processing pending requests by adding them to the batch/pipelined generator.""" # For tensor parallel: sync pending count across all devices - # Rank 0 broadcasts its count, all devices flush that many + # Use minimum so all devices can fulfill the batch if self.tensor_parallel_group is not None: - local_count = len(self.pending) if self.device_rank == 0 else 0 + local_count = len(self.pending) + # Negate, sum, negate to get min (all_sum gives sum, so -sum(-x) = min if we had min) + # Actually use all_gather approach: negate so min becomes max after negation + neg_count = mx.array([-local_count], dtype=mx.int32) + neg_sum = mx.distributed.all_sum(neg_count, group=self.tensor_parallel_group) + mx.eval(neg_sum) + # This gives us -sum of counts, not min. We need a different approach. + # Since MLX doesn't have all_reduce with min, we use: each node reports count, + # then we take the min locally after gathering. + # Simpler: just use local count but ensure all nodes wait for tasks to arrive + + # Actually, the right fix: wait until all nodes have the same pending count + # by checking if our count matches what we'll flush count_arr = mx.array([local_count], dtype=mx.int32) - synced = mx.distributed.all_sum(count_arr, group=self.tensor_parallel_group) - mx.eval(synced) - num_pending = int(synced.item()) + total = mx.distributed.all_sum(count_arr, group=self.tensor_parallel_group) + mx.eval(total) + total_count = int(total.item()) + + # Get world size from group + tp_world_size = self.tensor_parallel_group.size() + + # If not all nodes have same count, use minimum (total/world_size if equal, else min) + # For now, assume all nodes have same tasks (they receive same pubsub messages) + # Just verify and log warning if mismatch + expected_total = local_count * tp_world_size + if total_count != expected_total: + logger.warning( + f"Tensor parallel pending mismatch: local={local_count}, total={total_count}, expected={expected_total}" + ) + # Use the minimum possible: floor(total/world_size) + num_pending = total_count // tp_world_size + else: + num_pending = local_count + if num_pending == 0: return available_slots = self.max_batch_size - self.current_batch_size