mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 23:06:23 -05:00
fix: match load_model's selective quantization for large model transfer
The receiver's nn.quantize() call passed no class_predicate, so it quantized ALL Linear/Embedding layers. load_model's internal quantize step is selective: it skips layers that have no .scales entry in the weights (e.g. lm_head, embeddings). For large models that use selective quantization this created shape mismatches — broadcast weights couldn't be loaded into the incorrectly quantized layers, leaving them with garbage data.

Fix: use the weight names from the broadcast metadata to build the same class_predicate as load_model; also pass the mode parameter and respect per-layer overrides from config.json.

Exclude .safetensors.index.json from the metadata transfer to avoid stale weight references on receivers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -43,8 +43,16 @@ def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
|
||||
|
||||
|
||||
def _is_metadata_file(filename: str) -> bool:
|
||||
"""A metadata file is anything that isn't a weight file (.safetensors)."""
|
||||
return not filename.endswith(".safetensors")
|
||||
"""A metadata file is anything that isn't a weight file or weight index.
|
||||
|
||||
Weight indices (.safetensors.index.json) reference safetensors shard paths.
|
||||
Transferring them to a receiver that has no safetensors files is harmless
|
||||
today (load_model's glob doesn't match them), but excluding them avoids
|
||||
stale references and keeps the transfer minimal.
|
||||
"""
|
||||
if filename.endswith(".safetensors"):
|
||||
return False
|
||||
return not filename.endswith(".safetensors.index.json")
|
||||
|
||||
|
||||
def model_path_for_id(model_id: ModelId) -> Path:
|
||||
|
||||
@@ -251,21 +251,41 @@ def shard_and_load(
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
if broadcast_state is not None:
|
||||
# When receiver has no weight files, load_model skips quantization.
|
||||
# Apply it explicitly so QuantizedLinear layers match broadcast weight shapes.
|
||||
# When receiver has no weight files, load_model skips quantization
|
||||
# (its class_predicate checks `f"{p}.scales" in weights`, which is
|
||||
# always False when weights is empty). Apply quantization explicitly
|
||||
# using the broadcast metadata to determine which layers are quantized,
|
||||
# matching load_model's selective quantization logic exactly.
|
||||
if not has_local_model:
|
||||
config_path = model_path / "config.json"
|
||||
with open(config_path) as f:
|
||||
config = json.load(f) # pyright: ignore[reportAny]
|
||||
quant_config: dict[str, int] | None = config.get( # pyright: ignore[reportAny]
|
||||
quant_config: dict[str, Any] | None = config.get( # pyright: ignore[reportAny]
|
||||
"quantization", None
|
||||
)
|
||||
if quant_config is not None:
|
||||
logger.info(f"Applying quantization to receiver model: {quant_config}")
|
||||
broadcast_weight_names = set(broadcast_state.meta.keys())
|
||||
|
||||
def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
|
||||
# Per-layer overrides from config (e.g. "lm_head": false)
|
||||
assert quant_config is not None
|
||||
if p in quant_config:
|
||||
return quant_config[p] # pyright: ignore[reportAny]
|
||||
if not hasattr(m, "to_quantized"):
|
||||
return False
|
||||
# Only quantize layers whose .scales exist in broadcast weights
|
||||
return f"{p}.scales" in broadcast_weight_names
|
||||
|
||||
group_size = int(quant_config.get("group_size", 64)) # pyright: ignore[reportAny]
|
||||
bits = int(quant_config.get("bits", 4)) # pyright: ignore[reportAny]
|
||||
mode: str = quant_config.get("mode", "affine") # pyright: ignore[reportAny]
|
||||
nn.quantize( # pyright: ignore[reportUnknownMemberType]
|
||||
model,
|
||||
group_size=quant_config.get("group_size", 64),
|
||||
bits=quant_config.get("bits", 4),
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
mode=mode,
|
||||
class_predicate=_class_predicate,
|
||||
)
|
||||
|
||||
# Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
|
||||
|
||||
Reference in New Issue
Block a user