mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-06 12:11:22 -05:00
Compare commits
1 Commits
ciaran/mes
...
update-tod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cc994c19a |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -35,6 +35,3 @@ hosts_*.json
|
||||
|
||||
# bench files
|
||||
bench/**/*.json
|
||||
|
||||
# tmp
|
||||
tmp/models
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
|
||||
|
||||
__all__ = ["Flux1Kontext"]
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from mlx import nn
|
||||
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
|
||||
CLIPEncoder,
|
||||
)
|
||||
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.model.flux_vae.vae import VAE
|
||||
from mflux.utils.generated_image import GeneratedImage
|
||||
|
||||
class Flux1Kontext(nn.Module):
|
||||
vae: VAE
|
||||
transformer: Transformer
|
||||
t5_text_encoder: T5Encoder
|
||||
clip_text_encoder: CLIPEncoder
|
||||
bits: int | None
|
||||
lora_paths: list[str] | None
|
||||
lora_scales: list[float] | None
|
||||
prompt_cache: dict[str, Any]
|
||||
tokenizers: dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quantize: int | None = ...,
|
||||
model_path: str | None = ...,
|
||||
lora_paths: list[str] | None = ...,
|
||||
lora_scales: list[float] | None = ...,
|
||||
model_config: ModelConfig = ...,
|
||||
) -> None: ...
|
||||
def generate_image(
|
||||
self,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
num_inference_steps: int = ...,
|
||||
height: int = ...,
|
||||
width: int = ...,
|
||||
guidance: float = ...,
|
||||
image_path: Path | str | None = ...,
|
||||
image_strength: float | None = ...,
|
||||
scheduler: str = ...,
|
||||
) -> GeneratedImage: ...
|
||||
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mflux.models.flux.model.flux_vae.vae import VAE
|
||||
|
||||
class KontextUtil:
|
||||
@staticmethod
|
||||
def create_image_conditioning_latents(
|
||||
vae: VAE,
|
||||
height: int,
|
||||
width: int,
|
||||
image_path: str,
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
@@ -1,153 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
max_position_embeddings: int
|
||||
moe_intermediate_size: int
|
||||
norm_topk_prob: bool
|
||||
num_attention_heads: int
|
||||
n_group: int
|
||||
head_dim: int
|
||||
topk_group: int
|
||||
n_shared_experts: int
|
||||
n_routed_experts: int
|
||||
routed_scaling_factor: float
|
||||
num_experts_per_tok: int
|
||||
first_k_dense_replace: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Any]]
|
||||
use_qk_norm: bool
|
||||
tie_word_embeddings: bool
|
||||
attention_bias: bool
|
||||
partial_rotary_factor: float
|
||||
scoring_func: str
|
||||
topk_method: str
|
||||
|
||||
class Attention(nn.Module):
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
use_qk_norm: bool
|
||||
q_norm: nn.RMSNorm
|
||||
k_norm: nn.RMSNorm
|
||||
rope: nn.RoPE
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class MLP(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
gate_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
config: ModelArgs
|
||||
top_k: int
|
||||
norm_topk_prob: bool
|
||||
n_routed_experts: int
|
||||
routed_scaling_factor: float
|
||||
n_group: int
|
||||
topk_group: int
|
||||
weight: mx.array
|
||||
e_score_correction_bias: mx.array
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
class MoE(nn.Module):
|
||||
config: ModelArgs
|
||||
num_experts_per_tok: int
|
||||
switch_mlp: SwitchGLU
|
||||
gate: MoEGate
|
||||
shared_experts: MLP
|
||||
sharding_group: Optional[mx.distributed.Group]
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
self_attn: Attention
|
||||
mlp: MLP | MoE
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
vocab_size: int
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[DecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
pipeline_rank: int
|
||||
pipeline_size: int
|
||||
start_idx: int
|
||||
end_idx: Optional[int]
|
||||
num_layers: int
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
@property
|
||||
def pipeline_layers(self) -> list[DecoderLayer]: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
args: ModelArgs
|
||||
model_type: str
|
||||
model: LanguageModel
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
|
||||
@property
|
||||
def layers(self) -> list[DecoderLayer]: ...
|
||||
@property
|
||||
def cast_predicate(self) -> Any: ...
|
||||
14
TODO.md
14
TODO.md
@@ -1,28 +1,14 @@
|
||||
2. Currently a lot of requests from the API are timing out, but we still process those requests internally. If an API request times out, we should cancel all corresponding tasks to that API request (why process a request with nobody listening).
|
||||
3. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
|
||||
4. I'd like to see profiled network latency / bandwidth.
|
||||
5. I'd like to see how much bandwidth each link is using.
|
||||
6. We should handle the case where one machine doesn't have the model downloaded and then other machines are waiting on it. In this case we get loads of timeout errors because the others are waiting for the one that needs to download the model.
|
||||
7. Solve the problem of in continuous batching when a new prompt comes in, it will block decode of the current batch until the prefill is complete.
|
||||
8. We want people to be able to copy models over to a new device without ever connecting EXO to the internet. Right now EXO require internet connection once to cache some files to check if a download is complete. Instead, we should simply check if there is a non-empty model folder locally with no .partial files. This indicates it's a fully downloaded model that can be loaded.
|
||||
10. More granular control over how to deploy instances.
|
||||
12. Nix is great but installing it is a pain and we have ended up in a lot of cases having PATH issues or installation issues. For example, after rebooting mike it seemed to no longer have a nix installation and needed reinstalling. It has a bunch of broken symlinks left over from nix that caused ssh to fail, making it even harder to debug. We need consistent environments (perhaps MDM) so we can guarantee nix is installed properly on each machine.
|
||||
13. Memory pressure instead of memory used.
|
||||
14. Show the type of each connection (TB5, Ethernet, etc.) in the UI. Refer to old exo: https://github.com/exo-explore/exo/blob/56f783b38dc6b08ce606b07a5386dc40dae00330/exo/helpers.py#L251
|
||||
15. Prioritise certain connection types (or by latency). TB5 > Ethernet > WiFi. Refer to old exo: https://github.com/exo-explore/exo/blob/56f783b38dc6b08ce606b07a5386dc40dae00330/exo/helpers.py#L251
|
||||
16. Dynamically switch to higher priority connection when it becomes available. Probably bring back InstanceReplacedAtomically.
|
||||
17. Faster model loads by streaming model from other devices in cluster.
|
||||
18. Add support for specifying the type of network connection to use in a test. Depends on 15/16.
|
||||
20. Add chat completion cancellations (e.g OpenWebUI has something for cancelling an ongoing request).
|
||||
23. Do we need cache_limit? We went back and forth on that a lot because we thought it might be causing issues. One problem is it sets it relative to model size. So if you have multiple models loaded in it will take the most recent model size for the cache_limit. This is problematic if you launch DeepSeek -> Llama for example.
|
||||
24. further openai/lmstudio api compatibility
|
||||
25. Rethink retry logic
|
||||
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
|
||||
27. Log cleanup - per-module log filters and default to DEBUG log levels
|
||||
28. Validate RDMA connections with ibv_devinfo in the info gatherer
|
||||
|
||||
Potential refactors:
|
||||
|
||||
2. Topology can be simplified
|
||||
|
||||
Random errors we've run into:
|
||||
|
||||
@@ -221,7 +221,6 @@
|
||||
}
|
||||
|
||||
function handleDeleteClick(messageId: string) {
|
||||
if (loading) return;
|
||||
deleteConfirmId = messageId;
|
||||
}
|
||||
|
||||
@@ -252,7 +251,7 @@
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col gap-4 sm:gap-6 {className}">
|
||||
{#each messageList as message, i (message.id)}
|
||||
{#each messageList as message (message.id)}
|
||||
<div
|
||||
class="group flex {message.role === 'user'
|
||||
? 'justify-end'
|
||||
@@ -314,11 +313,9 @@
|
||||
<!-- Delete confirmation -->
|
||||
<div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
|
||||
<p class="text-xs text-red-400 mb-3">
|
||||
{#if i === messageList.length - 1}
|
||||
Delete this message?
|
||||
{:else}
|
||||
Delete this message and all messages after it?
|
||||
{/if}
|
||||
Delete this message{message.role === "user"
|
||||
? " and all responses after it"
|
||||
: ""}?
|
||||
</p>
|
||||
<div class="flex gap-2 justify-end">
|
||||
<button
|
||||
@@ -716,13 +713,8 @@
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
onclick={() => handleDeleteClick(message.id)}
|
||||
disabled={loading}
|
||||
class="p-1.5 transition-colors rounded {loading
|
||||
? 'text-exo-light-gray/30 cursor-not-allowed'
|
||||
: 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
|
||||
title={loading
|
||||
? "Cannot delete while generating"
|
||||
: "Delete message"}
|
||||
class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
|
||||
title="Delete message"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
|
||||
@@ -64,8 +64,6 @@
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1365",
|
||||
"1365x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
@@ -150,15 +148,6 @@
|
||||
setImageGenerationParams({ guidance: null });
|
||||
}
|
||||
|
||||
function handleNumSyncStepsChange(event: Event) {
|
||||
const value = parseInt((event.target as HTMLInputElement).value, 10);
|
||||
setImageGenerationParams({ numSyncSteps: value });
|
||||
}
|
||||
|
||||
function clearNumSyncSteps() {
|
||||
setImageGenerationParams({ numSyncSteps: null });
|
||||
}
|
||||
|
||||
function handleReset() {
|
||||
resetImageGenerationParams();
|
||||
showAdvanced = false;
|
||||
@@ -168,8 +157,7 @@
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null,
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
|
||||
);
|
||||
</script>
|
||||
|
||||
@@ -590,50 +578,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 3: Sync Steps -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span
|
||||
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
|
||||
>SYNC STEPS:</span
|
||||
>
|
||||
<div class="flex items-center gap-2 flex-1 max-w-xs">
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
max="100"
|
||||
value={params.numSyncSteps ?? 1}
|
||||
oninput={handleNumSyncStepsChange}
|
||||
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
|
||||
/>
|
||||
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
|
||||
{params.numSyncSteps ?? "--"}
|
||||
</span>
|
||||
{#if params.numSyncSteps !== null}
|
||||
<button
|
||||
type="button"
|
||||
onclick={clearNumSyncSteps}
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
title="Clear"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 4: Negative Prompt -->
|
||||
<!-- Row 3: Negative Prompt -->
|
||||
<div class="flex flex-col gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>NEGATIVE PROMPT:</span
|
||||
|
||||
@@ -286,14 +286,7 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
|
||||
// Image generation params interface matching backend API
|
||||
export interface ImageGenerationParams {
|
||||
// Basic params
|
||||
size:
|
||||
| "512x512"
|
||||
| "768x768"
|
||||
| "1024x1024"
|
||||
| "1024x768"
|
||||
| "768x1024"
|
||||
| "1024x1365"
|
||||
| "1365x1024";
|
||||
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
@@ -305,7 +298,6 @@ export interface ImageGenerationParams {
|
||||
numInferenceSteps: number | null;
|
||||
guidance: number | null;
|
||||
negativePrompt: string | null;
|
||||
numSyncSteps: number | null;
|
||||
// Edit mode params
|
||||
inputFidelity: "low" | "high";
|
||||
}
|
||||
@@ -327,7 +319,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
numInferenceSteps: null,
|
||||
guidance: null,
|
||||
negativePrompt: null,
|
||||
numSyncSteps: null,
|
||||
inputFidelity: "low",
|
||||
};
|
||||
|
||||
@@ -2405,9 +2396,7 @@ class AppStore {
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null;
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "");
|
||||
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model,
|
||||
@@ -2432,9 +2421,6 @@ class AppStore {
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
...(params.numSyncSteps !== null && {
|
||||
num_sync_steps: params.numSyncSteps,
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -2684,19 +2670,11 @@ class AppStore {
|
||||
formData.append("input_fidelity", params.inputFidelity);
|
||||
|
||||
// Advanced params
|
||||
const hasAdvancedParams =
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null;
|
||||
|
||||
if (hasAdvancedParams) {
|
||||
if (params.seed !== null) {
|
||||
formData.append(
|
||||
"advanced_params",
|
||||
JSON.stringify({
|
||||
...(params.seed !== null && { seed: params.seed }),
|
||||
seed: params.seed,
|
||||
...(params.numInferenceSteps !== null && {
|
||||
num_inference_steps: params.numInferenceSteps,
|
||||
}),
|
||||
@@ -2705,9 +2683,24 @@ class AppStore {
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
...(params.numSyncSteps !== null && {
|
||||
num_sync_steps: params.numSyncSteps,
|
||||
}),
|
||||
);
|
||||
} else if (
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "")
|
||||
) {
|
||||
formData.append(
|
||||
"advanced_params",
|
||||
JSON.stringify({
|
||||
...(params.numInferenceSteps !== null && {
|
||||
num_inference_steps: params.numInferenceSteps,
|
||||
}),
|
||||
...(params.guidance !== null && { guidance: params.guidance }),
|
||||
...(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ dependencies = [
|
||||
"httpx>=0.28.1",
|
||||
"tomlkit>=0.14.0",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux==0.15.5",
|
||||
"mflux==0.15.4",
|
||||
"python-multipart>=0.0.21",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -272,7 +272,6 @@ class AdvancedImageParams(BaseModel):
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
|
||||
negative_prompt: str | None = None
|
||||
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
|
||||
|
||||
class ImageGenerationTaskParams(BaseModel):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -22,7 +23,7 @@ class ImageModelConfig(BaseModel):
|
||||
block_configs: tuple[TransformerBlockConfig, ...]
|
||||
|
||||
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
|
||||
num_sync_steps: int # Number of sync steps for distributed inference
|
||||
num_sync_steps_factor: float # Fraction of steps for sync phase
|
||||
|
||||
guidance_scale: float | None = None # None or <= 1.0 disables CFG
|
||||
|
||||
@@ -44,3 +45,6 @@ class ImageModelConfig(BaseModel):
|
||||
|
||||
def get_steps_for_quality(self, quality: str) -> int:
|
||||
return self.default_steps[quality]
|
||||
|
||||
def get_num_sync_steps(self, steps: int) -> int:
|
||||
return ceil(steps * self.num_sync_steps_factor)
|
||||
|
||||
@@ -150,10 +150,7 @@ class DistributedImageModel:
|
||||
guidance=guidance_override if guidance_override is not None else 4.0,
|
||||
)
|
||||
|
||||
if advanced_params is not None and advanced_params.num_sync_steps is not None:
|
||||
num_sync_steps = advanced_params.num_sync_steps
|
||||
else:
|
||||
num_sync_steps = self._config.num_sync_steps
|
||||
num_sync_steps = self._config.get_num_sync_steps(steps)
|
||||
|
||||
for result in self._runner.generate_image(
|
||||
runtime_config=config,
|
||||
|
||||
@@ -5,9 +5,7 @@ from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import ModelAdapter
|
||||
from exo.worker.engines.image.models.flux import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_KONTEXT_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
FluxKontextModelAdapter,
|
||||
FluxModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.models.qwen import (
|
||||
@@ -28,16 +26,13 @@ AdapterFactory = Callable[
|
||||
# Registry maps model_family string to adapter factory
|
||||
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
|
||||
"flux": FluxModelAdapter,
|
||||
"flux-kontext": FluxKontextModelAdapter,
|
||||
"qwen-edit": QwenEditModelAdapter,
|
||||
"qwen": QwenModelAdapter,
|
||||
}
|
||||
|
||||
# Config registry: maps model ID patterns to configs
|
||||
# Order matters: longer/more-specific patterns must come before shorter ones
|
||||
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
|
||||
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
|
||||
"flux.1-kontext": FLUX_KONTEXT_CONFIG, # Must come before "flux.1-dev" for pattern matching
|
||||
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
|
||||
"flux.1-dev": FLUX_DEV_CONFIG,
|
||||
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
|
||||
|
||||
@@ -66,19 +66,6 @@ class PromptData(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
"""Kontext-style position IDs for image conditioning.
|
||||
|
||||
For FLUX.1-Kontext models, returns position IDs with first_coord=1
|
||||
to distinguish conditioning tokens from generation tokens (first_coord=0).
|
||||
|
||||
Returns:
|
||||
Position IDs array [1, seq_len, 3] for Kontext, None for other models.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
|
||||
from exo.worker.engines.image.models.flux.config import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_KONTEXT_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
)
|
||||
from exo.worker.engines.image.models.flux.kontext_adapter import (
|
||||
FluxKontextModelAdapter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FluxModelAdapter",
|
||||
"FluxKontextModelAdapter",
|
||||
"FLUX_DEV_CONFIG",
|
||||
"FLUX_KONTEXT_CONFIG",
|
||||
"FLUX_SCHNELL_CONFIG",
|
||||
]
|
||||
|
||||
@@ -59,10 +59,6 @@ class FluxPromptData(PromptData):
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
|
||||
@@ -15,7 +15,7 @@ FLUX_SCHNELL_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 1, "medium": 2, "high": 4},
|
||||
num_sync_steps=1,
|
||||
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
|
||||
)
|
||||
|
||||
|
||||
@@ -30,21 +30,5 @@ FLUX_DEV_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=4,
|
||||
)
|
||||
|
||||
|
||||
FLUX_KONTEXT_CONFIG = ImageModelConfig(
|
||||
model_family="flux-kontext",
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
|
||||
),
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=4,
|
||||
guidance_scale=4.0,
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
|
||||
)
|
||||
|
||||
@@ -1,348 +0,0 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, final
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
|
||||
from mflux.models.flux.variants.kontext.kontext_util import KontextUtil
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
ModelAdapter,
|
||||
PromptData,
|
||||
RotaryEmbeddings,
|
||||
)
|
||||
from exo.worker.engines.image.models.flux.wrappers import (
|
||||
FluxJointBlockWrapper,
|
||||
FluxSingleBlockWrapper,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FluxKontextPromptData(PromptData):
|
||||
"""Prompt data for FLUX.1-Kontext image editing.
|
||||
|
||||
Stores text embeddings along with conditioning latents and position IDs
|
||||
for the input image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
pooled_prompt_embeds: mx.array,
|
||||
conditioning_latents: mx.array,
|
||||
kontext_image_ids: mx.array,
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._pooled_prompt_embeds = pooled_prompt_embeds
|
||||
self._conditioning_latents = conditioning_latents
|
||||
self._kontext_image_ids = kontext_image_ids
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
return self._pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def cond_image_grid(
|
||||
self,
|
||||
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""VAE-encoded input image latents for Kontext conditioning."""
|
||||
return self._conditioning_latents
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
"""Position IDs for Kontext conditioning (first_coord=1)."""
|
||||
return self._kontext_image_ids
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Kontext doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
None,
|
||||
self._pooled_prompt_embeds,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
# Kontext doesn't use CFG
|
||||
return None
|
||||
|
||||
|
||||
@final
|
||||
class FluxKontextModelAdapter(ModelAdapter[Flux1Kontext, Transformer]):
|
||||
"""Adapter for FLUX.1-Kontext image editing model.
|
||||
|
||||
Key differences from standard FluxModelAdapter:
|
||||
- Takes an input image and computes output dimensions from it
|
||||
- Creates conditioning latents from the input image via VAE
|
||||
- Creates special position IDs (kontext_image_ids) for conditioning tokens
|
||||
- Creates pure noise latents (not img2img blending)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = Flux1Kontext(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
model_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
# Stores image path and computed dimensions after set_image_dimensions
|
||||
self._image_path: str | None = None
|
||||
self._output_height: int | None = None
|
||||
self._output_width: int | None = None
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.x_embedder.weight.shape[0] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
return False
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return FluxLatentCreator
|
||||
|
||||
def get_joint_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
) -> list[JointBlockWrapper[Any]]:
|
||||
"""Create wrapped joint blocks for Flux Kontext."""
|
||||
return [
|
||||
FluxJointBlockWrapper(block, text_seq_len)
|
||||
for block in self._transformer.transformer_blocks
|
||||
]
|
||||
|
||||
def get_single_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
) -> list[SingleBlockWrapper[Any]]:
|
||||
"""Create wrapped single blocks for Flux Kontext."""
|
||||
return [
|
||||
FluxSingleBlockWrapper(block, text_seq_len)
|
||||
for block in self._transformer.single_transformer_blocks
|
||||
]
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
):
|
||||
all_joint = list(self._transformer.transformer_blocks)
|
||||
all_single = list(self._transformer.single_transformer_blocks)
|
||||
total_joint_blocks = len(all_joint)
|
||||
if end_layer <= total_joint_blocks:
|
||||
# All assigned are joint blocks
|
||||
joint_start, joint_end = start_layer, end_layer
|
||||
single_start, single_end = 0, 0
|
||||
elif start_layer >= total_joint_blocks:
|
||||
# All assigned are single blocks
|
||||
joint_start, joint_end = 0, 0
|
||||
single_start = start_layer - total_joint_blocks
|
||||
single_end = end_layer - total_joint_blocks
|
||||
else:
|
||||
# Spans both joint and single
|
||||
joint_start, joint_end = start_layer, total_joint_blocks
|
||||
single_start = 0
|
||||
single_end = end_layer - total_joint_blocks
|
||||
|
||||
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
|
||||
self._transformer.single_transformer_blocks = all_single[
|
||||
single_start:single_end
|
||||
]
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
|
||||
"""Compute and store dimensions from input image.
|
||||
|
||||
Also stores image_path for use in encode_prompt().
|
||||
|
||||
Args:
|
||||
image_path: Path to the input image
|
||||
|
||||
Returns:
|
||||
(output_width, output_height) for runtime config
|
||||
"""
|
||||
from mflux.utils.image_util import ImageUtil
|
||||
|
||||
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
|
||||
image_size = pil_image.size
|
||||
|
||||
# Compute output dimensions from input image aspect ratio
|
||||
# Target area of 1024x1024 = ~1M pixels
|
||||
target_area = 1024 * 1024
|
||||
ratio = image_size[0] / image_size[1]
|
||||
output_width = math.sqrt(target_area * ratio)
|
||||
output_height = output_width / ratio
|
||||
output_width = round(output_width / 32) * 32
|
||||
output_height = round(output_height / 32) * 32
|
||||
|
||||
# Ensure multiple of 16 for VAE
|
||||
vae_scale_factor = 8
|
||||
multiple_of = vae_scale_factor * 2
|
||||
output_width = output_width // multiple_of * multiple_of
|
||||
output_height = output_height // multiple_of * multiple_of
|
||||
|
||||
self._image_path = str(image_path)
|
||||
self._output_width = int(output_width)
|
||||
self._output_height = int(output_height)
|
||||
|
||||
return self._output_width, self._output_height
|
||||
|
||||
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
|
||||
"""Create initial noise latents for Kontext.
|
||||
|
||||
Unlike standard img2img which blends noise with encoded input,
|
||||
Kontext uses pure noise latents. The input image is provided
|
||||
separately as conditioning.
|
||||
"""
|
||||
return FluxLatentCreator.create_noise(
|
||||
seed=seed,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
)
|
||||
|
||||
def encode_prompt(
|
||||
self, prompt: str, negative_prompt: str | None = None
|
||||
) -> FluxKontextPromptData:
|
||||
"""Encode prompt and create conditioning from stored input image.
|
||||
|
||||
Must call set_image_dimensions() before this method.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt for editing
|
||||
negative_prompt: Ignored (Kontext doesn't use CFG)
|
||||
|
||||
Returns:
|
||||
FluxKontextPromptData with text embeddings and image conditioning
|
||||
"""
|
||||
del negative_prompt # Kontext doesn't support negative prompts or CFG
|
||||
|
||||
if (
|
||||
self._image_path is None
|
||||
or self._output_height is None
|
||||
or self._output_width is None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"set_image_dimensions() must be called before encode_prompt() "
|
||||
"for FluxKontextModelAdapter"
|
||||
)
|
||||
|
||||
assert isinstance(self.model.prompt_cache, dict)
|
||||
assert isinstance(self.model.tokenizers, dict)
|
||||
|
||||
# Encode text prompt
|
||||
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_cache=self.model.prompt_cache,
|
||||
t5_tokenizer=self.model.tokenizers["t5"], # pyright: ignore[reportAny]
|
||||
clip_tokenizer=self.model.tokenizers["clip"], # pyright: ignore[reportAny]
|
||||
t5_text_encoder=self.model.t5_text_encoder,
|
||||
clip_text_encoder=self.model.clip_text_encoder,
|
||||
)
|
||||
|
||||
# Create conditioning latents from input image
|
||||
conditioning_latents, kontext_image_ids = (
|
||||
KontextUtil.create_image_conditioning_latents(
|
||||
vae=self.model.vae,
|
||||
height=self._output_height,
|
||||
width=self._output_width,
|
||||
image_path=self._image_path,
|
||||
)
|
||||
)
|
||||
|
||||
return FluxKontextPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
conditioning_latents=conditioning_latents,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
embedded_hidden = self._transformer.x_embedder(hidden_states)
|
||||
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: Config,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
if pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"pooled_prompt_embeds is required for Flux Kontext text embeddings"
|
||||
)
|
||||
|
||||
return Transformer.compute_text_embeddings(
|
||||
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
|
||||
)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: Config,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
cond_image_grid: tuple[int, int, int]
|
||||
| list[tuple[int, int, int]]
|
||||
| None = None,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> RotaryEmbeddings:
|
||||
return Transformer.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
self._transformer.pos_embed,
|
||||
runtime_config,
|
||||
kontext_image_ids,
|
||||
)
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
raise NotImplementedError("Flux Kontext does not use classifier-free guidance")
|
||||
@@ -69,10 +69,6 @@ class QwenPromptData(PromptData):
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
|
||||
@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=7,
|
||||
num_sync_steps_factor=0.25,
|
||||
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
|
||||
)
|
||||
|
||||
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=7,
|
||||
num_sync_steps_factor=0.25,
|
||||
guidance_scale=3.5,
|
||||
)
|
||||
|
||||
@@ -85,10 +85,6 @@ class QwenEditPromptData(PromptData):
|
||||
def qwen_image_ids(self) -> mx.array:
|
||||
return self._qwen_image_ids
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_edit_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -567,7 +567,6 @@ class DiffusionRunner:
|
||||
| list[tuple[int, int, int]]
|
||||
| None = None,
|
||||
conditioning_latents: mx.array | None = None,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Run a single forward pass through the transformer.
|
||||
Args:
|
||||
@@ -579,7 +578,6 @@ class DiffusionRunner:
|
||||
encoder_hidden_states_mask: Attention mask for text (Qwen)
|
||||
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
|
||||
conditioning_latents: Conditioning latents for edit mode
|
||||
kontext_image_ids: Position IDs for Kontext conditioning (Flux Kontext)
|
||||
|
||||
Returns:
|
||||
Noise prediction tensor
|
||||
@@ -612,7 +610,6 @@ class DiffusionRunner:
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
@@ -684,7 +681,6 @@ class DiffusionRunner:
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
kontext_image_ids = prompt_data.kontext_image_ids
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
@@ -704,7 +700,6 @@ class DiffusionRunner:
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=branch.cond_latents,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
@@ -907,10 +902,10 @@ class DiffusionRunner:
|
||||
config: Config,
|
||||
hidden_states: mx.array,
|
||||
prompt_data: PromptData,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
kontext_image_ids = prompt_data.kontext_image_ids
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
|
||||
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
|
||||
@@ -984,10 +979,10 @@ class DiffusionRunner:
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
is_first_async_step: bool,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
kontext_image_ids = prompt_data.kontext_image_ids
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
|
||||
|
||||
@@ -386,15 +386,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, Glm4MoeModel):
|
||||
tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Qwen3NextModel)):
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -520,6 +512,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
|
||||
# layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
# layer.self_attn.kv_b_proj
|
||||
# )
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
@@ -549,7 +544,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
mx.eval(layer)
|
||||
@@ -557,9 +552,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
class ShardedMoE(CustomMlxLayer):
|
||||
"""Wraps any MoE layer with distributed sum_gradients / all_sum."""
|
||||
|
||||
class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
@@ -630,13 +623,27 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
|
||||
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # type: ignore
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class WrappedMiniMaxAttention(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
|
||||
super().__init__(layer)
|
||||
@@ -749,7 +756,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
mx.eval(layer)
|
||||
return model
|
||||
@@ -854,7 +861,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
|
||||
if isinstance(
|
||||
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
|
||||
):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
@@ -866,7 +875,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.shared_expert.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
|
||||
layer.mlp = ShardedMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
# Shard the MLP
|
||||
@@ -879,50 +888,18 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Glm4MoeModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
class ShardedQwenMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
|
||||
if isinstance(layer.mlp, MoE):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
layer.mlp = ShardedMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
else:
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
@@ -960,7 +937,21 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedMoE(layer.mlp) # type: ignore
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -290,6 +290,7 @@ def make_kv_cache(
|
||||
) -> KVCacheType:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache"):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
@@ -1,377 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Download an mflux model, quantize it, and upload to HuggingFace.
|
||||
|
||||
Usage (run from mflux project directory):
|
||||
cd /path/to/mflux
|
||||
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
|
||||
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
|
||||
uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
|
||||
|
||||
Requires:
|
||||
- Must be run from mflux project directory using `uv run`
|
||||
- huggingface_hub installed (add to mflux deps or install separately)
|
||||
- HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
|
||||
|
||||
HF_ORG = "exolabs"
|
||||
|
||||
|
||||
def get_model_class(model_name: str) -> type:
|
||||
"""Get the appropriate model class based on model name."""
|
||||
from mflux.models.fibo.variants.txt2img.fibo import FIBO
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
|
||||
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
|
||||
from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo
|
||||
|
||||
model_name_lower = model_name.lower()
|
||||
if "qwen" in model_name_lower:
|
||||
return QwenImage
|
||||
elif "fibo" in model_name_lower:
|
||||
return FIBO
|
||||
elif "z-image" in model_name_lower or "zimage" in model_name_lower:
|
||||
return ZImageTurbo
|
||||
elif "flux2" in model_name_lower or "flux.2" in model_name_lower:
|
||||
return Flux2Klein
|
||||
else:
|
||||
return Flux1
|
||||
|
||||
|
||||
def get_repo_name(model_name: str, bits: int | None) -> str:
|
||||
"""Get the HuggingFace repo name for a model variant."""
|
||||
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
|
||||
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
suffix = f"-{bits}bit" if bits else ""
|
||||
return f"{HF_ORG}/{base_name}{suffix}"
|
||||
|
||||
|
||||
def get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path:
|
||||
"""Get the local save path for a model variant."""
|
||||
# Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
|
||||
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
suffix = f"-{bits}bit" if bits else ""
|
||||
return output_dir / f"{base_name}{suffix}"
|
||||
|
||||
|
||||
def copy_source_repo(
|
||||
source_repo: str,
|
||||
local_path: Path,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Copy all files from source repo (replicating original HF structure)."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Copying full repo from source: {source_repo}")
|
||||
print(f"Output path: {local_path}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would download all files from source repo")
|
||||
return
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download all files to our local path
|
||||
snapshot_download(
|
||||
repo_id=source_repo,
|
||||
local_dir=local_path,
|
||||
)
|
||||
|
||||
# Remove root-level safetensors files (flux.1-dev.safetensors, etc.)
|
||||
# These are redundant with the component directories
|
||||
for f in local_path.glob("*.safetensors"):
|
||||
print(f"Removing root-level safetensors: {f.name}")
|
||||
if not dry_run:
|
||||
f.unlink()
|
||||
|
||||
print(f"Source repo copied to {local_path}")
|
||||
|
||||
|
||||
def load_and_save_quantized_model(
|
||||
model_name: str,
|
||||
bits: int,
|
||||
output_path: Path,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Load a model with quantization and save it in mflux format."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Loading {model_name} with {bits}-bit quantization...")
|
||||
print(f"Output path: {output_path}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would load and save quantized model")
|
||||
return
|
||||
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
|
||||
model_class = get_model_class(model_name)
|
||||
model_config = ModelConfig.from_name(model_name=model_name, base_model=None)
|
||||
|
||||
model: Flux1 = model_class(
|
||||
quantize=bits,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
print(f"Saving model to {output_path}...")
|
||||
model.save_model(str(output_path))
|
||||
print(f"Model saved successfully to {output_path}")
|
||||
|
||||
|
||||
def copy_source_metadata(
|
||||
source_repo: str,
|
||||
local_path: Path,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Copy metadata files (LICENSE, README, etc.) from source repo, excluding safetensors."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Copying metadata from source repo: {source_repo}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would download metadata files (excluding *.safetensors)")
|
||||
return
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download all files except safetensors to our local path
|
||||
snapshot_download(
|
||||
repo_id=source_repo,
|
||||
local_dir=local_path,
|
||||
ignore_patterns=["*.safetensors"],
|
||||
)
|
||||
print(f"Metadata files copied to {local_path}")
|
||||
|
||||
|
||||
def upload_to_huggingface(
|
||||
local_path: Path,
|
||||
repo_id: str,
|
||||
dry_run: bool = False,
|
||||
clean_remote: bool = False,
|
||||
) -> None:
|
||||
"""Upload a saved model to HuggingFace."""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Uploading to HuggingFace: {repo_id}")
|
||||
print(f"Local path: {local_path}")
|
||||
print(f"Clean remote first: {clean_remote}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would upload to HuggingFace")
|
||||
return
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# Create the repo if it doesn't exist
|
||||
print(f"Creating/verifying repo: {repo_id}")
|
||||
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
|
||||
|
||||
# Clean remote repo if requested (delete old mflux-format files)
|
||||
if clean_remote:
|
||||
print("Cleaning old mflux-format files from remote...")
|
||||
try:
|
||||
# Pattern for mflux numbered shards: <dir>/<number>.safetensors
|
||||
numbered_pattern = re.compile(r".*/\d+\.safetensors$")
|
||||
|
||||
repo_files = api.list_repo_files(repo_id=repo_id, repo_type="model")
|
||||
for file_path in repo_files:
|
||||
# Delete numbered safetensors (mflux format) and mflux index files
|
||||
if numbered_pattern.match(file_path) or file_path.endswith(
|
||||
"/model.safetensors.index.json"
|
||||
):
|
||||
print(f" Deleting: {file_path}")
|
||||
api.delete_file(
|
||||
path_in_repo=file_path, repo_id=repo_id, repo_type="model"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not clean remote files: {e}")
|
||||
|
||||
# Upload the folder
|
||||
print("Uploading folder contents...")
|
||||
api.upload_folder(
|
||||
folder_path=str(local_path),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f"Upload complete: https://huggingface.co/{repo_id}")
|
||||
|
||||
|
||||
def clean_local_files(local_path: Path, dry_run: bool = False) -> None:
|
||||
"""Remove local model files after upload."""
|
||||
print(f"\nCleaning up: {local_path}")
|
||||
if dry_run:
|
||||
print("[DRY RUN] Would remove local files")
|
||||
return
|
||||
|
||||
if local_path.exists():
|
||||
shutil.rmtree(local_path)
|
||||
print(f"Removed {local_path}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download an mflux model, quantize it, and upload to HuggingFace.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Process all variants (base, 4-bit, 8-bit) for FLUX.1-Kontext-dev
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
|
||||
|
||||
# Only process 4-bit variant
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
|
||||
|
||||
# Save locally without uploading
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-upload
|
||||
|
||||
# Preview what would happen
|
||||
python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
required=True,
|
||||
help="HuggingFace model path (e.g., black-forest-labs/FLUX.1-Kontext-dev)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("./tmp/models"),
|
||||
help="Local directory to save models (default: ./tmp/models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-base",
|
||||
action="store_true",
|
||||
help="Skip base model (no quantization)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-4bit",
|
||||
action="store_true",
|
||||
help="Skip 4-bit quantized model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-8bit",
|
||||
action="store_true",
|
||||
help="Skip 8-bit quantized model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-download",
|
||||
action="store_true",
|
||||
help="Skip downloading/processing, only do upload/clean operations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-upload",
|
||||
action="store_true",
|
||||
help="Only save locally, don't upload to HuggingFace",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clean",
|
||||
action="store_true",
|
||||
help="Remove local files after upload",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clean-remote",
|
||||
action="store_true",
|
||||
help="Delete old mflux-format files from remote repo before uploading",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print actions without executing",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine which variants to process
|
||||
variants: list[int | None] = []
|
||||
if not args.skip_base:
|
||||
variants.append(None) # Base model (no quantization)
|
||||
if not args.skip_4bit:
|
||||
variants.append(4)
|
||||
if not args.skip_8bit:
|
||||
variants.append(8)
|
||||
|
||||
if not variants:
|
||||
print("Error: All variants skipped. Nothing to do.")
|
||||
return 1
|
||||
|
||||
# Create output directory
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Output directory: {args.output_dir}")
|
||||
print(
|
||||
f"Variants to process: {['base' if v is None else f'{v}-bit' for v in variants]}"
|
||||
)
|
||||
print(f"Upload to HuggingFace: {not args.skip_upload}")
|
||||
print(f"Clean after upload: {args.clean}")
|
||||
if args.dry_run:
|
||||
print("\n*** DRY RUN MODE - No actual changes will be made ***")
|
||||
|
||||
# Process each variant
|
||||
for bits in variants:
|
||||
local_path = get_local_path(args.output_dir, args.model, bits)
|
||||
repo_id = get_repo_name(args.model, bits)
|
||||
|
||||
if not args.skip_download:
|
||||
if bits is None:
|
||||
# Base model: copy original HF repo structure (no mflux conversion)
|
||||
copy_source_repo(
|
||||
source_repo=args.model,
|
||||
local_path=local_path,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
else:
|
||||
# Quantized model: load, quantize, and save with mflux
|
||||
load_and_save_quantized_model(
|
||||
model_name=args.model,
|
||||
bits=bits,
|
||||
output_path=local_path,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
# Copy metadata from source repo (LICENSE, README, etc.)
|
||||
copy_source_metadata(
|
||||
source_repo=args.model,
|
||||
local_path=local_path,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
# Upload
|
||||
if not args.skip_upload:
|
||||
upload_to_huggingface(
|
||||
local_path=local_path,
|
||||
repo_id=repo_id,
|
||||
dry_run=args.dry_run,
|
||||
clean_remote=args.clean_remote,
|
||||
)
|
||||
|
||||
# Clean up if requested
|
||||
if args.clean:
|
||||
clean_local_files(local_path, dry_run=args.dry_run)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All done!")
|
||||
print("=" * 60)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
20
uv.lock
generated
20
uv.lock
generated
@@ -192,14 +192,20 @@ sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446, upload-time = "2025-09-08T23:23:03.472Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101, upload-time = "2025-09-08T23:23:04.792Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948, upload-time = "2025-09-08T23:23:06.127Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422, upload-time = "2025-09-08T23:23:07.753Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499, upload-time = "2025-09-08T23:23:09.648Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928, upload-time = "2025-09-08T23:23:10.928Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302, upload-time = "2025-09-08T23:23:12.42Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049, upload-time = "2025-09-08T23:23:20.853Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793, upload-time = "2025-09-08T23:23:22.08Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300, upload-time = "2025-09-08T23:23:23.314Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244, upload-time = "2025-09-08T23:23:24.541Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828, upload-time = "2025-09-08T23:23:26.143Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926, upload-time = "2025-09-08T23:23:27.873Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593, upload-time = "2025-09-08T23:23:31.91Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354, upload-time = "2025-09-08T23:23:33.214Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480, upload-time = "2025-09-08T23:23:34.495Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584, upload-time = "2025-09-08T23:23:36.096Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443, upload-time = "2025-09-08T23:23:37.328Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437, upload-time = "2025-09-08T23:23:38.945Z" },
|
||||
@@ -305,8 +311,10 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/49/498c86566a1d80e978b42f0d702795f69887005548c041636df6ae1ca64c/cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d", size = 4450807, upload-time = "2025-10-15T23:16:56.414Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/0a/863a3604112174c8624a2ac3c038662d9e59970c7f926acdcfaed8d61142/cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb", size = 4299615, upload-time = "2025-10-15T23:16:58.442Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/64/02/b73a533f6b64a69f3cd3872acb6ebc12aef924d8d103133bb3ea750dc703/cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849", size = 4016800, upload-time = "2025-10-15T23:17:00.378Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/d5/16e41afbfa450cde85a3b7ec599bebefaef16b5c6ba4ec49a3532336ed72/cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8", size = 4984707, upload-time = "2025-10-15T23:17:01.98Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec", size = 4482541, upload-time = "2025-10-15T23:17:04.078Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/f6/50736d40d97e8483172f1bb6e698895b92a223dba513b0ca6f06b2365339/cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91", size = 4299464, upload-time = "2025-10-15T23:17:05.483Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/00/de/d8e26b1a855f19d9994a19c702fa2e93b0456beccbcfe437eda00e0701f2/cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e", size = 4950838, upload-time = "2025-10-15T23:17:07.425Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/29/798fc4ec461a1c9e9f735f2fc58741b0daae30688f41b2497dcbc9ed1355/cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926", size = 4481596, upload-time = "2025-10-15T23:17:09.343Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/8d/03cd48b20a573adfff7652b76271078e3045b9f49387920e7f1f631d125e/cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71", size = 4426782, upload-time = "2025-10-15T23:17:11.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/b1/ebacbfe53317d55cf33165bda24c86523497a6881f339f9aae5c2e13e57b/cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac", size = 4698381, upload-time = "2025-10-15T23:17:12.829Z" },
|
||||
@@ -314,8 +322,10 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/fd/bc1daf8230eaa075184cbbf5f8cd00ba9db4fd32d63fb83da4671b72ed8a/cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715", size = 4435078, upload-time = "2025-10-15T23:17:23.042Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/98/d3bd5407ce4c60017f8ff9e63ffee4200ab3e23fe05b765cab805a7db008/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54", size = 4293460, upload-time = "2025-10-15T23:17:24.885Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/e9/e23e7900983c2b8af7a08098db406cf989d7f09caea7897e347598d4cd5b/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459", size = 3995237, upload-time = "2025-10-15T23:17:26.449Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/91/15/af68c509d4a138cfe299d0d7ddb14afba15233223ebd933b4bbdbc7155d3/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422", size = 4967344, upload-time = "2025-10-15T23:17:28.06Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/e3/8643d077c53868b681af077edf6b3cb58288b5423610f21c62aadcbe99f4/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7", size = 4466564, upload-time = "2025-10-15T23:17:29.665Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/43/c1e8726fa59c236ff477ff2b5dc071e54b21e5a1e51aa2cee1676f1c986f/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044", size = 4292415, upload-time = "2025-10-15T23:17:31.686Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/f9/2f8fefdb1aee8a8e3256a0568cffc4e6d517b256a2fe97a029b3f1b9fe7e/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665", size = 4931457, upload-time = "2025-10-15T23:17:33.478Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/30/9b54127a9a778ccd6d27c3da7563e9f2d341826075ceab89ae3b41bf5be2/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3", size = 4466074, upload-time = "2025-10-15T23:17:35.158Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/68/b4f4a10928e26c941b1b6a179143af9f4d27d88fe84a6a3c53592d2e76bf/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20", size = 4420569, upload-time = "2025-10-15T23:17:37.188Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/49/3746dab4c0d1979888f125226357d3262a6dd40e114ac29e3d2abdf1ec55/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de", size = 4681941, upload-time = "2025-10-15T23:17:39.236Z" },
|
||||
@@ -323,8 +333,10 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/26/42/fa8389d4478368743e24e61eea78846a0006caffaf72ea24a15159215a14/cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d", size = 4440029, upload-time = "2025-10-15T23:17:49.837Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/eb/f483db0ec5ac040824f269e93dd2bd8a21ecd1027e77ad7bdf6914f2fd80/cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0", size = 4297222, upload-time = "2025-10-15T23:17:51.357Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/cf/da9502c4e1912cb1da3807ea3618a6829bee8207456fbbeebc361ec38ba3/cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc", size = 4012280, upload-time = "2025-10-15T23:17:52.964Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/8f/9adb86b93330e0df8b3dcf03eae67c33ba89958fc2e03862ef1ac2b42465/cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3", size = 4978958, upload-time = "2025-10-15T23:17:54.965Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/a0/5fa77988289c34bdb9f913f5606ecc9ada1adb5ae870bd0d1054a7021cc4/cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971", size = 4473714, upload-time = "2025-10-15T23:17:56.754Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/e5/fc82d72a58d41c393697aa18c9abe5ae1214ff6f2a5c18ac470f92777895/cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac", size = 4296970, upload-time = "2025-10-15T23:17:58.588Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/06/5663ed35438d0b09056973994f1aec467492b33bd31da36e468b01ec1097/cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04", size = 4940236, upload-time = "2025-10-15T23:18:00.897Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/59/873633f3f2dcd8a053b8dd1d38f783043b5fce589c0f6988bf55ef57e43e/cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506", size = 4472642, upload-time = "2025-10-15T23:18:02.749Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/39/8e71f3930e40f6877737d6f69248cf74d4e34b886a3967d32f919cc50d3b/cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963", size = 4423126, upload-time = "2025-10-15T23:18:04.85Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573, upload-time = "2025-10-15T23:18:06.908Z" },
|
||||
@@ -400,7 +412,7 @@ requires-dist = [
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mflux", specifier = "==0.15.5" },
|
||||
{ name = "mflux", specifier = "==0.15.4" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.5" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.5" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.6" },
|
||||
@@ -975,7 +987,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mflux"
|
||||
version = "0.15.5"
|
||||
version = "0.15.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1001,9 +1013,9 @@ dependencies = [
|
||||
{ name = "twine", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/8e/f20de51bf9dc0a986535d9a825db4ae314163421b3d3ddaa90a2b959b9fd/mflux-0.15.5.tar.gz", hash = "sha256:9a3372bd64d51c4caff4ff9e7d7d698bea5833242fd849c59cbb0c92f7d7aa3b", size = 743700, upload-time = "2026-01-26T12:41:45.272Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/95322db7a865e4df6bad108b1c99aa7fbe211aac3f298f3ad696c2744a39/mflux-0.15.4.tar.gz", hash = "sha256:138e1aedae86e13eafeb8faec017945fcdcca42c3234daabcd81a83c9a202ace", size = 741228, upload-time = "2026-01-20T15:39:26.807Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/bb/ef936eae2ae78a47cd92ddffc18fc06ad3fd5f438a0915fb62d8bb9508ec/mflux-0.15.5-py3-none-any.whl", hash = "sha256:c94891d4a518047a818863bb099c755e93af90c524ced358baf5b31502c09e82", size = 990939, upload-time = "2026-01-26T12:41:42.898Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/be/81cf4ce2d1933b9b210c028a05ac95e958008c0d43e377a5f2757b7f2d4d/mflux-0.15.4-py3-none-any.whl", hash = "sha256:f04d9b1d7c5cd67880f483ab29fb2097648a25459eef9c5ee6480fad46de5e82", size = 987644, upload-time = "2026-01-20T15:39:24.817Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user