Compare commits

..

1 Commits

Author SHA1 Message Date
Evan Quiney
6b907398a4 cancel downloads for deleted instances (#1393)
After deleting an instance, if a given (node_id, model_id) pair no longer appears in any of the remaining instances, cancel the download of model_id on node_id.
2026-02-05 18:16:43 +00:00
26 changed files with 1146 additions and 2066 deletions

View File

@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
__all__ = ["Flux1Kontext"]

View File

@@ -1,49 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
from typing import Any
from mlx import nn
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.model.flux_vae.vae import VAE
from mflux.utils.generated_image import GeneratedImage
# Pyright-generated stub for Flux1Kontext, declared as an mlx ``nn.Module`` subclass.
# Bodies are elided (``...``); only attribute types and call signatures are declared.
class Flux1Kontext(nn.Module):
# Model component attributes.
vae: VAE
transformer: Transformer
t5_text_encoder: T5Encoder
clip_text_encoder: CLIPEncoder
# Remaining declared attributes.
bits: int | None
lora_paths: list[str] | None
lora_scales: list[float] | None
prompt_cache: dict[str, Any]
tokenizers: dict[str, Any]
# Constructor: every parameter has a default, which the stub elides as ``...``.
def __init__(
self,
quantize: int | None = ...,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
model_config: ModelConfig = ...,
) -> None: ...
# Only ``seed`` and ``prompt`` are required; the call is declared to return a GeneratedImage.
def generate_image(
self,
seed: int,
prompt: str,
num_inference_steps: int = ...,
height: int = ...,
width: int = ...,
guidance: float = ...,
image_path: Path | str | None = ...,
image_strength: float | None = ...,
scheduler: str = ...,
) -> GeneratedImage: ...

View File

@@ -1,16 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.flux.model.flux_vae.vae import VAE
# Pyright-generated stub for KontextUtil; the body is elided (``...``).
class KontextUtil:
# Static helper taking a VAE, a target height/width and an image path.
# Declared to return a pair of ``mx.array`` values.
@staticmethod
def create_image_conditioning_latents(
vae: VAE,
height: int,
width: int,
image_path: str,
) -> tuple[mx.array, mx.array]: ...

View File

@@ -20,7 +20,7 @@ sync-clean:
rust-rebuild:
cargo run --bin stub_gen
just sync-clean
uv sync --reinstall-package exo_pyo3_bindings
build-dashboard:
#!/usr/bin/env bash

View File

@@ -26,7 +26,7 @@ dependencies = [
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux==0.15.5",
"mflux==0.15.4",
"python-multipart>=0.0.21",
]

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Kontext-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -16,6 +16,7 @@ from exo.download.download_utils import (
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
@@ -107,6 +108,13 @@ class DownloadCoordinator:
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
async def _cancel_download(self, model_id: ModelId) -> None:
    """Cancel the in-flight download for ``model_id``, if there is one.

    The download is cancelled only when ``model_id`` is present in both
    ``self.active_downloads`` and ``self.download_status``; otherwise the
    call is a no-op. On cancellation the entry is removed from
    ``self.active_downloads`` and its ``cancel()`` method is invoked.
    """
    tracked = model_id in self.active_downloads
    if not tracked or model_id not in self.download_status:
        return

    logger.info(f"Cancelling download for {model_id}")
    handle = self.active_downloads.pop(model_id)
    handle.cancel()
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id

View File

@@ -105,6 +105,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
er_send, er_recv = channel[ElectionResult]()
@@ -188,6 +189,9 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
)
self._tg.start_soon(self.master.run)
elif (

View File

@@ -6,6 +6,7 @@ from loguru import logger
from exo.master.placement import (
add_instance_to_placements,
cancel_unnecessary_downloads,
delete_instance,
get_transition_events,
place_instance,
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
@@ -66,12 +68,9 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
# Ideally these would be WorkerForwarderEvents but type system says no :(
local_event_receiver: Receiver[ForwarderEvent],
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -81,6 +80,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -280,6 +280,14 @@ class Master:
transition_events = get_transition_events(
self.state.instances, placement
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(

View File

@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CancelDownload,
CreateInstance,
DeleteInstance,
DownloadCommand,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -202,3 +208,29 @@ def get_transition_events(
)
return events
def cancel_unnecessary_downloads(
    instances: Mapping[InstanceId, Instance],
    download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> Sequence[DownloadCommand]:
    """Build cancel commands for downloads no remaining instance needs.

    Args:
        instances: The instances that remain; each one's shard assignments
            say which model every assigned node runs.
        download_status: Per-node download progress entries. Only entries
            that are ``DownloadOngoing`` are considered.

    Returns:
        One ``CancelDownload`` per ongoing ``(node_id, model_id)`` download
        that does not correspond to any ``(node_id, model_id)`` pair found
        in ``instances``, in the iteration order of ``download_status``.
    """
    # Every (node, model) pair that is being downloaded right now.
    in_flight = []
    for node_id, progress_entries in download_status.items():
        for entry in progress_entries:
            if isinstance(entry, DownloadOngoing):
                in_flight.append((node_id, entry.shard_metadata.model_card.model_id))

    # Every (node, model) pair that some remaining instance still relies on.
    still_needed = set()
    for instance in instances.values():
        assignments = instance.shard_assignments
        for node_id, runner_id in assignments.node_to_runner.items():
            shard = assignments.runner_to_shard[runner_id]
            still_needed.add((node_id, shard.model_card.model_id))

    return [
        CancelDownload(target_node_id=node_id, model_id=model_id)
        for node_id, model_id in in_flight
        if (node_id, model_id) not in still_needed
    ]

View File

@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
CommandId,
ForwarderCommand,
ForwarderDownloadCommand,
PlaceInstance,
TextGeneration,
)
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
download_command_sender=fcds,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
Command = (

View File

@@ -5,9 +5,7 @@ from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_KONTEXT_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxKontextModelAdapter,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
@@ -28,16 +26,13 @@ AdapterFactory = Callable[
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"flux-kontext": FluxKontextModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
# Order matters: longer/more-specific patterns must come before shorter ones
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-kontext": FLUX_KONTEXT_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching

View File

@@ -66,19 +66,6 @@ class PromptData(ABC):
"""
...
@property
@abstractmethod
def kontext_image_ids(self) -> mx.array | None:
"""Kontext-style position IDs for image conditioning.
For FLUX.1-Kontext models, returns position IDs with first_coord=1
to distinguish conditioning tokens from generation tokens (first_coord=0).
Returns:
Position IDs array [1, seq_len, 3] for Kontext, None for other models.
"""
...
@abstractmethod
def get_batched_cfg_data(
self,

View File

@@ -1,17 +1,11 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_KONTEXT_CONFIG,
FLUX_SCHNELL_CONFIG,
)
from exo.worker.engines.image.models.flux.kontext_adapter import (
FluxKontextModelAdapter,
)
__all__ = [
"FluxModelAdapter",
"FluxKontextModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_KONTEXT_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -59,10 +59,6 @@ class FluxPromptData(PromptData):
def conditioning_latents(self) -> mx.array | None:
return None
@property
def kontext_image_ids(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:

View File

@@ -32,19 +32,3 @@ FLUX_DEV_CONFIG = ImageModelConfig(
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
)
FLUX_KONTEXT_CONFIG = ImageModelConfig(
model_family="flux-kontext",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
guidance_scale=4.0,
)

View File

@@ -1,348 +0,0 @@
import math
from pathlib import Path
from typing import Any, final
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
from mflux.models.flux.variants.kontext.kontext_util import KontextUtil
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
PromptData,
RotaryEmbeddings,
)
from exo.worker.engines.image.models.flux.wrappers import (
FluxJointBlockWrapper,
FluxSingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
@final
class FluxKontextPromptData(PromptData):
    """Encoded prompt bundle for a FLUX.1-Kontext image edit.

    Carries the text embeddings together with the conditioning latents and
    Kontext position IDs it was constructed with. Negative-prompt, mask and
    image-grid accessors always return ``None``.
    """

    def __init__(
        self,
        prompt_embeds: mx.array,
        pooled_prompt_embeds: mx.array,
        conditioning_latents: mx.array,
        kontext_image_ids: mx.array,
    ):
        # Text side.
        self._prompt_embeds = prompt_embeds
        self._pooled_prompt_embeds = pooled_prompt_embeds
        # Image side.
        self._conditioning_latents = conditioning_latents
        self._kontext_image_ids = kontext_image_ids

    # --- values stored at construction time ---------------------------------

    @property
    def prompt_embeds(self) -> mx.array:
        return self._prompt_embeds

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        return self._pooled_prompt_embeds

    @property
    def conditioning_latents(self) -> mx.array | None:
        """Conditioning latents supplied for the input image."""
        return self._conditioning_latents

    @property
    def kontext_image_ids(self) -> mx.array | None:
        """Kontext position IDs supplied alongside the conditioning latents."""
        return self._kontext_image_ids

    # --- members that are always None for Kontext ---------------------------

    @property
    def negative_prompt_embeds(self) -> mx.array | None:
        return None

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        return None

    @property
    def cond_image_grid(
        self,
    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
        return None

    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
        return None

    # --- CFG helpers --------------------------------------------------------

    def get_cfg_branch_data(
        self, positive: bool
    ) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
        """Return the positive-branch data whatever ``positive`` is.

        Kontext has no CFG branches, so the same tuple
        ``(prompt_embeds, None, pooled_prompt_embeds, conditioning_latents)``
        is returned for compatibility with callers that ask per branch.
        """
        branch = (
            self._prompt_embeds,
            None,
            self._pooled_prompt_embeds,
            self._conditioning_latents,
        )
        return branch

    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
        # No CFG for Kontext, hence nothing to batch.
        return None
@final
class FluxKontextModelAdapter(ModelAdapter[Flux1Kontext, Transformer]):
    """Model adapter for the FLUX.1-Kontext image editing model.

    Compared with the plain Flux adapter, this one:
    - derives the output dimensions from an input image
      (``set_image_dimensions``),
    - builds conditioning latents and Kontext position IDs from that image
      (``encode_prompt``),
    - starts from pure noise latents rather than an img2img blend
      (``create_latents``),
    - does not use classifier-free guidance.
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        kontext_config = ModelConfig.from_name(model_name=model_id, base_model=None)
        self._model = Flux1Kontext(
            model_config=kontext_config,
            model_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer

        # Filled in by set_image_dimensions(); encode_prompt() refuses to run
        # while any of these is still None.
        self._image_path: str | None = None
        self._output_height: int | None = None
        self._output_width: int | None = None

    @property
    def hidden_dim(self) -> int:
        return self._transformer.x_embedder.weight.shape[0]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]

    @property
    def needs_cfg(self) -> bool:
        return False

    def _get_latent_creator(self) -> type:
        return FluxLatentCreator

    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper[Any]]:
        """Wrap each joint transformer block (the mask argument is unused)."""
        wrapped: list[JointBlockWrapper[Any]] = []
        for block in self._transformer.transformer_blocks:
            wrapped.append(FluxJointBlockWrapper(block, text_seq_len))
        return wrapped

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper[Any]]:
        """Wrap each single transformer block."""
        wrapped: list[SingleBlockWrapper[Any]] = []
        for block in self._transformer.single_transformer_blocks:
            wrapped.append(FluxSingleBlockWrapper(block, text_seq_len))
        return wrapped

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        """Keep only the blocks in the layer range ``[start_layer, end_layer)``.

        Layers are numbered with all joint blocks first, followed by the
        single blocks; both block lists on the transformer are replaced by
        the corresponding slices.
        """
        joint_blocks = list(self._transformer.transformer_blocks)
        single_blocks = list(self._transformer.single_transformer_blocks)
        joint_count = len(joint_blocks)

        if end_layer <= joint_count:
            # Range lies entirely within the joint blocks.
            joint_range = slice(start_layer, end_layer)
            single_range = slice(0, 0)
        elif start_layer >= joint_count:
            # Range lies entirely within the single blocks.
            joint_range = slice(0, 0)
            single_range = slice(start_layer - joint_count, end_layer - joint_count)
        else:
            # Range straddles the joint/single boundary.
            joint_range = slice(start_layer, joint_count)
            single_range = slice(0, end_layer - joint_count)

        self._transformer.transformer_blocks = joint_blocks[joint_range]
        self._transformer.single_transformer_blocks = single_blocks[single_range]

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
        """Derive output dimensions from the input image and remember them.

        The image path is stored as well so that ``encode_prompt()`` can use it.

        Args:
            image_path: Path to the input image.

        Returns:
            ``(output_width, output_height)`` for the runtime config.
        """
        from mflux.utils.image_util import ImageUtil

        source = ImageUtil.load_image(str(image_path)).convert("RGB")
        source_size = source.size

        # Keep the input aspect ratio while targeting a 1024x1024 pixel area.
        target_area = 1024 * 1024
        ratio = source_size[0] / source_size[1]
        width = math.sqrt(target_area * ratio)
        height = width / ratio

        # Snap to the nearest multiple of 32 ...
        width = round(width / 32) * 32
        height = round(height / 32) * 32

        # ... then floor to a multiple of 16 (twice the VAE scale factor).
        vae_scale_factor = 8
        multiple_of = vae_scale_factor * 2
        width = width // multiple_of * multiple_of
        height = height // multiple_of * multiple_of

        self._image_path = str(image_path)
        self._output_width = int(width)
        self._output_height = int(height)

        return self._output_width, self._output_height

    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
        """Return pure noise latents sized from ``runtime_config``.

        The input image is not blended in here; for Kontext it is supplied
        separately as conditioning.
        """
        return FluxLatentCreator.create_noise(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
        )

    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> FluxKontextPromptData:
        """Encode ``prompt`` and build conditioning from the stored image.

        ``set_image_dimensions()`` must have been called first.

        Args:
            prompt: Text prompt describing the edit.
            negative_prompt: Ignored; Kontext has no CFG / negative prompt.

        Returns:
            FluxKontextPromptData holding text embeddings and image conditioning.

        Raises:
            RuntimeError: If the image path or output dimensions are unset.
        """
        del negative_prompt  # Kontext doesn't support negative prompts or CFG

        image_path = self._image_path
        height = self._output_height
        width = self._output_width
        if image_path is None or height is None or width is None:
            raise RuntimeError(
                "set_image_dimensions() must be called before encode_prompt() "
                "for FluxKontextModelAdapter"
            )

        assert isinstance(self.model.prompt_cache, dict)
        assert isinstance(self.model.tokenizers, dict)

        # Text prompt -> (sequence embeddings, pooled embeddings).
        text_embeds, pooled_embeds = PromptEncoder.encode_prompt(
            prompt=prompt,
            prompt_cache=self.model.prompt_cache,
            t5_tokenizer=self.model.tokenizers["t5"],  # pyright: ignore[reportAny]
            clip_tokenizer=self.model.tokenizers["clip"],  # pyright: ignore[reportAny]
            t5_text_encoder=self.model.t5_text_encoder,
            clip_text_encoder=self.model.clip_text_encoder,
        )

        # Input image -> (conditioning latents, Kontext position IDs).
        image_conditioning = KontextUtil.create_image_conditioning_latents(
            vae=self.model.vae,
            height=height,
            width=width,
            image_path=image_path,
        )
        cond_latents, image_ids = image_conditioning

        return FluxKontextPromptData(
            prompt_embeds=text_embeds,
            pooled_prompt_embeds=pooled_embeds,
            conditioning_latents=cond_latents,
            kontext_image_ids=image_ids,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Embed the latents and the prompt with the transformer's embedders."""
        return (
            self._transformer.x_embedder(hidden_states),
            self._transformer.context_embedder(prompt_embeds),
        )

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Delegate to ``Transformer.compute_text_embeddings``.

        Raises:
            ValueError: If ``pooled_prompt_embeds`` is ``None``.
        """
        if pooled_prompt_embeds is not None:
            return Transformer.compute_text_embeddings(
                t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
            )
        raise ValueError(
            "pooled_prompt_embeds is required for Flux Kontext text embeddings"
        )

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> RotaryEmbeddings:
        """Delegate to ``Transformer.compute_rotary_embeddings``.

        Only ``prompt_embeds``, ``runtime_config`` and ``kontext_image_ids``
        are forwarded; the mask and image-grid arguments are not used.
        """
        rotary = Transformer.compute_rotary_embeddings(
            prompt_embeds,
            self._transformer.pos_embed,
            runtime_config,
            kontext_image_ids,
        )
        return rotary

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Always raises: this adapter does not use classifier-free guidance."""
        raise NotImplementedError("Flux Kontext does not use classifier-free guidance")

View File

@@ -69,10 +69,6 @@ class QwenPromptData(PromptData):
def conditioning_latents(self) -> mx.array | None:
return None
@property
def kontext_image_ids(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:

View File

@@ -85,10 +85,6 @@ class QwenEditPromptData(PromptData):
def qwen_image_ids(self) -> mx.array:
return self._qwen_image_ids
@property
def kontext_image_ids(self) -> mx.array | None:
return None
@property
def is_edit_mode(self) -> bool:
return True

View File

@@ -567,7 +567,6 @@ class DiffusionRunner:
| list[tuple[int, int, int]]
| None = None,
conditioning_latents: mx.array | None = None,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
"""Run a single forward pass through the transformer.
Args:
@@ -579,7 +578,6 @@ class DiffusionRunner:
encoder_hidden_states_mask: Attention mask for text (Qwen)
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
conditioning_latents: Conditioning latents for edit mode
kontext_image_ids: Position IDs for Kontext conditioning (Flux Kontext)
Returns:
Noise prediction tensor
@@ -612,7 +610,6 @@ class DiffusionRunner:
config,
encoder_hidden_states_mask=encoder_hidden_states_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
assert self.joint_block_wrappers is not None
@@ -684,7 +681,6 @@ class DiffusionRunner:
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
@@ -704,7 +700,6 @@ class DiffusionRunner:
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
conditioning_latents=branch.cond_latents,
kontext_image_ids=kontext_image_ids,
)
results.append((branch.positive, noise))
@@ -907,10 +902,10 @@ class DiffusionRunner:
config: Config,
hidden_states: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
@@ -984,10 +979,10 @@ class DiffusionRunner:
latents: mx.array,
prompt_data: PromptData,
is_first_async_step: bool,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
prev_patch_latents = [p for p in patch_latents]

View File

@@ -35,7 +35,7 @@ i=0
for host; do
colour=${colours[i++ % 4]}
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done

View File

@@ -1,377 +0,0 @@
#!/usr/bin/env python3
"""
Download an mflux model, quantize it, and upload to HuggingFace.
Usage (run from mflux project directory):
cd /path/to/mflux
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
Requires:
- Must be run from mflux project directory using `uv run`
- huggingface_hub installed (add to mflux deps or install separately)
- HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN
"""
from __future__ import annotations
import argparse
import re
import shutil
import sys
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from mflux.models.flux.variants.txt2img.flux import Flux1
HF_ORG = "exolabs"
def get_model_class(model_name: str) -> type:
    """Return the mflux model class whose family marker occurs in ``model_name``.

    Matching is case-insensitive and checked in a fixed order (qwen, fibo,
    z-image, flux2); names matching none of the markers fall back to ``Flux1``.
    """
    from mflux.models.fibo.variants.txt2img.fibo import FIBO
    from mflux.models.flux.variants.txt2img.flux import Flux1
    from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
    from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
    from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo

    # Order matters: the first family with a matching marker wins.
    families = (
        (("qwen",), QwenImage),
        (("fibo",), FIBO),
        (("z-image", "zimage"), ZImageTurbo),
        (("flux2", "flux.2"), Flux2Klein),
    )

    lowered = model_name.lower()
    for markers, model_class in families:
        if any(marker in lowered for marker in markers):
            return model_class
    return Flux1
def get_repo_name(model_name: str, bits: int | None) -> str:
    """Return the ``HF_ORG`` repo id for a model variant.

    The part of ``model_name`` after the last ``/`` is kept (for example
    "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev"), and a
    ``-<bits>bit`` suffix is appended when ``bits`` is truthy.
    """
    repo_leaf = model_name.rpartition("/")[2]
    if bits:
        repo_leaf = f"{repo_leaf}-{bits}bit"
    return f"{HF_ORG}/{repo_leaf}"
def get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path:
    """Return the local directory used to save a model variant.

    The directory name is the part of ``model_name`` after the last ``/``
    (for example "black-forest-labs/FLUX.1-Kontext-dev" ->
    "FLUX.1-Kontext-dev"), plus a ``-<bits>bit`` suffix when ``bits`` is truthy.
    """
    _, _, leaf = model_name.rpartition("/")
    directory_name = f"{leaf}-{bits}bit" if bits else leaf
    return output_dir / directory_name
def copy_source_repo(
    source_repo: str,
    local_path: Path,
    dry_run: bool = False,
) -> None:
    """Copy every file of ``source_repo`` into ``local_path``.

    After the download, ``*.safetensors`` files sitting directly in
    ``local_path`` are deleted; files in sub-directories are left alone.

    Args:
        source_repo: HuggingFace repo id to download from.
        local_path: Directory the snapshot is written to.
        dry_run: When true, only print what would happen and return without
            downloading or deleting anything.
    """
    print(f"\n{'=' * 60}")
    print(f"Copying full repo from source: {source_repo}")
    print(f"Output path: {local_path}")
    print(f"{'=' * 60}")

    if dry_run:
        print("[DRY RUN] Would download all files from source repo")
        return

    from huggingface_hub import snapshot_download

    # Download all files to our local path
    snapshot_download(
        repo_id=source_repo,
        local_dir=local_path,
    )

    # Remove root-level safetensors files (flux.1-dev.safetensors, etc.)
    # These are redundant with the component directories.
    # dry_run is always False here because of the early return above, so the
    # per-file re-check the loop used to carry was dead code and is dropped.
    for f in local_path.glob("*.safetensors"):
        print(f"Removing root-level safetensors: {f.name}")
        f.unlink()

    print(f"Source repo copied to {local_path}")
def load_and_save_quantized_model(
    model_name: str,
    bits: int,
    output_path: Path,
    dry_run: bool = False,
) -> None:
    """Load ``model_name`` quantized to ``bits`` and save it to ``output_path``.

    Args:
        model_name: Model name passed to ``get_model_class`` and
            ``ModelConfig.from_name``.
        bits: Quantization level handed to the model constructor.
        output_path: Destination passed to the model's ``save_model``.
        dry_run: When true, only print what would happen and return.
    """
    for line in (
        f"\n{'=' * 60}",
        f"Loading {model_name} with {bits}-bit quantization...",
        f"Output path: {output_path}",
        f"{'=' * 60}",
    ):
        print(line)

    if dry_run:
        print("[DRY RUN] Would load and save quantized model")
        return

    from mflux.models.common.config.model_config import ModelConfig

    # Resolve the class first, then the config, then instantiate.
    quantized: Flux1 = get_model_class(model_name)(
        quantize=bits,
        model_config=ModelConfig.from_name(model_name=model_name, base_model=None),
    )

    print(f"Saving model to {output_path}...")
    quantized.save_model(str(output_path))
    print(f"Model saved successfully to {output_path}")
def copy_source_metadata(
    source_repo: str,
    local_path: Path,
    dry_run: bool = False,
) -> None:
    """Download everything except ``*.safetensors`` from ``source_repo``.

    Used to bring along metadata files (LICENSE, README, etc.) from the
    source repo.

    Args:
        source_repo: HuggingFace repo id to download from.
        local_path: Directory the files are written to.
        dry_run: When true, only print what would happen and return.
    """
    for line in (
        f"\n{'=' * 60}",
        f"Copying metadata from source repo: {source_repo}",
        f"{'=' * 60}",
    ):
        print(line)

    if dry_run:
        print("[DRY RUN] Would download metadata files (excluding *.safetensors)")
        return

    from huggingface_hub import snapshot_download

    # Weights are excluded; everything else in the repo is fetched.
    excluded = ["*.safetensors"]
    snapshot_download(
        repo_id=source_repo,
        local_dir=local_path,
        ignore_patterns=excluded,
    )

    print(f"Metadata files copied to {local_path}")
def upload_to_huggingface(
    local_path: Path,
    repo_id: str,
    dry_run: bool = False,
    clean_remote: bool = False,
) -> None:
    """Upload the folder at ``local_path`` to the HuggingFace repo ``repo_id``.

    The repo is created if it does not exist yet. With ``clean_remote``,
    numbered ``<dir>/<number>.safetensors`` shards and
    ``*/model.safetensors.index.json`` files are deleted from the remote
    before uploading; a failure during that cleanup is printed as a warning
    and the upload still proceeds.

    Args:
        local_path: Folder whose contents are uploaded.
        repo_id: Target HuggingFace model repo.
        dry_run: When true, only print what would happen and return.
        clean_remote: Whether to delete old mflux-format files first.
    """
    print(f"\n{'=' * 60}")
    print(f"Uploading to HuggingFace: {repo_id}")
    print(f"Local path: {local_path}")
    print(f"Clean remote first: {clean_remote}")
    print(f"{'=' * 60}")

    if dry_run:
        print("[DRY RUN] Would upload to HuggingFace")
        return

    from huggingface_hub import HfApi

    hub = HfApi()

    # Make sure the target repo exists.
    print(f"Creating/verifying repo: {repo_id}")
    hub.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

    if clean_remote:
        print("Cleaning old mflux-format files from remote...")
        try:
            # mflux numbered shards look like <dir>/<number>.safetensors
            shard_re = re.compile(r".*/\d+\.safetensors$")
            for file_path in hub.list_repo_files(repo_id=repo_id, repo_type="model"):
                stale = shard_re.match(file_path) or file_path.endswith(
                    "/model.safetensors.index.json"
                )
                if not stale:
                    continue
                print(f" Deleting: {file_path}")
                hub.delete_file(
                    path_in_repo=file_path, repo_id=repo_id, repo_type="model"
                )
        except Exception as e:
            # Best effort: cleanup problems must not block the upload.
            print(f"Warning: Could not clean remote files: {e}")

    print("Uploading folder contents...")
    hub.upload_folder(
        folder_path=str(local_path),
        repo_id=repo_id,
        repo_type="model",
    )

    print(f"Upload complete: https://huggingface.co/{repo_id}")
def clean_local_files(local_path: Path, dry_run: bool = False) -> None:
    """Delete the directory tree at ``local_path`` if it exists.

    Args:
        local_path: Directory to remove recursively.
        dry_run: If true, only print the planned action and leave files alone.
    """
    print(f"\nCleaning up: {local_path}")

    if dry_run:
        print("[DRY RUN] Would remove local files")
        return

    # Nothing to remove (and nothing to report) when the path is absent.
    if not local_path.exists():
        return

    shutil.rmtree(local_path)
    print(f"Removed {local_path}")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the quantize-and-upload script."""
    parser = argparse.ArgumentParser(
        description="Download an mflux model, quantize it, and upload to HuggingFace.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Process all variants (base, 4-bit, 8-bit) for FLUX.1-Kontext-dev
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
# Only process 4-bit variant
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
# Save locally without uploading
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-upload
# Preview what would happen
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
""",
    )
    parser.add_argument(
        "--model",
        "-m",
        required=True,
        help="HuggingFace model path (e.g., black-forest-labs/FLUX.1-Kontext-dev)",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("./tmp/models"),
        help="Local directory to save models (default: ./tmp/models)",
    )
    parser.add_argument(
        "--skip-base",
        action="store_true",
        help="Skip base model (no quantization)",
    )
    parser.add_argument(
        "--skip-4bit",
        action="store_true",
        help="Skip 4-bit quantized model",
    )
    parser.add_argument(
        "--skip-8bit",
        action="store_true",
        help="Skip 8-bit quantized model",
    )
    parser.add_argument(
        "--skip-download",
        action="store_true",
        help="Skip downloading/processing, only do upload/clean operations",
    )
    parser.add_argument(
        "--skip-upload",
        action="store_true",
        help="Only save locally, don't upload to HuggingFace",
    )
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Remove local files after upload",
    )
    parser.add_argument(
        "--clean-remote",
        action="store_true",
        help="Delete old mflux-format files from remote repo before uploading",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print actions without executing",
    )
    return parser


def main() -> int:
    """Run the quantize-and-upload workflow from command-line arguments.

    For each selected variant (base, 4-bit, 8-bit) the model is prepared
    locally unless ``--skip-download`` is given, uploaded unless
    ``--skip-upload`` is given, and removed locally when ``--clean`` is set.

    Returns:
        ``0`` after all selected variants were processed, ``1`` when every
        variant was skipped and there is nothing to do.
    """
    args = _build_arg_parser().parse_args()

    # Determine which variants to process; None stands for the base model
    # (no quantization).
    variants: list[int | None] = []
    if not args.skip_base:
        variants.append(None)
    if not args.skip_4bit:
        variants.append(4)
    if not args.skip_8bit:
        variants.append(8)

    if not variants:
        print("Error: All variants skipped. Nothing to do.")
        return 1

    # A dry run promises "no actual changes", so the output directory is only
    # created for real runs.
    if not args.dry_run:
        args.output_dir.mkdir(parents=True, exist_ok=True)

    variant_labels = ["base" if v is None else f"{v}-bit" for v in variants]
    print(f"Model: {args.model}")
    print(f"Output directory: {args.output_dir}")
    print(f"Variants to process: {variant_labels}")
    print(f"Upload to HuggingFace: {not args.skip_upload}")
    print(f"Clean after upload: {args.clean}")
    if args.dry_run:
        print("\n*** DRY RUN MODE - No actual changes will be made ***")

    # Process each variant
    for bits in variants:
        local_path = get_local_path(args.output_dir, args.model, bits)
        repo_id = get_repo_name(args.model, bits)

        if not args.skip_download:
            if bits is None:
                # Base model: copy original HF repo structure (no mflux conversion)
                copy_source_repo(
                    source_repo=args.model,
                    local_path=local_path,
                    dry_run=args.dry_run,
                )
            else:
                # Quantized model: load, quantize, and save with mflux
                load_and_save_quantized_model(
                    model_name=args.model,
                    bits=bits,
                    output_path=local_path,
                    dry_run=args.dry_run,
                )
                # Copy metadata from source repo (LICENSE, README, etc.)
                copy_source_metadata(
                    source_repo=args.model,
                    local_path=local_path,
                    dry_run=args.dry_run,
                )

        # Upload
        if not args.skip_upload:
            upload_to_huggingface(
                local_path=local_path,
                repo_id=repo_id,
                dry_run=args.dry_run,
                clean_remote=args.clean_remote,
            )

        # Clean up if requested
        if args.clean:
            clean_local_files(local_path, dry_run=args.dry_run)

    print("\n" + "=" * 60)
    print("All done!")
    print("=" * 60)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())

2143
uv.lock generated
View File

File diff suppressed because it is too large Load Diff