Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-13 07:32:30 -05:00)

Compare commits: e2e-tests...leo/add-gl (11 commits)

| SHA1 |
|---|
| ce0eef999e |
| 20fb6a9acc |
| 4a1234106b |
| 2929249147 |
| 837ffc6b97 |
| 2366ed0299 |
| c95c088952 |
| 2af1c81cde |
| 6922dd4ead |
| 8c2fb7f130 |
| 0488cb2967 |
@@ -1,15 +0,0 @@
|
||||
.venv/
|
||||
.direnv/
|
||||
target/
|
||||
.git/
|
||||
.idea/
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
dashboard/node_modules/
|
||||
dashboard/.svelte-kit/
|
||||
dashboard/build/
|
||||
dist/
|
||||
*.pdb
|
||||
**/__pycache__
|
||||
**/.DS_Store
|
||||
.mlx_typings/
|
||||
.github/workflows/e2e.yml (vendored, 29 lines deleted)
@@ -1,29 +0,0 @@
name: e2e-tests

on:
  push:
  pull_request:
    branches:
      - staging
      - main

jobs:
  e2e:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    steps:
      - name: Free up disk space
        run: |
          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
            /opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
            /opt/microsoft /opt/az
          docker system prune -af
          df -h /

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - name: Run E2E tests
        run: python3 e2e/run_all.py
.mlx_typings/mlx_lm/models/glm_moe_dsa.pyi (new file, 46 lines)
@@ -0,0 +1,46 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""

from dataclasses import dataclass
from typing import Any, Dict, Optional

from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    index_head_dim: int
    index_n_heads: int
    index_topk: int
    intermediate_size: int
    moe_intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    n_shared_experts: Optional[int]
    n_routed_experts: Optional[int]
    routed_scaling_factor: float
    kv_lora_rank: int
    q_lora_rank: int
    qk_rope_head_dim: int
    v_head_dim: int
    qk_nope_head_dim: int
    topk_method: str
    scoring_func: str
    norm_topk_prob: bool
    n_group: int
    topk_group: int
    num_experts_per_tok: int
    moe_layer_freq: int
    first_k_dense_replace: int
    max_position_embeddings: int
    rms_norm_eps: float
    rope_parameters: Dict[str, Any]
    attention_bias: bool
    rope_scaling: Dict[str, Any] | None
    rope_theta: float | None

class Model(DSV32Model):
    def __init__(self, config: ModelArgs) -> None: ...
convert_glm5_mxfp4_q8.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""Convert GLM-5 to MXFP4-Q8: experts use MXFP4, dense layers use 8-bit affine.

Matches the same quantization scheme used for mlx-community/GPT-OSS-MXFP4-Q8:
- Global default: MXFP4 (bits=4, group_size=32) for expert/switch_mlp weights
- Per-layer overrides: affine Q8 (bits=8, group_size=64) for attention, embeddings,
  shared experts, dense MLP, lm_head

Usage:
    python convert_glm5_mxfp4_q8.py \
        --hf-path ~/.exo/models/zai-org--GLM-5 \
        --mlx-path ~/.exo/models/mlx-community--GLM-5-MXFP4-Q8 \
        --upload-repo mlx-community/GLM-5-MXFP4-Q8
"""

import argparse
import copy
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map_with_path

from mlx_lm.utils import compute_bits_per_weight, load, save, upload_to_hub

# Global default = MXFP4 for expert weights (switch_mlp)
MXFP4_PARAMS = {"group_size": 32, "bits": 4, "mode": "mxfp4"}
# Per-layer override = affine Q8 for everything else
AFFINE_Q8_PARAMS = {"group_size": 64, "bits": 8, "mode": "affine"}


def mxfp4_q8_predicate(path: str, module: nn.Module) -> dict | bool:
    """MXFP4 for expert (switch_mlp) weights, 8-bit affine for everything else."""
    if not hasattr(module, "to_quantized"):
        return False

    # Expert layers get MXFP4 (global default)
    if "switch_mlp" in path:
        if module.weight.shape[-1] % MXFP4_PARAMS["group_size"] != 0:
            return False
        return MXFP4_PARAMS

    # Everything else gets 8-bit affine
    if module.weight.shape[-1] % AFFINE_Q8_PARAMS["group_size"] != 0:
        return False
    return AFFINE_Q8_PARAMS


def main():
    parser = argparse.ArgumentParser(description="Convert GLM-5 to MXFP4-Q8")
    parser.add_argument("--hf-path", required=True, help="Path to HF model")
    parser.add_argument("--mlx-path", required=True, default="mlx_model", help="Output path")
    parser.add_argument("--upload-repo", default=None, help="HF repo to upload to")
    args = parser.parse_args()

    mlx_path = Path(args.mlx_path)
    if mlx_path.exists():
        raise ValueError(f"Output path {mlx_path} already exists. Delete it first.")

    print("[INFO] Loading")
    model, tokenizer, config = load(
        args.hf_path,
        return_config=True,
        lazy=True,
    )

    # Apply dtype from config
    dtype = config.get("torch_dtype", None)
    if dtype in ("float16", "bfloat16", "float32"):
        print(f"[INFO] Using dtype: {dtype}")
        dt = getattr(mx, dtype)
        cast_predicate = getattr(model, "cast_predicate", lambda _: True)

        def set_dtype(k, v):
            if cast_predicate(k) and mx.issubdtype(v.dtype, mx.floating):
                return v.astype(dt)
            return v

        model.update(tree_map_with_path(set_dtype, model.parameters()))

    # Build quantization config matching GPT-OSS format:
    # global default = mxfp4, per-layer overrides for Q8 layers
    quantized_config = copy.deepcopy(config)
    quantized_config["quantization"] = {**MXFP4_PARAMS}

    def tracked_predicate(path: str, module: nn.Module) -> dict | bool:
        result = mxfp4_q8_predicate(path, module)
        if isinstance(result, dict) and result is not MXFP4_PARAMS:
            # Only store overrides for non-default (Q8) layers
            quantized_config["quantization"][path] = result
        return result

    print("[INFO] Quantizing (MXFP4 experts + Q8 dense)")
    nn.quantize(
        model,
        class_predicate=tracked_predicate,
    )
    # Duplicate for HF compat (same as mlx_lm.convert does)
    quantized_config["quantization_config"] = quantized_config["quantization"]

    bpw = compute_bits_per_weight(model)
    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")

    save(mlx_path, args.hf_path, model, tokenizer, quantized_config)

    if args.upload_repo:
        upload_to_hub(mlx_path, args.upload_repo)


if __name__ == "__main__":
    main()
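A minimal sketch (not part of the changeset) of how mxfp4_q8_predicate routes layers. It uses stand-in objects instead of real mlx modules, and the layer paths and shapes below are invented for illustration; importing the script pulls in mlx.

from convert_glm5_mxfp4_q8 import mxfp4_q8_predicate  # requires mlx installed

class _FakeWeight:
    def __init__(self, last_dim: int):
        self.shape = (1024, last_dim)

class _FakeLinear:
    """Stand-in for a quantizable layer: has a weight and a to_quantized method."""
    def __init__(self, last_dim: int):
        self.weight = _FakeWeight(last_dim)
    def to_quantized(self, **kwargs):
        pass

# Expert weight -> global MXFP4 default (group_size 32 divides the last dim)
print(mxfp4_q8_predicate("model.layers.3.mlp.switch_mlp.up_proj", _FakeLinear(2048)))
# {'group_size': 32, 'bits': 4, 'mode': 'mxfp4'}

# Dense/attention weight -> affine Q8 override (group_size 64 divides the last dim)
print(mxfp4_q8_predicate("model.layers.3.self_attn.o_proj", _FakeLinear(4096)))
# {'group_size': 64, 'bits': 8, 'mode': 'affine'}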
download_glm5_shard.py (new file, 236 lines)
@@ -0,0 +1,236 @@
#!/usr/bin/env python3
"""
Fast parallel downloader for a range of safetensors files from zai-org/GLM-5.
Uses aiohttp with 8 concurrent downloads and 8MB chunks (same approach as exo).

Usage:
    python download_glm5_shard.py <start> <end> [--dir GLM-5] [--jobs 8]

Split across 2 Macs:
    Mac 1: python download_glm5_shard.py 1 141
    Mac 2: python download_glm5_shard.py 142 282

Split across 4 Macs:
    Mac 1: python download_glm5_shard.py 1 71
    Mac 2: python download_glm5_shard.py 72 141
    Mac 3: python download_glm5_shard.py 142 212
    Mac 4: python download_glm5_shard.py 213 282
"""

import argparse
import asyncio
import os
import ssl
import sys
import time

import aiofiles
import aiohttp
import certifi

REPO = "zai-org/GLM-5"
TOTAL_SHARDS = 282
CHUNK_SIZE = 8 * 1024 * 1024  # 8 MB
HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")


def get_token() -> str | None:
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if token:
        return token
    token_path = os.path.expanduser("~/.cache/huggingface/token")
    if os.path.exists(token_path):
        with open(token_path) as f:
            return f.read().strip() or None
    return None


def make_session() -> aiohttp.ClientSession:
    ssl_ctx = ssl.create_default_context(cafile=certifi.where())
    conn = aiohttp.TCPConnector(ssl=ssl_ctx, limit=0)
    timeout = aiohttp.ClientTimeout(total=1800, connect=60, sock_read=60)
    return aiohttp.ClientSession(connector=conn, timeout=timeout)


class ProgressTracker:
    def __init__(self, total_files: int):
        self.total_files = total_files
        self.completed_files = 0
        self.skipped_files = 0
        self.failed_files = 0
        self.total_bytes = 0
        self.downloaded_bytes = 0
        self.start_time = time.monotonic()
        self.lock = asyncio.Lock()
        # per-file tracking: filename -> (downloaded, total)
        self.active: dict[str, tuple[int, int]] = {}

    async def file_skip(self, filename: str) -> None:
        async with self.lock:
            self.skipped_files += 1
            self._render()

    async def file_fail(self, filename: str) -> None:
        async with self.lock:
            self.active.pop(filename, None)
            self.failed_files += 1
            self._render()

    async def file_start(self, filename: str, total: int, resumed: int) -> None:
        async with self.lock:
            self.total_bytes += total
            self.downloaded_bytes += resumed
            self.active[filename] = (resumed, total)
            self._render()

    async def file_progress(self, filename: str, downloaded: int, total: int) -> None:
        async with self.lock:
            prev, _ = self.active.get(filename, (0, total))
            self.downloaded_bytes += downloaded - prev
            self.active[filename] = (downloaded, total)
            self._render()

    async def file_done(self, filename: str) -> None:
        async with self.lock:
            self.active.pop(filename, None)
            self.completed_files += 1
            self._render()

    def _render(self) -> None:
        elapsed = time.monotonic() - self.start_time
        speed = self.downloaded_bytes / elapsed if elapsed > 0 else 0
        done = self.completed_files + self.skipped_files
        remaining_bytes = self.total_bytes - self.downloaded_bytes
        eta = remaining_bytes / speed if speed > 0 else 0

        # Overall progress bar
        pct = self.downloaded_bytes / self.total_bytes * 100 if self.total_bytes else 0
        bar_width = 30
        filled = int(bar_width * pct / 100)
        bar = "=" * filled + ">" * (1 if filled < bar_width else 0) + " " * (bar_width - filled - 1)

        # Active file names (short)
        active_names = []
        for fn, (dl, tot) in sorted(self.active.items()):
            short = fn.replace("model-", "").replace(f"-of-{TOTAL_SHARDS:05d}.safetensors", "")
            file_pct = dl / tot * 100 if tot else 0
            active_names.append(f"{short}:{file_pct:.0f}%")
        active_str = " ".join(active_names[:8])

        eta_m, eta_s = divmod(int(eta), 60)
        eta_h, eta_m = divmod(eta_m, 60)
        eta_str = f"{eta_h}h{eta_m:02d}m" if eta_h else f"{eta_m}m{eta_s:02d}s"

        line = (
            f"\r[{bar}] {pct:5.1f}% "
            f"{done}/{self.total_files} files "
            f"{self.downloaded_bytes / 1024**3:.1f}/{self.total_bytes / 1024**3:.1f} GB "
            f"{speed / 1024**2:.1f} MB/s "
            f"ETA {eta_str} "
            f"{active_str}"
        )
        # Pad to clear previous line, truncate to terminal width
        try:
            cols = os.get_terminal_size().columns
        except OSError:
            cols = 120
        line = line[:cols].ljust(cols)
        sys.stderr.write(line)
        sys.stderr.flush()

    def final_summary(self) -> None:
        elapsed = time.monotonic() - self.start_time
        speed = self.downloaded_bytes / elapsed if elapsed > 0 else 0
        mins, secs = divmod(int(elapsed), 60)
        sys.stderr.write("\n")
        print(
            f"Done: {self.completed_files} downloaded, {self.skipped_files} skipped, "
            f"{self.failed_files} failed. "
            f"{self.downloaded_bytes / 1024**3:.1f} GB in {mins}m{secs:02d}s "
            f"({speed / 1024**2:.1f} MB/s avg)"
        )


async def download_file(
    session: aiohttp.ClientSession,
    filename: str,
    target_dir: str,
    headers: dict[str, str],
    sem: asyncio.Semaphore,
    progress: ProgressTracker,
) -> None:
    async with sem:
        url = f"{HF_ENDPOINT}/{REPO}/resolve/main/{filename}"
        target = os.path.join(target_dir, filename)
        partial = target + ".partial"
        os.makedirs(os.path.dirname(target), exist_ok=True)

        if os.path.exists(target):
            await progress.file_skip(filename)
            return

        resume_pos = 0
        req_headers = dict(headers)
        if os.path.exists(partial):
            resume_pos = os.path.getsize(partial)
            req_headers["Range"] = f"bytes={resume_pos}-"

        async with session.get(url, headers=req_headers) as r:
            if r.status == 416:
                os.rename(partial, target)
                await progress.file_skip(filename)
                return
            if r.status not in (200, 206):
                await progress.file_fail(filename)
                return

            total = int(r.headers.get("Content-Length", 0)) + resume_pos
            downloaded = resume_pos
            await progress.file_start(filename, total, resume_pos)

            async with aiofiles.open(partial, "ab" if resume_pos else "wb") as f:
                while True:
                    chunk = await r.content.read(CHUNK_SIZE)
                    if not chunk:
                        break
                    await f.write(chunk)
                    downloaded += len(chunk)
                    await progress.file_progress(filename, downloaded, total)

        os.rename(partial, target)
        await progress.file_done(filename)


async def main() -> None:
    parser = argparse.ArgumentParser(description="Fast parallel GLM-5 shard downloader")
    parser.add_argument("start", type=int, help="First shard number (1-based)")
    parser.add_argument("end", type=int, help="Last shard number (inclusive)")
    parser.add_argument("--dir", default="GLM-5", help="Target directory (default: GLM-5)")
    parser.add_argument("--jobs", type=int, default=8, help="Parallel downloads (default: 8)")
    args = parser.parse_args()

    files = [
        f"model-{i:05d}-of-{TOTAL_SHARDS:05d}.safetensors"
        for i in range(args.start, args.end + 1)
    ]

    headers: dict[str, str] = {"Accept-Encoding": "identity"}
    token = get_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    print(f"Downloading {len(files)} files ({args.start}-{args.end}) to {args.dir}/ with {args.jobs} parallel jobs")

    progress = ProgressTracker(len(files))
    sem = asyncio.Semaphore(args.jobs)
    async with make_session() as session:
        await asyncio.gather(*[
            download_file(session, f, args.dir, headers, sem, progress)
            for f in files
        ])

    progress.final_summary()


if __name__ == "__main__":
    asyncio.run(main())
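The per-machine ranges in the docstring above are hand-picked; a small sketch (not part of the changeset) of computing contiguous, near-even shard ranges for any machine count. Its 4-way split differs from the hand-picked one by a single shard at two boundaries.

def split_shards(total: int, machines: int) -> list[tuple[int, int]]:
    """Split shard numbers 1..total into `machines` contiguous, near-even ranges."""
    base, extra = divmod(total, machines)
    ranges, start = [], 1
    for i in range(machines):
        size = base + (1 if i < extra else 0)
        ranges.append((start, start + size - 1))
        start += size
    return ranges

print(split_shards(282, 2))  # [(1, 141), (142, 282)]
print(split_shards(282, 4))  # [(1, 71), (72, 142), (143, 212), (213, 282)]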
download_glm5_shard.sh (new executable file, 22 lines)
@@ -0,0 +1,22 @@
#!/bin/bash
# Usage: ./download_glm5_shard.sh <start> <end> [local_dir]
#
# Split across 4 Macs:
#   Mac 1: ./download_glm5_shard.sh 1 71
#   Mac 2: ./download_glm5_shard.sh 72 141
#   Mac 3: ./download_glm5_shard.sh 142 212
#   Mac 4: ./download_glm5_shard.sh 213 282

set -euo pipefail

START=${1:?Usage: $0 <start> <end> [local_dir]}
END=${2:?Usage: $0 <start> <end> [local_dir]}
LOCAL_DIR="${3:-GLM-5}"

INCLUDES=()
for i in $(seq "$START" "$END"); do
  INCLUDES+=(--include "$(printf 'model-%05d-of-00282.safetensors' "$i")")
done

echo "Downloading safetensors $START-$END to $LOCAL_DIR"
hf download zai-org/GLM-5 "${INCLUDES[@]}" --local-dir "$LOCAL_DIR"
@@ -1,53 +0,0 @@
|
||||
# Stage 1: Build the dashboard
|
||||
FROM node:22-slim AS dashboard
|
||||
WORKDIR /app/dashboard
|
||||
COPY dashboard/package.json dashboard/package-lock.json ./
|
||||
RUN npm ci
|
||||
COPY dashboard/ .
|
||||
RUN npm run build
|
||||
|
||||
# Stage 2: Build and run exo
|
||||
FROM python:3.13-slim
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
curl \
|
||||
protobuf-compiler \
|
||||
iptables \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Rust nightly
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
|
||||
# Install uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files first for better layer caching
|
||||
COPY pyproject.toml Cargo.toml uv.lock README.md ./
|
||||
COPY rust/ ./rust/
|
||||
COPY bench/pyproject.toml ./bench/pyproject.toml
|
||||
|
||||
# Copy source and resources
|
||||
COPY src/ ./src/
|
||||
COPY resources/ ./resources/
|
||||
|
||||
# Copy built dashboard from stage 1
|
||||
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
|
||||
|
||||
# Install Python deps and build Rust bindings, then clean up build artifacts
|
||||
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
|
||||
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
|
||||
|
||||
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
|
||||
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
|
||||
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
|
||||
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
|
||||
chmod +x /usr/bin/g++
|
||||
|
||||
CMD [".venv/bin/exo", "-v"]
|
||||
e2e/conftest.py (182 lines deleted)
@@ -1,182 +0,0 @@
"""Shared E2E test infrastructure for exo cluster tests."""

import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen

E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))


class Cluster:
    """Async wrapper around a docker compose exo cluster."""

    def __init__(self, name: str, overrides: list[str] | None = None):
        self.name = name
        self.project = f"e2e-{name}"
        compose_files = [str(E2E_DIR / "docker-compose.yml")]
        for path in overrides or []:
            compose_files.append(str(E2E_DIR / path))
        self._compose_base = [
            "docker",
            "compose",
            "-p",
            self.project,
            *[arg for f in compose_files for arg in ("-f", f)],
        ]

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        await self.stop()

    async def _run(self, *args: str, check: bool = True) -> str:
        proc = await asyncio.create_subprocess_exec(
            *self._compose_base,
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        stdout, _ = await proc.communicate()
        output = stdout.decode()
        if check and proc.returncode != 0:
            print(output, file=sys.stderr)
            raise RuntimeError(
                f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
            )
        return output

    async def build(self):
        print(" Building images...")
        await self._run("build", "--quiet")

    async def start(self):
        print(" Starting cluster...")
        await self._run("up", "-d")

    async def stop(self):
        print(" Cleaning up...")
        await self._run("down", "--timeout", "5", check=False)

    async def logs(self) -> str:
        return await self._run("logs", check=False)

    async def exec(
        self, service: str, *cmd: str, check: bool = True
    ) -> tuple[int, str]:
        """Run a command inside a running container. Returns (returncode, output)."""
        proc = await asyncio.create_subprocess_exec(
            *self._compose_base,
            "exec",
            "-T",
            service,
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        stdout, _ = await proc.communicate()
        output = stdout.decode()
        if check and proc.returncode != 0:
            raise RuntimeError(
                f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
            )
        return proc.returncode, output

    async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
        """Poll check_fn every 2s until it returns True or timeout expires."""
        print(f" Waiting for {description}...")
        deadline = asyncio.get_event_loop().time() + timeout
        while asyncio.get_event_loop().time() < deadline:
            if await check_fn():
                print(f" {description}")
                return
            await asyncio.sleep(2)
        output = await self.logs()
        print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
        raise TimeoutError(f"Timed out waiting for {description}")

    async def assert_healthy(self):
        """Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""

        async def both_nodes_started():
            log = await self.logs()
            return log.count("Starting node") >= 2

        async def nodes_discovered():
            log = await self.logs()
            return log.count("ConnectionMessageType.Connected") >= 2

        async def master_elected():
            log = await self.logs()
            return "demoting self" in log

        async def api_responding():
            try:
                with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
                    return resp.status == 200
            except (URLError, OSError):
                return False

        await self.wait_for("Both nodes started", both_nodes_started)
        await self.wait_for("Nodes discovered each other", nodes_discovered)
        await self.wait_for("Master election resolved", master_elected)
        await self.wait_for("API responding", api_responding)

    async def _api(
        self, method: str, path: str, body: dict | None = None, timeout: int = 30
    ) -> dict:
        """Make an API request to the cluster. Returns parsed JSON."""
        url = f"http://localhost:52415{path}"
        data = json.dumps(body).encode() if body else None
        req = Request(
            url, data=data, headers={"Content-Type": "application/json"}, method=method
        )
        loop = asyncio.get_event_loop()
        resp_bytes = await loop.run_in_executor(
            None, lambda: urlopen(req, timeout=timeout).read()
        )
        return json.loads(resp_bytes)

    async def place_model(self, model: str, timeout: int = 600):
        """Place a model instance on the cluster (triggers download) and wait until it's ready."""
        await self._api("POST", "/place_instance", {"model_id": model})

        async def model_ready():
            try:
                resp = await self._api("GET", "/v1/models")
                return any(m.get("id") == model for m in resp.get("data", []))
            except Exception:
                return False

        await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)

    async def chat(
        self, model: str, messages: list[dict], timeout: int = 600, **kwargs
    ) -> dict:
        """Send a chat completion request. Retries until model is downloaded and inference completes."""
        body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
        deadline = asyncio.get_event_loop().time() + timeout
        last_error = None

        while asyncio.get_event_loop().time() < deadline:
            try:
                req = Request(
                    "http://localhost:52415/v1/chat/completions",
                    data=body,
                    headers={"Content-Type": "application/json"},
                )
                loop = asyncio.get_event_loop()
                resp_bytes = await loop.run_in_executor(
                    None, lambda r=req: urlopen(r, timeout=300).read()
                )
                return json.loads(resp_bytes)
            except Exception as e:
                last_error = e
                await asyncio.sleep(5)

        raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")
@@ -1,18 +0,0 @@
|
||||
services:
|
||||
exo-node-1:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile
|
||||
environment:
|
||||
- EXO_LIBP2P_NAMESPACE=docker-e2e
|
||||
command: [".venv/bin/exo", "-v"]
|
||||
ports:
|
||||
- "52415:52415"
|
||||
|
||||
exo-node-2:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile
|
||||
environment:
|
||||
- EXO_LIBP2P_NAMESPACE=docker-e2e
|
||||
command: [".venv/bin/exo", "-v"]
|
||||
@@ -1,75 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Discovers and runs all E2E tests in e2e/test_*.py.
|
||||
|
||||
Tests with '# slow' on the first line of their docstring are skipped
|
||||
unless --slow is passed or E2E_SLOW=1 is set.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
E2E_DIR = Path(__file__).parent.resolve()
|
||||
|
||||
|
||||
def is_slow(test_file: Path) -> bool:
|
||||
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
|
||||
with open(test_file) as f:
|
||||
for line in f:
|
||||
if line.strip().startswith("#"):
|
||||
continue
|
||||
if line.strip().startswith('"""') or line.strip().startswith("'''"):
|
||||
# Read into the docstring
|
||||
for doc_line in f:
|
||||
if "slow" in doc_line.lower() and doc_line.strip().startswith(
|
||||
"slow"
|
||||
):
|
||||
return True
|
||||
if '"""' in doc_line or "'''" in doc_line:
|
||||
break
|
||||
break
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
|
||||
test_files = sorted(E2E_DIR.glob("test_*.py"))
|
||||
if not test_files:
|
||||
print("No test files found")
|
||||
sys.exit(1)
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
failures = []
|
||||
|
||||
for test_file in test_files:
|
||||
name = test_file.stem
|
||||
if is_slow(test_file) and not run_slow:
|
||||
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
print(f"=== {name} ===")
|
||||
result = subprocess.run([sys.executable, str(test_file)])
|
||||
if result.returncode == 0:
|
||||
passed += 1
|
||||
else:
|
||||
failed += 1
|
||||
failures.append(name)
|
||||
print()
|
||||
|
||||
total = passed + failed + skipped
|
||||
print("================================")
|
||||
print(
|
||||
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
|
||||
)
|
||||
|
||||
if failed:
|
||||
print(f"Failed: {' '.join(failures)}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"model": "mlx-community/Qwen3-0.6B-4bit",
|
||||
"seed": 42,
|
||||
"temperature": 0,
|
||||
"prompt": "What is 2+2? Reply with just the number.",
|
||||
"max_tokens": 32,
|
||||
"content": "<think>\nOkay, so I need to figure out what 2+2 is. Let me think. Well, if you add 2 and 2 together"
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Test: Basic cluster formation.
|
||||
|
||||
Verifies two nodes discover each other, elect a master, and the API responds.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("cluster_formation") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
print("PASSED: cluster_formation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Test: Deterministic inference output (snapshot test).
|
||||
slow
|
||||
|
||||
Sends a chat completion request with a fixed seed and temperature=0,
|
||||
then verifies the output matches a known-good snapshot. This ensures
|
||||
inference produces consistent results across runs.
|
||||
|
||||
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
|
||||
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
MODEL = "mlx-community/Qwen3-0.6B-4bit"
|
||||
SEED = 42
|
||||
PROMPT = "What is 2+2? Reply with just the number."
|
||||
MAX_TOKENS = 32
|
||||
SNAPSHOT_FILE = Path(__file__).parent / "snapshots" / "inference.json"
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster("inference_snapshot") as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
# Launch the model instance (triggers download + placement)
|
||||
print(f" Launching model {MODEL}...")
|
||||
await cluster.place_model(MODEL)
|
||||
|
||||
print(f" Sending chat completion (seed={SEED}, temperature=0)...")
|
||||
resp = await cluster.chat(
|
||||
model=MODEL,
|
||||
messages=[{"role": "user", "content": PROMPT}],
|
||||
seed=SEED,
|
||||
temperature=0,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
|
||||
content = resp["choices"][0]["message"]["content"]
|
||||
print(f" Response: {content!r}")
|
||||
|
||||
# Load or create snapshot
|
||||
if SNAPSHOT_FILE.exists():
|
||||
snapshot = json.loads(SNAPSHOT_FILE.read_text())
|
||||
expected = snapshot["content"]
|
||||
assert content == expected, (
|
||||
f"Snapshot mismatch!\n"
|
||||
f" Expected: {expected!r}\n"
|
||||
f" Got: {content!r}\n"
|
||||
f" Delete {SNAPSHOT_FILE} to regenerate."
|
||||
)
|
||||
print(" Output matches snapshot")
|
||||
else:
|
||||
SNAPSHOT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
SNAPSHOT_FILE.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"model": MODEL,
|
||||
"seed": SEED,
|
||||
"temperature": 0,
|
||||
"prompt": PROMPT,
|
||||
"max_tokens": MAX_TOKENS,
|
||||
"content": content,
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
print(f" Snapshot created: {SNAPSHOT_FILE}")
|
||||
|
||||
print("PASSED: inference_snapshot")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Test: Cluster works without internet access.
|
||||
|
||||
Verifies exo functions correctly when containers can talk to each other
|
||||
but cannot reach the internet. Uses iptables to block all outbound traffic
|
||||
except private subnets and multicast (for mDNS discovery).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
|
||||
async def main():
|
||||
async with Cluster(
|
||||
"no_internet",
|
||||
overrides=["tests/no_internet/docker-compose.override.yml"],
|
||||
) as cluster:
|
||||
await cluster.build()
|
||||
await cluster.start()
|
||||
await cluster.assert_healthy()
|
||||
|
||||
# Verify internet is actually blocked from inside the containers
|
||||
for node in ["exo-node-1", "exo-node-2"]:
|
||||
rc, _ = await cluster.exec(
|
||||
node,
|
||||
"curl",
|
||||
"-sf",
|
||||
"--max-time",
|
||||
"3",
|
||||
"https://huggingface.co",
|
||||
check=False,
|
||||
)
|
||||
assert rc != 0, f"{node} should not be able to reach the internet"
|
||||
print(f" {node}: internet correctly blocked")
|
||||
|
||||
# Verify exo detected no internet connectivity
|
||||
log = await cluster.logs()
|
||||
assert "Internet connectivity: False" in log, "exo should detect no internet"
|
||||
print(" exo correctly detected no internet connectivity")
|
||||
|
||||
print("PASSED: no_internet")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,32 +0,0 @@
|
||||
# Block all outbound internet traffic using iptables while preserving:
|
||||
# - Multicast (224.0.0.0/4) for mDNS peer discovery
|
||||
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
|
||||
# - Loopback (127/8)
|
||||
# Requires NET_ADMIN capability for iptables.
|
||||
services:
|
||||
exo-node-1:
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
entrypoint: ["/bin/sh", "-c"]
|
||||
command:
|
||||
- |
|
||||
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
|
||||
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
|
||||
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
|
||||
iptables -A OUTPUT -j REJECT
|
||||
exec .venv/bin/exo -v
|
||||
exo-node-2:
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
entrypoint: ["/bin/sh", "-c"]
|
||||
command:
|
||||
- |
|
||||
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
|
||||
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
|
||||
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
|
||||
iptables -A OUTPUT -j REJECT
|
||||
exec .venv/bin/exo -v
|
||||
@@ -17,7 +17,7 @@ dependencies = [
   "loguru>=0.7.3",
   "exo_pyo3_bindings", # rust bindings
   "anyio==4.11.0",
-  "mlx==0.30.6; sys_platform == 'darwin'",
+  "mlx==0.30.6",
   "mlx[cpu]==0.30.6; sys_platform == 'linux'",
   "mlx-lm==0.30.6",
   "tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -64,6 +64,8 @@ members = [
 
 [tool.uv.sources]
 exo_pyo3_bindings = { workspace = true }
+#mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", marker = "sys_platform == 'darwin'" }
+mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
 #mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
 # Uncomment to use local mlx/mlx-lm development versions:
 # mlx = { path = "/Users/Shared/mlx", editable=true }
 
@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-8bit"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 790517400864
@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 405478939008
@@ -0,0 +1,12 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]

[storage_size]
in_bytes = 1487822475264
@@ -7,11 +7,17 @@ from exo.utils.dashboard_path import find_dashboard, find_resources
 _EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
 
 
+def _resolve_env_path(env_value: str) -> Path:
+    """Resolve an environment variable path: absolute paths are used as-is, relative paths are resolved from home."""
+    p = Path(env_value)
+    return p if p.is_absolute() else Path.home() / p
+
+
 def _get_xdg_dir(env_var: str, fallback: str) -> Path:
     """Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
 
     if _EXO_HOME_ENV is not None:
-        return Path.home() / _EXO_HOME_ENV
+        return _resolve_env_path(_EXO_HOME_ENV)
 
     if sys.platform != "linux":
         return Path.home() / ".exo"
@@ -31,15 +37,19 @@ _EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None)
 EXO_MODELS_DIR = (
     EXO_DATA_HOME / "models"
     if _EXO_MODELS_DIR_ENV is None
-    else Path.home() / _EXO_MODELS_DIR_ENV
+    else _resolve_env_path(_EXO_MODELS_DIR_ENV)
 )
 _RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
 RESOURCES_DIR = (
-    find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
+    find_resources()
+    if _RESOURCES_DIR_ENV is None
+    else _resolve_env_path(_RESOURCES_DIR_ENV)
 )
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
 DASHBOARD_DIR = (
-    find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
+    find_dashboard()
+    if _DASHBOARD_DIR_ENV is None
+    else _resolve_env_path(_DASHBOARD_DIR_ENV)
 )
 
 # Log files (data/logs or cache)
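A quick illustration of the behaviour change (the function body is copied from the hunk above; the example paths are made up): absolute values of EXO_HOME, EXO_MODELS_DIR, EXO_RESOURCES_DIR, and EXO_DASHBOARD_DIR are now used as-is instead of always being appended to the home directory.

from pathlib import Path

def _resolve_env_path(env_value: str) -> Path:
    p = Path(env_value)
    return p if p.is_absolute() else Path.home() / p

print(_resolve_env_path("/mnt/fast-ssd/exo-models"))  # /mnt/fast-ssd/exo-models (absolute, used as-is)
print(_resolve_env_path(".exo/models"))               # <home>/.exo/models (relative, anchored at home)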
@@ -182,6 +182,7 @@ class ConfigData(BaseModel):
     def supports_tensor(self) -> bool:
         return self.architectures in [
             ["Glm4MoeLiteForCausalLM"],
+            ["GlmMoeDsaForCausalLM"],
             ["DeepseekV32ForCausalLM"],
             ["DeepseekV3ForCausalLM"],
             ["Qwen3NextForCausalLM"],
@@ -24,6 +24,8 @@ from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
 from mlx_lm.models.glm4_moe import MoE
 from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
 from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
+from mlx_lm.models.glm_moe_dsa import Glm4MoeLiteMoE as GlmMoeDsaMoE
+from mlx_lm.models.glm_moe_dsa import Model as GlmMoeDsaModel
 from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.models.kimi_k25 import Model as KimiK25Model
@@ -160,11 +162,14 @@ class PipelineLastLayer(CustomMlxLayer):
                 output, (self.r + 1) % self.s, group=self.group
             )
             if cache is not None:
-                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
+                # CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
+                # doesn't have .keys directly; access via first sub-cache.
+                dep_cache = cache[0] if hasattr(cache, "caches") else cache  # type: ignore
+                dep_cache.keys = mx.depends(dep_cache.keys, output)  # type: ignore[reportUnknownMemberType]
             if self.is_prefill:
                 mx.eval(output)
                 if cache is not None:
-                    mx.eval(cache.keys)  # type: ignore
+                    mx.eval(dep_cache.keys)  # type: ignore
 
             if not self.is_prefill:
                 output = mx.distributed.all_gather(output, group=self.group)[
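A small illustration (stand-in classes, not mlx_lm's real cache types) of why the change above indexes into cache[0]: MLA-style models hand back a CacheList-like container whose per-layer sub-caches hold the keys, while ordinary caches expose .keys directly. The conditional expression never evaluates cache[0] for a plain cache.

class PlainCache:
    def __init__(self):
        self.keys = "keys-tensor"

class CacheListLike:
    """Container with sub-caches, mimicking the `caches` attribute checked above."""
    def __init__(self, *caches):
        self.caches = caches
    def __getitem__(self, i):
        return self.caches[i]

for cache in (PlainCache(), CacheListLike(PlainCache(), PlainCache())):
    dep_cache = cache[0] if hasattr(cache, "caches") else cache
    print(type(cache).__name__, "->", type(dep_cache).__name__)
# PlainCache -> PlainCache
# CacheListLike -> PlainCache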
@@ -403,6 +408,14 @@ def tensor_auto_parallel(
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
+    elif isinstance(model, GlmMoeDsaModel):
+        tensor_parallel_sharding_strategy = GlmMoeDsaShardingStrategy(
+            group,
+            all_to_sharded_linear,
+            sharded_to_all_linear,
+            all_to_sharded_linear_in_place,
+            sharded_to_all_linear_in_place,
+        )
     elif isinstance(model, Glm4MoeModel):
         tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
             group,
@@ -654,6 +667,62 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
         return model
 
 
+class GlmMoeDsaShardingStrategy(TensorParallelShardingStrategy):
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
+        model = cast(GlmMoeDsaModel, model)
+        for layer in model.layers:
+            eval_with_timeout(
+                layer.parameters(),
+                timeout_seconds / len(model.layers),
+                on_timeout,
+            )
+            layer.self_attn.q_b_proj = self.all_to_sharded_linear(
+                layer.self_attn.q_b_proj
+            )
+            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
+            layer.self_attn.num_heads //= self.N
+
+            num_heads = layer.self_attn.num_heads
+            sh = self.group.rank() * num_heads
+            eh = sh + num_heads
+
+            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
+                return w[sh:eh]
+
+            layer.self_attn.embed_q.apply(shard_heads)
+            layer.self_attn.unembed_out.apply(shard_heads)
+
+            if isinstance(layer.mlp, Glm4MoeLiteMLP):
+                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
+                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
+                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
+            else:
+                moe = cast(GlmMoeDsaMoE, layer.mlp)
+                if moe.shared_experts is not None:
+                    self.all_to_sharded_linear_in_place(
+                        moe.shared_experts.gate_proj
+                    )
+                    self.sharded_to_all_linear_in_place(
+                        moe.shared_experts.down_proj
+                    )
+                    self.all_to_sharded_linear_in_place(
+                        moe.shared_experts.up_proj
+                    )
+                self.all_to_sharded_linear_in_place(moe.switch_mlp.gate_proj)
+                self.sharded_to_all_linear_in_place(moe.switch_mlp.down_proj)
+                self.all_to_sharded_linear_in_place(moe.switch_mlp.up_proj)
+                layer.mlp = ShardedMoE(moe)  # type: ignore
+                layer.mlp.sharding_group = self.group
+            mx.eval(layer)
+
+        return model
+
+
 class WrappedMiniMaxAttention(CustomMlxLayer):
     def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
         super().__init__(layer)
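For reference, a quick sketch (not from the changeset) of the head-range arithmetic used in shard_model above: each rank keeps a contiguous slice of the attention heads after num_heads is divided by the group size N. The head count and group size below are made up.

num_heads_total, N = 16, 4          # e.g. 16 attention heads across 4 ranks
num_heads = num_heads_total // N    # heads kept per rank, mirroring num_heads //= N
for rank in range(N):
    sh = rank * num_heads
    eh = sh + num_heads
    print(f"rank {rank}: heads [{sh}:{eh}]")
# rank 0: heads [0:4]
# rank 1: heads [4:8]
# rank 2: heads [8:12]
# rank 3: heads [12:16]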
@@ -311,10 +311,12 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
     model_id_lower = model_id.lower()
     if "kimi-k2" in model_id_lower:
         return [163586]
-    elif "glm-4.7-flash" in model_id_lower:
+    elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
+        # For GLM-5 and GLM-4.7
+        # 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
         return [154820, 154827, 154829]
     elif "glm" in model_id_lower:
         # For GLM-4.5 and older
         return [151336, 151329, 151338]
     return None
 
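A standalone restatement of the dispatch above (assumed equivalent to the function in the hunk, reproduced here only to show which branch different model ids hit after this change):

def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
    m = model_id.lower()
    if "kimi-k2" in m:
        return [163586]
    elif "glm-5" in m or "glm-4.7" in m:
        # 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
        return [154820, 154827, 154829]
    elif "glm" in m:
        # GLM-4.5 and older
        return [151336, 151329, 151338]
    return None

print(get_eos_token_ids_for_model("mlx-community/GLM-5-MXFP4-Q8"))  # [154820, 154827, 154829]
print(get_eos_token_ids_for_model("mlx-community/GLM-4.7-Flash"))   # now also [154820, 154827, 154829]
print(get_eos_token_ids_for_model("mlx-community/GLM-4.5-4bit"))    # [151336, 151329, 151338]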
@@ -295,8 +295,8 @@ def main(
         patch_kimi_tokenizer(tokenizer)
 
     # GLM models need patched parser (upstream has bug with None regex match)
-    elif "glm" in shard_metadata.model_card.model_id.lower():
-        patch_glm_tokenizer(tokenizer)
+    elif "glm-4" in shard_metadata.model_card.model_id.lower():
+        patch_glm4_tokenizer(tokenizer)
 
     # GPT-OSS specific parsing to match other model formats.
     elif isinstance(model, GptOssModel):
@@ -863,7 +863,7 @@
     tokenizer._tool_parser = parse_tool_call
 
 
-def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
+def patch_glm4_tokenizer(tokenizer: TokenizerWrapper):
     """
     Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
     """
uv.lock (generated, 12 lines changed)
@@ -416,9 +416,9 @@ requires-dist = [
     { name = "hypercorn", specifier = ">=0.18.0" },
     { name = "loguru", specifier = ">=0.7.3" },
     { name = "mflux", specifier = "==0.15.5" },
-    { name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
+    { name = "mlx", specifier = "==0.30.6" },
     { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
-    { name = "mlx-lm", specifier = "==0.30.6" },
+    { name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
     { name = "msgspec", specifier = ">=0.19.0" },
     { name = "openai-harmony", specifier = ">=0.0.8" },
     { name = "pillow", specifier = ">=11.0,<12.0" },
@@ -1098,8 +1098,8 @@ wheels = [
 
 [[package]]
 name = "mlx-lm"
-version = "0.30.6"
-source = { registry = "https://pypi.org/simple" }
+version = "0.30.7"
+source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#bcf630614ffb5624bcb19870a7bcb0d847e6e98f" }
 dependencies = [
     { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1109,10 +1109,6 @@ dependencies = [
     { name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
-]
 
 [[package]]
 name = "mlx-metal"