mirror of
https://github.com/exo-explore/exo.git
synced 2026-04-18 04:52:40 -04:00
## Motivation

Qwen3.5 MoE models (e.g., `Qwen3.5-397B-A17B-6bit`) are now supported by `mlx-lm` via the `qwen3_5_moe` model type, but exo lacks tensor-parallel sharding support for this architecture, which prevents running large Qwen3.5 models across multiple nodes. Qwen3.5 uses a GatedDeltaNet hybrid attention mechanism similar to Qwen3-Next's, but with a different projection layout: separate `in_proj_qkv`, `in_proj_z`, `in_proj_b`, and `in_proj_a` layers instead of Qwen3-Next's combined `in_proj_qkvz` and `in_proj_ba`. This requires architecture-aware sharding logic.

## Changes (evan summary)

- Enable qwen3_5 dense + MoE tensor parallelism from config.
- Defensively skip evaluating `_cache.keys` if it doesn't exist.
- Ignore kwargs in qwen35 pipeline masking and ensure pipeline segments match global model parameters for mask creation.
- Add sharding for qwen3_5 MoE linear attention.
- Added another 6 million model cards.

## Why It Works

Qwen3.5's GatedDeltaNet has an `in_proj_qkv` linear layer with three concatenated sections: `[q(key_dim), k(key_dim), v(value_dim)]`. A naive contiguous split (`segments=1`) would slice across section boundaries, corrupting q/k/v values and producing garbled output. By passing `segments=[key_dim, key_dim + key_dim]` to `shard_linear()`, each section is split independently before being distributed across devices, so every rank receives correctly aligned q, k, and v components. The remaining separate projections (`in_proj_z`, `in_proj_b`, `in_proj_a`) and the MoE layers follow the same `all_to_sharded` / `sharded_to_all` pattern already used for Qwen3-Next.

Some pipeline splits contained neither an SSM layer nor a linear-attention layer, so that subset of the model behaved as if it shouldn't create the masks needed by the next layer; we patch the model to create those masks manually.

## Test Plan

Tensor-sharded with 2, 3, and 4 nodes, and pipeline-sharded with 2, 3, and 4 nodes, verified with a simple eval.
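The segment-boundary argument above can be illustrated with a minimal NumPy sketch. This is not exo's actual `shard_linear`; the helper name, the toy dimensions, and the row-wise layout are assumptions made purely to show why splitting each section independently keeps q/k/v aligned across ranks:

```python
import numpy as np

key_dim, value_dim, n_ranks = 4, 6, 2

# Rows of the toy in_proj_qkv weight are [q(key_dim), k(key_dim), v(value_dim)]
# concatenated, mirroring the layout described above.
W = np.arange((2 * key_dim + value_dim) * 3).reshape(2 * key_dim + value_dim, 3)


def shard_rows(W, boundaries, rank, n_ranks):
    """Split W at each section boundary, shard every section independently,
    then re-concatenate this rank's slices so q, k, and v stay aligned."""
    sections = np.split(W, boundaries, axis=0)  # -> [q, k, v]
    return np.concatenate(
        [np.array_split(s, n_ranks, axis=0)[rank] for s in sections], axis=0
    )


# segments=[key_dim, key_dim + key_dim] corresponds to the split points
# between q|k and k|v.
rank0 = shard_rows(W, [key_dim, 2 * key_dim], rank=0, n_ranks=n_ranks)

# rank0 holds half of q, half of k, and half of v (2 + 2 + 3 = 7 rows).
# A naive contiguous split would instead give rank 0 the first 7 rows:
naive0 = np.array_split(W, n_ranks, axis=0)[0]
# ...which is all of q plus a partial slab of k, i.e. misaligned sections.
```

The same reasoning applies to output-dimension sharding of any projection whose output is a concatenation of logically distinct heads or sections.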
---------

Co-authored-by: hw <hw@hwStudio1.local>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
Co-authored-by: Evan <evanev7@gmail.com>
34 lines
860 B
Python
```python
"""
Generates inference model cards for EXO.

Usage:
    uv run tmp/gen_card.py mlx-community/my_cool_model-8bit [repo-id/model-id-2] [...]

Model Cards require cleanup for family & quantization data
"""

import sys

import anyio

from exo.shared.models.model_cards import ModelCard, ModelId


async def main():
    if len(sys.argv) == 1:
        print(f"USAGE: {sys.argv[0]} repo-id/model-id-1 [repo-id/model-id-2] [...]")
        quit(1)
    print("Remember! Model Cards require cleanup for family & quantization data")
    for arg in sys.argv[1:]:
        mid = ModelId(arg)
        mc = await ModelCard.fetch_from_hf(mid)
        await mc.save(
            anyio.Path(__file__).parent.parent
            / "resources"
            / "inference_model_cards"
            / (mid.normalize() + ".toml")
        )


if __name__ == "__main__":
    anyio.run(main)
```