diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index 64537616..3457b100 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -520,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy): layer.self_attn.kv_b_proj ) layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj) + # Store pre-shard head count and group for context parallelism + layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads + layer.self_attn._cp_group = self.group layer.self_attn.num_heads //= self.N # Shard the MLP @@ -542,6 +545,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy): mx.eval(layer) + # Store group for context parallelism + if hasattr(model, "model"): + model.model._cp_group = self.group + return model