diff --git a/.mlx_typings/mlx_lm/models/deepseek_v3.pyi b/.mlx_typings/mlx_lm/models/deepseek_v3.pyi
index 8e640ac9..142b4a33 100644
--- a/.mlx_typings/mlx_lm/models/deepseek_v3.pyi
+++ b/.mlx_typings/mlx_lm/models/deepseek_v3.pyi
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.models.mla import MultiLinear
 
 from .base import BaseModelArgs
 from .switch_layers import SwitchGLU
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
     q_b_proj: nn.Linear
     kv_a_proj_with_mqa: nn.Linear
     kv_a_layernorm: nn.RMSNorm
-    kv_b_proj: nn.Linear
+    # kv_b_proj: nn.Linear
+    embed_q: MultiLinear
+    unembed_out: MultiLinear
+
     o_proj: nn.Linear
 
     rope: Any
diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py
index 6fc8d8b9..ae85945e 100644
--- a/src/exo/worker/engines/mlx/auto_parallel.py
+++ b/src/exo/worker/engines/mlx/auto_parallel.py
@@ -511,12 +511,24 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
         layer.self_attn.q_b_proj = self.all_to_sharded_linear(
             layer.self_attn.q_b_proj
         )
-        layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
-            layer.self_attn.kv_b_proj
-        )
+
+        # layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
+        #     layer.self_attn.kv_b_proj
+        # )
         layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
         layer.self_attn.num_heads //= self.N
 
+        # Head-sharding logic from upstream mlx: each rank keeps heads [sh, eh).
+        num_heads = layer.self_attn.num_heads
+        sh = self.group.rank() * num_heads
+        eh = sh + num_heads
+
+        def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
+            return w[sh:eh]
+
+        layer.self_attn.embed_q.apply(shard_heads)
+        layer.self_attn.unembed_out.apply(shard_heads)
+
         # Shard the MLP
         if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
             layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py
index 20c9b9f9..1c15cd2c 100644
--- a/src/exo/worker/engines/mlx/generator/generate.py
+++ b/src/exo/worker/engines/mlx/generator/generate.py
@@ -249,9 +249,9 @@
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
+    # TODO: Randomise the task seed and set it in taskparams instead of hard-coding 42.
     seed = task.seed or 42
-    if seed is not None:
-        mx.random.seed(seed)
+    mx.random.seed(seed)
 
     # Encode prompt once at the top and fix unmatched think tags
     all_prompt_tokens = encode_prompt(tokenizer, prompt)
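
Reviewer note (illustrative, not part of the patch): the shard_heads slice in auto_parallel.py assumes each MultiLinear weight carries the attention heads on its leading axis, so w[sh:eh] leaves each rank with a contiguous block of heads. Below is a minimal sketch under that assumption; the toy shapes, world_size, and the rank loop are made up for illustration.

    import mlx.core as mx

    total_heads, head_dim, latent_dim = 8, 16, 32  # assumed toy shapes
    world_size = 2                                 # assumed tensor-parallel group size
    num_heads = total_heads // world_size          # heads kept per rank

    # Stand-in for one MultiLinear weight with heads on the leading axis.
    w = mx.random.normal((total_heads, head_dim, latent_dim))

    for rank in range(world_size):
        sh = rank * num_heads  # first head owned by this rank
        eh = sh + num_heads    # one past the last head owned by this rank

        def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
            # Same slice as in the patch: keep only this rank's heads.
            return w[sh:eh]

        shard = shard_heads(w)
        assert tuple(shard.shape) == (num_heads, head_dim, latent_dim)

Each rank ends up with total_heads / world_size heads, consistent with layer.self_attn.num_heads //= self.N in the patch; the per-rank attention outputs are then recombined through o_proj, which the patch wraps with sharded_to_all_linear.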