diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py
index 1d09293a..694a8bb0 100644
--- a/src/exo/shared/models/model_cards.py
+++ b/src/exo/shared/models/model_cards.py
@@ -40,6 +40,7 @@ class ModelCard(CamelCaseModel):
     supports_tensor: bool
     tasks: list[ModelTask]
     components: list[ComponentInfo] | None = None
+    quantization: int | None = None
 
     @field_validator("tasks", mode="before")
     @classmethod
@@ -413,7 +414,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
     ),
 }
 
-_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
+_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
     "flux1-schnell": ModelCard(
         model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
         storage_size=Memory.from_bytes(23782357120 + 9524621312),
@@ -428,7 +429,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 storage_size=Memory.from_kb(0),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="text_encoder_2",
@@ -442,7 +443,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 component_name="transformer",
                 component_path="transformer/",
                 storage_size=Memory.from_bytes(23782357120),
-                n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
+                n_layers=57,
                 can_shard=True,
                 safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
             ),
@@ -470,7 +471,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 storage_size=Memory.from_kb(0),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="text_encoder_2",
@@ -484,7 +485,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
                 component_name="transformer",
                 component_path="transformer/",
                 storage_size=Memory.from_bytes(23802816640),
-                n_layers=57,  # 19 transformer_blocks + 38 single_transformer_blocks
+                n_layers=57,
                 can_shard=True,
                 safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
             ),
@@ -543,7 +544,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     "qwen-image": ModelCard(
         model_id=ModelId("Qwen/Qwen-Image"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
+        n_layers=60,
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.TextToImage],
@@ -551,10 +552,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             ComponentInfo(
                 component_name="text_encoder",
                 component_path="text_encoder/",
-                storage_size=Memory.from_kb(16584333312),
+                storage_size=Memory.from_bytes(16584333312),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="transformer",
@@ -577,7 +578,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     "qwen-image-edit-2509": ModelCard(
         model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
         storage_size=Memory.from_bytes(16584333312 + 40860802176),
-        n_layers=60,  # Qwen has 60 transformer blocks (all joint-style)
+        n_layers=60,
         hidden_size=1,
         supports_tensor=False,
         tasks=[ModelTask.ImageToImage],
@@ -585,10 +586,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
             ComponentInfo(
                 component_name="text_encoder",
                 component_path="text_encoder/",
-                storage_size=Memory.from_kb(16584333312),
+                storage_size=Memory.from_bytes(16584333312),
                 n_layers=12,
                 can_shard=False,
-                safetensors_index_filename=None,  # Single file
+                safetensors_index_filename=None,
             ),
             ComponentInfo(
                 component_name="transformer",
@@ -610,6 +611,91 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
     ),
 }
 
+
+def _create_image_model_quant_variants(
+    base_name: str,
+    base_card: ModelCard,
+) -> dict[str, ModelCard]:
+    """Create quantized variants of an image model card.
+
+    Only the transformer component is quantized; text encoders stay at bf16.
+    Sizes are calculated exactly from the base card's component sizes.
+    """
+    if base_card.components is None:
+        raise ValueError(f"Image model {base_name} must have components defined")
+
+    quantizations = [8, 6, 5, 4, 3]
+
+    num_transformer_bytes = next(
+        c.storage_size.in_bytes
+        for c in base_card.components
+        if c.component_name == "transformer"
+    )
+
+    transformer_bytes = Memory.from_bytes(num_transformer_bytes)
+
+    remaining_bytes = Memory.from_bytes(
+        sum(
+            c.storage_size.in_bytes
+            for c in base_card.components
+            if c.component_name != "transformer"
+        )
+    )
+
+    def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
+        assert base_card.components is not None
+        return [
+            ComponentInfo(
+                component_name=c.component_name,
+                component_path=c.component_path,
+                storage_size=new_size
+                if c.component_name == "transformer"
+                else c.storage_size,
+                n_layers=c.n_layers,
+                can_shard=c.can_shard,
+                safetensors_index_filename=c.safetensors_index_filename,
+            )
+            for c in base_card.components
+        ]
+
+    variants = {
+        base_name: ModelCard(
+            model_id=base_card.model_id,
+            storage_size=transformer_bytes + remaining_bytes,
+            n_layers=base_card.n_layers,
+            hidden_size=base_card.hidden_size,
+            supports_tensor=base_card.supports_tensor,
+            tasks=base_card.tasks,
+            components=with_transformer_size(transformer_bytes),
+            quantization=None,
+        )
+    }
+
+    for quant in quantizations:
+        quant_transformer_bytes = Memory.from_bytes(
+            (num_transformer_bytes * quant) // 16
+        )
+        total_bytes = remaining_bytes + quant_transformer_bytes
+
+        variants[f"{base_name}-{quant}bit"] = ModelCard(
+            model_id=base_card.model_id,
+            storage_size=total_bytes,
+            n_layers=base_card.n_layers,
+            hidden_size=base_card.hidden_size,
+            supports_tensor=base_card.supports_tensor,
+            tasks=base_card.tasks,
+            components=with_transformer_size(quant_transformer_bytes),
+            quantization=quant,
+        )
+
+    return variants
+
+
+_image_model_cards: dict[str, ModelCard] = {}
+for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
+    _image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
+_IMAGE_MODEL_CARDS = _image_model_cards
+
 
 if EXO_ENABLE_IMAGE_MODELS:
     MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
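For reference, a sketch of the size arithmetic the helper applies, using the flux1-schnell byte counts from the card above (the variable names here are illustrative, not part of the patch). Only the transformer scales with the bit width; the remaining bf16 components carry over unchanged:

# Byte counts taken from the flux1-schnell card above.
transformer_bf16 = 23782357120  # bf16 transformer weights
other_components = 9524621312   # text encoders etc., kept at bf16

for quant in (8, 6, 5, 4, 3):
    total = other_components + (transformer_bf16 * quant) // 16
    print(f"flux1-schnell-{quant}bit ~= {total / 1e9:.1f} GB")
# e.g. 8-bit ~= 21.4 GB, 4-bit ~= 15.5 GB, vs. ~33.3 GB for the unquantized card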