Compare commits

..

5 Commits

Author SHA1 Message Date
Evan
7597e271ab api cancellation
closing the http request to the api now
- sends a cancellation from the api
- writes that canellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks
2026-02-06 17:09:49 +00:00
rltakashige
5455a97a8c Fix GLM4Moe Tensor Sharding (#1411)
## Motivation

Recent commit broke glm (non lite) sharding

## Why It Works

Assert is no longer hit, as isinstance check includes
GLM4MoeDecoderLayer.
Added type stubs to keep the type checker happy.

## Test Plan

### Manual Testing
Runs as expected without gibberish.
2026-02-06 16:53:15 +00:00
ciaranbor
6f0cb99919 Ciaran/flux1 kontext (#1394)
## Motivation

Add support for FLUX.1-Kontext-dev, an image editing variant of
FLUX.1-dev

## Changes

- New FluxKontextModelAdapter: Handles Kontext's image-to-image workflow
- encodes input image as conditioning latents with special position IDs,
generates from pure noise
- Model config: 57 transformer blocks (19 joint + 38 single), guidance
scale 4.0, ImageToImage task
- Pipeline updates: Added kontext_image_ids property to PromptData
interface, passed through diffusion runner
  - Model cards: Added TOML configs for base, 4-bit, and 8-bit variants
  - Dependency: mflux 0.15.4 → 0.15.5
- Utility: tmp/quantize_and_upload.py for quantizing and uploading
models to HuggingFace

## Test Plan

### Manual Testing

Works better than Qwen-Image-Edit
2026-02-06 16:20:31 +00:00
ciaranbor
c8d3154f83 More image dimensions (#1395)
## Motivation

More dimensions for image generation

## Changes

- dashboard/src/lib/components/ImageParamsPanel.svelte: Added
"1024x1365" and "1365x1024" to the sizeOptions array
- dashboard/src/lib/stores/app.svelte.ts: Extended the size type in
ImageGenerationParams interface to include the two new dimension options
2026-02-06 15:59:06 +00:00
ciaranbor
63e9cc4fea Ciaran/num sync steps (#1396)
## Motivation

Allow users to directly configure num_sync_steps for distributed image
generation instead of deriving it from a factor of total steps.

## Changes

  - Added num_sync_steps field to AdvancedImageParams API (range 1-50)
- Changed model configs from num_sync_steps_factor: float to
num_sync_steps: int
  - Updated Flux/Qwen configs with direct values (1, 4, 7 respectively)
  - Added slider control in dashboard advanced params panel
  - Falls back to model default when not specified

## Why It Works

Decouples sync steps from inference steps, giving users direct control
over distributed inference synchronization while preserving sensible
defaults.

## Test Plan

### Manual Testing

  - Generate images with various sync step values via dashboard slider
  - Verify default behavior when parameter is unset
2026-02-06 15:51:46 +00:00
44 changed files with 1560 additions and 228 deletions

3
.gitignore vendored
View File

@@ -35,3 +35,6 @@ hosts_*.json
# bench files
bench/**/*.json
# tmp
tmp/models

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
__all__ = ["Flux1Kontext"]

View File

@@ -0,0 +1,49 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
from typing import Any
from mlx import nn
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.model.flux_vae.vae import VAE
from mflux.utils.generated_image import GeneratedImage
class Flux1Kontext(nn.Module):
vae: VAE
transformer: Transformer
t5_text_encoder: T5Encoder
clip_text_encoder: CLIPEncoder
bits: int | None
lora_paths: list[str] | None
lora_scales: list[float] | None
prompt_cache: dict[str, Any]
tokenizers: dict[str, Any]
def __init__(
self,
quantize: int | None = ...,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
model_config: ModelConfig = ...,
) -> None: ...
def generate_image(
self,
seed: int,
prompt: str,
num_inference_steps: int = ...,
height: int = ...,
width: int = ...,
guidance: float = ...,
image_path: Path | str | None = ...,
image_strength: float | None = ...,
scheduler: str = ...,
) -> GeneratedImage: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.flux.model.flux_vae.vae import VAE
class KontextUtil:
@staticmethod
def create_image_conditioning_latents(
vae: VAE,
height: int,
width: int,
image_path: str,
) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,153 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
max_position_embeddings: int
moe_intermediate_size: int
norm_topk_prob: bool
num_attention_heads: int
n_group: int
head_dim: int
topk_group: int
n_shared_experts: int
n_routed_experts: int
routed_scaling_factor: float
num_experts_per_tok: int
first_k_dense_replace: int
num_hidden_layers: int
num_key_value_heads: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
use_qk_norm: bool
tie_word_embeddings: bool
attention_bias: bool
partial_rotary_factor: float
scoring_func: str
topk_method: str
class Attention(nn.Module):
n_heads: int
n_kv_heads: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
use_qk_norm: bool
q_norm: nn.RMSNorm
k_norm: nn.RMSNorm
rope: nn.RoPE
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class MLP(nn.Module):
config: ModelArgs
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class MoEGate(nn.Module):
config: ModelArgs
top_k: int
norm_topk_prob: bool
n_routed_experts: int
routed_scaling_factor: float
n_group: int
topk_group: int
weight: mx.array
e_score_correction_bias: mx.array
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class MoE(nn.Module):
config: ModelArgs
num_experts_per_tok: int
switch_mlp: SwitchGLU
gate: MoEGate
shared_experts: MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class DecoderLayer(nn.Module):
self_attn: Attention
mlp: MLP | MoE
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class LanguageModel(nn.Module):
vocab_size: int
embed_tokens: nn.Embedding
layers: list[DecoderLayer]
norm: nn.RMSNorm
pipeline_rank: int
pipeline_size: int
start_idx: int
end_idx: Optional[int]
num_layers: int
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
@property
def pipeline_layers(self) -> list[DecoderLayer]: ...
class Model(nn.Module):
args: ModelArgs
model_type: str
model: LanguageModel
lm_head: nn.Linear
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
@property
def layers(self) -> list[DecoderLayer]: ...
@property
def cast_predicate(self) -> Any: ...

View File

@@ -5,21 +5,21 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).

View File

@@ -64,6 +64,8 @@
"1024x1024",
"1024x768",
"768x1024",
"1024x1365",
"1365x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -148,6 +150,15 @@
setImageGenerationParams({ guidance: null });
}
function handleNumSyncStepsChange(event: Event) {
const value = parseInt((event.target as HTMLInputElement).value, 10);
setImageGenerationParams({ numSyncSteps: value });
}
function clearNumSyncSteps() {
setImageGenerationParams({ numSyncSteps: null });
}
function handleReset() {
resetImageGenerationParams();
showAdvanced = false;
@@ -157,7 +168,8 @@
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
(params.negativePrompt !== null && params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null,
);
</script>
@@ -578,7 +590,50 @@
</div>
</div>
<!-- Row 3: Negative Prompt -->
<!-- Row 3: Sync Steps -->
<div class="flex items-center gap-1.5">
<span
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
>SYNC STEPS:</span
>
<div class="flex items-center gap-2 flex-1 max-w-xs">
<input
type="range"
min="1"
max="100"
value={params.numSyncSteps ?? 1}
oninput={handleNumSyncStepsChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.numSyncSteps ?? "--"}
</span>
{#if params.numSyncSteps !== null}
<button
type="button"
onclick={clearNumSyncSteps}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>
<!-- Row 4: Negative Prompt -->
<div class="flex flex-col gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>NEGATIVE PROMPT:</span

View File

@@ -286,7 +286,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
// Image generation params interface matching backend API
export interface ImageGenerationParams {
// Basic params
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
size:
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1365"
| "1365x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -298,6 +305,7 @@ export interface ImageGenerationParams {
numInferenceSteps: number | null;
guidance: number | null;
negativePrompt: string | null;
numSyncSteps: number | null;
// Edit mode params
inputFidelity: "low" | "high";
}
@@ -319,6 +327,7 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
numInferenceSteps: null,
guidance: null,
negativePrompt: null,
numSyncSteps: null,
inputFidelity: "low",
};
@@ -2396,7 +2405,9 @@ class AppStore {
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "");
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
const requestBody: Record<string, unknown> = {
model,
@@ -2421,6 +2432,9 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
};
}
@@ -2670,29 +2684,19 @@ class AppStore {
formData.append("input_fidelity", params.inputFidelity);
// Advanced params
if (params.seed !== null) {
formData.append(
"advanced_params",
JSON.stringify({
seed: params.seed,
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
...(params.guidance !== null && { guidance: params.guidance }),
...(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
}),
);
} else if (
const hasAdvancedParams =
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "")
) {
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
if (hasAdvancedParams) {
formData.append(
"advanced_params",
JSON.stringify({
...(params.seed !== null && { seed: params.seed }),
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
@@ -2701,6 +2705,9 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
}),
);
}

View File

@@ -26,7 +26,7 @@ dependencies = [
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux==0.15.4",
"mflux==0.15.5",
"python-multipart>=0.0.21",
]

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Kontext-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -176,7 +176,7 @@ async def generate_chat_stream(
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ChatCompletionResponse:
) -> AsyncGenerator[str]:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
@@ -223,7 +223,7 @@ async def collect_chat_response(
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
yield ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
@@ -241,4 +241,5 @@ async def collect_chat_response(
finish_reason=finish_reason,
)
],
)
).model_dump_json()
return

View File

@@ -123,6 +123,7 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
TextGeneration,
)
@@ -529,16 +530,14 @@ class API:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
@@ -633,11 +632,14 @@ class API:
"X-Accel-Buffering": "no",
},
)
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
else:
return StreamingResponse(
collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -653,8 +655,7 @@ class API:
command = TextGeneration(task_params=task_params)
await self._send(command)
response = await self._collect_text_generation_with_stats(command.command_id)
return response
return await self._collect_text_generation_with_stats(command.command_id)
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
@@ -856,6 +857,11 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -937,6 +943,11 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))

