mirror of https://github.com/exo-explore/exo.git (synced 2026-02-14 16:15:43 -05:00)
update glm 5 to use upstream mlx lm
@@ -24,8 +24,6 @@ from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
 from mlx_lm.models.glm4_moe import MoE
 from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
 from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
-from mlx_lm.models.glm_moe_dsa import Glm4MoeLiteMoE as GlmMoeDsaMoE
-from mlx_lm.models.glm_moe_dsa import Model as GlmMoeDsaModel
 from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.models.kimi_k25 import Model as KimiK25Model
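Note: the two glm_moe_dsa imports are the ones dropped here; per the commit title, GLM 5 support now rides on upstream mlx-lm rather than a dedicated code path in exo. As a hedged aside, any model that upstream mlx-lm supports can be loaded and sampled through its public load/generate API; the checkpoint name below is a placeholder for illustration, not something this commit references.

from mlx_lm import load, generate

# Placeholder repo id; substitute an MLX-converted GLM MoE checkpoint.
model, tokenizer = load("mlx-community/SOME-GLM-MOE-CHECKPOINT")
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))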
@@ -408,14 +406,6 @@ def tensor_auto_parallel(
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
-    elif isinstance(model, GlmMoeDsaModel):
-        tensor_parallel_sharding_strategy = GlmMoeDsaShardingStrategy(
-            group,
-            all_to_sharded_linear,
-            sharded_to_all_linear,
-            all_to_sharded_linear_in_place,
-            sharded_to_all_linear_in_place,
-        )
     elif isinstance(model, Glm4MoeModel):
         tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
             group,
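Note on the helpers threaded through every strategy: judging by how they are used in the class removed below, all_to_sharded_linear appears to split a projection along its output features, while sharded_to_all_linear splits along its input features so that one all-reduce restores the full activation. That is the standard column/row tensor-parallel split. The sketch below only illustrates the weight slicing with made-up sizes and is not exo's implementation.

import mlx.core as mx

d_model, d_ff, world_size, rank = 8, 16, 4, 1   # illustrative sizes
cols = d_ff // world_size

# nn.Linear stores weights as (out_features, in_features).
w_up = mx.zeros((d_ff, d_model))      # "all_to_sharded": shard the output dim
w_down = mx.zeros((d_model, d_ff))    # "sharded_to_all": shard the input dim

w_up_shard = w_up[rank * cols:(rank + 1) * cols, :]
w_down_shard = w_down[:, rank * cols:(rank + 1) * cols]

# Each rank produces a partial down-projection; summing the partials across
# ranks (one all-reduce) yields the full MLP output.
print(w_up_shard.shape, w_down_shard.shape)   # (4, 8) (8, 4)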
@@ -667,62 +657,6 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
         return model
 
 
-class GlmMoeDsaShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(
-        self,
-        model: nn.Module,
-        timeout_seconds: float,
-        on_timeout: TimeoutCallback | None,
-    ) -> nn.Module:
-        model = cast(GlmMoeDsaModel, model)
-        for layer in model.layers:
-            eval_with_timeout(
-                layer.parameters(),
-                timeout_seconds / len(model.layers),
-                on_timeout,
-            )
-            layer.self_attn.q_b_proj = self.all_to_sharded_linear(
-                layer.self_attn.q_b_proj
-            )
-            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
-            layer.self_attn.num_heads //= self.N
-
-            num_heads = layer.self_attn.num_heads
-            sh = self.group.rank() * num_heads
-            eh = sh + num_heads
-
-            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
-                return w[sh:eh]
-
-            layer.self_attn.embed_q.apply(shard_heads)
-            layer.self_attn.unembed_out.apply(shard_heads)
-
-            if isinstance(layer.mlp, Glm4MoeLiteMLP):
-                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
-                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
-                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
-            else:
-                moe = cast(GlmMoeDsaMoE, layer.mlp)
-                if moe.shared_experts is not None:
-                    self.all_to_sharded_linear_in_place(
-                        moe.shared_experts.gate_proj
-                    )
-                    self.sharded_to_all_linear_in_place(
-                        moe.shared_experts.down_proj
-                    )
-                    self.all_to_sharded_linear_in_place(
-                        moe.shared_experts.up_proj
-                    )
-                self.all_to_sharded_linear_in_place(moe.switch_mlp.gate_proj)
-                self.sharded_to_all_linear_in_place(moe.switch_mlp.down_proj)
-                self.all_to_sharded_linear_in_place(moe.switch_mlp.up_proj)
-                layer.mlp = ShardedMoE(moe) # type: ignore
-                layer.mlp.sharding_group = self.group
-            mx.eval(layer)
-
-        return model
-
-
 class WrappedMiniMaxAttention(CustomMlxLayer):
     def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
         super().__init__(layer)
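For reference, the shard_heads helper in the removed class assigns each rank a contiguous block of attention heads: after num_heads is divided by the group size, rank r keeps rows [r * num_heads, (r + 1) * num_heads) of any head-major weight. A toy standalone version of that slice, with made-up sizes:

import mlx.core as mx

total_heads, head_dim, world_size = 8, 4, 4
heads_per_rank = total_heads // world_size

# Weight laid out with the head axis first, as shard_heads assumes.
w = mx.arange(total_heads * head_dim).reshape(total_heads, head_dim)

def shard_heads(w: mx.array, rank: int) -> mx.array:
    sh = rank * heads_per_rank   # first head owned by this rank
    eh = sh + heads_per_rank     # one past the last owned head
    return w[sh:eh]

for rank in range(world_size):
    print(rank, shard_heads(w, rank).shape)   # every rank gets (2, 4)

This mirrors, in miniature, the slice taken in the deleted class before the sharded output projection recombines the heads across ranks.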