exo/scripts/fetch_kv_heads.py
Mustafa Alp Yılmaz 2994b41089 fix: validate num_key_value_heads in tensor sharding placement (#1669)
## Problem

Models with fewer KV heads than nodes crash during tensor parallelism.
For example, Qwen3.5 MoE models have only 2 KV heads — trying to shard
across 4 nodes produces empty tensors and a reshape error at runtime.

The placement system already validates `hidden_size % num_nodes == 0`
but doesn't check KV heads, so it creates configurations that look valid
but blow up when the worker tries to split the attention heads.

Affected models include Qwen3.5-35B-A3B, Qwen3.5-122B-A10B,
Qwen3.5-397B-A17B, Qwen3-Next-80B-A3B, and Qwen3-Coder-Next (all have 2
KV heads).
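The toy model below shows why the head count must divide the node count. It assumes each node receives an equal, contiguous block of KV-head indices; exo's real sharding code may split tensors differently. `split_kv_heads` is a hypothetical helper, not part of exo.

```python
def split_kv_heads(num_kv_heads: int, num_nodes: int) -> list[list[int]]:
    """Toy model: give each node an equal, contiguous block of KV-head indices."""
    per_node = num_kv_heads // num_nodes  # 0 whenever there are fewer heads than nodes
    return [
        list(range(rank * per_node, (rank + 1) * per_node))
        for rank in range(num_nodes)
    ]


print(split_kv_heads(2, 2))  # [[0], [1]]
print(split_kv_heads(2, 4))  # [[], [], [], []]  -> every node gets an empty shard
print(split_kv_heads(8, 3))  # [[0, 1], [2, 3], [4, 5]]  -> heads 6 and 7 are lost
```

The last call shows that "fewer heads than nodes" is only the most visible case. Any head count that is not a multiple of the node count leaves this naive split uneven or lossy, which is why the check below is a divisibility test.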

## Changes

**Placement validation** (`src/exo/master/placement.py`):
- Combined the KV-head divisibility check with the existing `hidden_size`
filter in a single pass
- Cycles where `num_key_value_heads % len(cycle) != 0` are now excluded
for tensor sharding
- Error message includes both constraints when no valid cycle is found
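A minimal sketch of such a single-pass filter follows. It assumes a cycle is simply a list of node IDs. `tensor_cycles` and its error text are hypothetical, since the actual `placement.py` code is not shown here.

```python
from __future__ import annotations

from collections.abc import Sequence


def tensor_cycles(
    cycles: Sequence[Sequence[str]],
    hidden_size: int,
    num_key_value_heads: int | None,
) -> list[Sequence[str]]:
    """Keep cycles whose size evenly divides hidden_size and (if known) KV heads."""
    valid = [
        cycle
        for cycle in cycles
        if hidden_size % len(cycle) == 0
        # An unknown KV-head count (None) skips the second check entirely.
        and (num_key_value_heads is None or num_key_value_heads % len(cycle) == 0)
    ]
    if not valid:
        # Report every constraint that was applied, not just hidden_size.
        constraints = [f"hidden_size={hidden_size}"]
        if num_key_value_heads is not None:
            constraints.append(f"num_key_value_heads={num_key_value_heads}")
        sizes = sorted({len(cycle) for cycle in cycles})
        raise ValueError(
            f"No cycle size in {sizes} evenly divides {' and '.join(constraints)}"
        )
    return valid


cycles = [["a", "b"], ["a", "b", "c"], ["a", "b", "c", "d"]]
print(tensor_cycles(cycles, 4096, 2))     # [['a', 'b']]
print(tensor_cycles(cycles, 4096, None))  # [['a', 'b'], ['a', 'b', 'c', 'd']]
```

Folding both conditions into one comprehension keeps the filter to a single pass over the candidate cycles. It also means the error path can name both constraints without re-deriving which one failed.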

**Model card schema** (`src/exo/shared/models/model_cards.py`):
- Added optional `num_key_value_heads` field to `ModelCard` and
`ConfigData`
- Extracted from HuggingFace `config.json` (handles both top-level and
`text_config` nesting)
- Passed through in `fetch_from_hf()` for dynamically fetched cards
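The next sketch shows an optional field plus lookup across both config layouts, using a plain dataclass. `ConfigDataSketch` and `from_hf_config` are hypothetical names. The real `ConfigData`/`ModelCard` definitions in `model_cards.py` may use a different modelling library and carry more fields.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ConfigDataSketch:
    """Hypothetical stand-in for ConfigData, limited to the fields discussed here."""

    hidden_size: int | None = None
    num_key_value_heads: int | None = None  # optional: None means "not known"

    @classmethod
    def from_hf_config(cls, config: dict[str, Any]) -> ConfigDataSketch:
        # Some configs (e.g. multimodal ones) nest language-model fields under
        # "text_config"; it may also be missing or null, hence the `or {}`.
        text_config = config.get("text_config") or {}

        def lookup(key: str) -> int | None:
            # Prefer the top-level value, then fall back to the nested one.
            for source in (config, text_config):
                if source.get(key) is not None:
                    return int(source[key])
            return None

        return cls(
            hidden_size=lookup("hidden_size"),
            num_key_value_heads=lookup("num_key_value_heads"),
        )


print(ConfigDataSketch.from_hf_config({"hidden_size": 2048, "num_key_value_heads": 2}))
print(ConfigDataSketch.from_hf_config({"text_config": {"num_key_value_heads": 8}}))
```

Because the field defaults to `None`, a card or config that lacks it still loads, and downstream checks can treat `None` as "skip validation".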

**All 68 inference model cards**
(`resources/inference_model_cards/*.toml`):
- Populated `num_key_value_heads` from each model's HuggingFace config

**Utility script** (`scripts/fetch_kv_heads.py`):
- Fetches `num_key_value_heads` from HuggingFace and updates TOML cards
- `--missing`: only fills in cards that don't have the field yet
- `--all`: re-fetches and overwrites everything
- Uses tomlkit for style-preserving TOML edits and ThreadPoolExecutor for
parallel fetches

## Behavior

- Instance previews no longer show tensor options for models whose KV heads
can't be split evenly across the cluster
- `place_instance()` rejects with a clear error instead of crash-looping
- Pipeline parallelism is unaffected
- 2-node tensor sharding still works for 2-KV-head models (2 ÷ 2 = 1 KV head per node)
- Field is optional — existing custom cards without it continue to work
(validation is skipped when `None`)
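The sketch below shows which strategies a preview could offer per cluster size under these rules. `preview_options` is hypothetical. The `hidden_size` of 4096 is an illustrative value, not taken from any listed model's config. Pipeline is listed unconditionally only because this change adds no KV-head check for it; any other pipeline constraints are ignored here.

```python
from __future__ import annotations


def preview_options(
    hidden_size: int, num_key_value_heads: int | None, max_nodes: int
) -> dict[int, list[str]]:
    """Map each multi-node cluster size to the sharding strategies offered."""
    options: dict[int, list[str]] = {}
    for nodes in range(2, max_nodes + 1):
        strategies = ["pipeline"]  # no KV-head check applies to pipeline here
        # None (field absent from the card) skips the KV-head check.
        kv_ok = num_key_value_heads is None or num_key_value_heads % nodes == 0
        if hidden_size % nodes == 0 and kv_ok:
            strategies.append("tensor")
        options[nodes] = strategies
    return options


print(preview_options(4096, 2, 4))
# {2: ['pipeline', 'tensor'], 3: ['pipeline'], 4: ['pipeline']}
print(preview_options(4096, None, 4))
# {2: ['pipeline', 'tensor'], 3: ['pipeline'], 4: ['pipeline', 'tensor']}
```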
2026-03-11 13:46:33 +00:00

134 lines · 3.9 KiB · Python · executable file

```python
#!/usr/bin/env python3
"""Fetch num_key_value_heads from HuggingFace config.json and update TOML model cards.

Usage:
    # Update only cards missing num_key_value_heads
    uv run python scripts/fetch_kv_heads.py --missing

    # Update all cards (overwrite existing values)
    uv run python scripts/fetch_kv_heads.py --all
"""

from __future__ import annotations

import argparse
import json
import sys
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import tomlkit

CARDS_DIR = (
    Path(__file__).resolve().parent.parent / "resources" / "inference_model_cards"
)
MAX_WORKERS = 5


def fetch_kv_heads(model_id: str) -> int | None:
    """Fetch num_key_value_heads from HuggingFace config.json."""
    url = f"https://huggingface.co/{model_id}/raw/main/config.json"
    try:
        with urllib.request.urlopen(url, timeout=15) as resp:
            config = json.loads(resp.read())
    except Exception as e:
        print(f" ERROR fetching {url}: {e}", file=sys.stderr)
        return None

    # Check the top level first, then the nested text_config, which may be
    # absent or explicitly null.
    for source in (config, config.get("text_config") or {}):
        if source.get("num_key_value_heads") is not None:
            return int(source["num_key_value_heads"])
    return None


def update_toml(path: Path, kv_heads: int) -> bool:
    """Insert or update num_key_value_heads in a TOML file. Returns True if changed."""
    doc = tomlkit.parse(path.read_text())
    if doc.get("num_key_value_heads") == kv_heads:
        return False

    if "num_key_value_heads" in doc or "hidden_size" not in doc:
        # Overwrite the existing value in place, or simply add the key when
        # there is no top-level hidden_size to place it after.
        doc["num_key_value_heads"] = kv_heads
        path.write_text(tomlkit.dumps(doc))
        return True

    # First insertion: rebuild the top-level table so the new key lands
    # directly after hidden_size.
    new_doc = tomlkit.document()
    for key, value in doc.items():
        new_doc[key] = value
        if key == "hidden_size":
            new_doc["num_key_value_heads"] = kv_heads
    path.write_text(tomlkit.dumps(new_doc))
    return True


def process_card(path: Path) -> tuple[str, str]:
    """Fetch and update a single card. Returns (filename, status)."""
    doc = tomlkit.parse(path.read_text())
    model_id = doc.get("model_id")
    if not model_id:
        return path.name, "SKIP (no model_id)"

    kv_heads = fetch_kv_heads(str(model_id))
    if kv_heads is None:
        return path.name, "FAILED"

    changed = update_toml(path, kv_heads)
    return path.name, f"{kv_heads} ({'UPDATED' if changed else 'UNCHANGED'})"


def main():
    parser = argparse.ArgumentParser(
        description="Fetch num_key_value_heads from HuggingFace and update TOML cards."
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--all",
        action="store_true",
        help="Update all model cards (overwrite existing values)",
    )
    group.add_argument(
        "--missing",
        action="store_true",
        help="Only update cards missing num_key_value_heads",
    )
    args = parser.parse_args()

    toml_files = sorted(CARDS_DIR.glob("*.toml"))
    if not toml_files:
        print(f"No TOML files found in {CARDS_DIR}", file=sys.stderr)
        sys.exit(1)

    to_process = []
    skipped = 0
    for path in toml_files:
        # Check the parsed document rather than the raw text, so a comment that
        # merely mentions the key is not mistaken for the key itself.
        if args.missing and "num_key_value_heads" in tomlkit.parse(path.read_text()):
            skipped += 1
            continue
        to_process.append(path)

    updated = 0
    failed = 0
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        futures = {pool.submit(process_card, path): path for path in to_process}
        for future in as_completed(futures):
            name, status = future.result()
            print(f" {name}: {status}")
            if "UPDATED" in status:
                updated += 1
            elif "FAILED" in status:
                failed += 1

    print(f"\nDone: {updated} updated, {skipped} skipped, {failed} failed")


if __name__ == "__main__":
    main()
```