Add context parallelism support to DeepSeek sharding
Store pre-shard head count and distributed group on each attention layer during sharding, enabling automatic TP→CP switching at runtime when context length exceeds a threshold.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -520,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
             layer.self_attn.kv_b_proj
         )
         layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
+        # Store pre-shard head count and group for context parallelism
+        layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
+        layer.self_attn._cp_group = self.group
         layer.self_attn.num_heads //= self.N
 
         # Shard the MLP
@@ -542,6 +545,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
 
             mx.eval(layer)
 
+        # Store group for context parallelism
+        if hasattr(model, "model"):
+            model.model._cp_group = self.group
+
         return model
 
 
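At runtime, the attributes stored above let each attention layer decide per call whether to keep the tensor-parallel head split or gather the heads and split the sequence instead. The sketch below is illustrative only and not taken from this commit: the threshold constant and the helper names (should_use_context_parallel, split_sequence_for_cp, attention_head_count) are assumptions, while context_parallel_total_heads and _cp_group are the attributes written by the diff above.

import mlx.core as mx

# Assumed cutoff for illustration; the real threshold is configured elsewhere.
CP_CONTEXT_THRESHOLD = 8192


def should_use_context_parallel(attn, seq_len):
    # Switch to context parallelism only when the layer was sharded across a
    # group and the prompt is long enough for a sequence split to pay off.
    group = getattr(attn, "_cp_group", None)
    return group is not None and group.size() > 1 and seq_len > CP_CONTEXT_THRESHOLD


def split_sequence_for_cp(x: mx.array, group) -> mx.array:
    # Keep this rank's contiguous slice of the token axis (batch, tokens, dim).
    n, r = group.size(), group.rank()
    chunk = (x.shape[1] + n - 1) // n
    return x[:, r * chunk : (r + 1) * chunk]


def attention_head_count(attn, use_cp):
    # Under TP each rank runs num_heads // N heads over the full sequence;
    # under CP each rank runs the full pre-shard head count over its slice.
    return attn.context_parallel_total_heads if use_cp else attn.num_heads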