update deepseek sharding and fix probable issue with seed

Ryuichi Leo Takashige
2026-02-05 18:34:18 +00:00
parent 8c397b7341
commit 492029c86a
3 changed files with 22 additions and 6 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional
 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.models.mla import MultiLinear
 from .base import BaseModelArgs
 from .switch_layers import SwitchGLU
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
     q_b_proj: nn.Linear
     kv_a_proj_with_mqa: nn.Linear
     kv_a_layernorm: nn.RMSNorm
-    kv_b_proj: nn.Linear
+    # kv_b_proj: nn.Linear
+    embed_q: MultiLinear
+    unembed_out: MultiLinear
     o_proj: nn.Linear
     rope: Any

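For context on the attribute swap above: in MLA-style attention the kv_b up-projection can be folded into per-head query-embedding and output-unembedding operators, which is presumably why kv_b_proj is commented out in favour of embed_q and unembed_out. The sketch below is a hypothetical stand-in for MultiLinear, assuming only that it holds one weight matrix per head stacked along a leading head axis; the real class in mlx_lm.models.mla may differ.

import mlx.core as mx
import mlx.nn as nn

class StackedPerHeadLinear(nn.Module):
    # Stand-in for MultiLinear (assumption): weight has shape (num_heads, out_dim, in_dim).
    def __init__(self, num_heads: int, out_dim: int, in_dim: int):
        super().__init__()
        self.weight = mx.random.normal((num_heads, out_dim, in_dim))

    def __call__(self, x: mx.array) -> mx.array:
        # x: (num_heads, seq, in_dim) -> (num_heads, seq, out_dim), one matmul per head
        return x @ self.weight.swapaxes(-1, -2)

Because the leading axis is the head axis, slicing the weight is all it takes to shard such a module by heads, which is what the sharding change in the next file relies on.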
View File

@@ -511,12 +511,24 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
         layer.self_attn.q_b_proj = self.all_to_sharded_linear(
             layer.self_attn.q_b_proj
         )
-        layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
-            layer.self_attn.kv_b_proj
-        )
+        # layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
+        #     layer.self_attn.kv_b_proj
+        # )
         layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
         layer.self_attn.num_heads //= self.N
+
+        # Logic from upstream mlx
+        num_heads = layer.self_attn.num_heads
+        sh = self.group.rank() * num_heads
+        eh = sh + num_heads
+
+        def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
+            return w[sh:eh]
+
+        layer.self_attn.embed_q.apply(shard_heads)
+        layer.self_attn.unembed_out.apply(shard_heads)
+
         # Shard the MLP
         if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
             layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)

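To make the head slicing above concrete: after num_heads //= self.N, each rank computes the half-open range [sh, eh) of heads it owns, and shard_heads keeps only that slice of each stacked per-head weight. A self-contained sketch with hypothetical shapes (not code from this repo):

import mlx.core as mx

world_size, rank = 4, 1                   # hypothetical 4-way tensor-parallel group
num_heads, head_dim, hidden = 8, 16, 64
w = mx.random.normal((num_heads, head_dim, hidden))  # stacked per-head weight

local_heads = num_heads // world_size     # mirrors layer.self_attn.num_heads //= self.N
sh = rank * local_heads                   # first head owned by this rank
eh = sh + local_heads
print(w[sh:eh].shape)                     # (2, 16, 64): this rank keeps heads 2 and 3

Note that embed_q.apply(shard_heads) maps the function over every parameter array in the module, so the slice is only correct if each parameter's leading axis is the head axis (assumed here).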
View File

@@ -249,9 +249,9 @@ def mlx_generate(
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
+    # TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
     seed = task.seed or 42
-    if seed is not None:
-        mx.random.seed(seed)
+    mx.random.seed(seed)
     # Encode prompt once at the top and fix unmatched think tags
     all_prompt_tokens = encode_prompt(tokenizer, prompt)
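On the seed change: since seed = task.seed or 42 can never be None, the old if seed is not None guard was dead code, and the unconditional mx.random.seed(seed) makes it explicit that every generation is seeded the same way on every rank; note that a task seed of 0 is falsy in Python and would also fall back to 42. A toy illustration of why identical seeding keeps sampling deterministic (hypothetical, not from this repo):

import mlx.core as mx

def next_token(seed: int) -> int:
    mx.random.seed(seed)                  # same seed -> same draw everywhere
    logits = mx.zeros((8,))               # dummy uniform logits over 8 tokens
    return mx.random.categorical(logits).item()

assert next_token(42) == next_token(42)   # seeded runs stay in lockstep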