Compare commits

...

2 Commits

Author SHA1 Message Date
Alex Cheema
5ee257a13e feat: add tensor parallelism support for Step 3.5 Flash
Add Step3p5ShardingStrategy to auto_parallel.py following the
DeepSeek pattern (shared expert + routed experts). Shard attention
q/k/v/o projections across devices and MoE expert weights in-place
with all-reduce synchronization.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:52:28 -08:00
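The sharding described above reduces to a simple linear-algebra identity: q/k/v projections are split along their output features (each device computes its own heads with no communication), while the o projection is split along its input features, so each device produces a partial output and the all-reduce sums the partials back together. A minimal single-process sketch of that identity, with made-up dimensions and local slices standing in for the distributed helpers:

import mlx.core as mx

# Toy shapes only; the real code shards weights per device via mx.distributed
# rather than slicing locally like this.
N = 2                                # tensor-parallel group size
x = mx.random.normal((1, 8))         # one token's hidden state
w_qkv = mx.random.normal((8, 16))    # column-parallel: split output features
w_o = mx.random.normal((16, 8))      # row-parallel: split input features

# Column-parallel: each "device" keeps a slice of the output features.
local_h = [x @ w for w in mx.split(w_qkv, N, axis=1)]

# Row-parallel: each "device" holds a slice of the input features and yields a
# partial result; summing the partials is what the all-reduce does.
partials = [h @ w for h, w in zip(local_h, mx.split(w_o, N, axis=0))]
y = partials[0] + partials[1]

assert mx.allclose(y, (x @ w_qkv) @ w_o, atol=1e-4).item()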
Alex Cheema
24511ab7cb feat: add Step 3.5 Flash model cards and update mlx-lm
Update mlx-lm to v0.30.6 which includes Step 3.5 Flash support
(ml-explore/mlx-lm#836). Add model cards for the 4bit, 6bit, and 8bit
quantizations from mlx-community.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:33:57 -08:00
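For reference, the mlx-lm bump lands through the lockfile rather than a version pin: mlx-lm is consumed as a git dependency on the main branch, so re-resolving the lock moves the pinned revision forward (96699e6d… → ab050d1f… in the uv.lock diff below). A pyproject.toml declaration along these lines would produce that kind of lock entry; exo's actual pyproject.toml is not part of this diff, so treat the snippet as a sketch of uv's git-source syntax only:

[tool.uv.sources]
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }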
6 changed files with 91 additions and 2 deletions

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 114572190076
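Each card is plain TOML, so reading one takes only a few lines. The sketch below is a hypothetical loader (exo's actual card-loading code is not part of this diff), using only the field names visible in the card above; for a sense of scale, in_bytes = 114572190076 is roughly 107 GiB for the 4-bit quantization.

import tomllib
from dataclasses import dataclass

@dataclass
class ModelCard:
    model_id: str
    n_layers: int
    hidden_size: int
    supports_tensor: bool
    tasks: list[str]
    storage_bytes: int

def load_card(path: str) -> ModelCard:
    # Field names mirror the TOML above; everything else here is illustrative.
    with open(path, "rb") as f:
        data = tomllib.load(f)
    return ModelCard(
        model_id=data["model_id"],
        n_layers=data["n_layers"],
        hidden_size=data["hidden_size"],
        supports_tensor=data["supports_tensor"],
        tasks=data["tasks"],
        storage_bytes=data["storage_size"]["in_bytes"],
    )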

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 159039627774

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Step-3.5-Flash-8bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 209082699847

View File

@@ -233,6 +233,7 @@ class ConfigData(BaseModel):
        ["MiniMaxM2ForCausalLM"],
        ["LlamaForCausalLM"],
        ["GptOssForCausalLM"],
        ["Step3p5ForCausalLM"],
    ]

    @model_validator(mode="before")
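The list this hunk extends gates which Hugging Face architectures are accepted, so adding Step3p5ForCausalLM is presumably what lets the new model cards pass config validation. The validator body itself sits outside the hunk; the following is only an invented illustration of that kind of whitelist check (the names and the mode="after" validator are assumptions, not exo's code):

from pydantic import BaseModel, model_validator

SUPPORTED_ARCHITECTURES = {
    "MiniMaxM2ForCausalLM",
    "LlamaForCausalLM",
    "GptOssForCausalLM",
    "Step3p5ForCausalLM",
}

class HfConfig(BaseModel):
    architectures: list[str]

    @model_validator(mode="after")
    def check_supported(self) -> "HfConfig":
        # Reject configs whose architectures aren't in the supported set.
        if not SUPPORTED_ARCHITECTURES.intersection(self.architectures):
            raise ValueError(f"unsupported architectures: {self.architectures}")
        return self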

View File

@@ -31,6 +31,8 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step3p5Model
from mlx_lm.models.step3p5 import Step3p5MLP
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
@@ -380,6 +382,14 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, Step3p5Model):
        tensor_parallel_sharding_strategy = Step3p5ShardingStrategy(
            group,
            all_to_sharded_linear,
            sharded_to_all_linear,
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, GptOssModel):
        tensor_parallel_sharding_strategy = GptOssShardingStrategy(
            group,
@@ -774,3 +784,57 @@ class ShardedGptOssMoE(CustomMlxLayer):
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y


class Step3p5ShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
        model = cast(Step3p5Model, model)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )

            # Shard attention
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
            layer.self_attn.num_heads //= self.N  # pyright: ignore[reportUnknownMemberType]
            layer.self_attn.num_kv_heads //= self.N  # pyright: ignore[reportUnknownMemberType]

            if isinstance(layer.mlp, Step3p5MLP):
                # Dense MLP layer
                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
            else:
                # MoE layer: shared expert + routed experts
                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)

                layer.mlp = ShardedStep3p5MoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
                layer.mlp.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]

            mx.eval(layer)
        return model


class ShardedStep3p5MoE(CustomMlxLayer):
    def __init__(self, layer: _LayerCallable):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer.__call__(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y
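Two details in the strategy above are worth calling out. In the dense branch, new sharded Linear layers replace the originals; in the MoE branch, the shared-expert and routed-expert weights are sharded in place and the block is wrapped in ShardedStep3p5MoE, whose mx.distributed.all_sum folds each device's partial expert outputs back into the full activation. The integer divisions num_heads //= self.N and num_kv_heads //= self.N also quietly assume both counts divide evenly by the group size; a small guard makes that assumption explicit (hypothetical helper, placeholder numbers):

def shard_heads(num_heads: int, num_kv_heads: int, n_devices: int) -> tuple[int, int]:
    # Head counts must split evenly across the tensor-parallel group, or the
    # per-device q/k/v shards end up with mismatched shapes.
    if num_heads % n_devices or num_kv_heads % n_devices:
        raise ValueError(
            f"heads ({num_heads}, {num_kv_heads}) not divisible by {n_devices} devices"
        )
    return num_heads // n_devices, num_kv_heads // n_devices

print(shard_heads(64, 8, 2))  # -> (32, 4)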

uv.lock generated
View File

@@ -1072,8 +1072,8 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.5"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
version = "0.30.6"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#ab050d1fac2ef1d7bea6b8d870f1e5717d7f59f5" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },