Mirror of https://github.com/exo-explore/exo.git (synced 2026-04-17 20:40:35 -04:00)
## Problem

Models whose KV-head count is not divisible by the number of nodes (most visibly, models with fewer KV heads than nodes) crash during tensor parallelism. For example, Qwen3.5 MoE models have only 2 KV heads — trying to shard them across 4 nodes produces empty tensors and a reshape error at runtime.

The placement system already validates `hidden_size % num_nodes == 0` but doesn't check KV heads, so it creates configurations that look valid but blow up when the worker tries to split the attention heads.

Affected models include Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-397B-A17B, Qwen3-Next-80B-A3B, and Qwen3-Coder-Next (all have 2 KV heads).

## Changes

**Placement validation** (`src/exo/master/placement.py`):
- Combined the KV-heads divisibility check with the existing hidden_size filter in a single pass
- Cycles where `num_key_value_heads % len(cycle) != 0` are now excluded for tensor sharding
- The error message includes both constraints when no valid cycle is found

**Model card schema** (`src/exo/shared/models/model_cards.py`):
- Added an optional `num_key_value_heads` field to `ModelCard` and `ConfigData`
- Extracted from HuggingFace `config.json` (handles both top-level and `text_config` nesting)
- Passed through in `fetch_from_hf()` for dynamically fetched cards

**All 68 inference model cards** (`resources/inference_model_cards/*.toml`):
- Populated `num_key_value_heads` from each model's HuggingFace config

**Utility script** (`scripts/fetch_kv_heads.py`):
- Fetches `num_key_value_heads` from HuggingFace and updates TOML cards
- `--missing`: only fills in cards that don't have the field yet
- `--all`: re-fetches and overwrites everything
- Uses tomlkit for safe TOML editing and ThreadPoolExecutor for parallel fetches

## Behavior

