diff --git a/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py b/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py
index 7440b7885..28edb98d5 100644
--- a/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py
+++ b/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py
@@ -77,7 +77,7 @@ class DistributedTransformer:
         # This happens when we have joint blocks and either:
         # - We also have single blocks (transition within this stage), or
         # - We're the last stage to have joint blocks and next stage has single
-        self.is_concat_stage = self.has_joint_blocks and (
+        self.owns_concat_stage = self.has_joint_blocks and (
             self.has_single_blocks or self.end_layer == self.total_joint
         )
 
@@ -138,7 +138,7 @@ class DistributedTransformer:
             )
 
         # === PHASE 3: Joint→Single Transition ===
-        if self.is_concat_stage:
+        if self.owns_concat_stage:
            logger.info("concatenating")
            # Concatenate encoder and hidden states
            concatenated = mx.concatenate(
@@ -161,7 +161,10 @@ class DistributedTransformer:
        # === PHASE 4: Single Blocks with Communication ===
        if self.has_single_blocks:
            # Receive from previous stage if we didn't do concatenation
-            if not self.is_concat_stage and not self.is_first_stage:
+            if not self.owns_concat_stage and not self.is_first_stage:
+                hidden_states = mx.concatenate(
+                    [encoder_hidden_states, hidden_states], axis=1
+                )
                logger.info(f"receiving single block inputs: {hidden_states.shape}")
                hidden_states = mx.distributed.recv_like(
                    hidden_states, self.rank - 1, group=self.group
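
For context on the receive-path change: `mx.distributed.recv_like` infers the shape and dtype of the incoming array from its template argument, so a stage that runs single blocks but did not perform the joint→single concatenation itself must first build a template with the post-concatenation shape before receiving. A minimal sketch of that shape arithmetic follows; the dimensions are hypothetical placeholders for illustration, not values from this PR:

```python
import mlx.core as mx

# Hypothetical shapes for illustration only (not taken from this PR):
# encoder_hidden_states: (batch, text_tokens, dim)
# hidden_states:         (batch, image_tokens, dim)
encoder_hidden_states = mx.zeros((1, 512, 3072))
hidden_states = mx.zeros((1, 4096, 3072))

# The sending stage transmits the concatenated tensor after the
# joint->single transition, so the receiver builds a template with
# the post-concatenation shape before calling recv_like.
template = mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
print(template.shape)  # (1, 4608, 3072)

# On a real multi-rank run the receive would then be:
# hidden_states = mx.distributed.recv_like(template, rank - 1, group=group)
```

Without the added pre-concatenation, the receiving rank would pass a template with the pre-transition `hidden_states` shape to `recv_like`, which would not match what the upstream stage sends after it has concatenated the encoder and hidden states.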