Add MiniMax and fix Qwen sharding strategies

Ryuichi Leo Takashige
2026-01-28 19:27:56 +00:00
parent 748a026071
commit b60a59bbf6
4 changed files with 131 additions and 22 deletions

View File

@@ -1139,7 +1139,7 @@ class array:
     ) -> array:
         """See :func:`flatten`."""
-    def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
+    def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
         """
         Equivalent to :func:`reshape` but the shape can be passed either as a
         :obj:`tuple` or as separate arguments.
@@ -1222,7 +1222,7 @@ class array:
     ) -> array:
         """See :func:`swapaxes`."""
-    def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
+    def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
         """
         Equivalent to :func:`transpose` but the axes can be passed either as
         a tuple or as separate arguments.
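The stub changes above only tighten the annotations (*shape: int, *axes: int); runtime behaviour is unchanged. A minimal sketch of the two calling conventions the docstrings describe, using plain MLX:

import mlx.core as mx

x = mx.arange(12)
a = x.reshape(3, 4)       # shape passed as separate int arguments
b = x.reshape((3, 4))     # shape passed as a single tuple
c = a.transpose(1, 0)     # axes passed as separate int arguments
assert b.shape == (3, 4) and c.shape == (4, 3)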

View File

@@ -11,7 +11,10 @@ import mlx.core as mx
 class Cache(Protocol):
     keys: mx.array
     values: mx.array
-    def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
+    offset: int
+    def update_and_fetch(
+        self, keys: mx.array, values: mx.array
+    ) -> tuple[mx.array, mx.array]: ...
     @property
     def state(self) -> tuple[mx.array, mx.array]: ...
     @state.setter
@@ -87,6 +90,7 @@ def create_attention_mask(
 class _BaseCache(Cache):
     keys: mx.array
     values: mx.array
+    offset: int
     @property
     def state(self) -> tuple[mx.array, mx.array]: ...
     @state.setter
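With this change, update_and_fetch reports the post-append cache contents and every cache exposes an integer offset. For illustration only, a minimal in-memory cache that satisfies the updated protocol (a hypothetical sketch, not this project's implementation; the concrete cache classes come from mlx_lm.models.cache):

import mlx.core as mx

class SimpleKVCache:
    def __init__(self) -> None:
        self.keys = mx.zeros((0,))
        self.values = mx.zeros((0,))
        self.offset = 0  # number of tokens currently stored

    def update_and_fetch(
        self, keys: mx.array, values: mx.array
    ) -> tuple[mx.array, mx.array]:
        # Append along the sequence axis (B, H, S, D) and return the full contents.
        if self.offset == 0:
            self.keys, self.values = keys, values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=2)
            self.values = mx.concatenate([self.values, values], axis=2)
        self.offset += keys.shape[2]
        return self.keys, self.values

    @property
    def state(self) -> tuple[mx.array, mx.array]:
        return self.keys, self.values

    @state.setter
    def state(self, value: tuple[mx.array, mx.array]) -> None:
        self.keys, self.values = value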

View File

@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
     shard_linear,
     sum_gradients,
 )
+from mlx_lm.models.base import (
+    scaled_dot_product_attention,  # pyright: ignore[reportUnknownVariableType]
+)
 from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
 from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
 from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.models.kimi_k25 import Model as KimiK25Model
 from mlx_lm.models.llama import Model as LlamaModel
+from mlx_lm.models.minimax import MiniMaxAttention
+from mlx_lm.models.minimax import Model as MiniMaxModel
 from mlx_lm.models.ministral3 import Model as Ministral3Model
 from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
 from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
 from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
-from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
+from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
+from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
 from exo.shared.logging import logger
 from exo.shared.types.worker.shards import PipelineShardMetadata
+if TYPE_CHECKING:
+    from mlx_lm.models.cache import Cache
 TimeoutCallback = Callable[[], None]
@@ -615,6 +623,84 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
         return y
+class WrappedMiniMaxAttention(CustomMlxLayer):
+    def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
+        super().__init__(layer)
+        self.group = group
+    def __call__(
+        self,
+        x: mx.array,
+        mask: mx.array | None = None,
+        cache: Cache | None = None,
+    ) -> mx.array:
+        batch_dim, seq_dim, _ = x.shape
+        self._original_layer = cast(MiniMaxAttention, self.original_layer)  # type: ignore
+        queries: mx.array = self._original_layer.q_proj(x)
+        keys: mx.array = self._original_layer.k_proj(x)
+        values: mx.array = self._original_layer.v_proj(x)
+        if getattr(self, "use_qk_norm", False):
+            q_dim = queries.shape[-1]
+            k_dim = keys.shape[-1]
+            n = self.group.size()
+            qk = mx.concatenate(
+                [queries, keys], axis=-1
+            )  # (batch_dim, seq_dim, q_dim + k_dim)
+            qk = mx.distributed.all_gather(
+                qk, group=self.group
+            )  # (n*batch_dim, seq_dim, q_dim + k_dim)
+            qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
+            queries = qk[..., :q_dim].reshape(
+                batch_dim, seq_dim, -1
+            )  # (batch_dim, seq_dim, n * q_dim)
+            keys = qk[..., q_dim:].reshape(
+                batch_dim, seq_dim, -1
+            )  # (batch_dim, seq_dim, n * k_dim)
+            queries = self._original_layer.q_norm(queries)
+            keys = self._original_layer.k_norm(keys)
+            # Split back and take this rank's portion
+            queries = mx.split(queries, n, axis=-1)[self.group.rank()]
+            keys = mx.split(keys, n, axis=-1)[self.group.rank()]
+        queries = queries.reshape(
+            batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
+        ).transpose(0, 2, 1, 3)
+        keys = keys.reshape(
+            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
+        ).transpose(0, 2, 1, 3)
+        values = values.reshape(
+            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
+        ).transpose(0, 2, 1, 3)
+        if cache is not None:
+            queries = self._original_layer.rope(queries, offset=cache.offset)
+            keys = self._original_layer.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self._original_layer.rope(queries)
+            keys = self._original_layer.rope(keys)
+        output = scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            cache=cache,
+            scale=self._original_layer.scale,  # type: ignore
+            mask=mask,
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)
+        return self._original_layer.o_proj(output)
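A note on the qk_norm path above: q_norm and k_norm act on the full (unsharded) projection width, so applying them per rank on split weights changes the RMS statistics, which appears to be why the weight-splitting approach removed further down was dropped. Each rank instead all-gathers the Q/K shards, applies the norm at full width, and keeps only its own slice. A toy sketch of that index bookkeeping, emulating all_gather with a concatenate so it runs without a distributed group (shapes and rank count are made up for illustration):

import mlx.core as mx

n, batch_dim, seq_dim, q_dim, k_dim = 2, 1, 3, 4, 2
# Each rank's local [queries | keys] block, filled with its rank id for visibility.
shards = [mx.full((batch_dim, seq_dim, q_dim + k_dim), r) for r in range(n)]
gathered = mx.concatenate(shards, axis=0)  # all_gather stacks along axis 0: (n*batch, seq, q_dim+k_dim)
qk = gathered.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
queries = qk[..., :q_dim].reshape(batch_dim, seq_dim, -1)  # full-width Q: rank 0's columns, then rank 1's
# ...the full-width q_norm would be applied here...
mine = mx.split(queries, n, axis=-1)[0]  # rank 0 keeps only its own slice again
assert mine.shape == (batch_dim, seq_dim, q_dim)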
 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
     def shard_model(
         self,
@@ -623,7 +709,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
         on_timeout: TimeoutCallback | None,
     ) -> nn.Module:
         model = cast(MiniMaxModel, model)
-        rank = self.group.rank()
         for layer in model.layers:
             eval_with_timeout(
                 layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -634,18 +719,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
             layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
             layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
-            # Shard qk_norm weights if present (must match sharded head count)
-            if getattr(layer.self_attn, "use_qk_norm", False):
-                layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split(  # type: ignore
-                    self.N, axis=-1
-                )[rank]
-                layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split(  # type: ignore
-                    self.N, axis=-1
-                )[rank]
             layer.self_attn.num_attention_heads //= self.N
             layer.self_attn.num_key_value_heads //= self.N
+            layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group)  # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
             # Shard the MoE. Shard in place since the MoE should be responsible
             # for aggregating the results.
             self.all_to_sharded_linear_in_place(
@@ -670,18 +748,45 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
         timeout_seconds: float,
         on_timeout: TimeoutCallback | None,
     ) -> nn.Module:
-        model = cast(Qwen3MoeModel, model)
+        model = cast(Qwen3MoeModel | Qwen3NextModel, model)
         for layer in model.layers:
             eval_with_timeout(
                 layer.parameters(), timeout_seconds / len(model.layers), on_timeout
             )
             # Shard the self attention
-            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
-            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
-            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
-            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
-            layer.self_attn.n_heads //= self.N
-            layer.self_attn.n_kv_heads //= self.N
+            if isinstance(layer, Qwen3DecoderLayer):
+                layer.self_attn.q_proj = self.all_to_sharded_linear(
+                    layer.self_attn.q_proj
+                )
+                layer.self_attn.k_proj = self.all_to_sharded_linear(
+                    layer.self_attn.k_proj
+                )
+                layer.self_attn.v_proj = self.all_to_sharded_linear(
+                    layer.self_attn.v_proj
+                )
+                layer.self_attn.o_proj = self.sharded_to_all_linear(
+                    layer.self_attn.o_proj
+                )
+            else:
+                assert isinstance(layer, Qwen3NextDecoderLayer)
+                if hasattr(layer, "linear_attn"):
+                    # These layers are fast so we don't shard. This may change in future.
+                    pass
+                else:
+                    layer.self_attn.q_proj = self.all_to_sharded_linear(
+                        layer.self_attn.q_proj
+                    )
+                    layer.self_attn.k_proj = self.all_to_sharded_linear(
+                        layer.self_attn.k_proj
+                    )
+                    layer.self_attn.v_proj = self.all_to_sharded_linear(
+                        layer.self_attn.v_proj
+                    )
+                    layer.self_attn.o_proj = self.sharded_to_all_linear(
+                        layer.self_attn.o_proj
+                    )
+            layer.self_attn.num_attention_heads //= self.N
+            layer.self_attn.num_key_value_heads //= self.N
             # Shard the MoE. Shard in place since the MoE should be responsible
             # for aggregating the results.
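For context, the column/row split behind all_to_sharded_linear and sharded_to_all_linear, done by hand for a 2-way group. This is an illustrative sketch only: it assumes the helpers shard a Linear's weight along the output and input feature dimension respectively, with the real versions built on mlx.nn.layers.distributed:

import mlx.core as mx
import mlx.nn as nn

N, rank = 2, 0
full_qkv = nn.Linear(8, 8, bias=False)   # stands in for q_proj / k_proj / v_proj
full_out = nn.Linear(8, 8, bias=False)   # stands in for o_proj

# all-to-sharded: every rank sees the full input and produces 1/N of the output
# features, which is why the head counts are divided by N afterwards.
qkv_shard = nn.Linear(8, 8 // N, bias=False)
qkv_shard.weight = mx.split(full_qkv.weight, N, axis=0)[rank]

# sharded-to-all: each rank consumes its 1/N slice of the input features; summing
# the partial outputs across ranks (an all_reduce) reproduces the full o_proj output.
out_shard = nn.Linear(8 // N, 8, bias=False)
out_shard.weight = mx.split(full_out.weight, N, axis=1)[rank]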

View File

@@ -171,7 +171,7 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
 def _cache_length(cache: KVCacheType) -> int:
     """Get the number of tokens in a KV cache."""
     # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
-    return max(c.offset for c in cache)  # type: ignore
+    return max(c.offset for c in cache)
 def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
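Declaring offset: int on the cache protocol (second file above) is presumably what lets the type: ignore go away here. A trivial check with hypothetical stand-in caches, since only the .offset attribute is exercised:

class _StubCache:
    def __init__(self, offset: int) -> None:
        self.offset = offset

assert _cache_length([_StubCache(3), _StubCache(7)]) == 7  # the longest per-layer cache wins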