diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py
index 2ec9900e..69ea2984 100644
--- a/src/exo/worker/engines/mlx/auto_parallel.py
+++ b/src/exo/worker/engines/mlx/auto_parallel.py
@@ -24,8 +24,6 @@ from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
 from mlx_lm.models.glm4_moe import MoE
 from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
 from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
-from mlx_lm.models.glm_moe_dsa import Glm4MoeLiteMoE as GlmMoeDsaMoE
-from mlx_lm.models.glm_moe_dsa import Model as GlmMoeDsaModel
 from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.models.kimi_k25 import Model as KimiK25Model
@@ -408,14 +406,6 @@ def tensor_auto_parallel(
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
-    elif isinstance(model, GlmMoeDsaModel):
-        tensor_parallel_sharding_strategy = GlmMoeDsaShardingStrategy(
-            group,
-            all_to_sharded_linear,
-            sharded_to_all_linear,
-            all_to_sharded_linear_in_place,
-            sharded_to_all_linear_in_place,
-        )
     elif isinstance(model, Glm4MoeModel):
         tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
             group,
@@ -667,62 +657,6 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
         return model
 
 
-class GlmMoeDsaShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(
-        self,
-        model: nn.Module,
-        timeout_seconds: float,
-        on_timeout: TimeoutCallback | None,
-    ) -> nn.Module:
-        model = cast(GlmMoeDsaModel, model)
-        for layer in model.layers:
-            eval_with_timeout(
-                layer.parameters(),
-                timeout_seconds / len(model.layers),
-                on_timeout,
-            )
-            layer.self_attn.q_b_proj = self.all_to_sharded_linear(
-                layer.self_attn.q_b_proj
-            )
-            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
-            layer.self_attn.num_heads //= self.N
-
-            num_heads = layer.self_attn.num_heads
-            sh = self.group.rank() * num_heads
-            eh = sh + num_heads
-
-            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
-                return w[sh:eh]
-
-            layer.self_attn.embed_q.apply(shard_heads)
-            layer.self_attn.unembed_out.apply(shard_heads)
-
-            if isinstance(layer.mlp, Glm4MoeLiteMLP):
-                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
-                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
-                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
-            else:
-                moe = cast(GlmMoeDsaMoE, layer.mlp)
-                if moe.shared_experts is not None:
-                    self.all_to_sharded_linear_in_place(
-                        moe.shared_experts.gate_proj
-                    )
-                    self.sharded_to_all_linear_in_place(
-                        moe.shared_experts.down_proj
-                    )
-                    self.all_to_sharded_linear_in_place(
-                        moe.shared_experts.up_proj
-                    )
-                self.all_to_sharded_linear_in_place(moe.switch_mlp.gate_proj)
-                self.sharded_to_all_linear_in_place(moe.switch_mlp.down_proj)
-                self.all_to_sharded_linear_in_place(moe.switch_mlp.up_proj)
-                layer.mlp = ShardedMoE(moe)  # type: ignore
-                layer.mlp.sharding_group = self.group
-            mx.eval(layer)
-
-        return model
-
-
 class WrappedMiniMaxAttention(CustomMlxLayer):
     def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
         super().__init__(layer)