From b60a59bbf68ff6f4f417ff2369ebc99bc7800722 Mon Sep 17 00:00:00 2001
From: Ryuichi Leo Takashige
Date: Wed, 28 Jan 2026 19:27:56 +0000
Subject: [PATCH] Add minimax and fix qwen sharding strategies

---
 .mlx_typings/mlx/core/__init__.pyi          |   4 +-
 .mlx_typings/mlx_lm/models/cache.pyi        |   6 +-
 src/exo/worker/engines/mlx/auto_parallel.py | 143 +++++++++++++++++---
 src/exo/worker/engines/mlx/cache.py         |   2 +-
 4 files changed, 133 insertions(+), 22 deletions(-)

diff --git a/.mlx_typings/mlx/core/__init__.pyi b/.mlx_typings/mlx/core/__init__.pyi
index 48680a80..025b4ab2 100644
--- a/.mlx_typings/mlx/core/__init__.pyi
+++ b/.mlx_typings/mlx/core/__init__.pyi
@@ -1139,7 +1139,7 @@ class array:
     ) -> array:
         """See :func:`flatten`."""
 
-    def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
+    def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
         """
         Equivalent to :func:`reshape` but the shape can be passed either as a
         :obj:`tuple` or as separate arguments.
@@ -1222,7 +1222,7 @@ class array:
     ) -> array:
         """See :func:`swapaxes`."""
 
-    def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
+    def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
         """
         Equivalent to :func:`transpose` but the axes can be passed either as a
         tuple or as separate arguments.
diff --git a/.mlx_typings/mlx_lm/models/cache.pyi b/.mlx_typings/mlx_lm/models/cache.pyi
index 37f96845..909aa18f 100644
--- a/.mlx_typings/mlx_lm/models/cache.pyi
+++ b/.mlx_typings/mlx_lm/models/cache.pyi
@@ -11,7 +11,10 @@ import mlx.core as mx
 class Cache(Protocol):
     keys: mx.array
     values: mx.array
-    def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
+    offset: int
+    def update_and_fetch(
+        self, keys: mx.array, values: mx.array
+    ) -> tuple[mx.array, mx.array]: ...
     @property
     def state(self) -> tuple[mx.array, mx.array]: ...
     @state.setter
@@ -87,6 +90,7 @@ def create_attention_mask(
 class _BaseCache(Cache):
     keys: mx.array
     values: mx.array
+    offset: int
     @property
     def state(self) -> tuple[mx.array, mx.array]: ...
     @state.setter
diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py
index ff2052fb..4babac67 100644
--- a/src/exo/worker/engines/mlx/auto_parallel.py
+++ b/src/exo/worker/engines/mlx/auto_parallel.py
@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
     shard_linear,
     sum_gradients,
 )
+from mlx_lm.models.base import (
+    scaled_dot_product_attention,  # pyright: ignore[reportUnknownVariableType]
+)
 from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
 from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
 from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.models.kimi_k25 import Model as KimiK25Model
 from mlx_lm.models.llama import Model as LlamaModel
+from mlx_lm.models.minimax import MiniMaxAttention
 from mlx_lm.models.minimax import Model as MiniMaxModel
 from mlx_lm.models.ministral3 import Model as Ministral3Model
 from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
+from mlx_lm.models.qwen3_moe import Qwen3MoeDecoderLayer
 from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
 from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
-from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
+from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
 
 from exo.shared.logging import logger
 from exo.shared.types.worker.shards import PipelineShardMetadata
 
+if TYPE_CHECKING:
+    from mlx_lm.models.cache import Cache
+
 TimeoutCallback = Callable[[], None]
 
 
@@ -615,6 +623,84 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
         return y
 
 
+class WrappedMiniMaxAttention(CustomMlxLayer):
+    def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
+        super().__init__(layer)
+        self.group = group
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: mx.array | None = None,
+        cache: Cache | None = None,
+    ) -> mx.array:
+        batch_dim, seq_dim, _ = x.shape
+
+        self._original_layer = cast(MiniMaxAttention, self.original_layer)  # type: ignore
+
+        queries: mx.array = self._original_layer.q_proj(x)
+        keys: mx.array = self._original_layer.k_proj(x)
+        values: mx.array = self._original_layer.v_proj(x)
+
+        if getattr(self._original_layer, "use_qk_norm", False):
+            q_dim = queries.shape[-1]
+            k_dim = keys.shape[-1]
+            n = self.group.size()
+
+            qk = mx.concatenate(
+                [queries, keys], axis=-1
+            )  # (batch_dim, seq_dim, q_dim + k_dim)
+            qk = mx.distributed.all_gather(
+                qk, group=self.group
+            )  # (n*batch_dim, seq_dim, q_dim + k_dim)
+
+            qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
+            queries = qk[..., :q_dim].reshape(
+                batch_dim, seq_dim, -1
+            )  # (batch_dim, seq_dim, n * q_dim)
+            keys = qk[..., q_dim:].reshape(
+                batch_dim, seq_dim, -1
+            )  # (batch_dim, seq_dim, n * k_dim)
+
+            queries = self._original_layer.q_norm(queries)
+            keys = self._original_layer.k_norm(keys)
+
+            # Split back and take this rank's portion
+            queries = mx.split(queries, n, axis=-1)[self.group.rank()]
+            keys = mx.split(keys, n, axis=-1)[self.group.rank()]
+
+        queries = queries.reshape(
+            batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
+        ).transpose(0, 2, 1, 3)
+        keys = keys.reshape(
+            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
+        ).transpose(0, 2, 1, 3)
+        values = values.reshape(
+            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
+        ).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self._original_layer.rope(queries, offset=cache.offset)
+            keys = self._original_layer.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self._original_layer.rope(queries)
+            keys = self._original_layer.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            cache=cache,
+            scale=self._original_layer.scale,  # type: ignore
+            mask=mask,
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)
+
+        return self._original_layer.o_proj(output)
+
+
 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
     def shard_model(
         self,
@@ -623,7 +709,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
         on_timeout: TimeoutCallback | None,
     ) -> nn.Module:
         model = cast(MiniMaxModel, model)
-        rank = self.group.rank()
         for layer in model.layers:
             eval_with_timeout(
                 layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -634,18 +719,11 @@
             layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
             layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
 
-            # Shard qk_norm weights if present (must match sharded head count)
-            if getattr(layer.self_attn, "use_qk_norm", False):
-                layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split(  # type: ignore
-                    self.N, axis=-1
-                )[rank]
-                layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split(  # type: ignore
-                    self.N, axis=-1
-                )[rank]
-
             layer.self_attn.num_attention_heads //= self.N
             layer.self_attn.num_key_value_heads //= self.N
 
+            layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group)  # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
+
             # Shard the MoE. Shard in place since the MoE should be responsible
             # for aggregating the results.
             self.all_to_sharded_linear_in_place(
@@ -670,18 +748,47 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
         timeout_seconds: float,
         on_timeout: TimeoutCallback | None,
     ) -> nn.Module:
-        model = cast(Qwen3MoeModel, model)
+        model = cast(Qwen3MoeModel | Qwen3NextModel, model)
         for layer in model.layers:
             eval_with_timeout(
                 layer.parameters(), timeout_seconds / len(model.layers), on_timeout
             )
             # Shard the self attention
-            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
-            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
-            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
-            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
-            layer.self_attn.n_heads //= self.N
-            layer.self_attn.n_kv_heads //= self.N
+            if isinstance(layer, Qwen3MoeDecoderLayer):
+                layer.self_attn.q_proj = self.all_to_sharded_linear(
+                    layer.self_attn.q_proj
+                )
+                layer.self_attn.k_proj = self.all_to_sharded_linear(
+                    layer.self_attn.k_proj
+                )
+                layer.self_attn.v_proj = self.all_to_sharded_linear(
+                    layer.self_attn.v_proj
+                )
+                layer.self_attn.o_proj = self.sharded_to_all_linear(
+                    layer.self_attn.o_proj
+                )
+                layer.self_attn.n_heads //= self.N
+                layer.self_attn.n_kv_heads //= self.N
+            else:
+                assert isinstance(layer, Qwen3NextDecoderLayer)
+                if hasattr(layer, "linear_attn"):
+                    # These layers are fast so we don't shard. This may change in the future.
+                    pass
+                else:
+                    layer.self_attn.q_proj = self.all_to_sharded_linear(
+                        layer.self_attn.q_proj
+                    )
+                    layer.self_attn.k_proj = self.all_to_sharded_linear(
+                        layer.self_attn.k_proj
+                    )
+                    layer.self_attn.v_proj = self.all_to_sharded_linear(
+                        layer.self_attn.v_proj
+                    )
+                    layer.self_attn.o_proj = self.sharded_to_all_linear(
+                        layer.self_attn.o_proj
+                    )
+                    layer.self_attn.num_attention_heads //= self.N
+                    layer.self_attn.num_key_value_heads //= self.N
 
             # Shard the MoE. Shard in place since the MoE should be responsible
             # for aggregating the results.
diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py
index ee7b1581..cd01fda4 100644
--- a/src/exo/worker/engines/mlx/cache.py
+++ b/src/exo/worker/engines/mlx/cache.py
@@ -171,7 +171,7 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
 def _cache_length(cache: KVCacheType) -> int:
     """Get the number of tokens in a KV cache."""
     # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
-    return max(c.offset for c in cache)  # type: ignore
+    return max(c.offset for c in cache)
 
 
 def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
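
Why WrappedMiniMaxAttention gathers before normalizing: RMSNorm divides by the root mean square taken over the entire last axis, so it does not commute with column sharding. The removed qk_norm weight-splitting gave each rank the right weight slice but the wrong denominator, computed from only its local chunk. Assuming, as the gather-then-normalize sequence implies, that MiniMax's q_norm/k_norm span the full concatenated head dimension, the norm has to see the full-width activation. A minimal single-process sketch of the mismatch; the rms_norm helper is illustrative, mirroring what nn.RMSNorm computes:

    import mlx.core as mx

    def rms_norm(x: mx.array, w: mx.array, eps: float = 1e-6) -> mx.array:
        # Normalize by the RMS over the whole last axis, like nn.RMSNorm.
        return w * x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)

    n, dim = 4, 16
    x = mx.random.normal((dim,))
    w = mx.random.normal((dim,))

    # Full-width norm, then shard: what the wrapper now computes.
    full = mx.split(rms_norm(x, w), n)
    # Per-chunk norm with split weights: what the removed code computed.
    chunked = [rms_norm(xc, wc) for xc, wc in zip(mx.split(x, n), mx.split(w, n))]

    print(mx.allclose(full[0], chunked[0]).item())  # False in general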
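
The gather itself is layout-sensitive: all_gather concatenates rank chunks along the leading axis, and the reshape(n, batch, seq, ...), transpose(1, 2, 0, 3), reshape(batch, seq, -1) sequence is what restores rank-major column order, so that mx.split(..., n, axis=-1)[rank] afterwards recovers exactly the chunk each rank started with. A sketch that fakes all_gather on one process by stacking chunks (sizes are illustrative):

    import mlx.core as mx

    n, batch_dim, seq_dim, q_dim, k_dim = 4, 2, 3, 8, 4  # per-rank widths

    # Full-width activations as they would exist in the unsharded model.
    full_q = mx.random.normal((batch_dim, seq_dim, n * q_dim))
    full_k = mx.random.normal((batch_dim, seq_dim, n * k_dim))

    # Each rank holds one contiguous column block of q and of k.
    q_shards = mx.split(full_q, n, axis=-1)
    k_shards = mx.split(full_k, n, axis=-1)

    # Per-rank concat, then a fake all_gather along the leading axis.
    gathered = mx.concatenate(
        [mx.concatenate([q, k], axis=-1) for q, k in zip(q_shards, k_shards)],
        axis=0,
    )  # (n * batch_dim, seq_dim, q_dim + k_dim)

    # The reshape/transpose from the patch restores the original column order.
    qk = gathered.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
    q = qk[..., :q_dim].reshape(batch_dim, seq_dim, -1)
    k = qk[..., q_dim:].reshape(batch_dim, seq_dim, -1)

    assert mx.allclose(q, full_q).item() and mx.allclose(k, full_k).item()
    # Splitting back by rank returns each rank's original chunk.
    assert mx.allclose(mx.split(q, n, axis=-1)[2], q_shards[2]).item()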
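
The Qwen changes keep the usual Megatron-style split for full attention layers: q/k/v projections are column-sharded (all_to_sharded_linear), o_proj is row-sharded (sharded_to_all_linear, whose partial outputs are summed across ranks), and head counts are divided by the group size so each rank reshapes its narrower activations into complete heads; Qwen3Next's linear_attn layers are deliberately left unsharded. Column-shard followed by row-shard-and-sum reproduces the unsharded product exactly, which a single-process sketch with plain matmuls shows (names and sizes are illustrative):

    import mlx.core as mx

    N, hidden = 2, 8
    x = mx.random.normal((hidden,))
    w_in = mx.random.normal((hidden, hidden))   # stands in for q/k/v projections
    w_out = mx.random.normal((hidden, hidden))  # stands in for o_proj

    full = (x @ w_in) @ w_out

    # Column-shard w_in: each rank computes a slice of the activation.
    acts = [x @ w for w in mx.split(w_in, N, axis=1)]
    # Row-shard w_out and sum the partials (the all-reduce in the real layer).
    partials = [a @ w for a, w in zip(acts, mx.split(w_out, N, axis=0))]
    print(mx.allclose(full, partials[0] + partials[1], atol=1e-5).item())  # True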