Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-08 13:13:09 -05:00)
update deepseek sharding and fix probable issue with seed
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional

 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.models.mla import MultiLinear

 from .base import BaseModelArgs
 from .switch_layers import SwitchGLU
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
     q_b_proj: nn.Linear
     kv_a_proj_with_mqa: nn.Linear
     kv_a_layernorm: nn.RMSNorm
-    kv_b_proj: nn.Linear
+    # kv_b_proj: nn.Linear
+    embed_q: MultiLinear
+    unembed_out: MultiLinear
+
     o_proj: nn.Linear
     rope: Any

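Note on the new fields: in DeepSeek's multi-head latent attention, the dense kv_b_proj can be split into per-head weight stacks that act on the query path and on the attention-output path instead ("weight absorption"), which is presumably what the MultiLinear embed_q and unembed_out operands hold. A minimal sketch of that split follows; the dimension names follow the DeepSeek config convention and the layout here is an assumption, not the actual mlx_lm.models.mla code.

import mlx.core as mx

num_heads, kv_lora_rank = 4, 8
qk_nope_head_dim, v_head_dim = 16, 16

# kv_b_proj maps the compressed KV latent (kv_lora_rank wide) to the
# concatenated per-head no-RoPE keys and values.
w_kv_b = mx.random.normal(
    (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
)

# Recover the per-head axis, then split key rows from value rows.
w = w_kv_b.reshape(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
w_uk = w[:, :qk_nope_head_dim, :]  # query-side ("embed") stack, one matrix per head
w_uv = w[:, qk_nope_head_dim:, :]  # output-side ("unembed") stack, one matrix per head

# The leading per-head axis is what lets tensor parallelism slice heads
# directly, as shard_heads does in the next hunk.
print(w_uk.shape, w_uv.shape)  # (4, 16, 8) (4, 16, 8)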
@@ -511,12 +511,24 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
             layer.self_attn.q_b_proj = self.all_to_sharded_linear(
                 layer.self_attn.q_b_proj
             )
-            layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
-                layer.self_attn.kv_b_proj
-            )
+
+            # layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
+            #     layer.self_attn.kv_b_proj
+            # )
             layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
             layer.self_attn.num_heads //= self.N

+            # Logic from upstream mlx
+            num_heads = layer.self_attn.num_heads
+            sh = self.group.rank() * num_heads
+            eh = sh + num_heads
+
+            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
+                return w[sh:eh]
+
+            layer.self_attn.embed_q.apply(shard_heads)
+            layer.self_attn.unembed_out.apply(shard_heads)
+
             # Shard the MLP
             if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
                 layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)

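The attention sharding above divides num_heads by the world size N and keeps only the contiguous slice of heads owned by this rank; shard_heads realises that slice on each stacked per-head weight. A single-machine sketch of the same arithmetic (N and the shapes are assumed values, looping over ranks instead of running distributed):

import mlx.core as mx

N = 2                                        # assumed world size
total_heads = 8
w = mx.random.normal((total_heads, 16, 8))   # stacked per-head weight

num_heads = total_heads // N                 # mirrors: num_heads //= self.N
for rank in range(N):                        # stand-in for self.group.rank()
    sh = rank * num_heads
    eh = sh + num_heads
    shard = w[sh:eh]                         # exactly what shard_heads returns
    print(f"rank {rank}: heads [{sh}:{eh}) -> {shard.shape}")

Every head lands on exactly one rank, which is why embed_q and unembed_out can be sliced with the same sh:eh bounds already used for the other per-head projections.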
@@ -249,9 +249,9 @@ def mlx_generate(
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
+    # TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
     seed = task.seed or 42
-    if seed is not None:
-        mx.random.seed(seed)
+    mx.random.seed(seed)

     # Encode prompt once at the top and fix unmatched think tags
     all_prompt_tokens = encode_prompt(tokenizer, prompt)

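With the redundant conditional removed, mx.random.seed now runs on every generation, so all nodes seed identically even when the task carries no explicit seed. One subtlety in the retained task.seed or 42 pattern, shown below in plain illustrative Python (not exo code): or falls back on any falsy value, so an explicit seed of 0 would silently become 42, whereas a None-check preserves it.

def pick_seed_or(seed):
    # pattern in the diff: any falsy seed, including 0, becomes 42
    return seed or 42

def pick_seed_none_check(seed):
    # alternative: only a missing seed becomes 42
    return seed if seed is not None else 42

assert pick_seed_or(None) == 42
assert pick_seed_or(0) == 42            # an explicit seed of 0 is lost
assert pick_seed_none_check(0) == 0     # 0 survives
assert pick_seed_none_check(None) == 42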