mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-05 19:52:16 -05:00
Compare commits
1 Commits
ciaran/flu
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b907398a4 |
@@ -1,7 +0,0 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
|
||||
|
||||
__all__ = ["Flux1Kontext"]
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from mlx import nn
|
||||
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
|
||||
CLIPEncoder,
|
||||
)
|
||||
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.model.flux_vae.vae import VAE
|
||||
from mflux.utils.generated_image import GeneratedImage
|
||||
|
||||
class Flux1Kontext(nn.Module):
|
||||
vae: VAE
|
||||
transformer: Transformer
|
||||
t5_text_encoder: T5Encoder
|
||||
clip_text_encoder: CLIPEncoder
|
||||
bits: int | None
|
||||
lora_paths: list[str] | None
|
||||
lora_scales: list[float] | None
|
||||
prompt_cache: dict[str, Any]
|
||||
tokenizers: dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quantize: int | None = ...,
|
||||
model_path: str | None = ...,
|
||||
lora_paths: list[str] | None = ...,
|
||||
lora_scales: list[float] | None = ...,
|
||||
model_config: ModelConfig = ...,
|
||||
) -> None: ...
|
||||
def generate_image(
|
||||
self,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
num_inference_steps: int = ...,
|
||||
height: int = ...,
|
||||
width: int = ...,
|
||||
guidance: float = ...,
|
||||
image_path: Path | str | None = ...,
|
||||
image_strength: float | None = ...,
|
||||
scheduler: str = ...,
|
||||
) -> GeneratedImage: ...
|
||||
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mflux.models.flux.model.flux_vae.vae import VAE
|
||||
|
||||
class KontextUtil:
|
||||
@staticmethod
|
||||
def create_image_conditioning_latents(
|
||||
vae: VAE,
|
||||
height: int,
|
||||
width: int,
|
||||
image_path: str,
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
2
justfile
2
justfile
@@ -20,7 +20,7 @@ sync-clean:
|
||||
|
||||
rust-rebuild:
|
||||
cargo run --bin stub_gen
|
||||
just sync-clean
|
||||
uv sync --reinstall-package exo_pyo3_bindings
|
||||
|
||||
build-dashboard:
|
||||
#!/usr/bin/env bash
|
||||
|
||||
@@ -26,7 +26,7 @@ dependencies = [
|
||||
"httpx>=0.28.1",
|
||||
"tomlkit>=0.14.0",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux==0.15.5",
|
||||
"mflux==0.15.4",
|
||||
"python-multipart>=0.0.21",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -16,6 +16,7 @@ from exo.download.download_utils import (
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
@@ -107,6 +108,13 @@ class DownloadCoordinator:
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
case CancelDownload(model_id=model_id):
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads and model_id in self.download_status:
|
||||
logger.info(f"Cancelling download for {model_id}")
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
@@ -105,6 +105,7 @@ class Node:
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
@@ -188,6 +189,9 @@ class Node:
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif (
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
add_instance_to_placements,
|
||||
cancel_unnecessary_downloads,
|
||||
delete_instance,
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
@@ -66,12 +68,9 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
# Receiving indexed events from the forwarder to be applied to state
|
||||
# Ideally these would be WorkerForwarderEvents but type system says no :(
|
||||
local_event_receiver: Receiver[ForwarderEvent],
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
@@ -81,6 +80,7 @@ class Master:
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
@@ -280,6 +280,14 @@ class Master:
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id, command=cmd
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
|
||||
@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -202,3 +208,29 @@ def get_transition_events(
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def cancel_unnecessary_downloads(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> Sequence[DownloadCommand]:
|
||||
commands: list[DownloadCommand] = []
|
||||
currently_downloading = [
|
||||
(k, v.shard_metadata.model_card.model_id)
|
||||
for k, vs in download_status.items()
|
||||
for v in vs
|
||||
if isinstance(v, (DownloadOngoing))
|
||||
]
|
||||
active_models = set(
|
||||
(
|
||||
node_id,
|
||||
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
|
||||
)
|
||||
for instance in instances.values()
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
)
|
||||
for pair in currently_downloading:
|
||||
if pair not in active_models:
|
||||
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
|
||||
|
||||
return commands
|
||||
|
||||
@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -47,6 +48,7 @@ async def test_master():
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +69,7 @@ async def test_master():
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
download_command_sender=fcds,
|
||||
)
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
class CancelDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
Command = (
|
||||
|
||||
@@ -5,9 +5,7 @@ from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import ModelAdapter
|
||||
from exo.worker.engines.image.models.flux import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_KONTEXT_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
FluxKontextModelAdapter,
|
||||
FluxModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.models.qwen import (
|
||||
@@ -28,16 +26,13 @@ AdapterFactory = Callable[
|
||||
# Registry maps model_family string to adapter factory
|
||||
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
|
||||
"flux": FluxModelAdapter,
|
||||
"flux-kontext": FluxKontextModelAdapter,
|
||||
"qwen-edit": QwenEditModelAdapter,
|
||||
"qwen": QwenModelAdapter,
|
||||
}
|
||||
|
||||
# Config registry: maps model ID patterns to configs
|
||||
# Order matters: longer/more-specific patterns must come before shorter ones
|
||||
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
|
||||
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
|
||||
"flux.1-kontext": FLUX_KONTEXT_CONFIG, # Must come before "flux.1-dev" for pattern matching
|
||||
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
|
||||
"flux.1-dev": FLUX_DEV_CONFIG,
|
||||
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
|
||||
|
||||
@@ -66,19 +66,6 @@ class PromptData(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
"""Kontext-style position IDs for image conditioning.
|
||||
|
||||
For FLUX.1-Kontext models, returns position IDs with first_coord=1
|
||||
to distinguish conditioning tokens from generation tokens (first_coord=0).
|
||||
|
||||
Returns:
|
||||
Position IDs array [1, seq_len, 3] for Kontext, None for other models.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
|
||||
from exo.worker.engines.image.models.flux.config import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_KONTEXT_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
)
|
||||
from exo.worker.engines.image.models.flux.kontext_adapter import (
|
||||
FluxKontextModelAdapter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FluxModelAdapter",
|
||||
"FluxKontextModelAdapter",
|
||||
"FLUX_DEV_CONFIG",
|
||||
"FLUX_KONTEXT_CONFIG",
|
||||
"FLUX_SCHNELL_CONFIG",
|
||||
]
|
||||
|
||||
@@ -59,10 +59,6 @@ class FluxPromptData(PromptData):
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
|
||||
@@ -32,19 +32,3 @@ FLUX_DEV_CONFIG = ImageModelConfig(
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
|
||||
)
|
||||
|
||||
|
||||
FLUX_KONTEXT_CONFIG = ImageModelConfig(
|
||||
model_family="flux-kontext",
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
|
||||
),
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
|
||||
guidance_scale=4.0,
|
||||
)
|
||||
|
||||
@@ -1,348 +0,0 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, final
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
|
||||
from mflux.models.flux.variants.kontext.kontext_util import KontextUtil
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
ModelAdapter,
|
||||
PromptData,
|
||||
RotaryEmbeddings,
|
||||
)
|
||||
from exo.worker.engines.image.models.flux.wrappers import (
|
||||
FluxJointBlockWrapper,
|
||||
FluxSingleBlockWrapper,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FluxKontextPromptData(PromptData):
|
||||
"""Prompt data for FLUX.1-Kontext image editing.
|
||||
|
||||
Stores text embeddings along with conditioning latents and position IDs
|
||||
for the input image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
pooled_prompt_embeds: mx.array,
|
||||
conditioning_latents: mx.array,
|
||||
kontext_image_ids: mx.array,
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._pooled_prompt_embeds = pooled_prompt_embeds
|
||||
self._conditioning_latents = conditioning_latents
|
||||
self._kontext_image_ids = kontext_image_ids
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
return self._pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def cond_image_grid(
|
||||
self,
|
||||
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""VAE-encoded input image latents for Kontext conditioning."""
|
||||
return self._conditioning_latents
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
"""Position IDs for Kontext conditioning (first_coord=1)."""
|
||||
return self._kontext_image_ids
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Kontext doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
None,
|
||||
self._pooled_prompt_embeds,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
# Kontext doesn't use CFG
|
||||
return None
|
||||
|
||||
|
||||
@final
|
||||
class FluxKontextModelAdapter(ModelAdapter[Flux1Kontext, Transformer]):
|
||||
"""Adapter for FLUX.1-Kontext image editing model.
|
||||
|
||||
Key differences from standard FluxModelAdapter:
|
||||
- Takes an input image and computes output dimensions from it
|
||||
- Creates conditioning latents from the input image via VAE
|
||||
- Creates special position IDs (kontext_image_ids) for conditioning tokens
|
||||
- Creates pure noise latents (not img2img blending)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = Flux1Kontext(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
model_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
# Stores image path and computed dimensions after set_image_dimensions
|
||||
self._image_path: str | None = None
|
||||
self._output_height: int | None = None
|
||||
self._output_width: int | None = None
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.x_embedder.weight.shape[0] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
return False
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return FluxLatentCreator
|
||||
|
||||
def get_joint_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
) -> list[JointBlockWrapper[Any]]:
|
||||
"""Create wrapped joint blocks for Flux Kontext."""
|
||||
return [
|
||||
FluxJointBlockWrapper(block, text_seq_len)
|
||||
for block in self._transformer.transformer_blocks
|
||||
]
|
||||
|
||||
def get_single_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
) -> list[SingleBlockWrapper[Any]]:
|
||||
"""Create wrapped single blocks for Flux Kontext."""
|
||||
return [
|
||||
FluxSingleBlockWrapper(block, text_seq_len)
|
||||
for block in self._transformer.single_transformer_blocks
|
||||
]
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
):
|
||||
all_joint = list(self._transformer.transformer_blocks)
|
||||
all_single = list(self._transformer.single_transformer_blocks)
|
||||
total_joint_blocks = len(all_joint)
|
||||
if end_layer <= total_joint_blocks:
|
||||
# All assigned are joint blocks
|
||||
joint_start, joint_end = start_layer, end_layer
|
||||
single_start, single_end = 0, 0
|
||||
elif start_layer >= total_joint_blocks:
|
||||
# All assigned are single blocks
|
||||
joint_start, joint_end = 0, 0
|
||||
single_start = start_layer - total_joint_blocks
|
||||
single_end = end_layer - total_joint_blocks
|
||||
else:
|
||||
# Spans both joint and single
|
||||
joint_start, joint_end = start_layer, total_joint_blocks
|
||||
single_start = 0
|
||||
single_end = end_layer - total_joint_blocks
|
||||
|
||||
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
|
||||
self._transformer.single_transformer_blocks = all_single[
|
||||
single_start:single_end
|
||||
]
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
|
||||
"""Compute and store dimensions from input image.
|
||||
|
||||
Also stores image_path for use in encode_prompt().
|
||||
|
||||
Args:
|
||||
image_path: Path to the input image
|
||||
|
||||
Returns:
|
||||
(output_width, output_height) for runtime config
|
||||
"""
|
||||
from mflux.utils.image_util import ImageUtil
|
||||
|
||||
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
|
||||
image_size = pil_image.size
|
||||
|
||||
# Compute output dimensions from input image aspect ratio
|
||||
# Target area of 1024x1024 = ~1M pixels
|
||||
target_area = 1024 * 1024
|
||||
ratio = image_size[0] / image_size[1]
|
||||
output_width = math.sqrt(target_area * ratio)
|
||||
output_height = output_width / ratio
|
||||
output_width = round(output_width / 32) * 32
|
||||
output_height = round(output_height / 32) * 32
|
||||
|
||||
# Ensure multiple of 16 for VAE
|
||||
vae_scale_factor = 8
|
||||
multiple_of = vae_scale_factor * 2
|
||||
output_width = output_width // multiple_of * multiple_of
|
||||
output_height = output_height // multiple_of * multiple_of
|
||||
|
||||
self._image_path = str(image_path)
|
||||
self._output_width = int(output_width)
|
||||
self._output_height = int(output_height)
|
||||
|
||||
return self._output_width, self._output_height
|
||||
|
||||
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
|
||||
"""Create initial noise latents for Kontext.
|
||||
|
||||
Unlike standard img2img which blends noise with encoded input,
|
||||
Kontext uses pure noise latents. The input image is provided
|
||||
separately as conditioning.
|
||||
"""
|
||||
return FluxLatentCreator.create_noise(
|
||||
seed=seed,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
)
|
||||
|
||||
def encode_prompt(
|
||||
self, prompt: str, negative_prompt: str | None = None
|
||||
) -> FluxKontextPromptData:
|
||||
"""Encode prompt and create conditioning from stored input image.
|
||||
|
||||
Must call set_image_dimensions() before this method.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt for editing
|
||||
negative_prompt: Ignored (Kontext doesn't use CFG)
|
||||
|
||||
Returns:
|
||||
FluxKontextPromptData with text embeddings and image conditioning
|
||||
"""
|
||||
del negative_prompt # Kontext doesn't support negative prompts or CFG
|
||||
|
||||
if (
|
||||
self._image_path is None
|
||||
or self._output_height is None
|
||||
or self._output_width is None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"set_image_dimensions() must be called before encode_prompt() "
|
||||
"for FluxKontextModelAdapter"
|
||||
)
|
||||
|
||||
assert isinstance(self.model.prompt_cache, dict)
|
||||
assert isinstance(self.model.tokenizers, dict)
|
||||
|
||||
# Encode text prompt
|
||||
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_cache=self.model.prompt_cache,
|
||||
t5_tokenizer=self.model.tokenizers["t5"], # pyright: ignore[reportAny]
|
||||
clip_tokenizer=self.model.tokenizers["clip"], # pyright: ignore[reportAny]
|
||||
t5_text_encoder=self.model.t5_text_encoder,
|
||||
clip_text_encoder=self.model.clip_text_encoder,
|
||||
)
|
||||
|
||||
# Create conditioning latents from input image
|
||||
conditioning_latents, kontext_image_ids = (
|
||||
KontextUtil.create_image_conditioning_latents(
|
||||
vae=self.model.vae,
|
||||
height=self._output_height,
|
||||
width=self._output_width,
|
||||
image_path=self._image_path,
|
||||
)
|
||||
)
|
||||
|
||||
return FluxKontextPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
conditioning_latents=conditioning_latents,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
embedded_hidden = self._transformer.x_embedder(hidden_states)
|
||||
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: Config,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
if pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"pooled_prompt_embeds is required for Flux Kontext text embeddings"
|
||||
)
|
||||
|
||||
return Transformer.compute_text_embeddings(
|
||||
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
|
||||
)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: Config,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
cond_image_grid: tuple[int, int, int]
|
||||
| list[tuple[int, int, int]]
|
||||
| None = None,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> RotaryEmbeddings:
|
||||
return Transformer.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
self._transformer.pos_embed,
|
||||
runtime_config,
|
||||
kontext_image_ids,
|
||||
)
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
raise NotImplementedError("Flux Kontext does not use classifier-free guidance")
|
||||
@@ -69,10 +69,6 @@ class QwenPromptData(PromptData):
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
|
||||
@@ -85,10 +85,6 @@ class QwenEditPromptData(PromptData):
|
||||
def qwen_image_ids(self) -> mx.array:
|
||||
return self._qwen_image_ids
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_edit_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -567,7 +567,6 @@ class DiffusionRunner:
|
||||
| list[tuple[int, int, int]]
|
||||
| None = None,
|
||||
conditioning_latents: mx.array | None = None,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Run a single forward pass through the transformer.
|
||||
Args:
|
||||
@@ -579,7 +578,6 @@ class DiffusionRunner:
|
||||
encoder_hidden_states_mask: Attention mask for text (Qwen)
|
||||
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
|
||||
conditioning_latents: Conditioning latents for edit mode
|
||||
kontext_image_ids: Position IDs for Kontext conditioning (Flux Kontext)
|
||||
|
||||
Returns:
|
||||
Noise prediction tensor
|
||||
@@ -612,7 +610,6 @@ class DiffusionRunner:
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
@@ -684,7 +681,6 @@ class DiffusionRunner:
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
kontext_image_ids = prompt_data.kontext_image_ids
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
@@ -704,7 +700,6 @@ class DiffusionRunner:
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=branch.cond_latents,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
@@ -907,10 +902,10 @@ class DiffusionRunner:
|
||||
config: Config,
|
||||
hidden_states: mx.array,
|
||||
prompt_data: PromptData,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
kontext_image_ids = prompt_data.kontext_image_ids
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
|
||||
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
|
||||
@@ -984,10 +979,10 @@ class DiffusionRunner:
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
is_first_async_step: bool,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
kontext_image_ids = prompt_data.kontext_image_ids
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
done
|
||||
|
||||
|
||||
@@ -1,377 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Download an mflux model, quantize it, and upload to HuggingFace.
|
||||
|
||||
Usage (run from mflux project directory):
|
||||
cd /path/to/mflux
|
||||
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
|
||||
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
|
||||
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
|
||||
|
||||
Requires:
|
||||
- Must be run from mflux project directory using `uv run`
|
||||
- huggingface_hub installed (add to mflux deps or install separately)
|
||||
- HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
|
||||
|
||||
HF_ORG = "exolabs"
|
||||
|
||||
|
||||
def get_model_class(model_name: str) -> type:
|
||||
"""Get the appropriate model class based on model name."""
|
||||
from mflux.models.fibo.variants.txt2img.fibo import FIBO
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
|
||||
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
|
||||
from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo
|
||||
|
||||
model_name_lower = model_name.lower()
|
||||
if "qwen" in model_name_lower:
|
||||
return QwenImage
|
||||
elif "fibo" in model_name_lower:
|
||||
return FIBO
|
||||
elif "z-image" in model_name_lower or "zimage" in model_name_lower:
|
||||
return ZImageTurbo
|
||||
elif "flux2" in model_name_lower or "flux.2" in model_name_lower:
|
||||
return Flux2Klein
|
||||
else:
|
||||
return Flux1
|
||||
|
||||
|
||||
def get_repo_name(model_name: str, bits: int | None) -> str:
|
||||
"""Get the HuggingFace repo name for a model variant."""
|
||||
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
|
||||
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
suffix = f"-{bits}bit" if bits else ""
|
||||
return f"{HF_ORG}/{base_name}{suffix}"
|
||||
|
||||
|
||||
def get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path:
|
||||
"""Get the local save path for a model variant."""
|
||||
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
|
||||
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
suffix = f"-{bits}bit" if bits else ""
|
||||
return output_dir / f"{base_name}{suffix}"
|
||||
|
||||
|
||||
def copy_source_repo(
|
||||
source_repo: str,
|
||||
local_path: Path,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Copy all files from source repo (replicating original HF structure)."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Copying full repo from source: {source_repo}")
|
||||
print(f"Output path: {local_path}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would download all files from source repo")
|
||||
return
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download all files to our local path
|
||||
snapshot_download(
|
||||
repo_id=source_repo,
|
||||
local_dir=local_path,
|
||||
)
|
||||
|
||||
# Remove root-level safetensors files (flux.1-dev.safetensors, etc.)
|
||||
# These are redundant with the component directories
|
||||
for f in local_path.glob("*.safetensors"):
|
||||
print(f"Removing root-level safetensors: {f.name}")
|
||||
if not dry_run:
|
||||
f.unlink()
|
||||
|
||||
print(f"Source repo copied to {local_path}")
|
||||
|
||||
|
||||
def load_and_save_quantized_model(
|
||||
model_name: str,
|
||||
bits: int,
|
||||
output_path: Path,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Load a model with quantization and save it in mflux format."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Loading {model_name} with {bits}-bit quantization...")
|
||||
print(f"Output path: {output_path}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would load and save quantized model")
|
||||
return
|
||||
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
|
||||
model_class = get_model_class(model_name)
|
||||
model_config = ModelConfig.from_name(model_name=model_name, base_model=None)
|
||||
|
||||
model: Flux1 = model_class(
|
||||
quantize=bits,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
print(f"Saving model to {output_path}...")
|
||||
model.save_model(str(output_path))
|
||||
print(f"Model saved successfully to {output_path}")
|
||||
|
||||
|
||||
def copy_source_metadata(
|
||||
source_repo: str,
|
||||
local_path: Path,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Copy metadata files (LICENSE, README, etc.) from source repo, excluding safetensors."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Copying metadata from source repo: {source_repo}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would download metadata files (excluding *.safetensors)")
|
||||
return
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download all files except safetensors to our local path
|
||||
snapshot_download(
|
||||
repo_id=source_repo,
|
||||
local_dir=local_path,
|
||||
ignore_patterns=["*.safetensors"],
|
||||
)
|
||||
print(f"Metadata files copied to {local_path}")
|
||||
|
||||
|
||||
def upload_to_huggingface(
|
||||
local_path: Path,
|
||||
repo_id: str,
|
||||
dry_run: bool = False,
|
||||
clean_remote: bool = False,
|
||||
) -> None:
|
||||
"""Upload a saved model to HuggingFace."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Uploading to HuggingFace: {repo_id}")
|
||||
print(f"Local path: {local_path}")
|
||||
print(f"Clean remote first: {clean_remote}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would upload to HuggingFace")
|
||||
return
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# Create the repo if it doesn't exist
|
||||
print(f"Creating/verifying repo: {repo_id}")
|
||||
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
|
||||
|
||||
# Clean remote repo if requested (delete old mflux-format files)
|
||||
if clean_remote:
|
||||
print("Cleaning old mflux-format files from remote...")
|
||||
try:
|
||||
# Pattern for mflux numbered shards: <dir>/<number>.safetensors
|
||||
numbered_pattern = re.compile(r".*/\d+\.safetensors$")
|
||||
|
||||
repo_files = api.list_repo_files(repo_id=repo_id, repo_type="model")
|
||||
for file_path in repo_files:
|
||||
# Delete numbered safetensors (mflux format) and mflux index files
|
||||
if numbered_pattern.match(file_path) or file_path.endswith(
|
||||
"/model.safetensors.index.json"
|
||||
):
|
||||
print(f" Deleting: {file_path}")
|
||||
api.delete_file(
|
||||
path_in_repo=file_path, repo_id=repo_id, repo_type="model"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not clean remote files: {e}")
|
||||
|
||||
# Upload the folder
|
||||
print("Uploading folder contents...")
|
||||
api.upload_folder(
|
||||
folder_path=str(local_path),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f"Upload complete: https://huggingface.co/{repo_id}")
|
||||
|
||||
|
||||
def clean_local_files(local_path: Path, dry_run: bool = False) -> None:
|
||||
"""Remove local model files after upload."""
|
||||
print(f"\nCleaning up: {local_path}")
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would remove local files")
|
||||
return
|
||||
|
||||
if local_path.exists():
|
||||
shutil.rmtree(local_path)
|
||||
print(f"Removed {local_path}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download an mflux model, quantize it, and upload to HuggingFace.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Process all variants (base, 4-bit, 8-bit) for FLUX.1-Kontext-dev
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
|
||||
|
||||
# Only process 4-bit variant
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
|
||||
|
||||
# Save locally without uploading
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-upload
|
||||
|
||||
# Preview what would happen
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
required=True,
|
||||
help="HuggingFace model path (e.g., black-forest-labs/FLUX.1-Kontext-dev)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("./tmp/models"),
|
||||
help="Local directory to save models (default: ./tmp/models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-base",
|
||||
action="store_true",
|
||||
help="Skip base model (no quantization)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-4bit",
|
||||
action="store_true",
|
||||
help="Skip 4-bit quantized model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-8bit",
|
||||
action="store_true",
|
||||
help="Skip 8-bit quantized model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-download",
|
||||
action="store_true",
|
||||
help="Skip downloading/processing, only do upload/clean operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-upload",
|
||||
action="store_true",
|
||||
help="Only save locally, don't upload to HuggingFace",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clean",
|
||||
action="store_true",
|
||||
help="Remove local files after upload",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clean-remote",
|
||||
action="store_true",
|
||||
help="Delete old mflux-format files from remote repo before uploading",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print actions without executing",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine which variants to process
|
||||
variants: list[int | None] = []
|
||||
if not args.skip_base:
|
||||
variants.append(None) # Base model (no quantization)
|
||||
if not args.skip_4bit:
|
||||
variants.append(4)
|
||||
if not args.skip_8bit:
|
||||
variants.append(8)
|
||||
|
||||
if not variants:
|
||||
print("Error: All variants skipped. Nothing to do.")
|
||||
return 1
|
||||
|
||||
# Create output directory
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Output directory: {args.output_dir}")
|
||||
print(
|
||||
f"Variants to process: {['base' if v is None else f'{v}-bit' for v in variants]}"
|
||||
)
|
||||
print(f"Upload to HuggingFace: {not args.skip_upload}")
|
||||
print(f"Clean after upload: {args.clean}")
|
||||
if args.dry_run:
|
||||
print("\n*** DRY RUN MODE - No actual changes will be made ***")
|
||||
|
||||
# Process each variant
|
||||
for bits in variants:
|
||||
local_path = get_local_path(args.output_dir, args.model, bits)
|
||||
repo_id = get_repo_name(args.model, bits)
|
||||
|
||||
if not args.skip_download:
|
||||
if bits is None:
|
||||
# Base model: copy original HF repo structure (no mflux conversion)
|
||||
copy_source_repo(
|
||||
source_repo=args.model,
|
||||
local_path=local_path,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
else:
|
||||
# Quantized model: load, quantize, and save with mflux
|
||||
load_and_save_quantized_model(
|
||||
model_name=args.model,
|
||||
bits=bits,
|
||||
output_path=local_path,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
# Copy metadata from source repo (LICENSE, README, etc.)
|
||||
copy_source_metadata(
|
||||
source_repo=args.model,
|
||||
local_path=local_path,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
# Upload
|
||||
if not args.skip_upload:
|
||||
upload_to_huggingface(
|
||||
local_path=local_path,
|
||||
repo_id=repo_id,
|
||||
dry_run=args.dry_run,
|
||||
clean_remote=args.clean_remote,
|
||||
)
|
||||
|
||||
# Clean up if requested
|
||||
if args.clean:
|
||||
clean_local_files(local_path, dry_run=args.dry_run)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All done!")
|
||||
print("=" * 60)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user