- Instance previews no longer show tensor options for models that can't split their KV heads evenly across the cluster size
- `place_instance()` rejects the request with a clear error instead of crash-looping
- Pipeline parallelism is unaffected
- 2-node tensor parallelism still works for 2-KV-head models (2 ÷ 2 = 1)
- The field is optional — existing custom cards without it continue to work (validation is skipped when it is `None`)
134 lines · 3.9 KiB · Python · Executable File
#!/usr/bin/env python3
"""Fetch num_key_value_heads from HuggingFace config.json and update TOML model cards.

For each ``*.toml`` card in ``resources/inference_model_cards``, the script:

1. reads the card's ``model_id``;
2. downloads ``config.json`` for that model from the ``main`` branch on
   huggingface.co;
3. inserts or updates ``num_key_value_heads`` in the card.

The value is taken from the top level of the config if present; otherwise it
is taken from a nested ``text_config`` mapping.

Downloads run in a small thread pool, and the script needs network access to
huggingface.co.

Usage:
    # Update only cards missing num_key_value_heads
    uv run python scripts/fetch_kv_heads.py --missing

    # Update all cards (overwrite existing values)
    uv run python scripts/fetch_kv_heads.py --all
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
import urllib.request
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from pathlib import Path
|
|
|
|
import tomlkit
|
|
|
|
# Bundled inference model cards, located relative to this script:
# <repo>/scripts/fetch_kv_heads.py -> <repo>/resources/inference_model_cards
CARDS_DIR = Path(__file__).resolve().parent.parent.joinpath(
    "resources", "inference_model_cards"
)

# Upper bound on concurrent config.json downloads from HuggingFace.
MAX_WORKERS = 5
|
|
|
|
|
|
def extract_kv_heads(config: object) -> int | None:
    """Return ``num_key_value_heads`` from a parsed HuggingFace config.

    The top-level mapping is checked first, then a nested ``text_config``
    mapping. Missing, ``null``, boolean, non-numeric, or non-positive values
    are skipped, so a malformed config yields ``None`` instead of raising.
    """
    if not isinstance(config, dict):
        return None

    for source in (config, config.get("text_config")):
        # ``text_config`` may be absent or explicitly null in the JSON.
        if not isinstance(source, dict):
            continue
        raw = source.get("num_key_value_heads")
        # bool is an int subclass; true/false is never a head count.
        if raw is None or isinstance(raw, bool):
            continue
        try:
            heads = int(raw)
        except (TypeError, ValueError):
            continue
        if heads > 0:
            return heads

    return None


def fetch_kv_heads(model_id: str) -> int | None:
    """Fetch num_key_value_heads for *model_id* from its HuggingFace config.json.

    The config is downloaded from the model's ``main`` branch on huggingface.co.

    Args:
        model_id: HuggingFace repository id, e.g. ``"org/model"``.

    Returns:
        The KV-head count, or ``None`` if the download or JSON decoding
        fails, or if the config contains no usable value. Problems are
        reported on stderr rather than raised, so one bad model does not
        abort a batch run.
    """
    url = f"https://huggingface.co/{model_id}/raw/main/config.json"
    try:
        with urllib.request.urlopen(url, timeout=15) as resp:
            config = json.loads(resp.read())
    except Exception as e:  # best-effort per model: network, HTTP, JSON errors
        print(f"  ERROR fetching {url}: {e}", file=sys.stderr)
        return None

    kv_heads = extract_kv_heads(config)
    if kv_heads is None:
        print(f"  ERROR no usable num_key_value_heads in {url}", file=sys.stderr)
    return kv_heads
|
|
|
|
|
|
def update_toml(path: Path, kv_heads: int) -> bool:
    """Insert or update ``num_key_value_heads`` in the TOML card at *path*.

    When the key is new and the card has ``hidden_size``, the key is placed
    directly after ``hidden_size`` so cards keep a consistent layout. When the
    card has no ``hidden_size``, the key is added as an ordinary top-level key.
    Previously that case wrote nothing yet still reported a change.

    Args:
        path: TOML model card to edit in place.
        kv_heads: Value to store.

    Returns:
        True if the file was rewritten, False if it already held *kv_heads*.
    """
    # TOML is UTF-8 by specification; don't depend on the platform locale.
    doc = tomlkit.parse(path.read_text(encoding="utf-8"))

    if doc.get("num_key_value_heads") == kv_heads:
        return False

    if "num_key_value_heads" in doc or "hidden_size" not in doc:
        # Either overwrite the existing value in place, or add a new
        # top-level key when there is no hidden_size anchor to follow.
        doc["num_key_value_heads"] = kv_heads
    else:
        # Rebuild the document so the new key lands right after hidden_size.
        # NOTE(review): this copies key/value items only, so standalone
        # comment lines between top-level keys are probably not preserved.
        new_doc = tomlkit.document()
        for key, value in doc.items():
            new_doc[key] = value
            if key == "hidden_size":
                new_doc["num_key_value_heads"] = kv_heads
        doc = new_doc

    path.write_text(tomlkit.dumps(doc), encoding="utf-8")
    return True
|
|
|
|
|
|
def process_card(path: Path) -> tuple[str, str]:
    """Fetch the KV-head count for one model card and write it back.

    Args:
        path: TOML model card; its ``model_id`` names the HuggingFace repo.

    Returns:
        ``(filename, status)``. The status is one of:

        - ``"SKIP (no model_id)"`` when the card has no ``model_id``;
        - ``"FAILED"`` when no KV-head count could be fetched;
        - ``"<kv_heads> (UPDATED)"`` when the card was rewritten;
        - ``"<kv_heads> (UNCHANGED)"`` when it already held that value.
    """
    # TOML is UTF-8 by specification; don't depend on the platform locale.
    doc = tomlkit.parse(path.read_text(encoding="utf-8"))

    model_id = doc.get("model_id")
    if not model_id:
        return path.name, "SKIP (no model_id)"

    kv_heads = fetch_kv_heads(str(model_id))
    if kv_heads is None:
        return path.name, "FAILED"

    changed = update_toml(path, kv_heads)
    outcome = "UPDATED" if changed else "UNCHANGED"
    return path.name, f"{kv_heads} ({outcome})"
|
|
|
|
|
|
def _has_kv_heads(path: Path) -> bool:
    """Return True if the card at *path* already defines num_key_value_heads.

    The TOML is parsed rather than substring-matched, so a comment that merely
    mentions the key does not count. A card that cannot be read or parsed
    returns False. It is then still processed, so the failure gets reported
    instead of the card being silently skipped.
    """
    try:
        doc = tomlkit.parse(path.read_text(encoding="utf-8"))
    except Exception:
        return False
    return "num_key_value_heads" in doc


def main() -> None:
    """Parse CLI flags, update the selected model cards in parallel, and report.

    Exits with status 1 if no cards are found, or if any card fails to update.
    """
    parser = argparse.ArgumentParser(
        description="Fetch num_key_value_heads from HuggingFace and update TOML cards."
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--all",
        action="store_true",
        help="Update all model cards (overwrite existing values)",
    )
    group.add_argument(
        "--missing",
        action="store_true",
        help="Only update cards missing num_key_value_heads",
    )
    args = parser.parse_args()

    toml_files = sorted(CARDS_DIR.glob("*.toml"))
    if not toml_files:
        print(f"No TOML files found in {CARDS_DIR}", file=sys.stderr)
        sys.exit(1)

    to_process: list[Path] = []
    skipped = 0
    for path in toml_files:
        if args.missing and _has_kv_heads(path):
            skipped += 1
            continue
        to_process.append(path)

    updated = 0
    failed = 0

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        futures = {pool.submit(process_card, path): path for path in to_process}
        for future in as_completed(futures):
            try:
                name, status = future.result()
            except Exception as e:
                # One broken card must not abort reporting for the rest.
                name = futures[future].name
                status = f"FAILED ({type(e).__name__}: {e})"
            print(f"  {name}: {status}")
            if status.endswith("(UPDATED)"):
                updated += 1
            elif status.startswith("FAILED"):
                failed += 1
            elif status.startswith("SKIP"):
                skipped += 1

    print(f"\nDone: {updated} updated, {skipped} skipped, {failed} failed")
    if failed:
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed directly (e.g. via
    # `uv run python scripts/fetch_kv_heads.py --missing`), not on import.
    main()
|