Implement sync/async switching logic

ciaranbor
2025-12-08 11:12:57 +00:00
parent cb5ce5a130
commit 64bc62eecc
2 changed files with 115 additions and 43 deletions
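In effect, the first num_sync_steps denoising steps keep running the fully synchronous pipeline, and every later step takes the new async path, where each stage runs only its assigned blocks and hands activations to its neighbour. A rough sketch of the resulting schedule (the step count and print statements are illustrative only, not part of this diff):

num_sync_steps = 1
for t in range(4):  # e.g. a 4-step denoise
    if t < num_sync_steps:
        print(t, "sync pipeline: every stage computes every phase")
    else:
        print(t, "async pipeline: each stage runs only its assigned blocks")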

View File

@@ -251,18 +251,10 @@ class DistributedFlux1:
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
) -> mx.array:
"""
Single diffusion step with distributed communication.
Currently uses block wrappers for communication. Will be refactored
to handle send/recv at this level for async pipeline.
"""
model = self._model
# Scale model input if needed by the scheduler
latents = config.scheduler.scale_model_input(latents, t)
# Predict noise (communication happens in block wrappers)
noise = model.transformer(
t=t,
config=config,
@@ -271,7 +263,6 @@ class DistributedFlux1:
pooled_prompt_embeds=pooled_prompt_embeds,
)
# Apply scheduler step (denoising)
latents = config.scheduler.step(
model_output=noise,
timestep=t,

View File

@@ -5,8 +5,6 @@ from mflux.config.runtime_config import RuntimeConfig
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.utils_mlx import mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedTransformer:
@@ -28,6 +26,7 @@ class DistributedTransformer:
transformer: Transformer,
group: mx.distributed.Group,
shard_metadata: PipelineShardMetadata,
num_sync_steps: int = 1,
):
self.transformer = transformer
self.group = group
@@ -36,6 +35,8 @@ class DistributedTransformer:
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_sync_steps = num_sync_steps
# Get block counts from the original transformer (before slicing)
# Note: These are the ORIGINAL counts, not the sliced counts
self.total_joint = 19 # Flux has 19 joint blocks
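For context, a hypothetical construction of the wrapper with the new argument; flux.transformer and shard_metadata are assumed names used for illustration, and only the num_sync_steps keyword comes from this diff:

import mlx.core as mx

group = mx.distributed.init()
distributed_transformer = DistributedTransformer(
    transformer=flux.transformer,   # assumed: the underlying mflux Transformer
    group=group,
    shard_metadata=shard_metadata,  # assumed: this worker's PipelineShardMetadata
    num_sync_steps=2,               # keep the first two denoising steps synchronous
)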
@@ -99,23 +100,11 @@ class DistributedTransformer:
def _sync_pipeline(
self,
t: int,
config: RuntimeConfig,
hidden_states: mx.array,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
kontext_image_ids: mx.array | None = None,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
image_rotary_embeddings: mx.array,
):
# === PHASE 1: Create Embeddings (all stages compute, for consistency) ===
hidden_states = self.transformer.x_embedder(hidden_states)
encoder_hidden_states = self.transformer.context_embedder(prompt_embeds)
text_embeddings = Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self.transformer.time_text_embed, config
)
image_rotary_embeddings = Transformer.compute_rotary_embeddings(
prompt_embeds, self.transformer.pos_embed, config, kontext_image_ids
)
# === PHASE 2: Joint Blocks with Communication ===
if self.has_joint_blocks:
# Receive from previous stage (if not first stage)
@@ -181,17 +170,79 @@ class DistributedTransformer:
hidden_states, self.rank + 1, group=self.group
)
# === PHASE 5: All-gather Final Output ===
# All stages participate to receive the final output
hidden_states = mx.distributed.all_gather(hidden_states, group=self.group)[
-hidden_states.shape[0] :
]
return hidden_states
# === PHASE 6: Final Projection (last stage only) ===
# Extract image portion (remove text embeddings prefix)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.transformer.norm_out(hidden_states, text_embeddings)
hidden_states = self.transformer.proj_out(hidden_states)
def _async_pipeline(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
image_rotary_embeddings: mx.array,
):
# === PHASE 2: Joint Blocks with Communication ===
if self.has_joint_blocks:
# Receive from previous stage (if not first stage)
if not self.is_first_stage:
hidden_states = mx.distributed.recv_like(
hidden_states, self.rank - 1, group=self.group
)
encoder_hidden_states = mx.distributed.recv_like(
encoder_hidden_states, self.rank - 1, group=self.group
)
# Run assigned joint blocks
for block in self.transformer.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
# === PHASE 3: Joint→Single Transition ===
if self.owns_concat_stage:
# Concatenate encoder and hidden states
concatenated = mx.concatenate(
[encoder_hidden_states, hidden_states], axis=1
)
if self.has_single_blocks:
# We continue with single blocks on this stage
hidden_states = concatenated
else:
# Send concatenated state to next stage (which has single blocks)
mx.distributed.send(concatenated, self.rank + 1, group=self.group)
# This stage is done with blocks, but will participate in all_gather
elif self.has_joint_blocks and not self.is_last_stage:
# Send joint block outputs to next stage (which has more joint blocks)
mx.distributed.send(hidden_states, self.rank + 1, group=self.group)
mx.distributed.send(encoder_hidden_states, self.rank + 1, group=self.group)
# === PHASE 4: Single Blocks with Communication ===
if self.has_single_blocks:
# Receive from previous stage if we didn't do concatenation
if not self.owns_concat_stage and not self.is_first_stage:
concatenated = mx.concatenate(
[encoder_hidden_states, hidden_states], axis=1
)
hidden_states = mx.distributed.recv_like(
concatenated, self.rank - 1, group=self.group
)
mx.eval(hidden_states)
# Run assigned single blocks
for block in self.transformer.single_transformer_blocks:
hidden_states = block(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
# Send to next stage if not last
if not self.is_last_stage:
hidden_states = mx.distributed.send(
hidden_states, self.rank + 1, group=self.group
)
return hidden_states
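The handoffs above use two MLX primitives: mx.distributed.send() forwards an array to a neighbouring rank, and mx.distributed.recv_like() only needs a template with the right shape and dtype, which is why a throwaway concatenated array is built before receiving in Phase 4. A minimal standalone sketch of the pattern (shapes are illustrative):

import mlx.core as mx

group = mx.distributed.init()
rank, size = group.rank(), group.size()

x = mx.zeros((1, 4096, 3072))  # stand-in for an activation tensor
if rank > 0:
    # recv_like reads only the template's shape/dtype; the data comes from rank - 1
    x = mx.distributed.recv_like(x, rank - 1, group=group)
# ... this stage's blocks would run on x here ...
if rank < size - 1:
    # send returns its input so the transfer stays in the lazy graph until eval
    x = mx.distributed.send(x, rank + 1, group=group)
mx.eval(x)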
@@ -206,14 +257,44 @@ class DistributedTransformer:
controlnet_single_block_samples: list[mx.array] | None = None,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
return self._sync_pipeline(
t,
config,
hidden_states,
prompt_embeds,
pooled_prompt_embeds,
kontext_image_ids,
# === PHASE 1: Create Embeddings (all stages compute, for consistency) ===
hidden_states = self.transformer.x_embedder(hidden_states)
encoder_hidden_states = self.transformer.context_embedder(prompt_embeds)
text_embeddings = Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self.transformer.time_text_embed, config
)
image_rotary_embeddings = Transformer.compute_rotary_embeddings(
prompt_embeds, self.transformer.pos_embed, config, kontext_image_ids
)
if t < self.num_sync_steps:
hidden_states = self._sync_pipeline(
hidden_states,
encoder_hidden_states,
text_embeddings,
image_rotary_embeddings,
)
else:
hidden_states = self._async_pipeline(
hidden_states,
encoder_hidden_states,
text_embeddings,
image_rotary_embeddings,
)
#
# === PHASE 5: All-gather Final Output ===
# All stages participate to receive the final output
hidden_states = mx.distributed.all_gather(hidden_states, group=self.group)[
-hidden_states.shape[0] :
]
# === PHASE 6: Final Projection ===
# Extract image portion (remove text embeddings prefix)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.transformer.norm_out(hidden_states, text_embeddings)
hidden_states = self.transformer.proj_out(hidden_states)
return hidden_states
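The trailing slice after all_gather (used both here and in _sync_pipeline) works because all_gather concatenates each rank's array along the first axis in rank order, so the last hidden_states.shape[0] rows are the contribution of the highest rank, i.e. the final pipeline stage. A minimal sketch:

import mlx.core as mx

group = mx.distributed.init()
x = mx.full((2, 3), group.rank())                      # each rank holds its own rows
gathered = mx.distributed.all_gather(x, group=group)   # shape: (2 * size, 3)
final_stage_rows = gathered[-x.shape[0]:]              # rows contributed by rank size - 1
mx.eval(final_stage_rows)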
# Delegate attribute access to the underlying transformer for compatibility
def __getattr__(self, name: str) -> Any: