Compare commits

...

2 Commits

Author SHA1 Message Date
dmcc73
6018a9c97c Point mlx-lm to davidmcc73 fork with context parallelism support 2026-02-02 19:16:43 +00:00
dmcc73
07b8405d3e Add context parallelism support to DeepSeek sharding
Store pre-shard head count and distributed group on each attention
layer during sharding, enabling automatic TP→CP switching at runtime
when context length exceeds a threshold.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 18:35:20 +00:00
2 changed files with 8 additions and 0 deletions

View File

@@ -71,6 +71,7 @@ exo_pyo3_bindings = { workspace = true }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
[build-system]
requires = ["uv_build>=0.8.9,<0.9.0"]

View File

@@ -520,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.kv_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
# Store pre-shard head count and group for context parallelism
layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
layer.self_attn._cp_group = self.group
layer.self_attn.num_heads //= self.N
# Shard the MLP
@@ -542,6 +545,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
mx.eval(layer)
# Store group for context parallelism
if hasattr(model, "model"):
model.model._cp_group = self.group
return model