Compare commits

...

2 Commits

Author SHA1 Message Date
rltakashige
c8cbe31a09 Merge branch 'main' into leo/add-step35-flash 2026-02-12 18:32:56 +00:00
Ryuichi Leo Takashige
7e08a6302b Add support for Step 3.5 flash! 2026-02-12 17:53:50 +00:00
6 changed files with 255 additions and 0 deletions

View File

@@ -0,0 +1,151 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
vocab_size: int
num_attention_heads: int
num_attention_groups: int
head_dim: int
intermediate_size: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
max_position_embeddings: int
sliding_window: int
layer_types: Optional[List[str]]
yarn_only_types: Optional[List[str]]
partial_rotary_factors: Optional[List[float]]
attention_other_setting: Optional[Dict[str, Any]]
use_head_wise_attn_gate: bool
moe_num_experts: int
moe_top_k: int
moe_intermediate_size: int
share_expert_dim: int
moe_layers_enum: Optional[str]
moe_router_scaling_factor: float
norm_expert_weight: bool
swiglu_limits: Optional[List[float]]
swiglu_limits_shared: Optional[List[float]]
tie_word_embeddings: bool
class Step3p5MLP(nn.Module):
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
limit: Optional[float]
def __init__(
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5MoEGate(nn.Module):
top_k: int
n_routed_experts: int
routed_scaling_factor: float
norm_topk_prob: bool
gate: nn.Linear
router_bias: mx.array
def __init__(self, args: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class Step3p5MoE(nn.Module):
gate: Step3p5MoEGate
switch_mlp: SwitchGLU
share_expert: Step3p5MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5Attention(nn.Module):
is_sliding: bool
num_heads: int
num_kv_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
q_norm: nn.Module
k_norm: nn.Module
use_head_wise_attn_gate: bool
g_proj: nn.Linear
rope: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5DecoderLayer(nn.Module):
self_attn: Step3p5Attention
is_sliding: bool
is_moe_layer: bool
mlp: Step3p5MLP | Step3p5MoE
input_layernorm: nn.Module
post_attention_layernorm: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5Model(nn.Module):
args: ModelArgs
vocab_size: int
num_layers: int
embed_tokens: nn.Embedding
layers: list[Step3p5DecoderLayer]
norm: nn.Module
_swa_idx: Optional[int]
_full_idx: Optional[int]
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
class Model(nn.Module):
args: ModelArgs
model_type: str
model: Step3p5Model
lm_head: nn.Linear
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
@property
def layers(self) -> list[Step3p5DecoderLayer]: ...
def make_cache(self) -> list[Any]: ...
@property
def cast_predicate(self) -> Any: ...
@property
def quant_predicate(self) -> Any: ...

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

View File

@@ -189,6 +189,7 @@ class ConfigData(BaseModel):
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
["Step3p5ForCausalLM"],
]
@model_validator(mode="before")

View File

@@ -35,6 +35,9 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
@@ -264,6 +267,19 @@ def pipeline_auto_parallel(
)
)
if isinstance(inner_model_instance, Step35InnerModel):
inner_model_instance.num_layers = len(layers)
sliding_layers = [
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
]
full_layers = [
i
for i, layer in enumerate(layers)
if not getattr(layer, "is_sliding", True)
]
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -427,6 +443,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Step35Model):
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -981,3 +1005,46 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class Step35ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Step35Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
layer.self_attn.num_kv_heads //= self.N
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
layer.self_attn.g_proj = self.all_to_sharded_linear(
layer.self_attn.g_proj
)
if isinstance(layer.mlp, Step35MLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
else:
layer.mlp.sharding_group = self.group
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
mx.eval(layer)
return model