Compare commits

...

28 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
ea5fc8125f Address comment 2026-02-18 16:06:30 +00:00
rltakashige
14da9c3f59 Merge branch 'main' into leo/add-glm-5 2026-02-18 13:31:04 +00:00
Ryuichi Leo Takashige
242067e8eb Update mlx to the nuclear solution 2026-02-18 13:17:39 +00:00
Ryuichi Leo Takashige
03267171ea move mx_barrier to right before prefill 2026-02-18 12:40:31 +00:00
Ryuichi Leo Takashige
080995a409 try a longer join timeout to prevent crashes on instance deletion 2026-02-18 12:34:05 +00:00
Ryuichi Leo Takashige
0d92d001e6 wrong model card 2026-02-18 12:00:40 +00:00
Ryuichi Leo Takashige
0287ad71a1 format and lint 2026-02-18 11:51:30 +00:00
Ryuichi Leo Takashige
82ac16cc0f format and lint 2026-02-18 11:48:45 +00:00
Ryuichi Leo Takashige
d78dd516f3 bad diffs 2 2026-02-18 11:47:55 +00:00
Ryuichi Leo Takashige
5c154d091d bad diffs 2026-02-18 11:45:41 +00:00
rltakashige
e4e0517627 Merge branch 'main' into leo/add-glm-5 2026-02-18 11:44:43 +00:00
Ryuichi Leo Takashige
b28c1d9e92 Merge branch 'main' into leo/add-glm-5
# Conflicts:
#	pyproject.toml
#	src/exo/worker/runner/runner.py
#	uv.lock
2026-02-18 11:43:38 +00:00
Ryuichi Leo Takashige
58e751a930 revert 2026-02-13 20:50:27 +00:00
Ryuichi Leo Takashige
6718da7af3 try optimisations 2026-02-13 19:54:23 +00:00
Ryuichi Leo Takashige
9d9237f68f delete unnecessary files 2026-02-13 19:54:08 +00:00
Ryuichi Leo Takashige
8de4e10736 fix depends for CacheList 2026-02-13 15:29:53 +00:00
Ryuichi Leo Takashige
0de3e486df update glm 5 to use upstream mlx lm 2026-02-13 12:50:08 +00:00
Ryuichi Leo Takashige
ce0eef999e return to mlx lm main 2026-02-13 12:31:07 +00:00
Ryuichi Leo Takashige
20fb6a9acc handle absolute paths 2026-02-13 11:09:46 +00:00
Ryuichi Leo Takashige
4a1234106b add type stub 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
2929249147 fix glm eos id 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
837ffc6b97 dont patch glm5 tokenizer? 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
2366ed0299 add glm5 model cards 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
c95c088952 convert glm5 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
2af1c81cde convert glm5 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
6922dd4ead download faster 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
8c2fb7f130 Add tensor sharding 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige
0488cb2967 update pyproject.toml 2026-02-12 23:46:13 +00:00
11 changed files with 127 additions and 31 deletions


@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...
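
The important detail in this stub is the final class: Model inherits from mlx_lm's DeepSeek V3.2 implementation, so GLM MoE DSA reuses that architecture's forward pass and only supplies its own ModelArgs. A minimal sketch of what the stub lets a type checker verify (the build helper below is hypothetical, not part of this PR):

    # Hypothetical usage, for illustration only: with the stub in place,
    # pyright can type-check code that constructs the GLM MoE DSA model.
    from mlx_lm.models.glm_moe_dsa import Model, ModelArgs

    def build(config: ModelArgs) -> Model:
        # Model.__init__ accepts ModelArgs; the forward pass itself is
        # inherited from deepseek_v32.Model.
        return Model(config)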


@@ -41,7 +41,7 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260217+50487b41"; in
version = let v = "0.30.7.dev20260218+14841977"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
@@ -49,8 +49,8 @@ let
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
};
patches = [
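
The assert above ties the Nix pin to uv.lock: the build fails loudly if the two drift. A small Python sketch of the same consistency check (a standalone script, with the lock-file layout assumed from the uv.lock hunks at the bottom of this diff; not part of the PR):

    # Hypothetical cross-check: compare the git-sourced mlx version in
    # uv.lock against the version pinned in nix/mlx.nix.
    import tomllib

    NIX_PINNED = "0.30.7.dev20260218+14841977"  # the new pin from nix/mlx.nix

    with open("uv.lock", "rb") as f:
        lock = tomllib.load(f)

    for pkg in lock["package"]:
        # uv.lock stores each package as a [[package]] table; the darwin
        # mlx entry has a git source rather than a registry source.
        if pkg["name"] == "mlx" and "git" in pkg.get("source", {}):
            assert pkg["version"] == NIX_PINNED, (
                f"MLX version mismatch: uv.lock has {pkg['version']}, "
                f"nix/mlx.nix has {NIX_PINNED}"
            )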


@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864


@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008


@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264
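
For scale, the in_bytes values in the three GLM-5 model cards convert as follows (plain arithmetic, GiB = 2**30 bytes):

    # Storage sizes from the three model cards above, in GiB:
    for name, n in [
        ("bf16", 1_487_822_475_264),
        ("8bit-MXFP8", 790_517_400_864),
        ("MXFP4-Q8", 405_478_939_008),
    ]:
        print(f"{name}: {n / 2**30:.1f} GiB")
    # bf16: 1385.6 GiB, 8bit-MXFP8: 736.2 GiB, MXFP4-Q8: 377.6 GiB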


@@ -183,6 +183,7 @@ class ConfigData(BaseModel):
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["GlmMoeDsaForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],


@@ -163,11 +163,14 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
- cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
+ # CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
+ # doesn't have .keys directly; access via first sub-cache.
+ _cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
+ _cache.keys = mx.depends(_cache.keys, output) # type: ignore
if self.is_prefill:
mx.eval(output)
if cache is not None:
- mx.eval(cache.keys) # type: ignore
+ mx.eval(_cache.keys) # type: ignore
if not self.is_prefill:
output = mx.distributed.all_gather(output, group=self.group)[
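
The fix in this hunk: MLA models (DeepSeek V3.2, GLM MoE DSA) hold a CacheList rather than a plain KV cache, and CacheList has no .keys attribute of its own. A standalone sketch of the duck-typed unwrapping, with the cache shapes assumed from the diff:

    # CacheList exposes its sub-caches via a .caches attribute and
    # indexing; a plain KV cache exposes .keys directly. Unwrap once,
    # then the mx.depends dependency hook works for both shapes:
    def innermost_cache(cache):
        return cache[0] if hasattr(cache, "caches") else cache

    # e.g.  c = innermost_cache(cache)
    #       c.keys = mx.depends(c.keys, output)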
@@ -307,7 +310,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None:
- cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
+ last = cache[-1] # type: ignore
+ dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
+ dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
return logits
@@ -333,7 +338,9 @@ def patch_tensor_model[T](model: T) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
- cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
+ last = cache[-1] # pyright: ignore[reportAny]
+ dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
+ dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
@@ -547,10 +554,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -581,12 +590,18 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
- # Shard the MoE. Shard in place since the MoE should be responsible
- # for aggregating the results.
+ # Shard the MoE.
else:
- self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
- self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
- self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
+ if getattr(layer.mlp, "shared_experts", None) is not None:
+     self.all_to_sharded_linear_in_place(
+         layer.mlp.shared_experts.gate_proj
+     )
+     self.sharded_to_all_linear_in_place(
+         layer.mlp.shared_experts.down_proj
+     )
+     self.all_to_sharded_linear_in_place(
+         layer.mlp.shared_experts.up_proj
+     )
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
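
The new guard matters because n_shared_experts is Optional in ModelArgs (see the type stub above): a config may have no shared experts at all, in which case the old unconditional calls would raise. A compact restatement, with the two sharding callables standing in for the in-place helpers above:

    # Shard shared experts only if the MoE block has them; the routed
    # experts (switch_mlp) are always present and handled separately.
    def maybe_shard_shared_experts(mlp, all_to_sharded, sharded_to_all) -> None:
        shared = getattr(mlp, "shared_experts", None)
        if shared is None:
            return  # e.g. a config with n_shared_experts = None
        all_to_sharded(shared.gate_proj)
        sharded_to_all(shared.down_proj)
        all_to_sharded(shared.up_proj)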
@@ -779,8 +794,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
- # Shard the MoE. Shard in place since the MoE should be responsible
- # for aggregating the results.
+ # Shard the MoE.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
@@ -893,8 +907,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
- # Shard the MoE. Shard in place since the MoE should be responsible
- # for aggregating the results.
+ # Shard the MoE.
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)


@@ -57,6 +57,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
+ group: mx.distributed.Group | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -86,6 +87,9 @@ def prefill(
set_pipeline_prefill(model, is_prefill=True)
+ mx_barrier(group)
+ logger.info("Starting prefill")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
@@ -305,16 +309,9 @@ def mlx_generate(
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
- mx_barrier(group)
- logger.info("Starting prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
- model,
- tokenizer,
- sampler,
- prompt_tokens[:-1],
- caches,
+ model, tokenizer, sampler, prompt_tokens[:-1], caches, group
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
@@ -331,6 +328,7 @@ def mlx_generate(
think_start = tokenizer.think_start
think_end = tokenizer.think_end
logger.info("Starting decode")
+ mx_barrier(group)
for completion_tokens, out in enumerate(
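
Taken together, these hunks move the rank synchronization: the barrier and the "Starting prefill" log now live inside prefill (hence the new group parameter), and a second barrier precedes the decode loop, presumably so no rank enters a collective op before its peers. A toy ordering sketch with stand-in functions, not the real implementations:

    # Toy model of the ordering after this change: each collective-heavy
    # phase is fronted by a barrier so ranks cannot drift apart.
    def mx_barrier(group) -> None:
        pass  # stand-in for the distributed barrier

    def prefill(group) -> None:
        mx_barrier(group)  # moved: sync immediately before the prompt pass

    group = None  # stand-in for mx.distributed.Group
    prefill(group)
    mx_barrier(group)  # added: re-sync before token-by-token decode begins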


@@ -285,10 +285,12 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-4.7-flash" in model_id_lower:
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
# For GLM-5 and GLM-4.7
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
return None
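
Usage sketch for the branches above (the GLM-4.5 model id below is illustrative; the GLM-5 id comes from the model cards in this diff):

    # GLM-5 and GLM-4.7 share the newer stop-token set; other GLM models
    # fall through to the GLM-4.5-era ids; everything else returns None.
    assert get_eos_token_ids_for_model("mlx-community/GLM-5-MXFP4-Q8") == [
        154820, 154827, 154829
    ]
    assert get_eos_token_ids_for_model("mlx-community/GLM-4.5-4bit") == [
        151336, 151329, 151338
    ]
    assert get_eos_token_ids_for_model("org/unrelated-model") is None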


@@ -191,7 +191,7 @@ class RunnerSupervisor:
logger.info("Checking runner's status")
if self.runner_process.is_alive():
logger.info("Runner was found to be alive, attempting to join process")
- await to_thread.run_sync(self.runner_process.join, 1)
+ await to_thread.run_sync(self.runner_process.join, 5)
rc = self.runner_process.exitcode
logger.info(f"RunnerSupervisor exited with exit code {rc}")
if rc == 0:
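
The change is just the join grace period, 1 s to 5 s, giving the runner process more time to exit before the supervisor reads its exit code (per the commit message, to prevent crashes on instance deletion). A sketch of the pattern, with the surrounding context assumed; anyio's to_thread.run_sync forwards positional args:

    # Process.join blocks, so run it on a worker thread to keep the event
    # loop responsive; join(timeout) returns even if the child is alive.
    from anyio import to_thread

    async def reap(proc) -> int | None:  # proc: a multiprocessing.Process
        if proc.is_alive():
            await to_thread.run_sync(proc.join, 5)  # up to 5 s of grace
        return proc.exitcode  # None means still running after the grace period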

uv.lock (generated)

@@ -378,7 +378,7 @@ dependencies = [
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1021,7 +1021,7 @@ dependencies = [
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1068,8 +1068,8 @@ cuda13 = [
[[package]]
name = "mlx"
version = "0.30.7.dev20260217+50487b41"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
version = "0.30.7.dev20260218+14841977"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }
resolution-markers = [
"sys_platform == 'darwin'",
]
@@ -1104,7 +1104,7 @@ version = "0.30.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },