Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-04 11:11:45 -05:00

Compare commits: 2 commits on rust-explo... by alexcheema

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | 5ee257a13e | |
| alexcheema | 24511ab7cb | |
@@ -0,0 +1,8 @@
+model_id = "mlx-community/Step-3.5-Flash-4bit"
+n_layers = 45
+hidden_size = 4096
+supports_tensor = true
+tasks = ["TextGeneration"]
+
+[storage_size]
+in_bytes = 114572190076
@@ -0,0 +1,8 @@
+model_id = "mlx-community/Step-3.5-Flash-6bit"
+n_layers = 45
+hidden_size = 4096
+supports_tensor = true
+tasks = ["TextGeneration"]
+
+[storage_size]
+in_bytes = 159039627774
@@ -0,0 +1,8 @@
+model_id = "mlx-community/Step-3.5-Flash-8Bit"
+n_layers = 45
+hidden_size = 4096
+supports_tensor = true
+tasks = ["TextGeneration"]
+
+[storage_size]
+in_bytes = 209082699847
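The three files above register the 4-bit, 6-bit, and 8-bit MLX quantizations of Step-3.5-Flash. All share the same architecture metadata (45 layers, hidden size 4096, tensor parallelism enabled, text generation only); only the model_id and the on-disk size differ, the latter tracking the quantization width (roughly 115 GB, 159 GB, and 209 GB). As a rough sketch of how such a file could be consumed (the dataclass and loader below are hypothetical, not exo's actual registry code), the format parses directly with Python's stdlib tomllib:

```python
# Hypothetical loader for model config TOMLs of the shape added above.
# Field names match the new files; the ModelConfig type is illustrative.
import tomllib
from dataclasses import dataclass

@dataclass
class ModelConfig:
    model_id: str
    n_layers: int
    hidden_size: int
    supports_tensor: bool
    tasks: list[str]
    storage_bytes: int

def load_model_config(path: str) -> ModelConfig:
    with open(path, "rb") as f:  # tomllib requires binary mode
        raw = tomllib.load(f)
    return ModelConfig(
        model_id=raw["model_id"],
        n_layers=raw["n_layers"],
        hidden_size=raw["hidden_size"],
        supports_tensor=raw["supports_tensor"],
        tasks=raw["tasks"],
        storage_bytes=raw["storage_size"]["in_bytes"],
    )
```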
@@ -233,6 +233,7 @@ class ConfigData(BaseModel):
             ["MiniMaxM2ForCausalLM"],
             ["LlamaForCausalLM"],
             ["GptOssForCausalLM"],
+            ["Step3p5ForCausalLM"],
         ]

     @model_validator(mode="before")
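This whitelists "Step3p5ForCausalLM" alongside the other supported architecture strings, i.e. the values Hugging Face checkpoints declare under "architectures" in config.json. A minimal sketch of that kind of membership check (illustrative only; exo's real validation lives in the @model_validator shown in context, which this diff does not display in full):

```python
# Illustrative architectures check; the list mirrors the diff above,
# but is_supported() is a sketch, not exo's actual validator.
SUPPORTED_ARCHITECTURES = [
    ["MiniMaxM2ForCausalLM"],
    ["LlamaForCausalLM"],
    ["GptOssForCausalLM"],
    ["Step3p5ForCausalLM"],
]

def is_supported(hf_config: dict) -> bool:
    # config.json carries the model classes as a list under "architectures"
    return hf_config.get("architectures") in SUPPORTED_ARCHITECTURES

assert is_supported({"architectures": ["Step3p5ForCausalLM"]})
```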
@@ -31,6 +31,8 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
 from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
 from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
 from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
+from mlx_lm.models.step3p5 import Model as Step3p5Model
+from mlx_lm.models.step3p5 import Step3p5MLP

 from exo.shared.logging import logger
 from exo.shared.types.worker.shards import PipelineShardMetadata
@@ -380,6 +382,14 @@ def tensor_auto_parallel(
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
+    elif isinstance(model, Step3p5Model):
+        tensor_parallel_sharding_strategy = Step3p5ShardingStrategy(
+            group,
+            all_to_sharded_linear,
+            sharded_to_all_linear,
+            all_to_sharded_linear_in_place,
+            sharded_to_all_linear_in_place,
+        )
     elif isinstance(model, GptOssModel):
         tensor_parallel_sharding_strategy = GptOssShardingStrategy(
             group,
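This follows the existing dispatch pattern in tensor_auto_parallel: each supported architecture gets its own TensorParallelShardingStrategy subclass, and every strategy receives the same four linear-sharding helpers (all_to_sharded_linear, sharded_to_all_linear, and their in-place variants). The new branch simply routes Step3p5Model instances to the Step3p5ShardingStrategy defined in the next hunk.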
@@ -774,3 +784,57 @@ class ShardedGptOssMoE(CustomMlxLayer):
         if self.sharding_group is not None:
             y = mx.distributed.all_sum(y, group=self.sharding_group)
         return y
+
+
+class Step3p5ShardingStrategy(TensorParallelShardingStrategy):
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
+        model = cast(Step3p5Model, model)
+        for layer in model.layers:
+            eval_with_timeout(
+                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
+            )
+            # Shard attention
+            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
+            layer.self_attn.num_heads //= self.N  # pyright: ignore[reportUnknownMemberType]
+            layer.self_attn.num_kv_heads //= self.N  # pyright: ignore[reportUnknownMemberType]
+
+            if isinstance(layer.mlp, Step3p5MLP):
+                # Dense MLP layer
+                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
+                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
+                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
+            else:
+                # MoE layer: shared expert + routed experts
+                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
+                self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
+                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
+                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
+                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
+                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
+                layer.mlp = ShardedStep3p5MoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+                layer.mlp.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
+
+            mx.eval(layer)
+        return model
+
+
+class ShardedStep3p5MoE(CustomMlxLayer):
+    def __init__(self, layer: _LayerCallable):
+        super().__init__(layer)
+        self.sharding_group: mx.distributed.Group | None = None
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if self.sharding_group is not None:
+            x = sum_gradients(self.sharding_group)(x)
+        y = self.original_layer.__call__(x)
+        if self.sharding_group is not None:
+            y = mx.distributed.all_sum(y, group=self.sharding_group)
+        return y
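The strategy applies the usual Megatron-style split: the q/k/v projections and the MLP gate/up projections are sharded on their output dimension (all_to_sharded, column parallel), while o_proj and down_proj are sharded on their input dimension (sharded_to_all, row parallel), with head counts divided by the shard count N. For MoE layers, both the shared expert and the routed switch_mlp are sharded in place, and the ShardedStep3p5MoE wrapper all-sums the partial outputs across the group (sum_gradients on the input makes gradients reduce correctly on the backward pass). Below is a toy single-process illustration of why the column-then-row split plus a final sum reproduces the unsharded computation; numpy stands in for MLX's distributed ops, the shapes are made up, and ReLU stands in for the model's actual activation:

```python
# Toy illustration of the column/row sharding pattern used above:
# W1 is split over output features (column parallel, "all_to_sharded"),
# W2 over input features (row parallel, "sharded_to_all"), and the
# per-shard partial outputs are summed -- the role mx.distributed.all_sum
# plays in ShardedStep3p5MoE.
import numpy as np

rng = np.random.default_rng(0)
N = 2                                  # number of shards ("self.N")
x = rng.standard_normal((1, 8))        # activations
W1 = rng.standard_normal((8, 16))      # e.g. a gate/up projection
W2 = rng.standard_normal((16, 8))      # e.g. a down/output projection

def relu(a: np.ndarray) -> np.ndarray:
    return np.maximum(a, 0.0)

# Unsharded reference MLP.
ref = relu(x @ W1) @ W2

# Shard the weights: columns of W1, matching rows of W2.
W1_shards = np.split(W1, N, axis=1)
W2_shards = np.split(W2, N, axis=0)

# Each "device" computes a partial result on its own feature slice.
# Elementwise activations commute with the column split, so applying
# relu per shard is exact, not an approximation.
partials = [relu(x @ W1_shards[i]) @ W2_shards[i] for i in range(N)]
out = sum(partials)  # the all-reduce step

assert np.allclose(ref, out)
```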
uv.lock (generated): 4 additions, 4 deletions
@@ -1072,8 +1072,8 @@ wheels = [

 [[package]]
 name = "mlx-lm"
-version = "0.30.5"
-source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
+version = "0.30.6"
+source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#ab050d1fac2ef1d7bea6b8d870f1e5717d7f59f5" }
 dependencies = [
     { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin'" },
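The lockfile bump pins mlx-lm to a newer revision of main (0.30.5 to 0.30.6), presumably to pull in the mlx_lm.models.step3p5 module that the new imports above depend on.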