@@ -10,18 +10,23 @@ from mlx.nn.layers.distributed import (
    shard_linear,
    sum_gradients,
)
from mlx_lm.models.cache import (
    _BaseCache,  # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock

from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
)
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata


class _LayerCallable(Protocol):
@@ -91,8 +96,6 @@ class PipelineLastLayer(CustomMlxLayer):
            x, *args, **kwargs
        ).arguments.get("cache", None)

        assert cache is None or issubclass(type(cache), _BaseCache)  # type: ignore

        output: mx.array = self.original_layer(x, *args, **kwargs)

        if self.r != self.s - 1:
@@ -100,7 +103,6 @@ class PipelineLastLayer(CustomMlxLayer):
                output, (self.r + 1) % self.s, group=self.group
            )
            if cache is not None:
                # This change mirrors an upstream change in mlx (see the mlx GitHub repo):
                # make the cache depend on `output` so the send above is evaluated.
                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

        output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
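        # Note: each non-final shard sends its activations to rank (r + 1) % s; `mx.depends`
        # ties the KV-cache update to that output so lazy evaluation cannot skip the send,
        # and the all_gather slice [-output.shape[0]:] leaves every rank holding the last
        # shard's output.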
@@ -132,24 +134,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
    return layers


def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
    inner_model_instance = _inner_model(model)
    if hasattr(inner_model_instance, "layers"):
        inner_model_instance.layers = layers

        # Update DeepSeek V3 specific parameters when layers are shrunk
        if isinstance(model, DeepseekV3Model) and hasattr(
            inner_model_instance, "num_layers"
        ):
            inner_model_instance.start_idx = 0
            inner_model_instance.end_idx = len(layers)
            inner_model_instance.num_layers = len(layers)
    elif hasattr(inner_model_instance, "h"):
        inner_model_instance.h = layers
    else:
        raise ValueError("Model must have either a 'layers' or 'h' attribute")


def pipeline_auto_parallel(
    model: nn.Module,
    group: mx.distributed.Group,
@@ -165,8 +149,7 @@ def pipeline_auto_parallel(
    """
    inner_model_instance: nn.Module = _inner_model(model)

-    # Handle both model.layers and model.h cases
-    layers: list[_LayerCallable] = _get_layers(inner_model_instance)
+    layers = _get_layers(inner_model_instance)

    start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
@@ -180,6 +163,17 @@ def pipeline_auto_parallel(
        group=group,
    )

    if isinstance(inner_model_instance, GptOssMoeModel):
        inner_model_instance.layer_types = inner_model_instance.layer_types[  # type: ignore
            start_layer:end_layer
        ]
        inner_model_instance.swa_idx = inner_model_instance.layer_types.index(  # type: ignore
            "sliding_attention"
        )
        inner_model_instance.ga_idx = inner_model_instance.layer_types.index(  # type: ignore
            "full_attention"
        )
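        # Note: layer_types records the attention kind per layer, so it is cut down to the
        # local [start_layer, end_layer) range; swa_idx / ga_idx are recomputed because the
        # first sliding/full-attention layer can sit at a different offset within the shard.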

    _set_layers(model, layers)

    assert isinstance(layers, list), (
@@ -204,18 +198,44 @@ def tensor_auto_parallel(
        group=group,
    )

    segments: int = 1

    def _all_to_sharded(path: str, weight: mx.array):
        if path.endswith("bias"):
            logger.info(f"Sharding bias for {path} - all to sharded")
            return weight.ndim - 1, segments
        return max(weight.ndim - 2, 0), segments
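    # Note: the callable form of `sharding=` appears to return a (shard_axis, segments) pair
    # per parameter: a bias is split along its only axis, while a weight matrix is split
    # along axis ndim - 2, i.e. the output dimension of the linear layer.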

    all_to_sharded_linear_in_place = partial(
        shard_inplace,
        sharding="all-to-sharded",
        group=group,
    )
    sharded_to_all_linear_in_place = partial(
        shard_inplace,
-        sharding="sharded-to-all",
+        sharding=_all_to_sharded,  # type: ignore
        group=group,
    )

    if isinstance(model, LlamaModel):
        n = group.size()

        def _sharded_to_all(path: str, weight: mx.array):
            if path.endswith("bias"):
                logger.info(f"Sharding bias for {path} - sharded to all")
                weight /= n
                return None
            return -1, segments
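        # Note: in the sharded-to-all (row-parallel) case every rank contributes a partial
        # sum that is later all-reduced, so the bias is pre-divided by the group size to
        # avoid being counted once per rank; returning None appears to leave the bias
        # unsharded, while -1 shards weights along their last (input) axis.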

        sharded_to_all_linear_in_place = partial(
            shard_inplace,
            sharding=_sharded_to_all,  # type: ignore
            group=group,
        )

    if hasattr(model, "shard"):
        try:
            model.shard(group)  # type: ignore
            return model
        except (AttributeError, TypeError, NameError):
            pass
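    # Note: models whose upstream mlx-lm implementation already provides `shard()` are
    # sharded there and returned early; the strategy classes below act as a fallback.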

    if isinstance(model, (LlamaModel, Ministral3Model)):
        logger.warning("shouldn't be hit - upstream sharding exists")
        tensor_parallel_sharding_strategy = LlamaShardingStrategy(
            group,
            all_to_sharded_linear,
@@ -223,7 +243,8 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
-    elif isinstance(model, DeepseekV3Model):
+    elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
        logger.warning("shouldn't be hit - upstream sharding exists")
        tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
            group,
            all_to_sharded_linear,
@@ -231,7 +252,7 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
-    elif isinstance(model, Qwen3MoeModel):
+    elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
        tensor_parallel_sharding_strategy = QwenShardingStrategy(
            group,
            all_to_sharded_linear,
@@ -239,6 +260,15 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, GptOssModel):
        tensor_parallel_sharding_strategy = GptOssShardingStrategy(
            group,
            all_to_sharded_linear,
            sharded_to_all_linear,
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )

    else:
        raise ValueError(f"Unsupported model type: {type(model)}")
@@ -284,6 +314,32 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
        return model


def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
    inner_model_instance = _inner_model(model)
    if hasattr(inner_model_instance, "layers"):
        inner_model_instance.layers = layers

        # Update DeepSeek V3 specific parameters when layers are shrunk
        if isinstance(
            model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
        ) and hasattr(inner_model_instance, "num_layers"):
            logger.info(
                f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
            )
            inner_model_instance.start_idx = 0
            inner_model_instance.end_idx = len(layers)
            inner_model_instance.num_layers = len(layers)
        elif isinstance(model, Qwen3MoeModel):
            logger.info(
                f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
            )
            inner_model_instance.num_hidden_layers = len(layers)
    elif hasattr(inner_model_instance, "h"):
        inner_model_instance.h = layers
    else:
        raise ValueError("Model must have either a 'layers' or 'h' attribute")
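# Note: these models appear to track their depth through start_idx / end_idx / num_layers
# (num_hidden_layers for Qwen3-MoE) rather than len(self.layers), so the counters are
# rewritten once the layer list has been truncated to the local shard.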


class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(DeepseekV3Model, model)
@@ -304,7 +360,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
            layer.self_attn.num_heads //= self.N

            # Shard the MLP
-            if isinstance(layer.mlp, DeepseekV3MLP):
+            if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -352,7 +408,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):

            # Shard the MoE. Shard in place since the MoE should be responsible
            # for aggregating the results.
-            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+            if isinstance(
+                layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
+            ):
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
@@ -380,3 +438,50 @@ class ShardedQwenMoE(CustomMlxLayer):
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y


class GptOssShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(GptOssMoeModel, model)

        for layer in model.layers:
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)

            layer.self_attn.num_attention_heads //= self.N
            layer.self_attn.num_key_value_heads //= self.N
            layer.self_attn.num_key_value_groups = (
                layer.self_attn.num_attention_heads
                // layer.self_attn.num_key_value_heads
            )

            layer.self_attn.sinks = layer.self_attn.sinks[
                layer.self_attn.num_attention_heads
                * self.group.rank() : layer.self_attn.num_attention_heads
                * (self.group.rank() + 1)
            ]
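            # Attention sinks are per-head parameters, so each rank keeps only the slice for
            # its local heads (num_attention_heads is already the per-rank count at this point).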

            self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
            self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
            self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)

            layer.mlp = ShardedGptOssMoE(layer.mlp)  # type: ignore
            layer.mlp.sharding_group = self.group

        return model


class ShardedGptOssMoE(CustomMlxLayer):
    def __init__(self, layer: nn.Module):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y
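# Note: each rank's experts yield a partial MoE output, hence the forward-pass all_sum over
# the sharding group; `sum_gradients` appears to be the matching backward hook, summing
# gradients across ranks while passing `x` through unchanged in the forward direction.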