Compare commits


4 Commits

Author     SHA1        Message                                                      Date
ciaranbor  b77642a46b  Use original model files for non-quantized upload            2026-02-05 18:05:08 +00:00
ciaranbor  c5f7ead69d  Update mflux to 0.15.5                                       2026-02-05 18:01:13 +00:00
ciaranbor  90f743e89c  Add quantize_and_upload script for quantized image models    2026-02-05 18:01:13 +00:00
ciaranbor  f366005ece  Add FLUX.1-Kontext-dev                                       2026-02-05 18:01:13 +00:00
24 changed files with 2091 additions and 1164 deletions

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
__all__ = ["Flux1Kontext"]

View File

@@ -0,0 +1,49 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
from typing import Any
from mlx import nn
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.model.flux_vae.vae import VAE
from mflux.utils.generated_image import GeneratedImage
class Flux1Kontext(nn.Module):
vae: VAE
transformer: Transformer
t5_text_encoder: T5Encoder
clip_text_encoder: CLIPEncoder
bits: int | None
lora_paths: list[str] | None
lora_scales: list[float] | None
prompt_cache: dict[str, Any]
tokenizers: dict[str, Any]
def __init__(
self,
quantize: int | None = ...,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
model_config: ModelConfig = ...,
) -> None: ...
def generate_image(
self,
seed: int,
prompt: str,
num_inference_steps: int = ...,
height: int = ...,
width: int = ...,
guidance: float = ...,
image_path: Path | str | None = ...,
image_strength: float | None = ...,
scheduler: str = ...,
) -> GeneratedImage: ...
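
The stub above captures the whole single-node API surface; a minimal usage sketch under those signatures (the paths, model name, and prompt below are hypothetical, not taken from this diff):

from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext

# Load a 4-bit quantized Kontext model from a local directory (hypothetical path)
model = Flux1Kontext(
    quantize=4,
    model_path="/models/FLUX.1-Kontext-dev-4bit",
    model_config=ModelConfig.from_name(
        model_name="exolabs/FLUX.1-Kontext-dev-4bit", base_model=None
    ),
)
# Kontext edits the given image according to the prompt and derives output size from it
image = model.generate_image(
    seed=0,
    prompt="make the sky stormy",
    image_path="input.png",
)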

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.flux.model.flux_vae.vae import VAE
class KontextUtil:
@staticmethod
def create_image_conditioning_latents(
vae: VAE,
height: int,
width: int,
image_path: str,
) -> tuple[mx.array, mx.array]: ...
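
The returned tuple unpacks as (conditioning_latents, kontext_image_ids), which is how the Kontext adapter later in this diff consumes it; a minimal sketch, assuming a loaded Flux1Kontext model bound to `model` and a hypothetical input path:

latents, image_ids = KontextUtil.create_image_conditioning_latents(
    vae=model.vae,           # VAE from a loaded Flux1Kontext model
    height=768,              # target output height
    width=1376,              # target output width
    image_path="input.png",  # conditioning image to encode (hypothetical)
)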

View File

@@ -148,15 +148,6 @@
setImageGenerationParams({ guidance: null });
}
function handleNumSyncStepsChange(event: Event) {
const value = parseInt((event.target as HTMLInputElement).value, 10);
setImageGenerationParams({ numSyncSteps: value });
}
function clearNumSyncSteps() {
setImageGenerationParams({ numSyncSteps: null });
}
function handleReset() {
resetImageGenerationParams();
showAdvanced = false;
@@ -166,8 +157,7 @@
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null,
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
);
</script>
@@ -588,50 +578,7 @@
</div>
</div>
<!-- Row 3: Sync Steps -->
<div class="flex items-center gap-1.5">
<span
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
>SYNC STEPS:</span
>
<div class="flex items-center gap-2 flex-1 max-w-xs">
<input
type="range"
min="1"
max="100"
value={params.numSyncSteps ?? 1}
oninput={handleNumSyncStepsChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.numSyncSteps ?? "--"}
</span>
{#if params.numSyncSteps !== null}
<button
type="button"
onclick={clearNumSyncSteps}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>
<!-- Row 4: Negative Prompt -->
<!-- Row 3: Negative Prompt -->
<div class="flex flex-col gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>NEGATIVE PROMPT:</span

View File

@@ -298,7 +298,6 @@ export interface ImageGenerationParams {
numInferenceSteps: number | null;
guidance: number | null;
negativePrompt: string | null;
numSyncSteps: number | null;
// Edit mode params
inputFidelity: "low" | "high";
}
@@ -320,7 +319,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
numInferenceSteps: null,
guidance: null,
negativePrompt: null,
numSyncSteps: null,
inputFidelity: "low",
};
@@ -2398,9 +2396,7 @@ class AppStore {
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
(params.negativePrompt !== null && params.negativePrompt.trim() !== "");
const requestBody: Record<string, unknown> = {
model,
@@ -2425,9 +2421,6 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
};
}
@@ -2677,19 +2670,11 @@ class AppStore {
formData.append("input_fidelity", params.inputFidelity);
// Advanced params
const hasAdvancedParams =
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
if (hasAdvancedParams) {
if (params.seed !== null) {
formData.append(
"advanced_params",
JSON.stringify({
...(params.seed !== null && { seed: params.seed }),
seed: params.seed,
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
@@ -2698,9 +2683,24 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
);
} else if (
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "")
) {
formData.append(
"advanced_params",
JSON.stringify({
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
...(params.guidance !== null && { guidance: params.guidance }),
...(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
}),
);
}

View File

@@ -26,7 +26,7 @@ dependencies = [
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux==0.15.4",
"mflux==0.15.5",
"python-multipart>=0.0.21",
]

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Kontext-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0
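
A quick sanity check across all three manifests above: the top-level storage_size is exactly the sum of the per-component sizes (text_encoder and vae both report 0 bytes):

# 4-bit:  text_encoder_2 + transformer
assert 9524621312 + 5950704160 == 15475325472
# 8-bit
assert 9524621312 + 11901408320 == 21426029632
# base (non-quantized)
assert 9524621312 + 23802816640 == 33327437952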

View File

@@ -272,7 +272,6 @@ class AdvancedImageParams(BaseModel):
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
negative_prompt: str | None = None
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None
class ImageGenerationTaskParams(BaseModel):

View File

@@ -1,4 +1,5 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
@@ -22,7 +23,7 @@ class ImageModelConfig(BaseModel):
block_configs: tuple[TransformerBlockConfig, ...]
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps: int # Number of sync steps for distributed inference
num_sync_steps_factor: float # Fraction of steps for sync phase
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@@ -44,3 +45,6 @@ class ImageModelConfig(BaseModel):
def get_steps_for_quality(self, quality: str) -> int:
return self.default_steps[quality]
def get_num_sync_steps(self, steps: int) -> int:
return ceil(steps * self.num_sync_steps_factor)
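
Concretely, this replaces the old fixed per-model sync-step counts with a fraction of whatever step count is requested; checking the factors chosen elsewhere in this diff against the old values:

from math import ceil

# FLUX schnell: factor 0.5, "medium" = 2 steps -> 1 sync step (old num_sync_steps=1)
assert ceil(2 * 0.5) == 1
# FLUX dev / Kontext: factor 0.125, "medium" = 25 steps -> 4 sync steps (old num_sync_steps=4)
assert ceil(25 * 0.125) == 4
# Qwen image / edit: factor 0.25, "medium" = 25 steps -> 7 sync steps (old num_sync_steps=7)
assert ceil(25 * 0.25) == 7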

View File

@@ -150,10 +150,7 @@ class DistributedImageModel:
guidance=guidance_override if guidance_override is not None else 4.0,
)
if advanced_params is not None and advanced_params.num_sync_steps is not None:
num_sync_steps = advanced_params.num_sync_steps
else:
num_sync_steps = self._config.num_sync_steps
num_sync_steps = self._config.get_num_sync_steps(steps)
for result in self._runner.generate_image(
runtime_config=config,

View File

@@ -5,7 +5,9 @@ from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_KONTEXT_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxKontextModelAdapter,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
@@ -26,13 +28,16 @@ AdapterFactory = Callable[
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"flux-kontext": FluxKontextModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
# Order matters: longer/more-specific patterns must come before shorter ones
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-kontext": FLUX_KONTEXT_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching

View File

@@ -66,6 +66,19 @@ class PromptData(ABC):
"""
...
@property
@abstractmethod
def kontext_image_ids(self) -> mx.array | None:
"""Kontext-style position IDs for image conditioning.
For FLUX.1-Kontext models, returns position IDs with first_coord=1
to distinguish conditioning tokens from generation tokens (first_coord=0).
Returns:
Position IDs array [1, seq_len, 3] for Kontext, None for other models.
"""
...
@abstractmethod
def get_batched_cfg_data(
self,

View File

@@ -1,11 +1,17 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_KONTEXT_CONFIG,
FLUX_SCHNELL_CONFIG,
)
from exo.worker.engines.image.models.flux.kontext_adapter import (
FluxKontextModelAdapter,
)
__all__ = [
"FluxModelAdapter",
"FluxKontextModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_KONTEXT_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -59,6 +59,10 @@ class FluxPromptData(PromptData):
def conditioning_latents(self) -> mx.array | None:
return None
@property
def kontext_image_ids(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:

View File

@@ -15,7 +15,7 @@ FLUX_SCHNELL_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps=1,
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
)
@@ -30,5 +30,21 @@ FLUX_DEV_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps=4,
num_sync_steps_factor=0.125, # ceil(25 * 0.125) = 4 sync steps for medium (25 steps)
)
FLUX_KONTEXT_CONFIG = ImageModelConfig(
model_family="flux-kontext",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ceil(25 * 0.125) = 4 sync steps for medium (25 steps)
guidance_scale=4.0,
)

View File

@@ -0,0 +1,348 @@
import math
from pathlib import Path
from typing import Any, final
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
from mflux.models.flux.variants.kontext.kontext_util import KontextUtil
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
PromptData,
RotaryEmbeddings,
)
from exo.worker.engines.image.models.flux.wrappers import (
FluxJointBlockWrapper,
FluxSingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
@final
class FluxKontextPromptData(PromptData):
"""Prompt data for FLUX.1-Kontext image editing.
Stores text embeddings along with conditioning latents and position IDs
for the input image.
"""
def __init__(
self,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
conditioning_latents: mx.array,
kontext_image_ids: mx.array,
):
self._prompt_embeds = prompt_embeds
self._pooled_prompt_embeds = pooled_prompt_embeds
self._conditioning_latents = conditioning_latents
self._kontext_image_ids = kontext_image_ids
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._pooled_prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array | None:
return None
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
return None
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
return None
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
return None
@property
def conditioning_latents(self) -> mx.array | None:
"""VAE-encoded input image latents for Kontext conditioning."""
return self._conditioning_latents
@property
def kontext_image_ids(self) -> mx.array | None:
"""Position IDs for Kontext conditioning (first_coord=1)."""
return self._kontext_image_ids
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Kontext doesn't use CFG, but we return positive data for compatibility."""
return (
self._prompt_embeds,
None,
self._pooled_prompt_embeds,
self._conditioning_latents,
)
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
# Kontext doesn't use CFG
return None
@final
class FluxKontextModelAdapter(ModelAdapter[Flux1Kontext, Transformer]):
"""Adapter for FLUX.1-Kontext image editing model.
Key differences from standard FluxModelAdapter:
- Takes an input image and computes output dimensions from it
- Creates conditioning latents from the input image via VAE
- Creates special position IDs (kontext_image_ids) for conditioning tokens
- Creates pure noise latents (not img2img blending)
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = Flux1Kontext(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
# Stores image path and computed dimensions after set_image_dimensions
self._image_path: str | None = None
self._output_height: int | None = None
self._output_width: int | None = None
@property
def hidden_dim(self) -> int:
return self._transformer.x_embedder.weight.shape[0] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
@property
def needs_cfg(self) -> bool:
return False
def _get_latent_creator(self) -> type:
return FluxLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper[Any]]:
"""Create wrapped joint blocks for Flux Kontext."""
return [
FluxJointBlockWrapper(block, text_seq_len)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper[Any]]:
"""Create wrapped single blocks for Flux Kontext."""
return [
FluxSingleBlockWrapper(block, text_seq_len)
for block in self._transformer.single_transformer_blocks
]
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
all_joint = list(self._transformer.transformer_blocks)
all_single = list(self._transformer.single_transformer_blocks)
total_joint_blocks = len(all_joint)
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_path for use in encode_prompt().
Args:
image_path: Path to the input image
Returns:
(output_width, output_height) for runtime config
"""
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Compute output dimensions from input image aspect ratio
# Target area of 1024x1024 = ~1M pixels
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
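# Worked example: a 1600x900 input has ratio ~1.778, so the unrounded target is
# ~1365.3 x 768.0; rounding to multiples of 32 yields a 1376 x 768 output.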
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
self._image_path = str(image_path)
self._output_width = int(output_width)
self._output_height = int(output_height)
return self._output_width, self._output_height
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial noise latents for Kontext.
Unlike standard img2img which blends noise with encoded input,
Kontext uses pure noise latents. The input image is provided
separately as conditioning.
"""
return FluxLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> FluxKontextPromptData:
"""Encode prompt and create conditioning from stored input image.
Must call set_image_dimensions() before this method.
Args:
prompt: Text prompt for editing
negative_prompt: Ignored (Kontext doesn't use CFG)
Returns:
FluxKontextPromptData with text embeddings and image conditioning
"""
del negative_prompt # Kontext doesn't support negative prompts or CFG
if (
self._image_path is None
or self._output_height is None
or self._output_width is None
):
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for FluxKontextModelAdapter"
)
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
# Encode text prompt
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self.model.prompt_cache,
t5_tokenizer=self.model.tokenizers["t5"], # pyright: ignore[reportAny]
clip_tokenizer=self.model.tokenizers["clip"], # pyright: ignore[reportAny]
t5_text_encoder=self.model.t5_text_encoder,
clip_text_encoder=self.model.clip_text_encoder,
)
# Create conditioning latents from input image
conditioning_latents, kontext_image_ids = (
KontextUtil.create_image_conditioning_latents(
vae=self.model.vae,
height=self._output_height,
width=self._output_width,
image_path=self._image_path,
)
)
return FluxKontextPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
conditioning_latents=conditioning_latents,
kontext_image_ids=kontext_image_ids,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
if pooled_prompt_embeds is None:
raise ValueError(
"pooled_prompt_embeds is required for Flux Kontext text embeddings"
)
return Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> RotaryEmbeddings:
return Transformer.compute_rotary_embeddings(
prompt_embeds,
self._transformer.pos_embed,
runtime_config,
kontext_image_ids,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
raise NotImplementedError("Flux Kontext does not use classifier-free guidance")
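
Reading the adapter as a whole, the call order is the contract: set_image_dimensions() must run before encode_prompt(), which builds both the text embeddings and the image conditioning; create_latents() then produces pure noise, since the input image participates only through the conditioning latents. A minimal sketch of that flow (config, paths, and prompt below are hypothetical):

from pathlib import Path

adapter = FluxKontextModelAdapter(
    config=FLUX_KONTEXT_CONFIG,
    model_id="exolabs/FLUX.1-Kontext-dev-4bit",
    local_path=Path("/models/FLUX.1-Kontext-dev-4bit"),  # hypothetical path
    quantize=4,
)
# 1. Derive output dimensions from the input image (also stores the image path)
width, height = adapter.set_image_dimensions(Path("input.png"))
# 2. Encode the prompt; this also encodes the image into conditioning latents + position IDs
prompt_data = adapter.encode_prompt("make the sky stormy")
# 3. Pure-noise latents; runtime Config construction is elided here
# latents = adapter.create_latents(seed=0, runtime_config=...)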

View File

@@ -69,6 +69,10 @@ class QwenPromptData(PromptData):
def conditioning_latents(self) -> mx.array | None:
return None
@property
def kontext_image_ids(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:

View File

@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps=7,
num_sync_steps_factor=0.25,
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
)
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps=7,
num_sync_steps_factor=0.25,
guidance_scale=3.5,
)

View File

@@ -85,6 +85,10 @@ class QwenEditPromptData(PromptData):
def qwen_image_ids(self) -> mx.array:
return self._qwen_image_ids
@property
def kontext_image_ids(self) -> mx.array | None:
return None
@property
def is_edit_mode(self) -> bool:
return True

View File

@@ -567,6 +567,7 @@ class DiffusionRunner:
| list[tuple[int, int, int]]
| None = None,
conditioning_latents: mx.array | None = None,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
"""Run a single forward pass through the transformer.
Args:
@@ -578,6 +579,7 @@ class DiffusionRunner:
encoder_hidden_states_mask: Attention mask for text (Qwen)
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
conditioning_latents: Conditioning latents for edit mode
kontext_image_ids: Position IDs for Kontext conditioning (Flux Kontext)
Returns:
Noise prediction tensor
@@ -610,6 +612,7 @@ class DiffusionRunner:
config,
encoder_hidden_states_mask=encoder_hidden_states_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
assert self.joint_block_wrappers is not None
@@ -681,6 +684,7 @@ class DiffusionRunner:
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
@@ -700,6 +704,7 @@ class DiffusionRunner:
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
conditioning_latents=branch.cond_latents,
kontext_image_ids=kontext_image_ids,
)
results.append((branch.positive, noise))
@@ -902,10 +907,10 @@ class DiffusionRunner:
config: Config,
hidden_states: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
@@ -979,10 +984,10 @@ class DiffusionRunner:
latents: mx.array,
prompt_data: PromptData,
is_first_async_step: bool,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
prev_patch_latents = [p for p in patch_latents]

tmp/quantize_and_upload.py (new executable file, 377 lines)
View File

@@ -0,0 +1,377 @@
#!/usr/bin/env python3
"""
Download an mflux model, quantize it, and upload to HuggingFace.
Usage (run from mflux project directory):
cd /path/to/mflux
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
Requires:
- Must be run from mflux project directory using `uv run`
- huggingface_hub installed (add to mflux deps or install separately)
- HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN
"""
from __future__ import annotations
import argparse
import re
import shutil
import sys
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from mflux.models.flux.variants.txt2img.flux import Flux1
HF_ORG = "exolabs"
def get_model_class(model_name: str) -> type:
"""Get the appropriate model class based on model name."""
from mflux.models.fibo.variants.txt2img.fibo import FIBO
from mflux.models.flux.variants.txt2img.flux import Flux1
from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo
model_name_lower = model_name.lower()
if "qwen" in model_name_lower:
return QwenImage
elif "fibo" in model_name_lower:
return FIBO
elif "z-image" in model_name_lower or "zimage" in model_name_lower:
return ZImageTurbo
elif "flux2" in model_name_lower or "flux.2" in model_name_lower:
return Flux2Klein
else:
return Flux1
def get_repo_name(model_name: str, bits: int | None) -> str:
"""Get the HuggingFace repo name for a model variant."""
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
suffix = f"-{bits}bit" if bits else ""
return f"{HF_ORG}/{base_name}{suffix}"
def get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path:
"""Get the local save path for a model variant."""
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
suffix = f"-{bits}bit" if bits else ""
return output_dir / f"{base_name}{suffix}"
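# For illustration, derived from the two helpers above (HF_ORG = "exolabs"):
#   get_repo_name("black-forest-labs/FLUX.1-Kontext-dev", 4)
#       -> "exolabs/FLUX.1-Kontext-dev-4bit"
#   get_repo_name("black-forest-labs/FLUX.1-Kontext-dev", None)
#       -> "exolabs/FLUX.1-Kontext-dev"
#   get_local_path(Path("./tmp/models"), "black-forest-labs/FLUX.1-Kontext-dev", 8)
#       -> tmp/models/FLUX.1-Kontext-dev-8bit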
def copy_source_repo(
source_repo: str,
local_path: Path,
dry_run: bool = False,
) -> None:
"""Copy all files from source repo (replicating original HF structure)."""
print(f"\n{'=' * 60}")
print(f"Copying full repo from source: {source_repo}")
print(f"Output path: {local_path}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would download all files from source repo")
return
from huggingface_hub import snapshot_download
# Download all files to our local path
snapshot_download(
repo_id=source_repo,
local_dir=local_path,
)
# Remove root-level safetensors files (flux.1-dev.safetensors, etc.)
# These are redundant with the component directories
for f in local_path.glob("*.safetensors"):
print(f"Removing root-level safetensors: {f.name}")
if not dry_run:
f.unlink()
print(f"Source repo copied to {local_path}")
def load_and_save_quantized_model(
model_name: str,
bits: int,
output_path: Path,
dry_run: bool = False,
) -> None:
"""Load a model with quantization and save it in mflux format."""
print(f"\n{'=' * 60}")
print(f"Loading {model_name} with {bits}-bit quantization...")
print(f"Output path: {output_path}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would load and save quantized model")
return
from mflux.models.common.config.model_config import ModelConfig
model_class = get_model_class(model_name)
model_config = ModelConfig.from_name(model_name=model_name, base_model=None)
model: Flux1 = model_class(
quantize=bits,
model_config=model_config,
)
print(f"Saving model to {output_path}...")
model.save_model(str(output_path))
print(f"Model saved successfully to {output_path}")
def copy_source_metadata(
source_repo: str,
local_path: Path,
dry_run: bool = False,
) -> None:
"""Copy metadata files (LICENSE, README, etc.) from source repo, excluding safetensors."""
print(f"\n{'=' * 60}")
print(f"Copying metadata from source repo: {source_repo}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would download metadata files (excluding *.safetensors)")
return
from huggingface_hub import snapshot_download
# Download all files except safetensors to our local path
snapshot_download(
repo_id=source_repo,
local_dir=local_path,
ignore_patterns=["*.safetensors"],
)
print(f"Metadata files copied to {local_path}")
def upload_to_huggingface(
local_path: Path,
repo_id: str,
dry_run: bool = False,
clean_remote: bool = False,
) -> None:
"""Upload a saved model to HuggingFace."""
print(f"\n{'=' * 60}")
print(f"Uploading to HuggingFace: {repo_id}")
print(f"Local path: {local_path}")
print(f"Clean remote first: {clean_remote}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would upload to HuggingFace")
return
from huggingface_hub import HfApi
api = HfApi()
# Create the repo if it doesn't exist
print(f"Creating/verifying repo: {repo_id}")
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
# Clean remote repo if requested (delete old mflux-format files)
if clean_remote:
print("Cleaning old mflux-format files from remote...")
try:
# Pattern for mflux numbered shards: <dir>/<number>.safetensors
numbered_pattern = re.compile(r".*/\d+\.safetensors$")
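# e.g. matches mflux-style shards like "transformer/0.safetensors" but not
# HF-style "transformer/diffusion_pytorch_model-00001-of-00002.safetensors"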
repo_files = api.list_repo_files(repo_id=repo_id, repo_type="model")
for file_path in repo_files:
# Delete numbered safetensors (mflux format) and mflux index files
if numbered_pattern.match(file_path) or file_path.endswith(
"/model.safetensors.index.json"
):
print(f" Deleting: {file_path}")
api.delete_file(
path_in_repo=file_path, repo_id=repo_id, repo_type="model"
)
except Exception as e:
print(f"Warning: Could not clean remote files: {e}")
# Upload the folder
print("Uploading folder contents...")
api.upload_folder(
folder_path=str(local_path),
repo_id=repo_id,
repo_type="model",
)
print(f"Upload complete: https://huggingface.co/{repo_id}")
def clean_local_files(local_path: Path, dry_run: bool = False) -> None:
"""Remove local model files after upload."""
print(f"\nCleaning up: {local_path}")
if dry_run:
print("[DRY RUN] Would remove local files")
return
if local_path.exists():
shutil.rmtree(local_path)
print(f"Removed {local_path}")
def main() -> int:
parser = argparse.ArgumentParser(
description="Download an mflux model, quantize it, and upload to HuggingFace.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Process all variants (base, 4-bit, 8-bit) for FLUX.1-Kontext-dev
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
# Only process 4-bit variant
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
# Save locally without uploading
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-upload
# Preview what would happen
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
""",
)
parser.add_argument(
"--model",
"-m",
required=True,
help="HuggingFace model path (e.g., black-forest-labs/FLUX.1-Kontext-dev)",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("./tmp/models"),
help="Local directory to save models (default: ./tmp/models)",
)
parser.add_argument(
"--skip-base",
action="store_true",
help="Skip base model (no quantization)",
)
parser.add_argument(
"--skip-4bit",
action="store_true",
help="Skip 4-bit quantized model",
)
parser.add_argument(
"--skip-8bit",
action="store_true",
help="Skip 8-bit quantized model",
)
parser.add_argument(
"--skip-download",
action="store_true",
help="Skip downloading/processing, only do upload/clean operations",
)
parser.add_argument(
"--skip-upload",
action="store_true",
help="Only save locally, don't upload to HuggingFace",
)
parser.add_argument(
"--clean",
action="store_true",
help="Remove local files after upload",
)
parser.add_argument(
"--clean-remote",
action="store_true",
help="Delete old mflux-format files from remote repo before uploading",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Print actions without executing",
)
args = parser.parse_args()
# Determine which variants to process
variants: list[int | None] = []
if not args.skip_base:
variants.append(None) # Base model (no quantization)
if not args.skip_4bit:
variants.append(4)
if not args.skip_8bit:
variants.append(8)
if not variants:
print("Error: All variants skipped. Nothing to do.")
return 1
# Create output directory
args.output_dir.mkdir(parents=True, exist_ok=True)
print(f"Model: {args.model}")
print(f"Output directory: {args.output_dir}")
print(
f"Variants to process: {['base' if v is None else f'{v}-bit' for v in variants]}"
)
print(f"Upload to HuggingFace: {not args.skip_upload}")
print(f"Clean after upload: {args.clean}")
if args.dry_run:
print("\n*** DRY RUN MODE - No actual changes will be made ***")
# Process each variant
for bits in variants:
local_path = get_local_path(args.output_dir, args.model, bits)
repo_id = get_repo_name(args.model, bits)
if not args.skip_download:
if bits is None:
# Base model: copy original HF repo structure (no mflux conversion)
copy_source_repo(
source_repo=args.model,
local_path=local_path,
dry_run=args.dry_run,
)
else:
# Quantized model: load, quantize, and save with mflux
load_and_save_quantized_model(
model_name=args.model,
bits=bits,
output_path=local_path,
dry_run=args.dry_run,
)
# Copy metadata from source repo (LICENSE, README, etc.)
copy_source_metadata(
source_repo=args.model,
local_path=local_path,
dry_run=args.dry_run,
)
# Upload
if not args.skip_upload:
upload_to_huggingface(
local_path=local_path,
repo_id=repo_id,
dry_run=args.dry_run,
clean_remote=args.clean_remote,
)
# Clean up if requested
if args.clean:
clean_local_files(local_path, dry_run=args.dry_run)
print("\n" + "=" * 60)
print("All done!")
print("=" * 60)
return 0
if __name__ == "__main__":
sys.exit(main())

uv.lock (generated, 2143 changed lines)
View File

File diff suppressed because it is too large.