fix: match load_model's selective quantization for large model transfer

The receiver's nn.quantize() used no class_predicate, quantizing ALL
Linear/Embedding layers. load_model's internal quantize selectively
skips layers without .scales in the weights (e.g. lm_head, embeddings).
For large models with selective quantization this created shape
mismatches: broadcast weights couldn't load into the incorrectly
quantized layers, leaving them with garbage data.

Fix: use the broadcast metadata's weight names to build the same
class_predicate as load_model, pass the mode parameter through, and
respect per-layer overrides from config.json. Also exclude
.safetensors.index.json from the metadata transfer to avoid stale
weight references on receivers.
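
For reference, the selective predicate inside load_model that this fix
mirrors looks roughly like the sketch below (paraphrased, not verbatim;
`weights` is the dict of tensors found on disk and `config` the parsed
config.json):

    def class_predicate(p: str, m: nn.Module) -> bool | dict:
        # Per-layer overrides, e.g. "lm_head": false
        if p in config["quantization"]:
            return config["quantization"][p]
        if not hasattr(m, "to_quantized"):
            return False
        # Only quantize layers whose scales were actually saved
        return f"{p}.scales" in weights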

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Alex Cheema
2026-02-13 11:04:26 -08:00
parent 89e3159871
commit b5c4df2700
2 changed files with 35 additions and 7 deletions

@@ -43,8 +43,16 @@ def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
 def _is_metadata_file(filename: str) -> bool:
-    """A metadata file is anything that isn't a weight file (.safetensors)."""
-    return not filename.endswith(".safetensors")
+    """A metadata file is anything that isn't a weight file or weight index.
+
+    Weight indices (.safetensors.index.json) reference safetensors shard paths.
+    Transferring them to a receiver that has no safetensors files is harmless
+    today (load_model's glob doesn't match them), but excluding them avoids
+    stale references and keeps the transfer minimal.
+    """
+    if filename.endswith(".safetensors"):
+        return False
+    return not filename.endswith(".safetensors.index.json")
 
 
 def model_path_for_id(model_id: ModelId) -> Path:
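
Illustrative classification under the new rule (hypothetical filenames):

    _is_metadata_file("config.json")                        # True  -> transferred as metadata
    _is_metadata_file("tokenizer.json")                     # True  -> transferred as metadata
    _is_metadata_file("model-00001-of-00002.safetensors")   # False -> weight shard
    _is_metadata_file("model.safetensors.index.json")       # False -> weight index, now excluded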

@@ -251,21 +251,41 @@ def shard_and_load(
     assert isinstance(model, nn.Module)
 
     if broadcast_state is not None:
-        # When receiver has no weight files, load_model skips quantization.
-        # Apply it explicitly so QuantizedLinear layers match broadcast weight shapes.
+        # When receiver has no weight files, load_model skips quantization
+        # (its class_predicate checks `f"{p}.scales" in weights`, which is
+        # always False when weights is empty). Apply quantization explicitly
+        # using the broadcast metadata to determine which layers are quantized,
+        # matching load_model's selective quantization logic exactly.
         if not has_local_model:
             config_path = model_path / "config.json"
             with open(config_path) as f:
                 config = json.load(f)  # pyright: ignore[reportAny]
-            quant_config: dict[str, int] | None = config.get(  # pyright: ignore[reportAny]
+            quant_config: dict[str, Any] | None = config.get(  # pyright: ignore[reportAny]
                 "quantization", None
             )
             if quant_config is not None:
                 logger.info(f"Applying quantization to receiver model: {quant_config}")
+                broadcast_weight_names = set(broadcast_state.meta.keys())
+
+                def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
+                    # Per-layer overrides from config (e.g. "lm_head": false)
+                    assert quant_config is not None
+                    if p in quant_config:
+                        return quant_config[p]  # pyright: ignore[reportAny]
+                    if not hasattr(m, "to_quantized"):
+                        return False
+                    # Only quantize layers whose .scales exist in broadcast weights
+                    return f"{p}.scales" in broadcast_weight_names
+
+                group_size = int(quant_config.get("group_size", 64))  # pyright: ignore[reportAny]
+                bits = int(quant_config.get("bits", 4))  # pyright: ignore[reportAny]
+                mode: str = quant_config.get("mode", "affine")  # pyright: ignore[reportAny]
                 nn.quantize(  # pyright: ignore[reportUnknownMemberType]
                     model,
-                    group_size=quant_config.get("group_size", 64),
-                    bits=quant_config.get("bits", 4),
+                    group_size=group_size,
+                    bits=bits,
+                    mode=mode,
+                    class_predicate=_class_predicate,
                 )
         # Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
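
A concrete illustration of how the per-layer override path interacts with the
predicate (hypothetical config values and weight names):

    quant_config = {"group_size": 64, "bits": 4, "mode": "affine", "lm_head": False}
    broadcast_weight_names = {
        "model.layers.0.self_attn.q_proj.scales",
        "model.layers.0.self_attn.q_proj.weight",
        "model.embed_tokens.weight",
        "lm_head.weight",
    }

    _class_predicate("lm_head", model.lm_head)                    # False: explicit override in config
    _class_predicate("model.layers.0.self_attn.q_proj", q_proj)   # True:  .scales present in broadcast
    _class_predicate("model.embed_tokens", embed)                 # False: no .scales broadcast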