@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from functools import partial
 from inspect import signature
-from typing import TYPE_CHECKING, Any, Protocol, cast
+from typing import TYPE_CHECKING, Any, cast

 import mlx.core as mx
 import mlx.nn as nn
@@ -67,27 +67,16 @@ def eval_with_timeout(
         completed.set()


-class _LayerCallable(Protocol):
-    """Structural type that any compatible layer must satisfy.
-
-    We require a single positional input of type ``mx.array`` and an
-    ``mx.array`` output, while permitting arbitrary *args / **kwargs so this
-    protocol matches the vast majority of `mlx.nn.Module` subclasses.
-    """
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
-
-
 class CustomMlxLayer(nn.Module):
     """Base class for replacing an MLX layer with a custom implementation."""

-    def __init__(self, original_layer: _LayerCallable):
+    def __init__(self, original_layer: nn.Module):
         super().__init__()
         object.__setattr__(self, "_original_layer", original_layer)

     @property
-    def original_layer(self) -> _LayerCallable:
-        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
+    def original_layer(self) -> nn.Module:
+        return cast(nn.Module, object.__getattribute__(self, "_original_layer"))

     # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
     if not TYPE_CHECKING:
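CustomMlxLayer stashes the wrapped layer with object.__setattr__ so that nn.Module's attribute tracking does not register it as a child module, and forwards unknown attribute lookups back to it. A minimal sketch of the same delegation pattern, with a hypothetical Wrapper around an arbitrary MLX module (not part of this patch):

    import mlx.core as mx
    import mlx.nn as nn

    class Wrapper(nn.Module):
        def __init__(self, inner: nn.Module):
            super().__init__()
            # bypass nn.Module.__setattr__ so `inner` is not tracked as a submodule
            object.__setattr__(self, "_inner", inner)

        def __call__(self, x: mx.array) -> mx.array:
            # delegate the forward pass to the wrapped layer
            return object.__getattribute__(self, "_inner")(x)

    y = Wrapper(nn.Linear(4, 4))(mx.zeros((1, 4)))
    mx.eval(y)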
@@ -100,52 +89,53 @@ class CustomMlxLayer(nn.Module):
             return getattr(original_layer, name)


-class PipelineFirstLayer(CustomMlxLayer):
-    def __init__(
-        self,
-        original_layer: _LayerCallable,
-        r: int,
-        group: mx.distributed.Group,
-    ):
-        super().__init__(original_layer)
-        self.r: int = r
-        self.group = group
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
-        if self.r != 0:
-            x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
-        return self.original_layer(x, *args, **kwargs)
+def patch_pipeline_first_layer(
+    pipeline_layer: nn.Module, group: mx.distributed.Group
+) -> nn.Module:
+    cls = type(pipeline_layer)
+    orig_call = cast(Callable[..., mx.array], cls.__call__)
+    rank = group.rank()
+
+    class PatchedFirstLayer(cls):
+        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
+            if rank != 0:
+                x = mx.distributed.recv_like(x, (rank - 1), group=group)
+            return orig_call(self, x, *args, **kwargs)
+
+    pipeline_layer.__class__ = PatchedFirstLayer
+
+    return pipeline_layer


-class PipelineLastLayer(CustomMlxLayer):
-    def __init__(
-        self,
-        original_layer: _LayerCallable,
-        r: int,
-        s: int,
-        group: mx.distributed.Group,
-    ):
-        super().__init__(original_layer)
-        self.r: int = r
-        self.s: int = s
-        self.group = group
-        self.original_layer_signature = signature(self.original_layer.__call__)
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
-        cache = self.original_layer_signature.bind_partial(
-            x, *args, **kwargs
-        ).arguments.get("cache", None)
-
-        output: mx.array = self.original_layer(x, *args, **kwargs)
-
-        if self.r != self.s - 1:
-            output = mx.distributed.send(
-                output, (self.r + 1) % self.s, group=self.group
-            )
-            if cache is not None:
-                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
-
-        return output
+def patch_pipeline_last_layer(
+    pipeline_layer: nn.Module, group: mx.distributed.Group
+) -> nn.Module:
+    cls = type(pipeline_layer)
+    orig_call = cast(Callable[..., mx.array], cls.__call__)
+    orig_call_sig = signature(orig_call)
+    rank = group.rank()
+    size = group.size()
+
+    class PatchedLastLayer(cls):
+        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
+            cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
+                "cache", None
+            )
+
+            output: mx.array = orig_call(self, x, *args, **kwargs)
+
+            if rank != size - 1:
+                output = mx.distributed.send(output, (rank + 1) % size, group=group)
+                if cache is not None:
+                    cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
+
+            return output
+
+    pipeline_layer.__class__ = PatchedLastLayer
+
+    return pipeline_layer


 def _inner_model(model: nn.Module) -> nn.Module:
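The new patch_pipeline_* helpers replace the wrapper classes with in-place retyping: a subclass of the layer's own class is created with an overridden __call__, and the instance's __class__ is reassigned, so the layer keeps its parameters and attributes while gaining the distributed send/recv behaviour. A minimal sketch of that idiom on a plain nn.Linear (patch_with_scale and the scaling step are illustrative, not part of this patch):

    import mlx.core as mx
    import mlx.nn as nn

    def patch_with_scale(layer: nn.Module, factor: float) -> nn.Module:
        cls = type(layer)
        orig_call = cls.__call__

        class Patched(cls):
            def __call__(self, x: mx.array, *args, **kwargs) -> mx.array:
                # run the unmodified forward pass, then apply the extra step
                return factor * orig_call(self, x, *args, **kwargs)

        # retype the existing instance; its weights and attributes are untouched
        layer.__class__ = Patched
        return layer

    lin = patch_with_scale(nn.Linear(4, 4), 2.0)
    mx.eval(lin(mx.zeros((1, 4))))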
@@ -160,13 +150,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
     raise ValueError("Model must either have a 'model' or 'transformer' attribute")


-def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
+def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
     # Handle both model.layers and model.h cases
-    layers: list[_LayerCallable]
+    layers: list[nn.Module]
     if hasattr(inner_model_instance, "layers"):
-        layers = cast(list[_LayerCallable], inner_model_instance.layers)
+        layers = cast(list[nn.Module], inner_model_instance.layers)
     elif hasattr(inner_model_instance, "h"):
-        layers = cast(list[_LayerCallable], inner_model_instance.h)
+        layers = cast(list[nn.Module], inner_model_instance.h)
     else:
         raise ValueError("Model must have either a 'layers' or 'h' attribute")

@@ -191,15 +181,12 @@ def pipeline_auto_parallel(
     layers = _get_layers(inner_model_instance)

     start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
-    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size

     layers = layers[start_layer:end_layer]
-    layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
-    layers[-1] = PipelineLastLayer(
+    layers[0] = patch_pipeline_first_layer(layers[0], group)
+    layers[-1] = patch_pipeline_last_layer(
         layers[-1],
-        device_rank,
-        world_size,
-        group=group,
+        group,
     )

     if isinstance(inner_model_instance, GptOssMoeModel):
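The patched first and last layers implement the stage boundaries of the pipeline: every rank except the first blocks on recv_like before its slice of layers runs, and every rank except the last sends its final activations onward. A standalone sketch of that boundary logic; it only moves data when launched across more than one rank (e.g. with MLX's distributed launcher), and with a single process both branches are skipped:

    import mlx.core as mx

    group = mx.distributed.init()
    rank, size = group.rank(), group.size()

    x = mx.zeros((1, 8))  # stand-in for this stage's input activations
    if rank != 0:
        # wait for the previous stage's activations (same shape/dtype as x)
        x = mx.distributed.recv_like(x, rank - 1, group=group)
    # ... run this rank's slice of layers on x ...
    if rank != size - 1:
        # hand the activations to the next pipeline stage
        x = mx.distributed.send(x, (rank + 1) % size, group=group)
    mx.eval(x)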
@@ -446,7 +433,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
         return model


-def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
+def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
     inner_model_instance = _inner_model(model)
     if hasattr(inner_model_instance, "layers"):
         inner_model_instance.layers = layers
@@ -521,17 +508,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):


 class ShardedDeepseekV3MoE(CustomMlxLayer):
-    def __init__(self, layer: _LayerCallable):
+    def __init__(self, layer: nn.Module):
         super().__init__(layer)
         self.sharding_group: mx.distributed.Group | None = None

     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer.__call__(x)
+        y = self.original_layer.__call__(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore


 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
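ShardedDeepseekV3MoE (and the analogous Qwen and GptOss wrappers below) run the expert block on this rank's shard and then all-reduce the partial outputs so every rank ends up with the full result. A minimal sketch of that reduce step, with nn.Linear standing in for the sharded expert block (not part of this patch):

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()      # size-1 group unless launched distributed
    shard = nn.Linear(8, 8)            # stand-in for this rank's shard of the MoE
    partial = shard(mx.zeros((1, 8)))  # each rank produces a partial output
    full = mx.distributed.all_sum(partial, group=group)  # sum partials across ranks
    mx.eval(full)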
@@ -565,7 +552,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
             self.all_to_sharded_linear_in_place(
                 layer.block_sparse_moe.switch_mlp.up_proj
             )
-            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue]
             layer.block_sparse_moe.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]

         return model
@@ -599,7 +586,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
             self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
             self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
             self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
-            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue]
             layer.mlp.sharding_group = self.group

         # Shard the MLP
@@ -612,17 +599,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):


 class ShardedQwenMoE(CustomMlxLayer):
-    def __init__(self, layer: _LayerCallable):
+    def __init__(self, layer: nn.Module):
         super().__init__(layer)
         self.sharding_group: mx.distributed.Group | None = None

     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer.__call__(x)
+        y = self.original_layer.__call__(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore


 class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -674,7 +661,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer(x)
+        y = self.original_layer(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore