Compare commits

...

15 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige  6718da7af3  try optimisations  2026-02-13 19:54:23 +00:00
Ryuichi Leo Takashige  9d9237f68f  delete unnecessary files  2026-02-13 19:54:08 +00:00
Ryuichi Leo Takashige  8de4e10736  fix depends for CacheList  2026-02-13 15:29:53 +00:00
Ryuichi Leo Takashige  0de3e486df  update glm 5 to use upstream mlx lm  2026-02-13 12:50:08 +00:00
Ryuichi Leo Takashige  ce0eef999e  return to mlx lm main  2026-02-13 12:31:07 +00:00
Ryuichi Leo Takashige  20fb6a9acc  handle absolute paths  2026-02-13 11:09:46 +00:00
Ryuichi Leo Takashige  4a1234106b  add type stub  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  2929249147  fix glm eos id  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  837ffc6b97  dont patch glm5 tokenizer?  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  2366ed0299  add glm5 model cards  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  c95c088952  convert glm5  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  2af1c81cde  convert glm5  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  6922dd4ead  download faster  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  8c2fb7f130  Add tensor sharding  2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  0488cb2967  update pyproject.toml  2026-02-12 23:46:13 +00:00
12 changed files with 229 additions and 26 deletions

View File

@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

download_glm5_shard.sh Executable file
View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Usage: ./download_glm5_shard.sh <start> <end> [local_dir]
#
# Split across 4 Macs:
# Mac 1: ./download_glm5_shard.sh 1 71
# Mac 2: ./download_glm5_shard.sh 72 141
# Mac 3: ./download_glm5_shard.sh 142 212
# Mac 4: ./download_glm5_shard.sh 213 282
set -euo pipefail
START=${1:?Usage: $0 <start> <end> [local_dir]}
END=${2:?Usage: $0 <start> <end> [local_dir]}
LOCAL_DIR="${3:-GLM-5}"
INCLUDES=()
for i in $(seq "$START" "$END"); do
  INCLUDES+=(--include "$(printf 'model-%05d-of-00282.safetensors' "$i")")
done
echo "Downloading safetensors $START-$END to $LOCAL_DIR"
hf download zai-org/GLM-5 "${INCLUDES[@]}" --local-dir "$LOCAL_DIR"
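
For reference, a quick arithmetic check (not part of this change) that the four suggested ranges in the usage comment cover all 282 shards exactly once, sketched in Python:

# Sketch only: verify the per-Mac ranges from the usage comment above.
ranges = [(1, 71), (72, 141), (142, 212), (213, 282)]
sizes = [end - start + 1 for start, end in ranges]
assert sizes == [71, 70, 71, 70]  # roughly 282 / 4 files per Mac
assert sum(sizes) == 282          # every shard is downloaded exactly once
assert all(ranges[i][1] + 1 == ranges[i + 1][0] for i in range(3))  # no gaps or overlaps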

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx==0.30.6",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -64,6 +64,8 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
#mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", marker = "sys_platform == 'darwin'" }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
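
Since mlx-lm now resolves from the git main branch rather than PyPI, one way to confirm what actually got installed after a `uv sync` (illustrative only, not part of the change):

# Sketch only: report the resolved mlx-lm version; with the lockfile further below
# this is expected to print 0.30.7 (git main) rather than the PyPI 0.30.6.
from importlib.metadata import version

print(version("mlx-lm"))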

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264
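
For a rough sense of scale, the three storage_size values above work out to about 1.35 TiB (bf16), 736 GiB (8bit) and 378 GiB (MXFP4-Q8); a minimal conversion sketch, not part of the change:

# Sketch only: convert the model-card in_bytes values to GiB / TiB.
sizes = {
    "GLM-5 (bf16)": 1487822475264,
    "GLM-5-8bit": 790517400864,
    "GLM-5-MXFP4-Q8": 405478939008,
}
for name, n in sizes.items():
    print(f"{name}: {n / 2**30:,.1f} GiB ({n / 2**40:.2f} TiB)")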

View File

@@ -7,11 +7,17 @@ from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
def _resolve_env_path(env_value: str) -> Path:
    """Resolve an environment variable path: absolute paths are used as-is, relative paths are resolved from home."""
    p = Path(env_value)
    return p if p.is_absolute() else Path.home() / p

def _get_xdg_dir(env_var: str, fallback: str) -> Path:
    """Get XDG directory, prioritising the EXO_HOME environment variable if it's set. On non-Linux platforms, default to ~/.exo."""
    if _EXO_HOME_ENV is not None:
        return Path.home() / _EXO_HOME_ENV
        return _resolve_env_path(_EXO_HOME_ENV)
    if sys.platform != "linux":
        return Path.home() / ".exo"
@@ -31,15 +37,19 @@ _EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None)
EXO_MODELS_DIR = (
    EXO_DATA_HOME / "models"
    if _EXO_MODELS_DIR_ENV is None
    else Path.home() / _EXO_MODELS_DIR_ENV
    else _resolve_env_path(_EXO_MODELS_DIR_ENV)
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
    find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
    find_resources()
    if _RESOURCES_DIR_ENV is None
    else _resolve_env_path(_RESOURCES_DIR_ENV)
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
    find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
    find_dashboard()
    if _DASHBOARD_DIR_ENV is None
    else _resolve_env_path(_DASHBOARD_DIR_ENV)
)
# Log files (data/logs or cache)
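
The net effect of routing each environment variable through _resolve_env_path (the "handle absolute paths" commit) is that absolute values are now used verbatim while relative values keep the previous home-relative behaviour. A standalone sketch of the same rule, assuming a home directory of /Users/alice:

# Sketch only: mirrors _resolve_env_path above.
from pathlib import Path

def resolve(env_value: str) -> Path:
    p = Path(env_value)
    return p if p.is_absolute() else Path.home() / p

# EXO_MODELS_DIR="/Volumes/fast/models" -> /Volumes/fast/models (absolute, used as-is)
# EXO_MODELS_DIR="my-models"            -> /Users/alice/my-models (relative, resolved from home)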

View File

@@ -182,6 +182,7 @@ class ConfigData(BaseModel):
    def supports_tensor(self) -> bool:
        return self.architectures in [
            ["Glm4MoeLiteForCausalLM"],
            ["GlmMoeDsaForCausalLM"],
            ["DeepseekV32ForCausalLM"],
            ["DeepseekV3ForCausalLM"],
            ["Qwen3NextForCausalLM"],

View File

@@ -16,7 +16,10 @@ from mlx.nn.layers.distributed import (
from mlx_lm.models.base import (
    scaled_dot_product_attention,  # pyright: ignore[reportUnknownVariableType]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import (
    DeepseekV3MLP,
    group_expert_select,  # pyright: ignore[reportAttributeAccessIssue,reportUnknownVariableType]
)
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
@@ -160,11 +163,14 @@ class PipelineLastLayer(CustomMlxLayer):
            output, (self.r + 1) % self.s, group=self.group
        )
        if cache is not None:
            cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
            # CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
            # doesn't have .keys directly; access via first sub-cache.
            dep_cache = cache[0] if hasattr(cache, "caches") else cache  # type: ignore
            dep_cache.keys = mx.depends(dep_cache.keys, output)  # type: ignore[reportUnknownMemberType]
        if self.is_prefill:
            mx.eval(output)
            if cache is not None:
                mx.eval(cache.keys)  # type: ignore
                mx.eval(dep_cache.keys)  # type: ignore
        if not self.is_prefill:
            output = mx.distributed.all_gather(output, group=self.group)[
@@ -291,7 +297,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
        # Add dependency to last cache entry to ensure distributed ops are evaluated
        if cache is not None:
            cache[-1].state = mx.depends(cache[-1].state, logits)  # type: ignore
            last = cache[-1]  # type: ignore
            dep_cache = last[0] if hasattr(last, "caches") else last  # type: ignore
            dep_cache.keys = mx.depends(dep_cache.keys, logits)  # type: ignore
        return logits
@@ -317,7 +325,9 @@ def patch_tensor_model[T](model: T) -> T:
        # Add dependency to last cache entry to ensure distributed ops are evaluated
        if cache is not None and len(cache) > 0:  # pyright: ignore[reportAny]
            cache[-1].state = mx.depends(cache[-1].state, logits)  # pyright: ignore[reportAny,reportUnknownMemberType]
            last = cache[-1]  # pyright: ignore[reportAny]
            dep_cache = last[0] if hasattr(last, "caches") else last  # pyright: ignore[reportAny]
            dep_cache.keys = mx.depends(dep_cache.keys, logits)  # pyright: ignore[reportAny,reportUnknownMemberType]
        return logits
@@ -515,6 +525,67 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
        raise ValueError("Model must have either a 'layers' or 'h' attribute")

def _all_gather_last_axis(x: mx.array, group: mx.distributed.Group) -> mx.array:
    """All-gather along the last axis by transposing, gathering on axis 0, transposing back."""
    batch_shape = x.shape[:-1]
    d = x.shape[-1]
    flat = x.reshape(-1, d).T  # (D/N, batch_prod)
    gathered = mx.distributed.all_gather(flat, group=group)  # (D, batch_prod)
    return gathered.T.reshape(*batch_shape, -1)  # (..., D)
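
# Shape walkthrough (illustrative note, not part of the change): with N=2 ranks and a
# full output dimension D=8, each rank's shard x has shape (B, T, 4);
#   flat     = x.reshape(-1, 4).T            -> (4, B*T)
#   gathered = all_gather(flat, group)       -> (8, B*T), rank 0's rows first
#   result   = gathered.T.reshape(B, T, -1)  -> (B, T, 8), the full output.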
class AllToAllLinear(nn.Module):
    """Column-sharded linear that all_gathers output to reconstruct full result.

    Wraps an AllToShardedLinear (created by shard_linear) and adds all_gather
    on the output so callers see the full unsharded output.
    """

    def __init__(self, sharded_linear: nn.Module, group: mx.distributed.Group):
        super().__init__()
        self.inner = sharded_linear
        self.group = group

    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
        out = cast(mx.array, self.inner(x, *args, **kwargs))
        return _all_gather_last_axis(out, self.group)
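
# Routing note (illustrative, not part of the change): ShardedMoEGate below keeps only
# n_experts // N rows of the gate weight per rank, so with N=2 ranks and, say, 256 routed
# experts each rank computes a (B*T, 128) score shard; _all_gather_last_axis then rebuilds
# the full (B*T, 256) score matrix before group_expert_select picks the top-k experts.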
class ShardedMoEGate(nn.Module):
    """MoEGate with column-sharded weight and all_gather for routing scores."""

    def __init__(self, gate: nn.Module, group: mx.distributed.Group):
        super().__init__()
        rank = group.rank()
        n = group.size()
        n_experts = int(gate.weight.shape[0])  # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
        shard_size = n_experts // n
        start = rank * shard_size
        end = start + shard_size
        self.weight: mx.array = gate.weight[start:end]  # pyright: ignore[reportUnknownMemberType]
        self.e_score_correction_bias: mx.array = gate.e_score_correction_bias  # pyright: ignore[reportUnknownMemberType]
        self.top_k: int = gate.top_k  # pyright: ignore[reportUnknownMemberType]
        self.norm_topk_prob: bool = gate.norm_topk_prob  # pyright: ignore[reportUnknownMemberType]
        self.n_group: int = gate.n_group  # pyright: ignore[reportUnknownMemberType]
        self.topk_group: int = gate.topk_group  # pyright: ignore[reportUnknownMemberType]
        self.routed_scaling_factor: float = gate.routed_scaling_factor  # pyright: ignore[reportUnknownMemberType]
        self.group = group

    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]:
        scores_shard = x @ self.weight.T
        scores = _all_gather_last_axis(scores_shard, self.group)
        return group_expert_select(  # pyright: ignore[reportUnknownVariableType]
            scores,
            self.e_score_correction_bias,
            self.top_k,
            self.n_group,
            self.topk_group,
            self.routed_scaling_factor,
            self.norm_topk_prob,
        )

class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
@@ -527,7 +598,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            # Shard the self attention
            if layer.self_attn.q_lora_rank is not None:
                layer.self_attn.q_a_proj = AllToAllLinear(  # pyright: ignore[reportAttributeAccessIssue]
                    self.all_to_sharded_linear(layer.self_attn.q_a_proj), self.group
                )
            if layer.self_attn.q_lora_rank is None:
                layer.self_attn.q_proj = self.all_to_sharded_linear(
                    layer.self_attn.q_proj
@@ -560,9 +636,21 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
            # Shard the MoE. Shard in place since the MoE should be responsible
            # for aggregating the results.
            else:
                self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
                # Shard the gate weight for reduced compute
                n_experts = int(layer.mlp.gate.weight.shape[0])
                if n_experts % self.N == 0:
                    layer.mlp.gate = ShardedMoEGate(layer.mlp.gate, self.group)  # pyright: ignore[reportAttributeAccessIssue]
                if getattr(layer.mlp, "shared_experts", None) is not None:
                    self.all_to_sharded_linear_in_place(
                        layer.mlp.shared_experts.gate_proj
                    )
                    self.sharded_to_all_linear_in_place(
                        layer.mlp.shared_experts.down_proj
                    )
                    self.all_to_sharded_linear_in_place(
                        layer.mlp.shared_experts.up_proj
                    )
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)

View File

@@ -311,10 +311,12 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
    model_id_lower = model_id.lower()
    if "kimi-k2" in model_id_lower:
        return [163586]
    elif "glm-4.7-flash" in model_id_lower:
    elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
        # For GLM-5 and GLM-4.7
        # 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
        return [154820, 154827, 154829]
    elif "glm" in model_id_lower:
        # For GLM-4.5 and older
        return [151336, 151329, 151338]
    return None
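
A minimal sketch (not part of the change) of how the new GLM-5 stop ids could be cross-checked against the upstream tokenizer; the repo id is taken from the download script above, and the exact loading flags are assumptions:

# Sketch only: confirm the stop ids named in the code comment above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("zai-org/GLM-5", trust_remote_code=True)
for tid in (154820, 154827, 154829):
    print(tid, tok.convert_ids_to_tokens(tid))
# Expected per the comment: <|endoftext|>, <|user|>, <|observation|>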

View File

@@ -295,8 +295,8 @@ def main(
        patch_kimi_tokenizer(tokenizer)
    # GLM models need a patched parser (upstream has a bug with a None regex match)
    elif "glm" in shard_metadata.model_card.model_id.lower():
        patch_glm_tokenizer(tokenizer)
    elif "glm-4" in shard_metadata.model_card.model_id.lower():
        patch_glm4_tokenizer(tokenizer)
    # GPT-OSS specific parsing to match other model formats.
    elif isinstance(model, GptOssModel):
@@ -863,7 +863,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
    tokenizer._tool_parser = parse_tool_call

def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
def patch_glm4_tokenizer(tokenizer: TokenizerWrapper):
    """
    Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
    """

uv.lock generated
View File

@@ -416,9 +416,9 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
{ name = "msgspec", specifier = ">=0.19.0" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
@@ -1098,8 +1098,8 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
version = "0.30.7"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#bcf630614ffb5624bcb19870a7bcb0d847e6e98f" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1109,10 +1109,6 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"