mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 14:55:13 -05:00
fix single block receive shape
This commit is contained in:
@@ -77,7 +77,7 @@ class DistributedTransformer:
|
||||
# This happens when we have joint blocks and either:
|
||||
# - We also have single blocks (transition within this stage), or
|
||||
# - We're the last stage to have joint blocks and next stage has single
|
||||
self.is_concat_stage = self.has_joint_blocks and (
|
||||
self.owns_concat_stage = self.has_joint_blocks and (
|
||||
self.has_single_blocks or self.end_layer == self.total_joint
|
||||
)
|
||||
|
||||
@@ -138,7 +138,7 @@ class DistributedTransformer:
|
||||
)
|
||||
|
||||
# === PHASE 3: Joint→Single Transition ===
|
||||
if self.is_concat_stage:
|
||||
if self.owns_concat_stage:
|
||||
logger.info("concatenating")
|
||||
# Concatenate encoder and hidden states
|
||||
concatenated = mx.concatenate(
|
||||
@@ -161,7 +161,10 @@ class DistributedTransformer:
|
||||
# === PHASE 4: Single Blocks with Communication ===
|
||||
if self.has_single_blocks:
|
||||
# Receive from previous stage if we didn't do concatenation
|
||||
if not self.is_concat_stage and not self.is_first_stage:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
hidden_states = mx.concatenate(
|
||||
[encoder_hidden_states, hidden_states], axis=1
|
||||
)
|
||||
logger.info(f"receiving single block inputs: {hidden_states.shape}")
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
hidden_states, self.rank - 1, group=self.group
|
||||
|
||||
Reference in New Issue
Block a user