Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-26 23:10:01 -05:00)

Implement sync/async switching logic
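In outline, the commit gives DistributedTransformer a num_sync_steps setting and has __call__ pick between the existing synchronous pipeline and the new asynchronous one based on the step index t. The sketch below restates only that switching rule; the class name and the placeholder pipeline bodies are hypothetical, while num_sync_steps, _sync_pipeline, and _async_pipeline are names taken from the diff.

class SwitchingSketch:
    """Minimal stand-in showing only the sync/async dispatch rule."""

    def __init__(self, num_sync_steps: int = 1):
        # The first num_sync_steps denoising steps take the synchronous path.
        self.num_sync_steps = num_sync_steps

    def _sync_pipeline(self, hidden_states):
        return hidden_states  # placeholder for the blocking per-step pipeline

    def _async_pipeline(self, hidden_states):
        return hidden_states  # placeholder for the async pipeline variant

    def __call__(self, t: int, hidden_states):
        # Same rule as DistributedTransformer.__call__ in the diff below.
        if t < self.num_sync_steps:
            return self._sync_pipeline(hidden_states)
        return self._async_pipeline(hidden_states)

For example, with the default num_sync_steps=1 only the first step (t == 0) runs the synchronous pipeline; every later step takes the async path.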
@@ -251,18 +251,10 @@ class DistributedFlux1:
        prompt_embeds: mx.array,
        pooled_prompt_embeds: mx.array,
    ) -> mx.array:
        """
        Single diffusion step with distributed communication.

        Currently uses block wrappers for communication. Will be refactored
        to handle send/recv at this level for async pipeline.
        """
        model = self._model

        # Scale model input if needed by the scheduler
        latents = config.scheduler.scale_model_input(latents, t)

        # Predict noise (communication happens in block wrappers)
        noise = model.transformer(
            t=t,
            config=config,
@@ -271,7 +263,6 @@ class DistributedFlux1:
            pooled_prompt_embeds=pooled_prompt_embeds,
        )

        # Apply scheduler step (denoising)
        latents = config.scheduler.step(
            model_output=noise,
            timestep=t,

@@ -5,8 +5,6 @@ from mflux.config.runtime_config import RuntimeConfig
from mflux.models.flux.model.flux_transformer.transformer import Transformer

from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.utils_mlx import mx_barrier
from exo.worker.runner.bootstrap import logger


class DistributedTransformer:
@@ -28,6 +26,7 @@ class DistributedTransformer:
        transformer: Transformer,
        group: mx.distributed.Group,
        shard_metadata: PipelineShardMetadata,
        num_sync_steps: int = 1,
    ):
        self.transformer = transformer
        self.group = group
@@ -36,6 +35,8 @@ class DistributedTransformer:
        self.start_layer = shard_metadata.start_layer
        self.end_layer = shard_metadata.end_layer

        self.num_sync_steps = num_sync_steps

        # Get block counts from the original transformer (before slicing)
        # Note: These are the ORIGINAL counts, not the sliced counts
        self.total_joint = 19  # Flux has 19 joint blocks
@@ -99,23 +100,11 @@ class DistributedTransformer:

    def _sync_pipeline(
        self,
        t: int,
        config: RuntimeConfig,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
        pooled_prompt_embeds: mx.array,
        kontext_image_ids: mx.array | None = None,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        image_rotary_embeddings: mx.array,
    ):
        # === PHASE 1: Create Embeddings (all stages compute, for consistency) ===
        hidden_states = self.transformer.x_embedder(hidden_states)
        encoder_hidden_states = self.transformer.context_embedder(prompt_embeds)
        text_embeddings = Transformer.compute_text_embeddings(
            t, pooled_prompt_embeds, self.transformer.time_text_embed, config
        )
        image_rotary_embeddings = Transformer.compute_rotary_embeddings(
            prompt_embeds, self.transformer.pos_embed, config, kontext_image_ids
        )

        # === PHASE 2: Joint Blocks with Communication ===
        if self.has_joint_blocks:
            # Receive from previous stage (if not first stage)
@@ -181,17 +170,79 @@ class DistributedTransformer:
                    hidden_states, self.rank + 1, group=self.group
                )

        # === PHASE 5: All-gather Final Output ===
        # All stages participate to receive the final output
        hidden_states = mx.distributed.all_gather(hidden_states, group=self.group)[
            -hidden_states.shape[0] :
        ]
        return hidden_states

        # === PHASE 6: Final Projection (last stage only) ===
        # Extract image portion (remove text embeddings prefix)
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
        hidden_states = self.transformer.norm_out(hidden_states, text_embeddings)
        hidden_states = self.transformer.proj_out(hidden_states)

    def _async_pipeline(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        image_rotary_embeddings: mx.array,
    ):
        # === PHASE 2: Joint Blocks with Communication ===
        if self.has_joint_blocks:
            # Receive from previous stage (if not first stage)
            if not self.is_first_stage:
                hidden_states = mx.distributed.recv_like(
                    hidden_states, self.rank - 1, group=self.group
                )
                encoder_hidden_states = mx.distributed.recv_like(
                    encoder_hidden_states, self.rank - 1, group=self.group
                )

            # Run assigned joint blocks
            for block in self.transformer.transformer_blocks:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                )

        # === PHASE 3: Joint→Single Transition ===
        if self.owns_concat_stage:
            # Concatenate encoder and hidden states
            concatenated = mx.concatenate(
                [encoder_hidden_states, hidden_states], axis=1
            )

            if self.has_single_blocks:
                # We continue with single blocks on this stage
                hidden_states = concatenated
            else:
                # Send concatenated state to next stage (which has single blocks)
                mx.distributed.send(concatenated, self.rank + 1, group=self.group)
                # This stage is done with blocks, but will participate in all_gather
        elif self.has_joint_blocks and not self.is_last_stage:
            # Send joint block outputs to next stage (which has more joint blocks)
            mx.distributed.send(hidden_states, self.rank + 1, group=self.group)
            mx.distributed.send(encoder_hidden_states, self.rank + 1, group=self.group)

        # === PHASE 4: Single Blocks with Communication ===
        if self.has_single_blocks:
            # Receive from previous stage if we didn't do concatenation
            if not self.owns_concat_stage and not self.is_first_stage:
                concatenated = mx.concatenate(
                    [encoder_hidden_states, hidden_states], axis=1
                )
                hidden_states = mx.distributed.recv_like(
                    concatenated, self.rank - 1, group=self.group
                )
                mx.eval(hidden_states)

            # Run assigned single blocks
            for block in self.transformer.single_transformer_blocks:
                hidden_states = block(
                    hidden_states=hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                )

            # Send to next stage if not last
            if not self.is_last_stage:
                hidden_states = mx.distributed.send(
                    hidden_states, self.rank + 1, group=self.group
                )

        return hidden_states

@@ -206,14 +257,44 @@ class DistributedTransformer:
        controlnet_single_block_samples: list[mx.array] | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        return self._sync_pipeline(
            t,
            config,
            hidden_states,
            prompt_embeds,
            pooled_prompt_embeds,
            kontext_image_ids,
        # === PHASE 1: Create Embeddings (all stages compute, for consistency) ===
        hidden_states = self.transformer.x_embedder(hidden_states)
        encoder_hidden_states = self.transformer.context_embedder(prompt_embeds)
        text_embeddings = Transformer.compute_text_embeddings(
            t, pooled_prompt_embeds, self.transformer.time_text_embed, config
        )
        image_rotary_embeddings = Transformer.compute_rotary_embeddings(
            prompt_embeds, self.transformer.pos_embed, config, kontext_image_ids
        )

        if t < self.num_sync_steps:
            hidden_states = self._sync_pipeline(
                hidden_states,
                encoder_hidden_states,
                text_embeddings,
                image_rotary_embeddings,
            )
        else:
            hidden_states = self._async_pipeline(
                hidden_states,
                encoder_hidden_states,
                text_embeddings,
                image_rotary_embeddings,
            )
        #
        # === PHASE 5: All-gather Final Output ===
        # All stages participate to receive the final output
        hidden_states = mx.distributed.all_gather(hidden_states, group=self.group)[
            -hidden_states.shape[0] :
        ]

        # === PHASE 6: Final Projection (last stage only) ===
        # Extract image portion (remove text embeddings prefix)
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
        hidden_states = self.transformer.norm_out(hidden_states, text_embeddings)
        hidden_states = self.transformer.proj_out(hidden_states)

        return hidden_states

    # Delegate attribute access to the underlying transformer for compatibility
    def __getattr__(self, name: str) -> Any:
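The diff shown here stops at the __getattr__ signature; the comment above it says attribute access is delegated to the wrapped transformer for compatibility. As a rough sketch of that general pattern only, under the assumption of a simple fallback (the wrapper class and body below are hypothetical, not the commit's code):

from typing import Any


class _DelegatingWrapper:
    def __init__(self, transformer: Any):
        self.transformer = transformer

    def __getattr__(self, name: str) -> Any:
        # __getattr__ is only invoked when normal lookup fails, so attributes
        # set in __init__ (like self.transformer) resolve without recursing here.
        return getattr(self.transformer, name)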