mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-13 08:29:21 -05:00
Compare commits
1 Commits
main
...
evan/autop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfebf1f2fe |
@@ -10,18 +10,23 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.cache import (
|
||||
_BaseCache, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
)
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
|
||||
class _LayerCallable(Protocol):
|
||||
@@ -69,6 +74,7 @@ class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if self.r != 0:
|
||||
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
|
||||
# mx.eval(x)
|
||||
return self.original_layer(x, *args, **kwargs)
|
||||
|
||||
|
||||
@@ -91,8 +97,6 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
@@ -100,7 +104,6 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# This change happened upstream - check out mlx github somewhere??
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
|
||||
@@ -132,24 +135,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
|
||||
return layers
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(model, DeepseekV3Model) and hasattr(
|
||||
inner_model_instance, "num_layers"
|
||||
):
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
def pipeline_auto_parallel(
|
||||
model: nn.Module,
|
||||
group: mx.distributed.Group,
|
||||
@@ -165,8 +150,7 @@ def pipeline_auto_parallel(
|
||||
"""
|
||||
inner_model_instance: nn.Module = _inner_model(model)
|
||||
|
||||
# Handle both model.layers and model.h cases
|
||||
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
|
||||
layers = _get_layers(inner_model_instance)
|
||||
|
||||
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
@@ -180,6 +164,17 @@ def pipeline_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, GptOssMoeModel):
|
||||
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
|
||||
start_layer:end_layer
|
||||
]
|
||||
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"sliding_attention"
|
||||
)
|
||||
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"full_attention"
|
||||
)
|
||||
|
||||
_set_layers(model, layers)
|
||||
|
||||
assert isinstance(layers, list), (
|
||||
@@ -204,18 +199,44 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
SEGMENTS: int = 1
|
||||
|
||||
def _all_to_sharded(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - all to sharded")
|
||||
return weight.ndim - 1, SEGMENTS
|
||||
return max(weight.ndim - 2, 0), SEGMENTS
|
||||
|
||||
all_to_sharded_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding="all-to-sharded",
|
||||
group=group,
|
||||
)
|
||||
sharded_to_all_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding="sharded-to-all",
|
||||
sharding=_all_to_sharded, # type: ignore
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(model, LlamaModel):
|
||||
N = group.size()
|
||||
|
||||
def _sharded_to_all(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - sharded to all")
|
||||
weight /= N
|
||||
return None
|
||||
return -1, SEGMENTS
|
||||
|
||||
sharded_to_all_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding=_sharded_to_all, # type: ignore
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard"):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return model
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -223,7 +244,8 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, DeepseekV3Model):
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -231,7 +253,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -239,6 +261,15 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -284,6 +315,32 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)) and hasattr(
|
||||
inner_model_instance, "num_layers"
|
||||
):
|
||||
logger.info(
|
||||
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
logger.info(
|
||||
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.num_hidden_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
@@ -304,7 +361,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, DeepseekV3MLP):
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
@@ -352,7 +409,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
@@ -380,3 +437,52 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
layer.self_attn.num_key_value_groups = (
|
||||
layer.self_attn.num_attention_heads
|
||||
// layer.self_attn.num_key_value_heads
|
||||
)
|
||||
|
||||
layer.self_attn.sinks = layer.self_attn.sinks[
|
||||
layer.self_attn.num_attention_heads
|
||||
* self.group.rank() : layer.self_attn.num_attention_heads
|
||||
* (self.group.rank() + 1)
|
||||
]
|
||||
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user