#!/usr/bin/env python3 """ Download an mflux model, quantize it, and upload to HuggingFace. Usage (run from mflux project directory): cd /path/to/mflux uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run Requires: - Must be run from mflux project directory using `uv run` - huggingface_hub installed (add to mflux deps or install separately) - HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN """ from __future__ import annotations import argparse import re import shutil import sys from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: from mflux.models.flux.variants.txt2img.flux import Flux1 HF_ORG = "exolabs" def get_model_class(model_name: str) -> type: """Get the appropriate model class based on model name.""" from mflux.models.fibo.variants.txt2img.fibo import FIBO from mflux.models.flux.variants.txt2img.flux import Flux1 from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo model_name_lower = model_name.lower() if "qwen" in model_name_lower: return QwenImage elif "fibo" in model_name_lower: return FIBO elif "z-image" in model_name_lower or "zimage" in model_name_lower: return ZImageTurbo elif "flux2" in model_name_lower or "flux.2" in model_name_lower: return Flux2Klein else: return Flux1 def get_repo_name(model_name: str, bits: int | None) -> str: """Get the HuggingFace repo name for a model variant.""" # Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev") base_name = model_name.split("/")[-1] if "/" in model_name else model_name suffix = f"-{bits}bit" if bits else "" return f"{HF_ORG}/{base_name}{suffix}" def get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path: """Get the local save path for a model variant.""" # Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev") base_name = model_name.split("/")[-1] if "/" in model_name else model_name suffix = f"-{bits}bit" if bits else "" return output_dir / f"{base_name}{suffix}" def copy_source_repo( source_repo: str, local_path: Path, dry_run: bool = False, ) -> None: """Copy all files from source repo (replicating original HF structure).""" print(f"\n{'=' * 60}") print(f"Copying full repo from source: {source_repo}") print(f"Output path: {local_path}") print(f"{'=' * 60}") if dry_run: print("[DRY RUN] Would download all files from source repo") return from huggingface_hub import snapshot_download # Download all files to our local path snapshot_download( repo_id=source_repo, local_dir=local_path, ) # Remove root-level safetensors files (flux.1-dev.safetensors, etc.) # These are redundant with the component directories for f in local_path.glob("*.safetensors"): print(f"Removing root-level safetensors: {f.name}") if not dry_run: f.unlink() print(f"Source repo copied to {local_path}") def load_and_save_quantized_model( model_name: str, bits: int, output_path: Path, dry_run: bool = False, ) -> None: """Load a model with quantization and save it in mflux format.""" print(f"\n{'=' * 60}") print(f"Loading {model_name} with {bits}-bit quantization...") print(f"Output path: {output_path}") print(f"{'=' * 60}") if dry_run: print("[DRY RUN] Would load and save quantized model") return from mflux.models.common.config.model_config import ModelConfig model_class = get_model_class(model_name) model_config = ModelConfig.from_name(model_name=model_name, base_model=None) model: Flux1 = model_class( quantize=bits, model_config=model_config, ) print(f"Saving model to {output_path}...") model.save_model(str(output_path)) print(f"Model saved successfully to {output_path}") def copy_source_metadata( source_repo: str, local_path: Path, dry_run: bool = False, ) -> None: """Copy metadata files (LICENSE, README, etc.) from source repo, excluding safetensors.""" print(f"\n{'=' * 60}") print(f"Copying metadata from source repo: {source_repo}") print(f"{'=' * 60}") if dry_run: print("[DRY RUN] Would download metadata files (excluding *.safetensors)") return from huggingface_hub import snapshot_download # Download all files except safetensors to our local path snapshot_download( repo_id=source_repo, local_dir=local_path, ignore_patterns=["*.safetensors"], ) print(f"Metadata files copied to {local_path}") def upload_to_huggingface( local_path: Path, repo_id: str, dry_run: bool = False, clean_remote: bool = False, ) -> None: """Upload a saved model to HuggingFace.""" print(f"\n{'=' * 60}") print(f"Uploading to HuggingFace: {repo_id}") print(f"Local path: {local_path}") print(f"Clean remote first: {clean_remote}") print(f"{'=' * 60}") if dry_run: print("[DRY RUN] Would upload to HuggingFace") return from huggingface_hub import HfApi api = HfApi() # Create the repo if it doesn't exist print(f"Creating/verifying repo: {repo_id}") api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True) # Clean remote repo if requested (delete old mflux-format files) if clean_remote: print("Cleaning old mflux-format files from remote...") try: # Pattern for mflux numbered shards: