fix single block receive shape

This commit is contained in:
ciaranbor
2025-12-05 12:57:05 +00:00
parent 9f5f763993
commit dd2d25951d

View File

@@ -77,7 +77,7 @@ class DistributedTransformer:
# This happens when we have joint blocks and either:
# - We also have single blocks (transition within this stage), or
# - We're the last stage to have joint blocks and next stage has single
-self.is_concat_stage = self.has_joint_blocks and (
+self.owns_concat_stage = self.has_joint_blocks and (
self.has_single_blocks or self.end_layer == self.total_joint
)
@@ -138,7 +138,7 @@ class DistributedTransformer:
)
# === PHASE 3: Joint→Single Transition ===
-if self.is_concat_stage:
+if self.owns_concat_stage:
logger.info("concatenating")
# Concatenate encoder and hidden states
concatenated = mx.concatenate(
@@ -161,7 +161,10 @@ class DistributedTransformer:
# === PHASE 4: Single Blocks with Communication ===
if self.has_single_blocks:
# Receive from previous stage if we didn't do concatenation
-if not self.is_concat_stage and not self.is_first_stage:
+if not self.owns_concat_stage and not self.is_first_stage:
+hidden_states = mx.concatenate(
+[encoder_hidden_states, hidden_states], axis=1
+)
logger.info(f"receiving single block inputs: {hidden_states.shape}")
hidden_states = mx.distributed.recv_like(
hidden_states, self.rank - 1, group=self.group