Compare commits


11 Commits

Author                 SHA1        Message                      Date
Ryuichi Leo Takashige  ce0eef999e  return to mlx lm main        2026-02-13 12:31:07 +00:00
Ryuichi Leo Takashige  20fb6a9acc  handle absolute paths        2026-02-13 11:09:46 +00:00
Ryuichi Leo Takashige  4a1234106b  add type stub                2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  2929249147  fix glm eos id               2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  837ffc6b97  dont patch glm5 tokenizer?   2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  2366ed0299  add glm5 model cards         2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  c95c088952  convert glm5                 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  2af1c81cde  convert glm5                 2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  6922dd4ead  download faster              2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  8c2fb7f130  Add tensor sharding          2026-02-12 23:46:13 +00:00
Ryuichi Leo Takashige  0488cb2967  update pyproject.toml        2026-02-12 23:46:13 +00:00
25 changed files with 549 additions and 582 deletions

View File

@@ -1,15 +0,0 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

View File

@@ -1,29 +0,0 @@
name: e2e-tests
on:
push:
pull_request:
branches:
- staging
- main
jobs:
e2e:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
/opt/microsoft /opt/az
docker system prune -af
df -h /
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Run E2E tests
run: python3 e2e/run_all.py

View File

@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

convert_glm5_mxfp4_q8.py Normal file
View File

@@ -0,0 +1,110 @@
"""Convert GLM-5 to MXFP4-Q8: experts use MXFP4, dense layers use 8-bit affine.
Matches the same quantization scheme used for mlx-community/GPT-OSS-MXFP4-Q8:
- Global default: MXFP4 (bits=4, group_size=32) for expert/switch_mlp weights
- Per-layer overrides: affine Q8 (bits=8, group_size=64) for attention, embeddings,
shared experts, dense MLP, lm_head
Usage:
python convert_glm5_mxfp4_q8.py \
--hf-path ~/.exo/models/zai-org--GLM-5 \
--mlx-path ~/.exo/models/mlx-community--GLM-5-MXFP4-Q8 \
--upload-repo mlx-community/GLM-5-MXFP4-Q8
"""
import argparse
import copy
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map_with_path
from mlx_lm.utils import compute_bits_per_weight, load, save, upload_to_hub
# Global default = MXFP4 for expert weights (switch_mlp)
MXFP4_PARAMS = {"group_size": 32, "bits": 4, "mode": "mxfp4"}
# Per-layer override = affine Q8 for everything else
AFFINE_Q8_PARAMS = {"group_size": 64, "bits": 8, "mode": "affine"}
def mxfp4_q8_predicate(path: str, module: nn.Module) -> dict | bool:
"""MXFP4 for expert (switch_mlp) weights, 8-bit affine for everything else."""
if not hasattr(module, "to_quantized"):
return False
# Expert layers get MXFP4 (global default)
if "switch_mlp" in path:
if module.weight.shape[-1] % MXFP4_PARAMS["group_size"] != 0:
return False
return MXFP4_PARAMS
# Everything else gets 8-bit affine
if module.weight.shape[-1] % AFFINE_Q8_PARAMS["group_size"] != 0:
return False
return AFFINE_Q8_PARAMS
def main():
parser = argparse.ArgumentParser(description="Convert GLM-5 to MXFP4-Q8")
parser.add_argument("--hf-path", required=True, help="Path to HF model")
parser.add_argument("--mlx-path", required=True, default="mlx_model", help="Output path")
parser.add_argument("--upload-repo", default=None, help="HF repo to upload to")
args = parser.parse_args()
mlx_path = Path(args.mlx_path)
if mlx_path.exists():
raise ValueError(f"Output path {mlx_path} already exists. Delete it first.")
print("[INFO] Loading")
model, tokenizer, config = load(
args.hf_path,
return_config=True,
lazy=True,
)
# Apply dtype from config
dtype = config.get("torch_dtype", None)
if dtype in ("float16", "bfloat16", "float32"):
print(f"[INFO] Using dtype: {dtype}")
dt = getattr(mx, dtype)
cast_predicate = getattr(model, "cast_predicate", lambda _: True)
def set_dtype(k, v):
if cast_predicate(k) and mx.issubdtype(v.dtype, mx.floating):
return v.astype(dt)
return v
model.update(tree_map_with_path(set_dtype, model.parameters()))
# Build quantization config matching GPT-OSS format:
# global default = mxfp4, per-layer overrides for Q8 layers
quantized_config = copy.deepcopy(config)
quantized_config["quantization"] = {**MXFP4_PARAMS}
def tracked_predicate(path: str, module: nn.Module) -> dict | bool:
result = mxfp4_q8_predicate(path, module)
if isinstance(result, dict) and result is not MXFP4_PARAMS:
# Only store overrides for non-default (Q8) layers
quantized_config["quantization"][path] = result
return result
print("[INFO] Quantizing (MXFP4 experts + Q8 dense)")
nn.quantize(
model,
class_predicate=tracked_predicate,
)
# Duplicate for HF compat (same as mlx_lm.convert does)
quantized_config["quantization_config"] = quantized_config["quantization"]
bpw = compute_bits_per_weight(model)
print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
save(mlx_path, args.hf_path, model, tokenizer, quantized_config)
if args.upload_repo:
upload_to_hub(mlx_path, args.upload_repo)
if __name__ == "__main__":
main()
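
As a quick sanity check on the compute_bits_per_weight figure printed by the script above, the expected bits per weight of each scheme can be estimated from its group size alone. The sketch below is not part of the diff; it assumes an 8-bit shared scale per MXFP4 group and a 16-bit scale plus 16-bit bias per affine Q8 group.

def estimated_bpw(bits: int, group_size: int, scale_bits: int, bias_bits: int = 0) -> float:
    # bits for the quantized element plus per-group metadata amortised over the group
    return bits + (scale_bits + bias_bits) / group_size

print(estimated_bpw(4, 32, scale_bits=8))                  # MXFP4 experts   -> 4.25
print(estimated_bpw(8, 64, scale_bits=16, bias_bits=16))   # affine Q8 dense -> 8.5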

download_glm5_shard.py Normal file
View File

@@ -0,0 +1,236 @@
#!/usr/bin/env python3
"""
Fast parallel downloader for a range of safetensors files from zai-org/GLM-5.
Uses aiohttp with 8 concurrent downloads and 8MB chunks (same approach as exo).
Usage:
python download_glm5_shard.py <start> <end> [--dir GLM-5] [--jobs 8]
Split across 2 Macs:
Mac 1: python download_glm5_shard.py 1 141
Mac 2: python download_glm5_shard.py 142 282
Split across 4 Macs:
Mac 1: python download_glm5_shard.py 1 71
Mac 2: python download_glm5_shard.py 72 141
Mac 3: python download_glm5_shard.py 142 212
Mac 4: python download_glm5_shard.py 213 282
"""
import argparse
import asyncio
import os
import ssl
import sys
import time
import aiofiles
import aiohttp
import certifi
REPO = "zai-org/GLM-5"
TOTAL_SHARDS = 282
CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB
HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
def get_token() -> str | None:
token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if token:
return token
token_path = os.path.expanduser("~/.cache/huggingface/token")
if os.path.exists(token_path):
with open(token_path) as f:
return f.read().strip() or None
return None
def make_session() -> aiohttp.ClientSession:
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
conn = aiohttp.TCPConnector(ssl=ssl_ctx, limit=0)
timeout = aiohttp.ClientTimeout(total=1800, connect=60, sock_read=60)
return aiohttp.ClientSession(connector=conn, timeout=timeout)
class ProgressTracker:
def __init__(self, total_files: int):
self.total_files = total_files
self.completed_files = 0
self.skipped_files = 0
self.failed_files = 0
self.total_bytes = 0
self.downloaded_bytes = 0
self.start_time = time.monotonic()
self.lock = asyncio.Lock()
# per-file tracking: filename -> (downloaded, total)
self.active: dict[str, tuple[int, int]] = {}
async def file_skip(self, filename: str) -> None:
async with self.lock:
self.skipped_files += 1
self._render()
async def file_fail(self, filename: str) -> None:
async with self.lock:
self.active.pop(filename, None)
self.failed_files += 1
self._render()
async def file_start(self, filename: str, total: int, resumed: int) -> None:
async with self.lock:
self.total_bytes += total
self.downloaded_bytes += resumed
self.active[filename] = (resumed, total)
self._render()
async def file_progress(self, filename: str, downloaded: int, total: int) -> None:
async with self.lock:
prev, _ = self.active.get(filename, (0, total))
self.downloaded_bytes += downloaded - prev
self.active[filename] = (downloaded, total)
self._render()
async def file_done(self, filename: str) -> None:
async with self.lock:
self.active.pop(filename, None)
self.completed_files += 1
self._render()
def _render(self) -> None:
elapsed = time.monotonic() - self.start_time
speed = self.downloaded_bytes / elapsed if elapsed > 0 else 0
done = self.completed_files + self.skipped_files
remaining_bytes = self.total_bytes - self.downloaded_bytes
eta = remaining_bytes / speed if speed > 0 else 0
# Overall progress bar
pct = self.downloaded_bytes / self.total_bytes * 100 if self.total_bytes else 0
bar_width = 30
filled = int(bar_width * pct / 100)
bar = "=" * filled + ">" * (1 if filled < bar_width else 0) + " " * (bar_width - filled - 1)
# Active file names (short)
active_names = []
for fn, (dl, tot) in sorted(self.active.items()):
short = fn.replace("model-", "").replace(f"-of-{TOTAL_SHARDS:05d}.safetensors", "")
file_pct = dl / tot * 100 if tot else 0
active_names.append(f"{short}:{file_pct:.0f}%")
active_str = " ".join(active_names[:8])
eta_m, eta_s = divmod(int(eta), 60)
eta_h, eta_m = divmod(eta_m, 60)
eta_str = f"{eta_h}h{eta_m:02d}m" if eta_h else f"{eta_m}m{eta_s:02d}s"
line = (
f"\r[{bar}] {pct:5.1f}% "
f"{done}/{self.total_files} files "
f"{self.downloaded_bytes / 1024**3:.1f}/{self.total_bytes / 1024**3:.1f} GB "
f"{speed / 1024**2:.1f} MB/s "
f"ETA {eta_str} "
f"{active_str}"
)
# Pad to clear previous line, truncate to terminal width
try:
cols = os.get_terminal_size().columns
except OSError:
cols = 120
line = line[:cols].ljust(cols)
sys.stderr.write(line)
sys.stderr.flush()
def final_summary(self) -> None:
elapsed = time.monotonic() - self.start_time
speed = self.downloaded_bytes / elapsed if elapsed > 0 else 0
mins, secs = divmod(int(elapsed), 60)
sys.stderr.write("\n")
print(
f"Done: {self.completed_files} downloaded, {self.skipped_files} skipped, "
f"{self.failed_files} failed. "
f"{self.downloaded_bytes / 1024**3:.1f} GB in {mins}m{secs:02d}s "
f"({speed / 1024**2:.1f} MB/s avg)"
)
async def download_file(
session: aiohttp.ClientSession,
filename: str,
target_dir: str,
headers: dict[str, str],
sem: asyncio.Semaphore,
progress: ProgressTracker,
) -> None:
async with sem:
url = f"{HF_ENDPOINT}/{REPO}/resolve/main/{filename}"
target = os.path.join(target_dir, filename)
partial = target + ".partial"
os.makedirs(os.path.dirname(target), exist_ok=True)
if os.path.exists(target):
await progress.file_skip(filename)
return
resume_pos = 0
req_headers = dict(headers)
if os.path.exists(partial):
resume_pos = os.path.getsize(partial)
req_headers["Range"] = f"bytes={resume_pos}-"
async with session.get(url, headers=req_headers) as r:
if r.status == 416:
os.rename(partial, target)
await progress.file_skip(filename)
return
if r.status not in (200, 206):
await progress.file_fail(filename)
return
total = int(r.headers.get("Content-Length", 0)) + resume_pos
downloaded = resume_pos
await progress.file_start(filename, total, resume_pos)
async with aiofiles.open(partial, "ab" if resume_pos else "wb") as f:
while True:
chunk = await r.content.read(CHUNK_SIZE)
if not chunk:
break
await f.write(chunk)
downloaded += len(chunk)
await progress.file_progress(filename, downloaded, total)
os.rename(partial, target)
await progress.file_done(filename)
async def main() -> None:
parser = argparse.ArgumentParser(description="Fast parallel GLM-5 shard downloader")
parser.add_argument("start", type=int, help="First shard number (1-based)")
parser.add_argument("end", type=int, help="Last shard number (inclusive)")
parser.add_argument("--dir", default="GLM-5", help="Target directory (default: GLM-5)")
parser.add_argument("--jobs", type=int, default=8, help="Parallel downloads (default: 8)")
args = parser.parse_args()
files = [
f"model-{i:05d}-of-{TOTAL_SHARDS:05d}.safetensors"
for i in range(args.start, args.end + 1)
]
headers: dict[str, str] = {"Accept-Encoding": "identity"}
token = get_token()
if token:
headers["Authorization"] = f"Bearer {token}"
print(f"Downloading {len(files)} files ({args.start}-{args.end}) to {args.dir}/ with {args.jobs} parallel jobs")
progress = ProgressTracker(len(files))
sem = asyncio.Semaphore(args.jobs)
async with make_session() as session:
await asyncio.gather(*[
download_file(session, f, args.dir, headers, sem, progress)
for f in files
])
progress.final_summary()
if __name__ == "__main__":
asyncio.run(main())
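
The per-Mac ranges in the docstring generalise to any number of machines. Below is a small helper, not part of the script, that splits the 282 shards as evenly as possible across N machines (an even split may differ by one shard from the hand-picked 4-Mac example):

def split_shards(total: int, machines: int) -> list[tuple[int, int]]:
    # distribute `total` shards into contiguous 1-based ranges, remainder to the first machines
    ranges, start = [], 1
    for i in range(machines):
        count = total // machines + (1 if i < total % machines else 0)
        ranges.append((start, start + count - 1))
        start += count
    return ranges

print(split_shards(282, 2))  # [(1, 141), (142, 282)] -- matches the 2-Mac split above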

download_glm5_shard.sh Executable file
View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Usage: ./download_glm5_shard.sh <start> <end> [local_dir]
#
# Split across 4 Macs:
# Mac 1: ./download_glm5_shard.sh 1 71
# Mac 2: ./download_glm5_shard.sh 72 141
# Mac 3: ./download_glm5_shard.sh 142 212
# Mac 4: ./download_glm5_shard.sh 213 282
set -euo pipefail
START=${1:?Usage: $0 <start> <end> [local_dir]}
END=${2:?Usage: $0 <start> <end> [local_dir]}
LOCAL_DIR="${3:-GLM-5}"
INCLUDES=()
for i in $(seq "$START" "$END"); do
INCLUDES+=(--include "$(printf 'model-%05d-of-00282.safetensors' "$i")")
done
echo "Downloading safetensors $START-$END to $LOCAL_DIR"
hf download zai-org/GLM-5 "${INCLUDES[@]}" --local-dir "$LOCAL_DIR"

View File

@@ -1,53 +0,0 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
CMD [".venv/bin/exo", "-v"]

View File

@@ -1,182 +0,0 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")

View File

@@ -1,18 +0,0 @@
services:
exo-node-1:
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]

View File

@@ -1,75 +0,0 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests whose docstring begins with a 'slow' line are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,8 +0,0 @@
{
"model": "mlx-community/Qwen3-0.6B-4bit",
"seed": 42,
"temperature": 0,
"prompt": "What is 2+2? Reply with just the number.",
"max_tokens": 32,
"content": "<think>\nOkay, so I need to figure out what 2+2 is. Let me think. Well, if you add 2 and 2 together"
}

View File

@@ -1,22 +0,0 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,82 +0,0 @@
"""Test: Deterministic inference output (snapshot test).
slow
Sends a chat completion request with a fixed seed and temperature=0,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
SNAPSHOT_FILE = Path(__file__).parent / "snapshots" / "inference.json"
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Launch the model instance (triggers download + placement)
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, temperature=0)...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
# Load or create snapshot
if SNAPSHOT_FILE.exists():
snapshot = json.loads(SNAPSHOT_FILE.read_text())
expected = snapshot["content"]
assert content == expected, (
f"Snapshot mismatch!\n"
f" Expected: {expected!r}\n"
f" Got: {content!r}\n"
f" Delete {SNAPSHOT_FILE} to regenerate."
)
print(" Output matches snapshot")
else:
SNAPSHOT_FILE.parent.mkdir(parents=True, exist_ok=True)
SNAPSHOT_FILE.write_text(
json.dumps(
{
"model": MODEL,
"seed": SEED,
"temperature": 0,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
"content": content,
},
indent=2,
)
+ "\n"
)
print(f" Snapshot created: {SNAPSHOT_FILE}")
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,47 +0,0 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,32 +0,0 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx==0.30.6",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -64,6 +64,8 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
#mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", marker = "sys_platform == 'darwin'" }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264
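
The three storage_size values above are mutually consistent. A rough check, assuming the bf16 checkpoint stores about 2 bytes per parameter (so parameter count ≈ size / 2; norms and other small tensors are ignored):

BF16 = 1_487_822_475_264     # bytes, from the bf16 card
Q8 = 790_517_400_864         # bytes, from the 8bit card
MXFP4_Q8 = 405_478_939_008   # bytes, from the MXFP4-Q8 card

params = BF16 / 2                 # ≈ 7.4e11 parameters
print(Q8 / params * 8)            # ≈ 8.5 bits/weight, matches affine Q8 with 16-bit scale+bias per group of 64
print(MXFP4_Q8 / params * 8)      # ≈ 4.36 bits/weight, MXFP4 experts plus Q8 dense layers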

View File

@@ -7,11 +7,17 @@ from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
def _resolve_env_path(env_value: str) -> Path:
"""Resolve an environment variable path: absolute paths are used as-is, relative paths are resolved from home."""
p = Path(env_value)
return p if p.is_absolute() else Path.home() / p
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
if _EXO_HOME_ENV is not None:
return Path.home() / _EXO_HOME_ENV
return _resolve_env_path(_EXO_HOME_ENV)
if sys.platform != "linux":
return Path.home() / ".exo"
@@ -31,15 +37,19 @@ _EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None)
EXO_MODELS_DIR = (
EXO_DATA_HOME / "models"
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
else _resolve_env_path(_EXO_MODELS_DIR_ENV)
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
find_resources()
if _RESOURCES_DIR_ENV is None
else _resolve_env_path(_RESOURCES_DIR_ENV)
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
find_dashboard()
if _DASHBOARD_DIR_ENV is None
else _resolve_env_path(_DASHBOARD_DIR_ENV)
)
# Log files (data/logs or cache)

View File

@@ -182,6 +182,7 @@ class ConfigData(BaseModel):
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["GlmMoeDsaForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],

View File

@@ -24,6 +24,8 @@ from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
from mlx_lm.models.glm_moe_dsa import Glm4MoeLiteMoE as GlmMoeDsaMoE
from mlx_lm.models.glm_moe_dsa import Model as GlmMoeDsaModel
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
@@ -160,11 +162,14 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
dep_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
dep_cache.keys = mx.depends(dep_cache.keys, output) # type: ignore[reportUnknownMemberType]
if self.is_prefill:
mx.eval(output)
if cache is not None:
mx.eval(cache.keys) # type: ignore
mx.eval(dep_cache.keys) # type: ignore
if not self.is_prefill:
output = mx.distributed.all_gather(output, group=self.group)[
@@ -403,6 +408,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GlmMoeDsaModel):
tensor_parallel_sharding_strategy = GlmMoeDsaShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Glm4MoeModel):
tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
group,
@@ -654,6 +667,62 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
return model
class GlmMoeDsaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GlmMoeDsaModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(),
timeout_seconds / len(model.layers),
on_timeout,
)
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
return w[sh:eh]
layer.self_attn.embed_q.apply(shard_heads)
layer.self_attn.unembed_out.apply(shard_heads)
if isinstance(layer.mlp, Glm4MoeLiteMLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
else:
moe = cast(GlmMoeDsaMoE, layer.mlp)
if moe.shared_experts is not None:
self.all_to_sharded_linear_in_place(
moe.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
moe.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
moe.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(moe.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(moe.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(moe.switch_mlp.up_proj)
layer.mlp = ShardedMoE(moe) # type: ignore
layer.mlp.sharding_group = self.group
mx.eval(layer)
return model
class WrappedMiniMaxAttention(CustomMlxLayer):
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
super().__init__(layer)
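
To make the shard_heads slicing above concrete: after num_heads is divided by the group size N, each rank keeps one contiguous block of heads. A minimal illustration with hypothetical numbers (32 heads, 4 ranks; not GLM-5's actual configuration):

N, total_heads = 4, 32            # assumed values for illustration
per_rank = total_heads // N       # mirrors `layer.self_attn.num_heads //= self.N`
for rank in range(N):
    sh, eh = rank * per_rank, (rank + 1) * per_rank   # mirrors the sh/eh slice bounds
    print(f"rank {rank}: heads {sh}..{eh - 1}")        # rank 0: heads 0..7, rank 1: 8..15, ...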

View File

@@ -311,10 +311,12 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-4.7-flash" in model_id_lower:
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
# For GLM-5 and GLM-4.7
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
return None
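
The three ids can be sanity-checked against the tokenizer itself. This is a sketch, assuming the Hugging Face tokenizer for zai-org/GLM-5 is available (add trust_remote_code=True if the repo requires it):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("zai-org/GLM-5")
for t in ("<|endoftext|>", "<|user|>", "<|observation|>"):
    print(t, tok.convert_tokens_to_ids(t))  # expected: 154820, 154827, 154829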

View File

@@ -295,8 +295,8 @@ def main(
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
elif "glm-4" in shard_metadata.model_card.model_id.lower():
patch_glm4_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
@@ -863,7 +863,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
def patch_glm4_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""

uv.lock generated
View File

@@ -416,9 +416,9 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
{ name = "msgspec", specifier = ">=0.19.0" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
@@ -1098,8 +1098,8 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
version = "0.30.7"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#bcf630614ffb5624bcb19870a7bcb0d847e6e98f" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1109,10 +1109,6 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"