From 64bc62eecc840024565aece7fa30fbb84cfd2ee4 Mon Sep 17 00:00:00 2001
From: ciaranbor
Date: Mon, 8 Dec 2025 11:12:57 +0000
Subject: [PATCH] Implement sync/async switching logic

---
 .../worker/engines/mflux/distributed_flux.py |   9 --
 .../pipefusion/distributed_transformer.py    | 149 ++++++++++++++----
 2 files changed, 115 insertions(+), 43 deletions(-)

diff --git a/src/exo/worker/engines/mflux/distributed_flux.py b/src/exo/worker/engines/mflux/distributed_flux.py
index bfed8007..9620d23d 100644
--- a/src/exo/worker/engines/mflux/distributed_flux.py
+++ b/src/exo/worker/engines/mflux/distributed_flux.py
@@ -251,18 +251,10 @@ class DistributedFlux1:
         prompt_embeds: mx.array,
         pooled_prompt_embeds: mx.array,
     ) -> mx.array:
-        """
-        Single diffusion step with distributed communication.
-
-        Currently uses block wrappers for communication. Will be refactored
-        to handle send/recv at this level for async pipeline.
-        """
         model = self._model
 
-        # Scale model input if needed by the scheduler
         latents = config.scheduler.scale_model_input(latents, t)
 
-        # Predict noise (communication happens in block wrappers)
         noise = model.transformer(
             t=t,
             config=config,
@@ -271,7 +263,6 @@ class DistributedFlux1:
             pooled_prompt_embeds=pooled_prompt_embeds,
         )
 
-        # Apply scheduler step (denoising)
         latents = config.scheduler.step(
             model_output=noise,
             timestep=t,
diff --git a/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py b/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py
index 30ca4e3b..d632a741 100644
--- a/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py
+++ b/src/exo/worker/engines/mflux/pipefusion/distributed_transformer.py
@@ -5,8 +5,6 @@ from mflux.config.runtime_config import RuntimeConfig
 from mflux.models.flux.model.flux_transformer.transformer import Transformer
 
 from exo.shared.types.worker.shards import PipelineShardMetadata
-from exo.worker.engines.mlx.utils_mlx import mx_barrier
-from exo.worker.runner.bootstrap import logger
 
 
 class DistributedTransformer:
@@ -28,6 +26,7 @@
         transformer: Transformer,
         group: mx.distributed.Group,
         shard_metadata: PipelineShardMetadata,
+        num_sync_steps: int = 1,
     ):
         self.transformer = transformer
         self.group = group
@@ -36,6 +35,8 @@
         self.start_layer = shard_metadata.start_layer
         self.end_layer = shard_metadata.end_layer
 
+        self.num_sync_steps = num_sync_steps
+
         # Get block counts from the original transformer (before slicing)
         # Note: These are the ORIGINAL counts, not the sliced counts
         self.total_joint = 19  # Flux has 19 joint blocks
@@ -99,23 +100,11 @@
 
     def _sync_pipeline(
         self,
-        t: int,
-        config: RuntimeConfig,
         hidden_states: mx.array,
-        prompt_embeds: mx.array,
-        pooled_prompt_embeds: mx.array,
-        kontext_image_ids: mx.array | None = None,
+        encoder_hidden_states: mx.array,
+        text_embeddings: mx.array,
+        image_rotary_embeddings: mx.array,
     ):
-        # === PHASE 1: Create Embeddings (all stages compute, for consistency) ===
-        hidden_states = self.transformer.x_embedder(hidden_states)
-        encoder_hidden_states = self.transformer.context_embedder(prompt_embeds)
-        text_embeddings = Transformer.compute_text_embeddings(
-            t, pooled_prompt_embeds, self.transformer.time_text_embed, config
-        )
-        image_rotary_embeddings = Transformer.compute_rotary_embeddings(
-            prompt_embeds, self.transformer.pos_embed, config, kontext_image_ids
-        )
-
         # === PHASE 2: Joint Blocks with Communication ===
         if self.has_joint_blocks:
             # Receive from previous stage (if not first stage)
@@ -181,17 +170,79 @@
                     hidden_states, self.rank + 1, group=self.group
                 )
 
-        # === PHASE 5: All-gather Final Output ===
-        # All stages participate to receive the final output
-        hidden_states = mx.distributed.all_gather(hidden_states, group=self.group)[
-            -hidden_states.shape[0] :
-        ]
+        return hidden_states
 
-        # === PHASE 6: Final Projection (last stage only) ===
-        # Extract image portion (remove text embeddings prefix)
-        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-        hidden_states = self.transformer.norm_out(hidden_states, text_embeddings)
-        hidden_states = self.transformer.proj_out(hidden_states)
+    def _async_pipeline(
+        self,
+        hidden_states: mx.array,
+        encoder_hidden_states: mx.array,
+        text_embeddings: mx.array,
+        image_rotary_embeddings: mx.array,
+    ):
+        # === PHASE 2: Joint Blocks with Communication ===
+        if self.has_joint_blocks:
+            # Receive from previous stage (if not first stage)
+            if not self.is_first_stage:
+                hidden_states = mx.distributed.recv_like(
+                    hidden_states, self.rank - 1, group=self.group
+                )
+                encoder_hidden_states = mx.distributed.recv_like(
+                    encoder_hidden_states, self.rank - 1, group=self.group
+                )
+
+            # Run assigned joint blocks
+            for block in self.transformer.transformer_blocks:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    text_embeddings=text_embeddings,
+                    rotary_embeddings=image_rotary_embeddings,
+                )
+
+        # === PHASE 3: Joint→Single Transition ===
+        if self.owns_concat_stage:
+            # Concatenate encoder and hidden states
+            concatenated = mx.concatenate(
+                [encoder_hidden_states, hidden_states], axis=1
+            )
+
+            if self.has_single_blocks:
+                # We continue with single blocks on this stage
+                hidden_states = concatenated
+            else:
+                # Send concatenated state to next stage (which has single blocks)
+                mx.distributed.send(concatenated, self.rank + 1, group=self.group)
+                # This stage is done with blocks, but will participate in all_gather
+        elif self.has_joint_blocks and not self.is_last_stage:
+            # Send joint block outputs to next stage (which has more joint blocks)
+            mx.distributed.send(hidden_states, self.rank + 1, group=self.group)
+            mx.distributed.send(encoder_hidden_states, self.rank + 1, group=self.group)
+
+        # === PHASE 4: Single Blocks with Communication ===
+        if self.has_single_blocks:
+            # Receive from previous stage if we didn't do concatenation
+            if not self.owns_concat_stage and not self.is_first_stage:
+                concatenated = mx.concatenate(
+                    [encoder_hidden_states, hidden_states], axis=1
+                )
+                hidden_states = mx.distributed.recv_like(
+                    concatenated, self.rank - 1, group=self.group
+                )
+                mx.eval(hidden_states)
+
+            # Run assigned single blocks
+            for block in self.transformer.single_transformer_blocks:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    text_embeddings=text_embeddings,
+                    rotary_embeddings=image_rotary_embeddings,
+                )
+
+            # Send to next stage if not last
+            if not self.is_last_stage:
+                hidden_states = mx.distributed.send(
+                    hidden_states, self.rank + 1, group=self.group
+                )
 
         return hidden_states
 
@@ -206,14 +257,44 @@
         controlnet_single_block_samples: list[mx.array] | None = None,
         kontext_image_ids: mx.array | None = None,
     ) -> mx.array:
-        return self._sync_pipeline(
-            t,
-            config,
-            hidden_states,
-            prompt_embeds,
-            pooled_prompt_embeds,
-            kontext_image_ids,
+        # === PHASE 1: Create Embeddings (all stages compute, for consistency) ===
+        hidden_states = self.transformer.x_embedder(hidden_states)
+        encoder_hidden_states = self.transformer.context_embedder(prompt_embeds)
+        text_embeddings = Transformer.compute_text_embeddings(
+            t, pooled_prompt_embeds, self.transformer.time_text_embed, config
         )
+        image_rotary_embeddings = Transformer.compute_rotary_embeddings(
+            prompt_embeds, self.transformer.pos_embed, config, kontext_image_ids
+        )
+
+        if t < self.num_sync_steps:
+            hidden_states = self._sync_pipeline(
+                hidden_states,
+                encoder_hidden_states,
+                text_embeddings,
+                image_rotary_embeddings,
+            )
+        else:
+            hidden_states = self._async_pipeline(
+                hidden_states,
+                encoder_hidden_states,
+                text_embeddings,
+                image_rotary_embeddings,
+            )
+
+        # === PHASE 5: All-gather Final Output ===
+        # All stages participate to receive the final output
+        hidden_states = mx.distributed.all_gather(hidden_states, group=self.group)[
+            -hidden_states.shape[0] :
+        ]
+
+        # === PHASE 6: Final Projection (last stage only) ===
+        # Extract image portion (remove text embeddings prefix)
+        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+        hidden_states = self.transformer.norm_out(hidden_states, text_embeddings)
+        hidden_states = self.transformer.proj_out(hidden_states)
+
+        return hidden_states
 
     # Delegate attribute access to the underlying transformer for compatibility
     def __getattr__(self, name: str) -> Any:
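
The switching rule this patch adds to DistributedTransformer.forward() is easy to see in isolation: the first num_sync_steps denoising steps take the synchronous pipeline, and every later step takes the asynchronous one. Below is a minimal sketch of that gate using a hypothetical SwitchingPipeline stand-in rather than the real transformer; only the num_sync_steps parameter and the `t < num_sync_steps` comparison come from the patch, everything else is illustrative.

    # Minimal sketch of the warm-up gate added in DistributedTransformer.forward().
    # SwitchingPipeline is a hypothetical stand-in; only the num_sync_steps gate
    # mirrors the patch.

    class SwitchingPipeline:
        def __init__(self, num_sync_steps: int = 1):
            self.num_sync_steps = num_sync_steps

        def step(self, t: int) -> str:
            # Early (warm-up) steps run the synchronous path so every stage has
            # fresh activations; later steps run the asynchronous path.
            return "sync" if t < self.num_sync_steps else "async"


    if __name__ == "__main__":
        pipe = SwitchingPipeline(num_sync_steps=2)
        print([pipe.step(t) for t in range(5)])
        # ['sync', 'sync', 'async', 'async', 'async']

With the default num_sync_steps=1 from the constructor, only the very first diffusion step runs synchronously, which matches the new default argument added to DistributedTransformer.__init__.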