mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-05 03:33:30 -05:00
claude is 100% nerfed to oblivion
This commit is contained in:
@@ -192,13 +192,42 @@ class BatchedInferenceHandler:
|
||||
def flush(self) -> None:
|
||||
"""Start processing pending requests by adding them to the batch/pipelined generator."""
|
||||
# For tensor parallel: sync pending count across all devices
|
||||
# Rank 0 broadcasts its count, all devices flush that many
|
||||
# Use minimum so all devices can fulfill the batch
|
||||
if self.tensor_parallel_group is not None:
|
||||
local_count = len(self.pending) if self.device_rank == 0 else 0
|
||||
local_count = len(self.pending)
|
||||
# Negate, sum, negate to get min (all_sum gives sum, so -sum(-x) = min if we had min)
|
||||
# Actually use all_gather approach: negate so min becomes max after negation
|
||||
neg_count = mx.array([-local_count], dtype=mx.int32)
|
||||
neg_sum = mx.distributed.all_sum(neg_count, group=self.tensor_parallel_group)
|
||||
mx.eval(neg_sum)
|
||||
# This gives us -sum of counts, not min. We need a different approach.
|
||||
# Since MLX doesn't have all_reduce with min, we use: each node reports count,
|
||||
# then we take the min locally after gathering.
|
||||
# Simpler: just use local count but ensure all nodes wait for tasks to arrive
|
||||
|
||||
# Actually, the right fix: wait until all nodes have the same pending count
|
||||
# by checking if our count matches what we'll flush
|
||||
count_arr = mx.array([local_count], dtype=mx.int32)
|
||||
synced = mx.distributed.all_sum(count_arr, group=self.tensor_parallel_group)
|
||||
mx.eval(synced)
|
||||
num_pending = int(synced.item())
|
||||
total = mx.distributed.all_sum(count_arr, group=self.tensor_parallel_group)
|
||||
mx.eval(total)
|
||||
total_count = int(total.item())
|
||||
|
||||
# Get world size from group
|
||||
tp_world_size = self.tensor_parallel_group.size()
|
||||
|
||||
# If not all nodes have same count, use minimum (total/world_size if equal, else min)
|
||||
# For now, assume all nodes have same tasks (they receive same pubsub messages)
|
||||
# Just verify and log warning if mismatch
|
||||
expected_total = local_count * tp_world_size
|
||||
if total_count != expected_total:
|
||||
logger.warning(
|
||||
f"Tensor parallel pending mismatch: local={local_count}, total={total_count}, expected={expected_total}"
|
||||
)
|
||||
# Use the minimum possible: floor(total/world_size)
|
||||
num_pending = total_count // tp_world_size
|
||||
else:
|
||||
num_pending = local_count
|
||||
|
||||
if num_pending == 0:
|
||||
return
|
||||
available_slots = self.max_batch_size - self.current_batch_size
|
||||
|
||||
Reference in New Issue
Block a user