diff --git a/src/exo/worker/engines/mlx/model_transfer.py b/src/exo/worker/engines/mlx/model_transfer.py
index 4ada7cb8c..ae30f684b 100644
--- a/src/exo/worker/engines/mlx/model_transfer.py
+++ b/src/exo/worker/engines/mlx/model_transfer.py
@@ -43,8 +43,16 @@ def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
 
 
 def _is_metadata_file(filename: str) -> bool:
-    """A metadata file is anything that isn't a weight file (.safetensors)."""
-    return not filename.endswith(".safetensors")
+    """A metadata file is anything that isn't a weight file or weight index.
+
+    Weight indices (.safetensors.index.json) reference safetensors shard paths.
+    Transferring them to a receiver that has no safetensors files is harmless
+    today (load_model's glob doesn't match them), but excluding them avoids
+    stale references and keeps the transfer minimal.
+    """
+    if filename.endswith(".safetensors"):
+        return False
+    return not filename.endswith(".safetensors.index.json")
 
 
 def model_path_for_id(model_id: ModelId) -> Path:
diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py
index 1b8730371..0b9b2370f 100644
--- a/src/exo/worker/engines/mlx/utils_mlx.py
+++ b/src/exo/worker/engines/mlx/utils_mlx.py
@@ -251,21 +251,41 @@ def shard_and_load(
     assert isinstance(model, nn.Module)
 
     if broadcast_state is not None:
-        # When receiver has no weight files, load_model skips quantization.
-        # Apply it explicitly so QuantizedLinear layers match broadcast weight shapes.
+        # When receiver has no weight files, load_model skips quantization
+        # (its class_predicate checks `f"{p}.scales" in weights`, which is
+        # always False when weights is empty). Apply quantization explicitly
+        # using the broadcast metadata to determine which layers are quantized,
+        # matching load_model's selective quantization logic exactly.
         if not has_local_model:
             config_path = model_path / "config.json"
             with open(config_path) as f:
                 config = json.load(f)  # pyright: ignore[reportAny]
-            quant_config: dict[str, int] | None = config.get(  # pyright: ignore[reportAny]
+            quant_config: dict[str, Any] | None = config.get(  # pyright: ignore[reportAny]
                 "quantization", None
             )
             if quant_config is not None:
                 logger.info(f"Applying quantization to receiver model: {quant_config}")
+                broadcast_weight_names = set(broadcast_state.meta.keys())
+
+                def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
+                    # Per-layer overrides from config (e.g. "lm_head": false)
+                    assert quant_config is not None
+                    if p in quant_config:
+                        return quant_config[p]  # pyright: ignore[reportAny]
+                    if not hasattr(m, "to_quantized"):
+                        return False
+                    # Only quantize layers whose .scales exist in broadcast weights
+                    return f"{p}.scales" in broadcast_weight_names
+
+                group_size = int(quant_config.get("group_size", 64))  # pyright: ignore[reportAny]
+                bits = int(quant_config.get("bits", 4))  # pyright: ignore[reportAny]
+                mode: str = quant_config.get("mode", "affine")  # pyright: ignore[reportAny]
                 nn.quantize(  # pyright: ignore[reportUnknownMemberType]
                     model,
-                    group_size=quant_config.get("group_size", 64),
-                    bits=quant_config.get("bits", 4),
+                    group_size=group_size,
+                    bits=bits,
+                    mode=mode,
+                    class_predicate=_class_predicate,
                 )
 
     # Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
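
Note on the `model_transfer.py` hunk: a quick standalone sanity sketch of the new classification. The function body is copied from the hunk above; the filenames are illustrative, not from any particular model repo.

```python
def _is_metadata_file(filename: str) -> bool:
    if filename.endswith(".safetensors"):
        return False
    return not filename.endswith(".safetensors.index.json")

# Weight shards and their index now stay out of the metadata transfer;
# everything else (config, tokenizer files, ...) still counts as metadata.
assert not _is_metadata_file("model-00001-of-00002.safetensors")
assert not _is_metadata_file("model.safetensors.index.json")
assert _is_metadata_file("config.json")
assert _is_metadata_file("tokenizer.json")
```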
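Note on the `utils_mlx.py` hunk, for reviewers unfamiliar with MLX's selective quantization: `nn.quantize` walks the model's leaf modules and calls `class_predicate(path, module)` on each one. Returning `False` leaves the layer in full precision, `True` quantizes it with the given `group_size`/`bits`, and a dict supplies per-layer overrides (which is how config entries like `"lm_head": false` flow through). A minimal sketch of that mechanism follows; the two-layer model and the broadcast name set are hypothetical stand-ins, not taken from this PR.

```python
from typing import Any

import mlx.nn as nn

# Hypothetical stand-in for broadcast_state.meta.keys(): on the sender,
# only the first linear layer was quantized, so only its .scales exist.
broadcast_weight_names = {"layers.0.weight", "layers.0.scales", "layers.0.biases"}

def class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
    if not hasattr(m, "to_quantized"):
        return False
    # Mirror the patch: quantize only layers whose .scales were broadcast.
    return f"{p}.scales" in broadcast_weight_names

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 8))
nn.quantize(model, group_size=64, bits=4, class_predicate=class_predicate)

print(type(model.layers[0]).__name__)  # QuantizedLinear, matches broadcast shapes
print(type(model.layers[1]).__name__)  # Linear, left in full precision
```

Without the predicate, `nn.quantize` would convert every layer that has a `to_quantized` method, which is exactly the shape mismatch the patch is guarding against when the sender skipped some layers.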