From 07b8405d3e360cca972e5ac3c8c194982bf2cbef Mon Sep 17 00:00:00 2001 From: dmcc73 Date: Mon, 2 Feb 2026 18:35:20 +0000 Subject: [PATCH] Add context parallelism support to DeepSeek sharding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Store pre-shard head count and distributed group on each attention layer during sharding, enabling automatic TP→CP switching at runtime when context length exceeds a threshold. Co-Authored-By: Claude Opus 4.5 --- src/exo/worker/engines/mlx/auto_parallel.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index 64537616..3457b100 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -520,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy): layer.self_attn.kv_b_proj ) layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj) + # Store pre-shard head count and group for context parallelism + layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads + layer.self_attn._cp_group = self.group layer.self_attn.num_heads //= self.N # Shard the MLP @@ -542,6 +545,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy): mx.eval(layer) + # Store group for context parallelism + if hasattr(model, "model"): + model.model._cp_group = self.group + return model