claude is 100% nerfed to oblivion

Ryuichi Leo Takashige
2026-02-02 23:07:40 +00:00
parent 163bb83195
commit 04cc92a97f


@@ -192,13 +192,42 @@ class BatchedInferenceHandler:
def flush(self) -> None:
"""Start processing pending requests by adding them to the batch/pipelined generator."""
    # For tensor parallel: every rank must flush the same number of requests,
    # otherwise ranks fall out of step on the collective ops during generation.
    local_count = len(self.pending)
    if self.tensor_parallel_group is not None:
        # MLX's distributed module exposes all_sum but no min reduction, so
        # sum the per-rank counts and cross-check against the local count.
        count_arr = mx.array([local_count], dtype=mx.int32)
        total = mx.distributed.all_sum(count_arr, group=self.tensor_parallel_group)
        mx.eval(total)
        total_count = int(total.item())
        tp_world_size = self.tensor_parallel_group.size()
        # All ranks receive the same pubsub messages, so counts should agree.
        # Warn on a mismatch rather than deadlocking inside generation.
        expected_total = local_count * tp_world_size
        if total_count != expected_total:
            logger.warning(
                f"Tensor parallel pending mismatch: local={local_count}, "
                f"total={total_count}, expected={expected_total}"
            )
        # floor(total / world_size) equals local_count when all ranks agree;
        # on a mismatch it is only a heuristic, not a true minimum.
        num_pending = total_count // tp_world_size
    else:
        num_pending = local_count
if num_pending == 0:
return
available_slots = self.max_batch_size - self.current_batch_size
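The discarded comments in this hunk circle a real gap: all_sum of negated counts yields -sum(x), not a minimum, since min(x) = -max(-x) needs a max reduction, not a sum. If a true cross-rank minimum is ever required, one way is to gather every rank's count with mx.distributed.all_gather (which MLX does provide) and reduce locally. A minimal sketch, assuming the same group handle; min_pending_across_ranks is a hypothetical helper, not part of this repository:

import mlx.core as mx

def min_pending_across_ranks(local_count: int, group) -> int:
    """Return the smallest pending count across all ranks in the group."""
    # all_gather concatenates each rank's array along axis 0, so every
    # rank ends up with a (world_size,) vector of all pending counts.
    counts = mx.distributed.all_gather(
        mx.array([local_count], dtype=mx.int32), group=group
    )
    mx.eval(counts)
    return int(mx.min(counts).item())

When the counts agree this returns local_count, and on a mismatch every rank flushes only what the slowest rank can actually serve, unlike the floor(total / world_size) heuristic above.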