Add context parallelism support to DeepSeek sharding

Store the pre-shard head count and the distributed group on each attention
layer during sharding, enabling automatic tensor-parallel (TP) → context-parallel (CP)
switching at runtime when the context length exceeds a threshold.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
dmcc73 committed 2026-02-02 18:35:20 +00:00
parent 5d3b407602
commit 07b8405d3e


@@ -520,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
                 layer.self_attn.kv_b_proj
             )
             layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
+            # Store pre-shard head count and group for context parallelism
+            layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
+            layer.self_attn._cp_group = self.group
             layer.self_attn.num_heads //= self.N
             # Shard the MLP
@@ -542,6 +545,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
             mx.eval(layer)
+        # Store group for context parallelism
+        if hasattr(model, "model"):
+            model.model._cp_group = self.group
         return model
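
As a usage illustration, a model-level toggle could flip every attention layer between the two layouts once the incoming context crosses the threshold mentioned in the commit message. set_context_parallel below is a hypothetical helper; it assumes the attributes stored by this commit, the usual model.model.layers layout, and a group object that exposes size() (as mx.distributed.Group does):

def set_context_parallel(model, enabled: bool):
    """Switch every attention layer between TP and CP head counts.

    Hypothetical sketch: relies on the _cp_group and
    context_parallel_total_heads attributes written during sharding.
    """
    group = getattr(model.model, "_cp_group", None)
    if group is None:
        return  # model was not sharded with this strategy
    n = group.size()
    for layer in model.model.layers:
        attn = layer.self_attn
        total = attn.context_parallel_total_heads
        # CP keeps the full head set on each rank; TP uses the sharded count.
        attn.num_heads = total if enabled else total // n

The runtime would call such a helper with enabled=True before processing a long-context request and with enabled=False afterward.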