update glm 5 to use upstream mlx lm

This commit is contained in:
Ryuichi Leo Takashige
2026-02-13 12:50:08 +00:00
parent ce0eef999e
commit 0de3e486df

View File

@@ -24,8 +24,6 @@ from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
from mlx_lm.models.glm_moe_dsa import Glm4MoeLiteMoE as GlmMoeDsaMoE
from mlx_lm.models.glm_moe_dsa import Model as GlmMoeDsaModel
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
@@ -408,14 +406,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GlmMoeDsaModel):
tensor_parallel_sharding_strategy = GlmMoeDsaShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Glm4MoeModel):
tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
group,
@@ -667,62 +657,6 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
return model
class GlmMoeDsaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GlmMoeDsaModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(),
timeout_seconds / len(model.layers),
on_timeout,
)
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
return w[sh:eh]
layer.self_attn.embed_q.apply(shard_heads)
layer.self_attn.unembed_out.apply(shard_heads)
if isinstance(layer.mlp, Glm4MoeLiteMLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
else:
moe = cast(GlmMoeDsaMoE, layer.mlp)
if moe.shared_experts is not None:
self.all_to_sharded_linear_in_place(
moe.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
moe.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
moe.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(moe.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(moe.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(moe.switch_mlp.up_proj)
layer.mlp = ShardedMoE(moe) # type: ignore
layer.mlp.sharding_group = self.group
mx.eval(layer)
return model
class WrappedMiniMaxAttention(CustomMlxLayer):
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
super().__init__(layer)