mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-14 08:04:15 -05:00
Compare commits
16 Commits
main
...
leo/add-gl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58e751a930 | ||
|
|
6718da7af3 | ||
|
|
9d9237f68f | ||
|
|
8de4e10736 | ||
|
|
0de3e486df | ||
|
|
ce0eef999e | ||
|
|
20fb6a9acc | ||
|
|
4a1234106b | ||
|
|
2929249147 | ||
|
|
837ffc6b97 | ||
|
|
2366ed0299 | ||
|
|
c95c088952 | ||
|
|
2af1c81cde | ||
|
|
6922dd4ead | ||
|
|
8c2fb7f130 | ||
|
|
0488cb2967 |
46
.mlx_typings/mlx_lm/models/glm_moe_dsa.pyi
Normal file
46
.mlx_typings/mlx_lm/models/glm_moe_dsa.pyi
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .deepseek_v32 import Model as DSV32Model
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
index_head_dim: int
|
||||
index_n_heads: int
|
||||
index_topk: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: int
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_parameters: Dict[str, Any]
|
||||
attention_bias: bool
|
||||
rope_scaling: Dict[str, Any] | None
|
||||
rope_theta: float | None
|
||||
|
||||
class Model(DSV32Model):
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
@@ -1,151 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
vocab_size: int
|
||||
num_attention_heads: int
|
||||
num_attention_groups: int
|
||||
head_dim: int
|
||||
intermediate_size: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Any]]
|
||||
max_position_embeddings: int
|
||||
sliding_window: int
|
||||
layer_types: Optional[List[str]]
|
||||
yarn_only_types: Optional[List[str]]
|
||||
partial_rotary_factors: Optional[List[float]]
|
||||
attention_other_setting: Optional[Dict[str, Any]]
|
||||
use_head_wise_attn_gate: bool
|
||||
moe_num_experts: int
|
||||
moe_top_k: int
|
||||
moe_intermediate_size: int
|
||||
share_expert_dim: int
|
||||
moe_layers_enum: Optional[str]
|
||||
moe_router_scaling_factor: float
|
||||
norm_expert_weight: bool
|
||||
swiglu_limits: Optional[List[float]]
|
||||
swiglu_limits_shared: Optional[List[float]]
|
||||
tie_word_embeddings: bool
|
||||
|
||||
class Step3p5MLP(nn.Module):
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
gate_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
limit: Optional[float]
|
||||
|
||||
def __init__(
|
||||
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
|
||||
) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Step3p5MoEGate(nn.Module):
|
||||
top_k: int
|
||||
n_routed_experts: int
|
||||
routed_scaling_factor: float
|
||||
norm_topk_prob: bool
|
||||
gate: nn.Linear
|
||||
router_bias: mx.array
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
class Step3p5MoE(nn.Module):
|
||||
gate: Step3p5MoEGate
|
||||
switch_mlp: SwitchGLU
|
||||
share_expert: Step3p5MLP
|
||||
sharding_group: Optional[mx.distributed.Group]
|
||||
|
||||
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Step3p5Attention(nn.Module):
|
||||
is_sliding: bool
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
q_norm: nn.Module
|
||||
k_norm: nn.Module
|
||||
use_head_wise_attn_gate: bool
|
||||
g_proj: nn.Linear
|
||||
rope: nn.Module
|
||||
|
||||
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Step3p5DecoderLayer(nn.Module):
|
||||
self_attn: Step3p5Attention
|
||||
is_sliding: bool
|
||||
is_moe_layer: bool
|
||||
mlp: Step3p5MLP | Step3p5MoE
|
||||
input_layernorm: nn.Module
|
||||
post_attention_layernorm: nn.Module
|
||||
|
||||
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Step3p5Model(nn.Module):
|
||||
args: ModelArgs
|
||||
vocab_size: int
|
||||
num_layers: int
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[Step3p5DecoderLayer]
|
||||
norm: nn.Module
|
||||
_swa_idx: Optional[int]
|
||||
_full_idx: Optional[int]
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[List[Any]] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
args: ModelArgs
|
||||
model_type: str
|
||||
model: Step3p5Model
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Any]] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
|
||||
@property
|
||||
def layers(self) -> list[Step3p5DecoderLayer]: ...
|
||||
def make_cache(self) -> list[Any]: ...
|
||||
@property
|
||||
def cast_predicate(self) -> Any: ...
|
||||
@property
|
||||
def quant_predicate(self) -> Any: ...
|
||||
@@ -806,7 +806,6 @@
|
||||
isFavorite={favorites.has(group.id)}
|
||||
{selectedModelId}
|
||||
{canModelFit}
|
||||
{getModelFitStatus}
|
||||
onToggleExpand={() => toggleGroupExpanded(group.id)}
|
||||
onSelectModel={handleSelect}
|
||||
{onToggleFavorite}
|
||||
|
||||
22
download_glm5_shard.sh
Executable file
22
download_glm5_shard.sh
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
# Usage: ./download_glm5_shard.sh <start> <end> [local_dir]
|
||||
#
|
||||
# Split across 4 Macs:
|
||||
# Mac 1: ./download_glm5_shard.sh 1 71
|
||||
# Mac 2: ./download_glm5_shard.sh 72 141
|
||||
# Mac 3: ./download_glm5_shard.sh 142 212
|
||||
# Mac 4: ./download_glm5_shard.sh 213 282
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
START=${1:?Usage: $0 <start> <end> [local_dir]}
|
||||
END=${2:?Usage: $0 <start> <end> [local_dir]}
|
||||
LOCAL_DIR="${3:-GLM-5}"
|
||||
|
||||
INCLUDES=()
|
||||
for i in $(seq "$START" "$END"); do
|
||||
INCLUDES+=(--include "$(printf 'model-%05d-of-00282.safetensors' "$i")")
|
||||
done
|
||||
|
||||
echo "Downloading safetensors $START-$END to $LOCAL_DIR"
|
||||
hf download zai-org/GLM-5 "${INCLUDES[@]}" --local-dir "$LOCAL_DIR"
|
||||
@@ -17,7 +17,7 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.6; sys_platform == 'darwin'",
|
||||
"mlx==0.30.6",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
@@ -64,6 +64,8 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
#mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", marker = "sys_platform == 'darwin'" }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5-8bit"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 790517400864
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5-MXFP4-Q8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "MXFP4-Q8"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405478939008
|
||||
@@ -0,0 +1,12 @@
|
||||
model_id = "mlx-community/GLM-5"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1487822475264
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-4bit"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "4bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 114572190076
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-6bit"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "6bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 159039627774
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-8Bit"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "8bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 209082699847
|
||||
@@ -7,11 +7,17 @@ from exo.utils.dashboard_path import find_dashboard, find_resources
|
||||
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
|
||||
|
||||
|
||||
def _resolve_env_path(env_value: str) -> Path:
|
||||
"""Resolve an environment variable path: absolute paths are used as-is, relative paths are resolved from home."""
|
||||
p = Path(env_value)
|
||||
return p if p.is_absolute() else Path.home() / p
|
||||
|
||||
|
||||
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
|
||||
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
|
||||
|
||||
if _EXO_HOME_ENV is not None:
|
||||
return Path.home() / _EXO_HOME_ENV
|
||||
return _resolve_env_path(_EXO_HOME_ENV)
|
||||
|
||||
if sys.platform != "linux":
|
||||
return Path.home() / ".exo"
|
||||
@@ -31,15 +37,19 @@ _EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None)
|
||||
EXO_MODELS_DIR = (
|
||||
EXO_DATA_HOME / "models"
|
||||
if _EXO_MODELS_DIR_ENV is None
|
||||
else Path.home() / _EXO_MODELS_DIR_ENV
|
||||
else _resolve_env_path(_EXO_MODELS_DIR_ENV)
|
||||
)
|
||||
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
|
||||
RESOURCES_DIR = (
|
||||
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
|
||||
find_resources()
|
||||
if _RESOURCES_DIR_ENV is None
|
||||
else _resolve_env_path(_RESOURCES_DIR_ENV)
|
||||
)
|
||||
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
|
||||
DASHBOARD_DIR = (
|
||||
find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
|
||||
find_dashboard()
|
||||
if _DASHBOARD_DIR_ENV is None
|
||||
else _resolve_env_path(_DASHBOARD_DIR_ENV)
|
||||
)
|
||||
|
||||
# Log files (data/logs or cache)
|
||||
|
||||
@@ -182,6 +182,7 @@ class ConfigData(BaseModel):
|
||||
def supports_tensor(self) -> bool:
|
||||
return self.architectures in [
|
||||
["Glm4MoeLiteForCausalLM"],
|
||||
["GlmMoeDsaForCausalLM"],
|
||||
["DeepseekV32ForCausalLM"],
|
||||
["DeepseekV3ForCausalLM"],
|
||||
["Qwen3NextForCausalLM"],
|
||||
@@ -189,7 +190,6 @@ class ConfigData(BaseModel):
|
||||
["MiniMaxM2ForCausalLM"],
|
||||
["LlamaForCausalLM"],
|
||||
["GptOssForCausalLM"],
|
||||
["Step3p5ForCausalLM"],
|
||||
]
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -35,9 +35,6 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.step3p5 import Model as Step35Model
|
||||
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
|
||||
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
|
||||
from exo.shared.logging import logger
|
||||
@@ -163,11 +160,14 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
|
||||
# doesn't have .keys directly; access via first sub-cache.
|
||||
dep_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
if self.is_prefill:
|
||||
mx.eval(output)
|
||||
if cache is not None:
|
||||
mx.eval(cache.keys) # type: ignore
|
||||
mx.eval(dep_cache.keys) # type: ignore
|
||||
|
||||
if not self.is_prefill:
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
@@ -267,19 +267,6 @@ def pipeline_auto_parallel(
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, Step35InnerModel):
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
sliding_layers = [
|
||||
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
|
||||
]
|
||||
full_layers = [
|
||||
i
|
||||
for i, layer in enumerate(layers)
|
||||
if not getattr(layer, "is_sliding", True)
|
||||
]
|
||||
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
|
||||
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
|
||||
|
||||
_set_layers(model, layers)
|
||||
|
||||
assert isinstance(layers, list), (
|
||||
@@ -307,7 +294,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
last = cache[-1] # type: ignore
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
|
||||
|
||||
return logits
|
||||
|
||||
@@ -333,7 +322,9 @@ def patch_tensor_model[T](model: T) -> T:
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
last = cache[-1] # pyright: ignore[reportAny]
|
||||
dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
|
||||
dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
|
||||
|
||||
return logits
|
||||
|
||||
@@ -443,14 +434,6 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, Step35Model):
|
||||
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -547,11 +530,13 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
|
||||
# Shard attention heads
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
@@ -561,10 +546,11 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
@@ -581,12 +567,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
else:
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
@@ -1005,46 +996,3 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
class Step35ShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Step35Model, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_heads //= self.N
|
||||
layer.self_attn.num_kv_heads //= self.N
|
||||
|
||||
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
|
||||
layer.self_attn.g_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.g_proj
|
||||
)
|
||||
|
||||
if isinstance(layer.mlp, Step35MLP):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
else:
|
||||
layer.mlp.sharding_group = self.group
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
@@ -311,10 +311,12 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm-4.7-flash" in model_id_lower:
|
||||
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
|
||||
# For GLM-5 and GLM-4.7
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif "glm" in model_id_lower:
|
||||
# For GLM-4.5 and older
|
||||
return [151336, 151329, 151338]
|
||||
return None
|
||||
|
||||
|
||||
@@ -295,8 +295,8 @@ def main(
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
elif "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
elif "glm-4" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm4_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(model, GptOssModel):
|
||||
@@ -863,7 +863,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
|
||||
def patch_glm4_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
|
||||
"""
|
||||
|
||||
12
uv.lock
generated
12
uv.lock
generated
@@ -416,9 +416,9 @@ requires-dist = [
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mflux", specifier = "==0.15.5" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
|
||||
{ name = "mlx", specifier = "==0.30.6" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.6" },
|
||||
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
|
||||
{ name = "msgspec", specifier = ">=0.19.0" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
@@ -1098,8 +1098,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx-lm"
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
version = "0.30.7"
|
||||
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#bcf630614ffb5624bcb19870a7bcb0d847e6e98f" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
@@ -1109,10 +1109,6 @@ dependencies = [
|
||||
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
|
||||
Reference in New Issue
Block a user