View File

@@ -23,6 +23,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
TextGeneration,
@@ -38,6 +39,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
TracesMerged,
@@ -278,7 +280,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
for cmd in cancel_unnecessary_downloads(
placement, self.state.downloads
@@ -298,7 +300,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -308,7 +310,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -318,6 +320,18 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -326,17 +340,12 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time
for i in range(
command.since_idx,
min(command.since_idx + 1000, len(self._event_log)),
):
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)

View File

@@ -22,13 +22,19 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -186,6 +192,7 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []
@@ -201,6 +208,18 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
events.append(
InstanceDeleted(
instance_id=instance_id,

View File

@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
target_instances = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
target_instances: dict[InstanceId, Instance] = {}
# act
events = get_transition_events(current_instances, target_instances)
events = get_transition_events(current_instances, target_instances, {})
# assert
assert len(events) == 1

View File

@@ -272,6 +272,7 @@ class AdvancedImageParams(BaseModel):
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
negative_prompt: str | None = None
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None
class ImageGenerationTaskParams(BaseModel):

View File

@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -89,6 +93,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
| LoadModel
| StartWarmup
| TextGeneration
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -125,7 +125,9 @@ class MpSender[T]:
self._state.buffer.put(item, block=True)
async def send_async(self, item: T) -> None:
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
await to_thread.run_sync(
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -1,5 +1,4 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
@@ -23,7 +22,7 @@ class ImageModelConfig(BaseModel):
block_configs: tuple[TransformerBlockConfig, ...]
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
num_sync_steps: int # Number of sync steps for distributed inference
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@@ -45,6 +44,3 @@ class ImageModelConfig(BaseModel):
def get_steps_for_quality(self, quality: str) -> int:
return self.default_steps[quality]
def get_num_sync_steps(self, steps: int) -> int:
return ceil(steps * self.num_sync_steps_factor)

View File

@@ -150,7 +150,10 @@ class DistributedImageModel:
guidance=guidance_override if guidance_override is not None else 4.0,
)
num_sync_steps = self._config.get_num_sync_steps(steps)
if advanced_params is not None and advanced_params.num_sync_steps is not None:
num_sync_steps = advanced_params.num_sync_steps
else:
num_sync_steps = self._config.num_sync_steps
for result in self._runner.generate_image(
runtime_config=config,

View File

@@ -5,7 +5,9 @@ from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_KONTEXT_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxKontextModelAdapter,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
@@ -26,13 +28,16 @@ AdapterFactory = Callable[
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"flux-kontext": FluxKontextModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
# Order matters: longer/more-specific patterns must come before shorter ones
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-kontext": FLUX_KONTEXT_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching

View File

@@ -66,6 +66,19 @@ class PromptData(ABC):
"""
...
@property
@abstractmethod
def kontext_image_ids(self) -> mx.array | None:
"""Kontext-style position IDs for image conditioning.
For FLUX.1-Kontext models, returns position IDs with first_coord=1
to distinguish conditioning tokens from generation tokens (first_coord=0).
Returns:
Position IDs array [1, seq_len, 3] for Kontext, None for other models.
"""
...
@abstractmethod
def get_batched_cfg_data(
self,

View File

@@ -1,11 +1,17 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_KONTEXT_CONFIG,
FLUX_SCHNELL_CONFIG,
)
from exo.worker.engines.image.models.flux.kontext_adapter import (
FluxKontextModelAdapter,
)
__all__ = [
"FluxModelAdapter",
"FluxKontextModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_KONTEXT_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -59,6 +59,10 @@ class FluxPromptData(PromptData):
def conditioning_latents(self) -> mx.array | None:
return None
@property
def kontext_image_ids(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:

View File

@@ -15,7 +15,7 @@ FLUX_SCHNELL_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
num_sync_steps=1,
)
@@ -30,5 +30,21 @@ FLUX_DEV_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
num_sync_steps=4,
)
FLUX_KONTEXT_CONFIG = ImageModelConfig(
model_family="flux-kontext",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps=4,
guidance_scale=4.0,
)

View File

@@ -0,0 +1,348 @@
import math
from pathlib import Path
from typing import Any, final
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
from mflux.models.flux.variants.kontext.kontext_util import KontextUtil
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
PromptData,
RotaryEmbeddings,
)
from exo.worker.engines.image.models.flux.wrappers import (
FluxJointBlockWrapper,
FluxSingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
@final
class FluxKontextPromptData(PromptData):
"""Prompt data for FLUX.1-Kontext image editing.
Stores text embeddings along with conditioning latents and position IDs
for the input image.
"""
def __init__(
self,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
conditioning_latents: mx.array,
kontext_image_ids: mx.array,
):
self._prompt_embeds = prompt_embeds
self._pooled_prompt_embeds = pooled_prompt_embeds
self._conditioning_latents = conditioning_latents
self._kontext_image_ids = kontext_image_ids
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._pooled_prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array | None:
return None
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
return None
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
return None
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
return None
@property
def conditioning_latents(self) -> mx.array | None:
"""VAE-encoded input image latents for Kontext conditioning."""
return self._conditioning_latents
@property
def kontext_image_ids(self) -> mx.array | None:
"""Position IDs for Kontext conditioning (first_coord=1)."""
return self._kontext_image_ids
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Kontext doesn't use CFG, but we return positive data for compatibility."""
return (
self._prompt_embeds,
None,
self._pooled_prompt_embeds,
self._conditioning_latents,
)
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
# Kontext doesn't use CFG
return None
@final
class FluxKontextModelAdapter(ModelAdapter[Flux1Kontext, Transformer]):
"""Adapter for FLUX.1-Kontext image editing model.
Key differences from standard FluxModelAdapter:
- Takes an input image and computes output dimensions from it
- Creates conditioning latents from the input image via VAE
- Creates special position IDs (kontext_image_ids) for conditioning tokens
- Creates pure noise latents (not img2img blending)
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = Flux1Kontext(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
# Stores image path and computed dimensions after set_image_dimensions
self._image_path: str | None = None
self._output_height: int | None = None
self._output_width: int | None = None
@property
def hidden_dim(self) -> int:
return self._transformer.x_embedder.weight.shape[0] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
@property
def needs_cfg(self) -> bool:
return False
def _get_latent_creator(self) -> type:
return FluxLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper[Any]]:
"""Create wrapped joint blocks for Flux Kontext."""
return [
FluxJointBlockWrapper(block, text_seq_len)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper[Any]]:
"""Create wrapped single blocks for Flux Kontext."""
return [
FluxSingleBlockWrapper(block, text_seq_len)
for block in self._transformer.single_transformer_blocks
]
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
all_joint = list(self._transformer.transformer_blocks)
all_single = list(self._transformer.single_transformer_blocks)
total_joint_blocks = len(all_joint)
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_path for use in encode_prompt().
Args:
image_path: Path to the input image
Returns:
(output_width, output_height) for runtime config
"""
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Compute output dimensions from input image aspect ratio
# Target area of 1024x1024 = ~1M pixels
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
self._image_path = str(image_path)
self._output_width = int(output_width)
self._output_height = int(output_height)
return self._output_width, self._output_height
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial noise latents for Kontext.
Unlike standard img2img which blends noise with encoded input,
Kontext uses pure noise latents. The input image is provided
separately as conditioning.
"""
return FluxLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> FluxKontextPromptData:
"""Encode prompt and create conditioning from stored input image.
Must call set_image_dimensions() before this method.
Args:
prompt: Text prompt for editing
negative_prompt: Ignored (Kontext doesn't use CFG)
Returns:
FluxKontextPromptData with text embeddings and image conditioning
"""
del negative_prompt # Kontext doesn't support negative prompts or CFG
if (
self._image_path is None
or self._output_height is None
or self._output_width is None
):
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for FluxKontextModelAdapter"
)
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
# Encode text prompt
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self.model.prompt_cache,
t5_tokenizer=self.model.tokenizers["t5"], # pyright: ignore[reportAny]
clip_tokenizer=self.model.tokenizers["clip"], # pyright: ignore[reportAny]
t5_text_encoder=self.model.t5_text_encoder,
clip_text_encoder=self.model.clip_text_encoder,
)
# Create conditioning latents from input image
conditioning_latents, kontext_image_ids = (
KontextUtil.create_image_conditioning_latents(
vae=self.model.vae,
height=self._output_height,
width=self._output_width,
image_path=self._image_path,
)
)
return FluxKontextPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
conditioning_latents=conditioning_latents,
kontext_image_ids=kontext_image_ids,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
if pooled_prompt_embeds is None:
raise ValueError(
"pooled_prompt_embeds is required for Flux Kontext text embeddings"
)
return Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> RotaryEmbeddings:
return Transformer.compute_rotary_embeddings(
prompt_embeds,
self._transformer.pos_embed,
runtime_config,
kontext_image_ids,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
raise NotImplementedError("Flux Kontext does not use classifier-free guidance")

View File

@@ -69,6 +69,10 @@ class QwenPromptData(PromptData):
def conditioning_latents(self) -> mx.array | None:
return None
@property
def kontext_image_ids(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:

View File

@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps=7,
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
)
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps=7,
guidance_scale=3.5,
)

View File

@@ -85,6 +85,10 @@ class QwenEditPromptData(PromptData):
def qwen_image_ids(self) -> mx.array:
return self._qwen_image_ids
@property
def kontext_image_ids(self) -> mx.array | None:
return None
@property
def is_edit_mode(self) -> bool:
return True

View File

@@ -567,6 +567,7 @@ class DiffusionRunner:
| list[tuple[int, int, int]]
| None = None,
conditioning_latents: mx.array | None = None,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
"""Run a single forward pass through the transformer.
Args:
@@ -578,6 +579,7 @@ class DiffusionRunner:
encoder_hidden_states_mask: Attention mask for text (Qwen)
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
conditioning_latents: Conditioning latents for edit mode
kontext_image_ids: Position IDs for Kontext conditioning (Flux Kontext)
Returns:
Noise prediction tensor
@@ -610,6 +612,7 @@ class DiffusionRunner:
config,
encoder_hidden_states_mask=encoder_hidden_states_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
assert self.joint_block_wrappers is not None
@@ -681,6 +684,7 @@ class DiffusionRunner:
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
@@ -700,6 +704,7 @@ class DiffusionRunner:
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
conditioning_latents=branch.cond_latents,
kontext_image_ids=kontext_image_ids,
)
results.append((branch.positive, noise))
@@ -902,10 +907,10 @@ class DiffusionRunner:
config: Config,
hidden_states: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
@@ -979,10 +984,10 @@ class DiffusionRunner:
latents: mx.array,
prompt_data: PromptData,
is_first_async_step: bool,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
cond_image_grid = prompt_data.cond_image_grid
kontext_image_ids = prompt_data.kontext_image_ids
prev_patch_latents = [p for p in patch_latents]

View File

@@ -386,7 +386,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
elif isinstance(model, Glm4MoeModel):
tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -512,9 +520,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.q_b_proj
)
# layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
# layer.self_attn.kv_b_proj
# )
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
@@ -544,7 +549,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
mx.eval(layer)
@@ -552,7 +557,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
return model
class ShardedDeepseekV3MoE(CustomMlxLayer):
class ShardedMoE(CustomMlxLayer):
"""Wraps any MoE layer with distributed sum_gradients / all_sum."""
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
@@ -623,27 +630,13 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # type: ignore
mx.eval(layer)
return model
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class WrappedMiniMaxAttention(CustomMlxLayer):
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
super().__init__(layer)
@@ -756,7 +749,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe = ShardedMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
@@ -861,9 +854,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
if isinstance(
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
):
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
@@ -875,7 +866,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.shared_expert.down_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp = ShardedMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -888,18 +879,50 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
return model
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Glm4MoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
layer.self_attn.n_kv_heads //= self.N
if isinstance(layer.mlp, MoE):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
if getattr(layer.mlp, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.up_proj
)
layer.mlp = ShardedMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
else:
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
return model
class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -937,21 +960,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -290,7 +290,6 @@ def make_kv_cache(
) -> KVCacheType:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache"):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore

View File

@@ -67,8 +67,6 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -86,30 +84,6 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -586,3 +560,23 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -32,6 +32,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -218,15 +219,22 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await runner.start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
runner.shutdown()
case CancelTask(
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -264,18 +272,18 @@ class Worker:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
await self._start_runner_task(modified_task)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
await self._start_runner_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:
await self.runners[
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
@@ -314,8 +322,6 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ConnectToGroup,
CreateRunner,
DownloadModel,
@@ -53,13 +54,14 @@ def plan(
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
_cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
)
@@ -270,7 +272,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -284,7 +286,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
received = len(input_chunk_buffer.get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -292,16 +294,33 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner_id, runner in runners.items():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id,
cancelled_task_id=task.task_id,
runner_id=runner_id,
)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -1,5 +1,6 @@
import base64
import json
import math
import time
from collections.abc import Generator
from functools import cache
@@ -87,6 +88,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -111,6 +113,7 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -125,11 +128,15 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
model: Model | DistributedImageModel | None = None
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -142,6 +149,7 @@ def main(
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -187,7 +195,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -199,7 +207,7 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
model = initialize_image_model(bound_instance)
image_model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -207,8 +215,6 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -220,16 +226,31 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
t = time.perf_counter()
toks = warmup_inference(
model=model,
model=inference_model,
tokenizer=tokenizer,
group=group,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / (time.perf_counter() - t)), 100
)
if group is not None:
check_for_cancel_every = int(
mx.max(
mx.distributed.all_gather(
mx.array([check_for_cancel_every]), group=group
)
).item()
)
logger.info(
f"runner checking for cancellation every {check_for_cancel_every} tokens"
)
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
@@ -237,8 +258,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -258,9 +279,9 @@ def main(
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
assert model and not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
assert check_for_cancel_every
try:
_check_for_debug_prompts(task_params)
@@ -270,7 +291,7 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
model=inference_model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -295,11 +316,11 @@ def main(
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -312,7 +333,18 @@ def main(
)
completion_tokens = 0
tokens_since_last_cancel_check = 0
for response in mlx_generator:
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= check_for_cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
completion_tokens += 1
@@ -384,7 +416,7 @@ def main(
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -397,7 +429,9 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
is_primary_output = _is_primary_output_node(shard_metadata)
if is_primary_output:
@@ -447,7 +481,7 @@ def main(
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -460,7 +494,9 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
for response in generate_image(
model=image_model, task=task_params
):
if _is_primary_output_node(shard_metadata):
match response:
case PartialImageResponse():
@@ -526,7 +562,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc

View File

@@ -47,9 +47,11 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -60,8 +62,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -69,6 +71,7 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -83,6 +86,7 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -97,6 +101,8 @@ class RunnerSupervisor:
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
@@ -112,14 +118,6 @@ class RunnerSupervisor:
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
@@ -141,6 +139,17 @@ class RunnerSupervisor:
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
with anyio.move_on_after(0.5) as scope:
await self._cancel_sender.send_async(task_id)
if scope.cancel_called:
logger.error("RunnerSupervisor cancel pipe blocked")
await self._check_runner(TimeoutError("cancel pipe blocked"))
async def _forward_events(self):
with self._ev_recv as events:
try:

View File

@@ -1,7 +1,9 @@
# Check tasks are complete before runner is ever ready.
import unittest.mock
from collections.abc import Iterable
from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
@@ -19,6 +21,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
TextGeneration,
)
@@ -113,6 +116,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
@@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
_cancel_sender, cancel_receiver = mp_channel[TaskId]()
event_sender = EventCollector()
with task_sender:
@@ -173,8 +178,13 @@ def _run(tasks: Iterable[Task]):
# this is some c++ nonsense
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
with unittest.mock.patch(
"exo.worker.runner.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
mlx_runner.main(
bound_instance, event_sender, task_receiver, cancel_receiver
) # pyright: ignore[reportArgumentType]
return event_sender.events

377
tmp/quantize_and_upload.py Executable file
View File

@@ -0,0 +1,377 @@
#!/usr/bin/env python3
"""
Download an mflux model, quantize it, and upload to HuggingFace.
Usage (run from mflux project directory):
cd /path/to/mflux
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
Requires:
- Must be run from mflux project directory using `uv run`
- huggingface_hub installed (add to mflux deps or install separately)
- HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN
"""
from __future__ import annotations
import argparse
import re
import shutil
import sys
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from mflux.models.flux.variants.txt2img.flux import Flux1
HF_ORG = "exolabs"
def get_model_class(model_name: str) -> type:
"""Get the appropriate model class based on model name."""
from mflux.models.fibo.variants.txt2img.fibo import FIBO
from mflux.models.flux.variants.txt2img.flux import Flux1
from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo
model_name_lower = model_name.lower()
if "qwen" in model_name_lower:
return QwenImage
elif "fibo" in model_name_lower:
return FIBO
elif "z-image" in model_name_lower or "zimage" in model_name_lower:
return ZImageTurbo
elif "flux2" in model_name_lower or "flux.2" in model_name_lower:
return Flux2Klein
else:
return Flux1
def get_repo_name(model_name: str, bits: int | None) -> str:
"""Get the HuggingFace repo name for a model variant."""
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
suffix = f"-{bits}bit" if bits else ""
return f"{HF_ORG}/{base_name}{suffix}"
def get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path:
"""Get the local save path for a model variant."""
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
suffix = f"-{bits}bit" if bits else ""
return output_dir / f"{base_name}{suffix}"
def copy_source_repo(
source_repo: str,
local_path: Path,
dry_run: bool = False,
) -> None:
"""Copy all files from source repo (replicating original HF structure)."""
print(f"\n{'=' * 60}")
print(f"Copying full repo from source: {source_repo}")
print(f"Output path: {local_path}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would download all files from source repo")
return
from huggingface_hub import snapshot_download
# Download all files to our local path
snapshot_download(
repo_id=source_repo,
local_dir=local_path,
)
# Remove root-level safetensors files (flux.1-dev.safetensors, etc.)
# These are redundant with the component directories
for f in local_path.glob("*.safetensors"):
print(f"Removing root-level safetensors: {f.name}")
if not dry_run:
f.unlink()
print(f"Source repo copied to {local_path}")
def load_and_save_quantized_model(
model_name: str,
bits: int,
output_path: Path,
dry_run: bool = False,
) -> None:
"""Load a model with quantization and save it in mflux format."""
print(f"\n{'=' * 60}")
print(f"Loading {model_name} with {bits}-bit quantization...")
print(f"Output path: {output_path}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would load and save quantized model")
return
from mflux.models.common.config.model_config import ModelConfig
model_class = get_model_class(model_name)
model_config = ModelConfig.from_name(model_name=model_name, base_model=None)
model: Flux1 = model_class(
quantize=bits,
model_config=model_config,
)
print(f"Saving model to {output_path}...")
model.save_model(str(output_path))
print(f"Model saved successfully to {output_path}")
def copy_source_metadata(
source_repo: str,
local_path: Path,
dry_run: bool = False,
) -> None:
"""Copy metadata files (LICENSE, README, etc.) from source repo, excluding safetensors."""
print(f"\n{'=' * 60}")
print(f"Copying metadata from source repo: {source_repo}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would download metadata files (excluding *.safetensors)")
return
from huggingface_hub import snapshot_download
# Download all files except safetensors to our local path
snapshot_download(
repo_id=source_repo,
local_dir=local_path,
ignore_patterns=["*.safetensors"],
)
print(f"Metadata files copied to {local_path}")
def upload_to_huggingface(
local_path: Path,
repo_id: str,
dry_run: bool = False,
clean_remote: bool = False,
) -> None:
"""Upload a saved model to HuggingFace."""
print(f"\n{'=' * 60}")
print(f"Uploading to HuggingFace: {repo_id}")
print(f"Local path: {local_path}")
print(f"Clean remote first: {clean_remote}")
print(f"{'=' * 60}")
if dry_run:
print("[DRY RUN] Would upload to HuggingFace")
return
from huggingface_hub import HfApi
api = HfApi()
# Create the repo if it doesn't exist
print(f"Creating/verifying repo: {repo_id}")
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
# Clean remote repo if requested (delete old mflux-format files)
if clean_remote:
print("Cleaning old mflux-format files from remote...")
try:
# Pattern for mflux numbered shards: <dir>/<number>.safetensors
numbered_pattern = re.compile(r".*/\d+\.safetensors$")
repo_files = api.list_repo_files(repo_id=repo_id, repo_type="model")
for file_path in repo_files:
# Delete numbered safetensors (mflux format) and mflux index files
if numbered_pattern.match(file_path) or file_path.endswith(
"/model.safetensors.index.json"
):
print(f" Deleting: {file_path}")
api.delete_file(
path_in_repo=file_path, repo_id=repo_id, repo_type="model"
)
except Exception as e:
print(f"Warning: Could not clean remote files: {e}")
# Upload the folder
print("Uploading folder contents...")
api.upload_folder(
folder_path=str(local_path),
repo_id=repo_id,
repo_type="model",
)
print(f"Upload complete: https://huggingface.co/{repo_id}")
def clean_local_files(local_path: Path, dry_run: bool = False) -> None:
"""Remove local model files after upload."""
print(f"\nCleaning up: {local_path}")
if dry_run:
print("[DRY RUN] Would remove local files")
return
if local_path.exists():
shutil.rmtree(local_path)
print(f"Removed {local_path}")
def main() -> int:
parser = argparse.ArgumentParser(
description="Download an mflux model, quantize it, and upload to HuggingFace.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Process all variants (base, 4-bit, 8-bit) for FLUX.1-Kontext-dev
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
# Only process 4-bit variant
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
# Save locally without uploading
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-upload
# Preview what would happen
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
""",
)
parser.add_argument(
"--model",
"-m",
required=True,
help="HuggingFace model path (e.g., black-forest-labs/FLUX.1-Kontext-dev)",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("./tmp/models"),
help="Local directory to save models (default: ./tmp/models)",
)
parser.add_argument(
"--skip-base",
action="store_true",
help="Skip base model (no quantization)",
)
parser.add_argument(
"--skip-4bit",
action="store_true",
help="Skip 4-bit quantized model",
)
parser.add_argument(
"--skip-8bit",
action="store_true",
help="Skip 8-bit quantized model",
)
parser.add_argument(
"--skip-download",
action="store_true",
help="Skip downloading/processing, only do upload/clean operations",
)
parser.add_argument(
"--skip-upload",
action="store_true",
help="Only save locally, don't upload to HuggingFace",
)
parser.add_argument(
"--clean",
action="store_true",
help="Remove local files after upload",
)
parser.add_argument(
"--clean-remote",
action="store_true",
help="Delete old mflux-format files from remote repo before uploading",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Print actions without executing",
)
args = parser.parse_args()
# Determine which variants to process
variants: list[int | None] = []
if not args.skip_base:
variants.append(None) # Base model (no quantization)
if not args.skip_4bit:
variants.append(4)
if not args.skip_8bit:
variants.append(8)
if not variants:
print("Error: All variants skipped. Nothing to do.")
return 1
# Create output directory
args.output_dir.mkdir(parents=True, exist_ok=True)
print(f"Model: {args.model}")
print(f"Output directory: {args.output_dir}")
print(
f"Variants to process: {['base' if v is None else f'{v}-bit' for v in variants]}"
)
print(f"Upload to HuggingFace: {not args.skip_upload}")
print(f"Clean after upload: {args.clean}")
if args.dry_run:
print("\n*** DRY RUN MODE - No actual changes will be made ***")
# Process each variant
for bits in variants:
local_path = get_local_path(args.output_dir, args.model, bits)
repo_id = get_repo_name(args.model, bits)
if not args.skip_download:
if bits is None:
# Base model: copy original HF repo structure (no mflux conversion)
copy_source_repo(
source_repo=args.model,
local_path=local_path,
dry_run=args.dry_run,
)
else:
# Quantized model: load, quantize, and save with mflux
load_and_save_quantized_model(
model_name=args.model,
bits=bits,
output_path=local_path,
dry_run=args.dry_run,
)
# Copy metadata from source repo (LICENSE, README, etc.)
copy_source_metadata(
source_repo=args.model,
local_path=local_path,
dry_run=args.dry_run,
)
# Upload
if not args.skip_upload:
upload_to_huggingface(
local_path=local_path,
repo_id=repo_id,
dry_run=args.dry_run,
clean_remote=args.clean_remote,
)
# Clean up if requested
if args.clean:
clean_local_files(local_path, dry_run=args.dry_run)
print("\n" + "=" * 60)
print("All done!")
print("=" * 60)
return 0
if __name__ == "__main__":
sys.exit(main())

20
uv.lock generated
View File

@@ -192,20 +192,14 @@ sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8
wheels = [
{ url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446, upload-time = "2025-09-08T23:23:03.472Z" },
{ url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101, upload-time = "2025-09-08T23:23:04.792Z" },
{ url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948, upload-time = "2025-09-08T23:23:06.127Z" },
{ url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422, upload-time = "2025-09-08T23:23:07.753Z" },
{ url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499, upload-time = "2025-09-08T23:23:09.648Z" },
{ url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928, upload-time = "2025-09-08T23:23:10.928Z" },
{ url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302, upload-time = "2025-09-08T23:23:12.42Z" },
{ url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049, upload-time = "2025-09-08T23:23:20.853Z" },
{ url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793, upload-time = "2025-09-08T23:23:22.08Z" },
{ url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300, upload-time = "2025-09-08T23:23:23.314Z" },
{ url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244, upload-time = "2025-09-08T23:23:24.541Z" },
{ url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828, upload-time = "2025-09-08T23:23:26.143Z" },
{ url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926, upload-time = "2025-09-08T23:23:27.873Z" },
{ url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593, upload-time = "2025-09-08T23:23:31.91Z" },
{ url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354, upload-time = "2025-09-08T23:23:33.214Z" },
{ url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480, upload-time = "2025-09-08T23:23:34.495Z" },
{ url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584, upload-time = "2025-09-08T23:23:36.096Z" },
{ url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443, upload-time = "2025-09-08T23:23:37.328Z" },
{ url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437, upload-time = "2025-09-08T23:23:38.945Z" },
@@ -311,10 +305,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/49/498c86566a1d80e978b42f0d702795f69887005548c041636df6ae1ca64c/cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d", size = 4450807, upload-time = "2025-10-15T23:16:56.414Z" },
{ url = "https://files.pythonhosted.org/packages/4b/0a/863a3604112174c8624a2ac3c038662d9e59970c7f926acdcfaed8d61142/cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb", size = 4299615, upload-time = "2025-10-15T23:16:58.442Z" },
{ url = "https://files.pythonhosted.org/packages/64/02/b73a533f6b64a69f3cd3872acb6ebc12aef924d8d103133bb3ea750dc703/cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849", size = 4016800, upload-time = "2025-10-15T23:17:00.378Z" },
{ url = "https://files.pythonhosted.org/packages/25/d5/16e41afbfa450cde85a3b7ec599bebefaef16b5c6ba4ec49a3532336ed72/cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8", size = 4984707, upload-time = "2025-10-15T23:17:01.98Z" },
{ url = "https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec", size = 4482541, upload-time = "2025-10-15T23:17:04.078Z" },
{ url = "https://files.pythonhosted.org/packages/78/f6/50736d40d97e8483172f1bb6e698895b92a223dba513b0ca6f06b2365339/cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91", size = 4299464, upload-time = "2025-10-15T23:17:05.483Z" },
{ url = "https://files.pythonhosted.org/packages/00/de/d8e26b1a855f19d9994a19c702fa2e93b0456beccbcfe437eda00e0701f2/cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e", size = 4950838, upload-time = "2025-10-15T23:17:07.425Z" },
{ url = "https://files.pythonhosted.org/packages/8f/29/798fc4ec461a1c9e9f735f2fc58741b0daae30688f41b2497dcbc9ed1355/cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926", size = 4481596, upload-time = "2025-10-15T23:17:09.343Z" },
{ url = "https://files.pythonhosted.org/packages/15/8d/03cd48b20a573adfff7652b76271078e3045b9f49387920e7f1f631d125e/cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71", size = 4426782, upload-time = "2025-10-15T23:17:11.22Z" },
{ url = "https://files.pythonhosted.org/packages/fa/b1/ebacbfe53317d55cf33165bda24c86523497a6881f339f9aae5c2e13e57b/cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac", size = 4698381, upload-time = "2025-10-15T23:17:12.829Z" },
@@ -322,10 +314,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/fd/bc1daf8230eaa075184cbbf5f8cd00ba9db4fd32d63fb83da4671b72ed8a/cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715", size = 4435078, upload-time = "2025-10-15T23:17:23.042Z" },
{ url = "https://files.pythonhosted.org/packages/82/98/d3bd5407ce4c60017f8ff9e63ffee4200ab3e23fe05b765cab805a7db008/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54", size = 4293460, upload-time = "2025-10-15T23:17:24.885Z" },
{ url = "https://files.pythonhosted.org/packages/26/e9/e23e7900983c2b8af7a08098db406cf989d7f09caea7897e347598d4cd5b/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459", size = 3995237, upload-time = "2025-10-15T23:17:26.449Z" },
{ url = "https://files.pythonhosted.org/packages/91/15/af68c509d4a138cfe299d0d7ddb14afba15233223ebd933b4bbdbc7155d3/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422", size = 4967344, upload-time = "2025-10-15T23:17:28.06Z" },
{ url = "https://files.pythonhosted.org/packages/ca/e3/8643d077c53868b681af077edf6b3cb58288b5423610f21c62aadcbe99f4/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7", size = 4466564, upload-time = "2025-10-15T23:17:29.665Z" },
{ url = "https://files.pythonhosted.org/packages/0e/43/c1e8726fa59c236ff477ff2b5dc071e54b21e5a1e51aa2cee1676f1c986f/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044", size = 4292415, upload-time = "2025-10-15T23:17:31.686Z" },
{ url = "https://files.pythonhosted.org/packages/42/f9/2f8fefdb1aee8a8e3256a0568cffc4e6d517b256a2fe97a029b3f1b9fe7e/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665", size = 4931457, upload-time = "2025-10-15T23:17:33.478Z" },
{ url = "https://files.pythonhosted.org/packages/79/30/9b54127a9a778ccd6d27c3da7563e9f2d341826075ceab89ae3b41bf5be2/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3", size = 4466074, upload-time = "2025-10-15T23:17:35.158Z" },
{ url = "https://files.pythonhosted.org/packages/ac/68/b4f4a10928e26c941b1b6a179143af9f4d27d88fe84a6a3c53592d2e76bf/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20", size = 4420569, upload-time = "2025-10-15T23:17:37.188Z" },
{ url = "https://files.pythonhosted.org/packages/a3/49/3746dab4c0d1979888f125226357d3262a6dd40e114ac29e3d2abdf1ec55/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de", size = 4681941, upload-time = "2025-10-15T23:17:39.236Z" },
@@ -333,10 +323,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/26/42/fa8389d4478368743e24e61eea78846a0006caffaf72ea24a15159215a14/cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d", size = 4440029, upload-time = "2025-10-15T23:17:49.837Z" },
{ url = "https://files.pythonhosted.org/packages/5f/eb/f483db0ec5ac040824f269e93dd2bd8a21ecd1027e77ad7bdf6914f2fd80/cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0", size = 4297222, upload-time = "2025-10-15T23:17:51.357Z" },
{ url = "https://files.pythonhosted.org/packages/fd/cf/da9502c4e1912cb1da3807ea3618a6829bee8207456fbbeebc361ec38ba3/cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc", size = 4012280, upload-time = "2025-10-15T23:17:52.964Z" },
{ url = "https://files.pythonhosted.org/packages/6b/8f/9adb86b93330e0df8b3dcf03eae67c33ba89958fc2e03862ef1ac2b42465/cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3", size = 4978958, upload-time = "2025-10-15T23:17:54.965Z" },
{ url = "https://files.pythonhosted.org/packages/d1/a0/5fa77988289c34bdb9f913f5606ecc9ada1adb5ae870bd0d1054a7021cc4/cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971", size = 4473714, upload-time = "2025-10-15T23:17:56.754Z" },
{ url = "https://files.pythonhosted.org/packages/14/e5/fc82d72a58d41c393697aa18c9abe5ae1214ff6f2a5c18ac470f92777895/cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac", size = 4296970, upload-time = "2025-10-15T23:17:58.588Z" },
{ url = "https://files.pythonhosted.org/packages/78/06/5663ed35438d0b09056973994f1aec467492b33bd31da36e468b01ec1097/cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04", size = 4940236, upload-time = "2025-10-15T23:18:00.897Z" },
{ url = "https://files.pythonhosted.org/packages/fc/59/873633f3f2dcd8a053b8dd1d38f783043b5fce589c0f6988bf55ef57e43e/cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506", size = 4472642, upload-time = "2025-10-15T23:18:02.749Z" },
{ url = "https://files.pythonhosted.org/packages/3d/39/8e71f3930e40f6877737d6f69248cf74d4e34b886a3967d32f919cc50d3b/cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963", size = 4423126, upload-time = "2025-10-15T23:18:04.85Z" },
{ url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573, upload-time = "2025-10-15T23:18:06.908Z" },
@@ -412,7 +400,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.5" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.5" },
{ name = "mlx-lm", specifier = "==0.30.6" },
@@ -987,7 +975,7 @@ wheels = [
[[package]]
name = "mflux"
version = "0.15.4"
version = "0.15.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1013,9 +1001,9 @@ dependencies = [
{ name = "twine", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/95322db7a865e4df6bad108b1c99aa7fbe211aac3f298f3ad696c2744a39/mflux-0.15.4.tar.gz", hash = "sha256:138e1aedae86e13eafeb8faec017945fcdcca42c3234daabcd81a83c9a202ace", size = 741228, upload-time = "2026-01-20T15:39:26.807Z" }
sdist = { url = "https://files.pythonhosted.org/packages/35/8e/f20de51bf9dc0a986535d9a825db4ae314163421b3d3ddaa90a2b959b9fd/mflux-0.15.5.tar.gz", hash = "sha256:9a3372bd64d51c4caff4ff9e7d7d698bea5833242fd849c59cbb0c92f7d7aa3b", size = 743700, upload-time = "2026-01-26T12:41:45.272Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/be/81cf4ce2d1933b9b210c028a05ac95e958008c0d43e377a5f2757b7f2d4d/mflux-0.15.4-py3-none-any.whl", hash = "sha256:f04d9b1d7c5cd67880f483ab29fb2097648a25459eef9c5ee6480fad46de5e82", size = 987644, upload-time = "2026-01-20T15:39:24.817Z" },
{ url = "https://files.pythonhosted.org/packages/ac/bb/ef936eae2ae78a47cd92ddffc18fc06ad3fd5f438a0915fb62d8bb9508ec/mflux-0.15.5-py3-none-any.whl", hash = "sha256:c94891d4a518047a818863bb099c755e93af90c524ced358baf5b31502c09e82", size = 990939, upload-time = "2026-01-26T12:41:42.898Z" },
]
[[package]]