mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-04 19:22:39 -05:00
Compare commits
2 Commits
leo/add-lo
...
david/mla-
| Author | SHA1 | Date | |
|---|---|---|---|
| | 6018a9c97c | | |
| | 07b8405d3e | | |
@@ -71,6 +71,7 @@ exo_pyo3_bindings = { workspace = true }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.9,<0.9.0"]
|
||||
|
||||
@@ -520,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
# Store pre-shard head count and group for context parallelism
|
||||
layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
|
||||
layer.self_attn._cp_group = self.group
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
@@ -542,6 +545,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
mx.eval(layer)
|
||||
|
||||
# Store group for context parallelism
|
||||
if hasattr(model, "model"):
|
||||
model.model._cp_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user