Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
5019ac489e foo 2026-01-20 16:52:28 +00:00

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, cast
import mlx.core as mx
import mlx.nn as nn
@@ -13,11 +13,17 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
@@ -60,28 +66,16 @@ def eval_with_timeout(
finally:
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
"""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: _LayerCallable):
def __init__(self, original_layer: nn.Module):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
def original_layer(self) -> nn.Module:
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -94,53 +88,49 @@ class CustomMlxLayer(nn.Module):
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.group = group
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
s: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
def patch_pipeline_first_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
rank = group.rank()
class PatchedFirstLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if rank != 0:
x = mx.distributed.recv_like(x, (rank - 1), group=group)
return orig_call(x, *args, **kwargs)
output: mx.array = self.original_layer(x, *args, **kwargs)
pipeline_layer.__class__ = PatchedFirstLayer
return pipeline_layer
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
def patch_pipeline_last_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
orig_call_sig = signature(orig_call)
return output
rank = group.rank()
size = group.size()
class PatchedLastLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = orig_call_sig.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
output: mx.array = orig_call(x, *args, **kwargs)
if rank != size - 1:
output = mx.distributed.send(
output, (rank + 1) % size, group=group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
pipeline_layer.__class__ = PatchedLastLayer
return pipeline_layer
def _inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
@@ -154,13 +144,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
# Handle both model.layers and model.h cases
layers: list[_LayerCallable]
layers: list[nn.Module]
if hasattr(inner_model_instance, "layers"):
layers = cast(list[_LayerCallable], inner_model_instance.layers)
layers = cast(list[nn.Module], inner_model_instance.layers)
elif hasattr(inner_model_instance, "h"):
layers = cast(list[_LayerCallable], inner_model_instance.h)
layers = cast(list[nn.Module], inner_model_instance.h)
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -185,15 +175,12 @@ def pipeline_auto_parallel(
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[0] = patch_pipeline_first_layer(layers[0], group)
layers[-1] = patch_pipeline_last_layer(
layers[-1],
device_rank,
world_size,
group=group,
group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
@@ -335,7 +322,33 @@ def tensor_auto_parallel(
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, MiniMaxModel):
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -343,6 +356,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -377,7 +399,35 @@ class TensorParallelShardingStrategy(ABC):
) -> nn.Module: ...
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
if layer.self_attn.n_kv_heads is not None:
layer.self_attn.n_kv_heads //= self.N
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
return model
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
@@ -403,6 +453,105 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
else:
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
layer.self_attn.kv_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
return model
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.block_sparse_moe.switch_mlp.down_proj
)
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -431,7 +580,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -444,14 +593,69 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn.num_key_value_groups = (
layer.self_attn.num_attention_heads
// layer.self_attn.num_key_value_heads
)
layer.self_attn.sinks = layer.self_attn.sinks[
layer.self_attn.num_attention_heads
* self.group.rank() : layer.self_attn.num_attention_heads
* (self.group.rank() + 1)
]
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore