Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-06 20:21:39 -05:00)

Compare commits: rust-explo...ciaran/req (14 commits)
| Author | SHA1 | Date |
|---|---|---|
| | f9f063db99 | |
| | 3b2f553a25 | |
| | 5455a97a8c | |
| | 6f0cb99919 | |
| | c8d3154f83 | |
| | 63e9cc4fea | |
| | 9b5cae3db6 | |
| | cf7201f91e | |
| | b315035ae0 | |
| | c8dbbee27b | |
| | f0107e9670 | |
| | 9f502793c1 | |
| | c8371349d5 | |
| | 6b907398a4 | |
.gitignore (vendored), 6 changed lines
@@ -32,3 +32,9 @@ dashboard/.svelte-kit/
# host config snapshots
hosts_*.json
.swp

# bench files
bench/**/*.json

# tmp
tmp/models
@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""

from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext

__all__ = ["Flux1Kontext"]
@@ -0,0 +1,49 @@
"""
This type stub file was generated by pyright.
"""

from pathlib import Path
from typing import Any

from mlx import nn

from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
    CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.model.flux_vae.vae import VAE
from mflux.utils.generated_image import GeneratedImage

class Flux1Kontext(nn.Module):
    vae: VAE
    transformer: Transformer
    t5_text_encoder: T5Encoder
    clip_text_encoder: CLIPEncoder
    bits: int | None
    lora_paths: list[str] | None
    lora_scales: list[float] | None
    prompt_cache: dict[str, Any]
    tokenizers: dict[str, Any]

    def __init__(
        self,
        quantize: int | None = ...,
        model_path: str | None = ...,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
        model_config: ModelConfig = ...,
    ) -> None: ...
    def generate_image(
        self,
        seed: int,
        prompt: str,
        num_inference_steps: int = ...,
        height: int = ...,
        width: int = ...,
        guidance: float = ...,
        image_path: Path | str | None = ...,
        image_strength: float | None = ...,
        scheduler: str = ...,
    ) -> GeneratedImage: ...
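The stub fixes the `Flux1Kontext` surface this branch depends on. As orientation, here is a minimal, hedged usage sketch based only on the signatures above; the weights path, step count, guidance value, and the `save()` call on `GeneratedImage` are illustrative assumptions, not values taken from the repository.

```python
# Hedged sketch: drives Flux1Kontext exactly as typed in the stub above.
# The model path and generation settings are placeholder values.
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext

def edit_image_example() -> None:
    model = Flux1Kontext(
        quantize=8,                       # optional weight quantization
        model_path="/tmp/flux-kontext",   # placeholder local weights directory
    )
    result = model.generate_image(
        seed=0,
        prompt="replace the sky with a sunset",
        num_inference_steps=25,
        guidance=4.0,
        image_path="input.png",           # Kontext conditions on an input image
    )
    result.save("output.png")             # assumes GeneratedImage exposes save()
```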
@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx

from mflux.models.flux.model.flux_vae.vae import VAE

class KontextUtil:
    @staticmethod
    def create_image_conditioning_latents(
        vae: VAE,
        height: int,
        width: int,
        image_path: str,
    ) -> tuple[mx.array, mx.array]: ...
@@ -1139,7 +1139,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`flatten`."""
|
||||
|
||||
def reshape(self, *shape, stream: Stream | Device | None = ...) -> array:
|
||||
def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`reshape` but the shape can be passed either as a
|
||||
:obj:`tuple` or as separate arguments.
|
||||
@@ -1222,7 +1222,7 @@ class array:
|
||||
) -> array:
|
||||
"""See :func:`swapaxes`."""
|
||||
|
||||
def transpose(self, *axes, stream: Stream | Device | None = ...) -> array:
|
||||
def transpose(self, *axes: int, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Equivalent to :func:`transpose` but the axes can be passed either as
|
||||
a tuple or as separate arguments.
|
||||
|
||||
@@ -30,6 +30,9 @@ class Conv1d(Module):
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||
Default: ``True``
|
||||
"""
|
||||
|
||||
weight: mx.array
|
||||
groups: int
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
|
||||
@@ -11,7 +11,10 @@ import mlx.core as mx
|
||||
class Cache(Protocol):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
|
||||
offset: int
|
||||
def update_and_fetch(
|
||||
self, keys: mx.array, values: mx.array
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
@@ -87,6 +90,7 @@ def create_attention_mask(
|
||||
class _BaseCache(Cache):
|
||||
keys: mx.array
|
||||
values: mx.array
|
||||
offset: int
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
@state.setter
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.mla import MultiLinear
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
@@ -60,7 +61,10 @@ class DeepseekV3Attention(nn.Module):
|
||||
q_b_proj: nn.Linear
|
||||
kv_a_proj_with_mqa: nn.Linear
|
||||
kv_a_layernorm: nn.RMSNorm
|
||||
kv_b_proj: nn.Linear
|
||||
# kv_b_proj: nn.Linear
|
||||
embed_q: MultiLinear
|
||||
unembed_out: MultiLinear
|
||||
|
||||
o_proj: nn.Linear
|
||||
rope: Any
|
||||
|
||||
|
||||
153
.mlx_typings/mlx_lm/models/glm4_moe.pyi
Normal file
153
.mlx_typings/mlx_lm/models/glm4_moe.pyi
Normal file
@@ -0,0 +1,153 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
max_position_embeddings: int
|
||||
moe_intermediate_size: int
|
||||
norm_topk_prob: bool
|
||||
num_attention_heads: int
|
||||
n_group: int
|
||||
head_dim: int
|
||||
topk_group: int
|
||||
n_shared_experts: int
|
||||
n_routed_experts: int
|
||||
routed_scaling_factor: float
|
||||
num_experts_per_tok: int
|
||||
first_k_dense_replace: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Any]]
|
||||
use_qk_norm: bool
|
||||
tie_word_embeddings: bool
|
||||
attention_bias: bool
|
||||
partial_rotary_factor: float
|
||||
scoring_func: str
|
||||
topk_method: str
|
||||
|
||||
class Attention(nn.Module):
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
use_qk_norm: bool
|
||||
q_norm: nn.RMSNorm
|
||||
k_norm: nn.RMSNorm
|
||||
rope: nn.RoPE
|
||||
|
||||
def __init__(self, args: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class MLP(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
gate_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
config: ModelArgs
|
||||
top_k: int
|
||||
norm_topk_prob: bool
|
||||
n_routed_experts: int
|
||||
routed_scaling_factor: float
|
||||
n_group: int
|
||||
topk_group: int
|
||||
weight: mx.array
|
||||
e_score_correction_bias: mx.array
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
class MoE(nn.Module):
|
||||
config: ModelArgs
|
||||
num_experts_per_tok: int
|
||||
switch_mlp: SwitchGLU
|
||||
gate: MoEGate
|
||||
shared_experts: MLP
|
||||
sharding_group: Optional[mx.distributed.Group]
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
self_attn: Attention
|
||||
mlp: MLP | MoE
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
vocab_size: int
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[DecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
pipeline_rank: int
|
||||
pipeline_size: int
|
||||
start_idx: int
|
||||
end_idx: Optional[int]
|
||||
num_layers: int
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
@property
|
||||
def pipeline_layers(self) -> list[DecoderLayer]: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
args: ModelArgs
|
||||
model_type: str
|
||||
model: LanguageModel
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
|
||||
@property
|
||||
def layers(self) -> list[DecoderLayer]: ...
|
||||
@property
|
||||
def cast_predicate(self) -> Any: ...
|
||||
114
.mlx_typings/mlx_lm/models/qwen3_next.pyi
Normal file
114
.mlx_typings/mlx_lm/models/qwen3_next.pyi
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Type stubs for mlx_lm.models.qwen3_next"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
class Qwen3NextMLP(nn.Module):
|
||||
gate_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextGatedDeltaNet(nn.Module):
|
||||
hidden_size: int
|
||||
num_v_heads: int
|
||||
num_k_heads: int
|
||||
head_k_dim: int
|
||||
head_v_dim: int
|
||||
key_dim: int
|
||||
value_dim: int
|
||||
conv_kernel_size: int
|
||||
conv_dim: int
|
||||
conv1d: nn.Conv1d
|
||||
in_proj_qkvz: nn.Linear
|
||||
in_proj_ba: nn.Linear
|
||||
dt_bias: mx.array
|
||||
A_log: mx.array
|
||||
out_proj: nn.Linear
|
||||
|
||||
def __init__(self, config: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextAttention(nn.Module):
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
k_proj: nn.Linear
|
||||
v_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
norm_topk_prob: bool
|
||||
num_experts: int
|
||||
top_k: int
|
||||
gate: nn.Linear
|
||||
switch_mlp: SwitchGLU
|
||||
shared_expert: Qwen3NextMLP
|
||||
shared_expert_gate: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class Qwen3NextDecoderLayer(nn.Module):
|
||||
is_linear: bool
|
||||
linear_attn: Qwen3NextGatedDeltaNet
|
||||
self_attn: Qwen3NextAttention
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock
|
||||
|
||||
def __init__(self, args: Any, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Qwen3NextModel(nn.Module):
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[Qwen3NextDecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
model_type: str
|
||||
model: Qwen3NextModel
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, args: Any) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
@property
|
||||
def layers(self) -> list[Qwen3NextDecoderLayer]: ...
|
||||
@@ -113,6 +113,10 @@ class TokenizerWrapper:
|
||||
bos_token: str | None
|
||||
vocab_size: int
|
||||
all_special_tokens: list[str]
|
||||
think_start: str | None
|
||||
think_end: str | None
|
||||
think_start_id: int | None
|
||||
think_end_id: int | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -431,7 +431,12 @@ def main() -> int:
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Pipeline jaccl is often pointless, skip by default",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
@@ -450,6 +455,7 @@ def main() -> int:
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument("--stdout", action="store_true", help="Write results to stdout")
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
@@ -533,6 +539,16 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
@@ -652,7 +668,9 @@ def main() -> int:
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.json_out:
|
||||
if args.stdout:
|
||||
json.dump(all_rows, sys.stdout, indent=2, ensure_ascii=False)
|
||||
elif args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
@@ -254,6 +254,7 @@
|
||||
|
||||
function handleSubmit() {
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
|
||||
if (isEditOnlyWithoutImage) return;
|
||||
|
||||
const content = message.trim();
|
||||
const files = [...uploadedFiles];
|
||||
@@ -278,7 +279,11 @@
|
||||
if (imageFile.preview) {
|
||||
editImage(content, imageFile.preview);
|
||||
}
|
||||
} else if (isImageModel() && content) {
|
||||
} else if (
|
||||
currentModel &&
|
||||
modelSupportsTextToImage(currentModel) &&
|
||||
content
|
||||
) {
|
||||
// Use image generation for text-to-image models
|
||||
generateImage(content);
|
||||
} else {
|
||||
|
||||
@@ -64,6 +64,8 @@
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1365",
|
||||
"1365x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
@@ -148,6 +150,15 @@
|
||||
setImageGenerationParams({ guidance: null });
|
||||
}
|
||||
|
||||
function handleNumSyncStepsChange(event: Event) {
|
||||
const value = parseInt((event.target as HTMLInputElement).value, 10);
|
||||
setImageGenerationParams({ numSyncSteps: value });
|
||||
}
|
||||
|
||||
function clearNumSyncSteps() {
|
||||
setImageGenerationParams({ numSyncSteps: null });
|
||||
}
|
||||
|
||||
function handleReset() {
|
||||
resetImageGenerationParams();
|
||||
showAdvanced = false;
|
||||
@@ -157,7 +168,8 @@
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null,
|
||||
);
|
||||
</script>
|
||||
|
||||
@@ -578,7 +590,50 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 3: Negative Prompt -->
|
||||
<!-- Row 3: Sync Steps -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span
|
||||
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
|
||||
>SYNC STEPS:</span
|
||||
>
|
||||
<div class="flex items-center gap-2 flex-1 max-w-xs">
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
max="100"
|
||||
value={params.numSyncSteps ?? 1}
|
||||
oninput={handleNumSyncStepsChange}
|
||||
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
|
||||
/>
|
||||
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
|
||||
{params.numSyncSteps ?? "--"}
|
||||
</span>
|
||||
{#if params.numSyncSteps !== null}
|
||||
<button
|
||||
type="button"
|
||||
onclick={clearNumSyncSteps}
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
title="Clear"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 4: Negative Prompt -->
|
||||
<div class="flex flex-col gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>NEGATIVE PROMPT:</span
|
||||
|
||||
@@ -286,7 +286,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
|
||||
// Image generation params interface matching backend API
|
||||
export interface ImageGenerationParams {
|
||||
// Basic params
|
||||
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
|
||||
size:
|
||||
| "512x512"
|
||||
| "768x768"
|
||||
| "1024x1024"
|
||||
| "1024x768"
|
||||
| "768x1024"
|
||||
| "1024x1365"
|
||||
| "1365x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
@@ -298,6 +305,7 @@ export interface ImageGenerationParams {
|
||||
numInferenceSteps: number | null;
|
||||
guidance: number | null;
|
||||
negativePrompt: string | null;
|
||||
numSyncSteps: number | null;
|
||||
// Edit mode params
|
||||
inputFidelity: "low" | "high";
|
||||
}
|
||||
@@ -319,6 +327,7 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
numInferenceSteps: null,
|
||||
guidance: null,
|
||||
negativePrompt: null,
|
||||
numSyncSteps: null,
|
||||
inputFidelity: "low",
|
||||
};
|
||||
|
||||
@@ -2396,7 +2405,9 @@ class AppStore {
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "");
|
||||
(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null;
|
||||
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model,
|
||||
@@ -2421,6 +2432,9 @@ class AppStore {
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
...(params.numSyncSteps !== null && {
|
||||
num_sync_steps: params.numSyncSteps,
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -2670,29 +2684,19 @@ class AppStore {
|
||||
formData.append("input_fidelity", params.inputFidelity);
|
||||
|
||||
// Advanced params
|
||||
if (params.seed !== null) {
|
||||
formData.append(
|
||||
"advanced_params",
|
||||
JSON.stringify({
|
||||
seed: params.seed,
|
||||
...(params.numInferenceSteps !== null && {
|
||||
num_inference_steps: params.numInferenceSteps,
|
||||
}),
|
||||
...(params.guidance !== null && { guidance: params.guidance }),
|
||||
...(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
} else if (
|
||||
const hasAdvancedParams =
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "")
|
||||
) {
|
||||
(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null;
|
||||
|
||||
if (hasAdvancedParams) {
|
||||
formData.append(
|
||||
"advanced_params",
|
||||
JSON.stringify({
|
||||
...(params.seed !== null && { seed: params.seed }),
|
||||
...(params.numInferenceSteps !== null && {
|
||||
num_inference_steps: params.numInferenceSteps,
|
||||
}),
|
||||
@@ -2701,6 +2705,9 @@ class AppStore {
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
...(params.numSyncSteps !== null && {
|
||||
num_sync_steps: params.numSyncSteps,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
justfile, 2 changed lines
@@ -20,7 +20,7 @@ sync-clean:

rust-rebuild:
    cargo run --bin stub_gen
    just sync-clean
    uv sync --reinstall-package exo_pyo3_bindings

build-dashboard:
    #!/usr/bin/env bash
@@ -41,7 +41,7 @@ let

  mlx = stdenv.mkDerivation rec {
    pname = "mlx";
    version = let v = "0.30.4"; in
    version = let v = "0.30.5"; in
      assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
      v;
    pyproject = true;
@@ -17,22 +17,20 @@ dependencies = [
    "loguru>=0.7.3",
    "exo_pyo3_bindings", # rust bindings
    "anyio==4.11.0",
    "mlx==0.30.4; sys_platform == 'darwin'",
    "mlx[cpu]==0.30.4; sys_platform == 'linux'",
    "mlx-lm",
    "mlx==0.30.5; sys_platform == 'darwin'",
    "mlx[cpu]==0.30.5; sys_platform == 'linux'",
    "mlx-lm==0.30.6",
    "tiktoken>=0.12.0", # required for kimi k2 tokenizer
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",
    "httpx>=0.28.1",
    "tomlkit>=0.14.0",
    "pillow>=11.0,<12.0", # compatibility with mflux
    "mflux==0.15.4",
    "mflux==0.15.5",
    "python-multipart>=0.0.21",
]

[project.scripts]
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"

# dependencies only required for development
@@ -63,7 +61,7 @@ members = [

[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
@@ -105,6 +103,7 @@ root = "src"

# supported platforms for this project
[tool.uv]
required-version = ">=0.8.6"
prerelease = "allow"
environments = [
    "sys_platform == 'darwin'",
@@ -59,6 +59,22 @@
|
||||
}
|
||||
);
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ exoVenv ];
|
||||
runtimeEnv = {
|
||||
EXO_DASHBOARD_DIR = self'.packages.dashboard;
|
||||
EXO_RESOURCES_DIR = inputs.self + /resources;
|
||||
};
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
|
||||
mkSimplePythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ pkgs.python313 ];
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
@@ -66,28 +82,30 @@
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + "/resources"} \
|
||||
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
done
|
||||
# Create wrapper script
|
||||
makeWrapper ${exoVenv}/bin/exo $out/bin/exo \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + /resources} \
|
||||
${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin
|
||||
{
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
exo-bench = mkPythonScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
} // {
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-Kontext-dev.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-Kontext-dev.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -16,6 +16,7 @@ from exo.download.download_utils import (
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
@@ -107,6 +108,13 @@ class DownloadCoordinator:
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
case CancelDownload(model_id=model_id):
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads and model_id in self.download_status:
|
||||
logger.info(f"Cancelling download for {model_id}")
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
@@ -158,6 +158,78 @@ async def seed_models(seed_dir: str | Path):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def _build_file_list_from_local_directory(
|
||||
model_id: ModelId,
|
||||
recursive: bool = False,
|
||||
) -> list[FileListEntry] | None:
|
||||
"""Build a file list from locally existing model files.
|
||||
|
||||
We can only figure out the files we need from safetensors index, so
|
||||
a local directory must contain a *.safetensors.index.json and
|
||||
safetensors listed there.
|
||||
"""
|
||||
model_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
if not await aios.path.exists(model_dir):
|
||||
return None
|
||||
|
||||
def _scan() -> list[FileListEntry] | None:
|
||||
index_files = list(model_dir.glob("**/*.safetensors.index.json"))
|
||||
if not index_files:
|
||||
return None
|
||||
|
||||
entries_by_path: dict[str, FileListEntry] = {}
|
||||
|
||||
if recursive:
|
||||
for dirpath, _, filenames in os.walk(model_dir):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".partial"):
|
||||
continue
|
||||
full_path = Path(dirpath) / filename
|
||||
rel_path = str(full_path.relative_to(model_dir))
|
||||
entries_by_path[rel_path] = FileListEntry(
|
||||
type="file",
|
||||
path=rel_path,
|
||||
size=full_path.stat().st_size,
|
||||
)
|
||||
else:
|
||||
for item in model_dir.iterdir():
|
||||
if item.is_file() and not item.name.endswith(".partial"):
|
||||
entries_by_path[item.name] = FileListEntry(
|
||||
type="file",
|
||||
path=item.name,
|
||||
size=item.stat().st_size,
|
||||
)
|
||||
|
||||
# Add expected weight files from index that haven't been downloaded yet
|
||||
for index_file in index_files:
|
||||
try:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(
|
||||
index_file.read_text()
|
||||
)
|
||||
relative_dir = index_file.parent.relative_to(model_dir)
|
||||
for filename in set(index_data.weight_map.values()):
|
||||
rel_path = (
|
||||
str(relative_dir / filename)
|
||||
if relative_dir != Path(".")
|
||||
else filename
|
||||
)
|
||||
if rel_path not in entries_by_path:
|
||||
entries_by_path[rel_path] = FileListEntry(
|
||||
type="file",
|
||||
path=rel_path,
|
||||
size=None,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return list(entries_by_path.values())
|
||||
|
||||
file_list = await asyncio.to_thread(_scan)
|
||||
if not file_list:
|
||||
return None
|
||||
return file_list
|
||||
|
||||
|
||||
_fetched_file_lists_this_session: set[str] = set()
|
||||
|
||||
|
||||
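The helper above can only enumerate a model's expected weight files from a `*.safetensors.index.json`, whose `weight_map` maps tensor names to shard filenames. A dependency-free sketch of that lookup follows; the index path is a placeholder, and exo wraps the same data in `ModelSafetensorsIndex`.

```python
# Hedged sketch: list the weight shards a *.safetensors.index.json expects.
# The path is a placeholder; exo parses the same JSON via ModelSafetensorsIndex.
import json
from pathlib import Path

def expected_weight_files(index_path: Path) -> set[str]:
    index = json.loads(index_path.read_text())
    # "weight_map" maps each tensor name to the shard file that stores it;
    # the distinct values are the files a complete download must contain.
    return set(index["weight_map"].values())

# expected_weight_files(Path("model.safetensors.index.json"))
# -> {"model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"}
```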
@@ -183,6 +255,14 @@ async def fetch_file_list_with_cache(
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
local_file_list = await _build_file_list_from_local_directory(
|
||||
model_id, recursive
|
||||
)
|
||||
if local_file_list is not None:
|
||||
logger.warning(
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
)
|
||||
return local_file_list
|
||||
raise FileNotFoundError(
|
||||
f"No internet connection and no cached file list for {model_id}"
|
||||
)
|
||||
@@ -203,10 +283,18 @@ async def fetch_file_list_with_cache(
|
||||
except Exception as e:
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id}, using cached data: {e}"
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
local_file_list = await _build_file_list_from_local_directory(
|
||||
model_id, recursive
|
||||
)
|
||||
if local_file_list is not None:
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id} and no cache exists, "
|
||||
)
|
||||
return local_file_list
|
||||
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
|
||||
|
||||
|
||||
@@ -378,10 +466,14 @@ async def download_file_with_retry(
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
except Exception as e:
|
||||
on_connection_lost()
|
||||
if attempt == n_attempts - 1:
|
||||
on_connection_lost()
|
||||
raise e
|
||||
break
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
raise Exception(
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
|
||||
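The retry change above sleeps `2.0**attempt` between attempts and now signals a lost connection only once the final attempt fails. The shape of that loop, isolated as a hedged sketch with illustrative names:

```python
# Hedged sketch of the retry pattern above: exponential backoff, and the
# connection-lost callback fires only after the last attempt. Names are illustrative.
import asyncio
from collections.abc import Awaitable, Callable

async def retry_with_backoff(
    op: Callable[[], Awaitable[None]],
    n_attempts: int = 5,
    on_give_up: Callable[[], None] = lambda: None,
) -> None:
    for attempt in range(n_attempts):
        try:
            await op()
            return
        except Exception:
            if attempt == n_attempts - 1:
                on_give_up()  # e.g. mark the internet connection as lost
                raise
            await asyncio.sleep(2.0**attempt)  # 1s, 2s, 4s, ...
```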
@@ -195,6 +195,10 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
self, shard: ShardMetadata
|
||||
) -> RepoDownloadProgress:
|
||||
_, progress = await download_shard(
|
||||
shard, self.on_progress_wrapper, skip_download=True
|
||||
shard,
|
||||
self.on_progress_wrapper,
|
||||
skip_download=True,
|
||||
skip_internet=not self.internet_connection,
|
||||
on_connection_lost=lambda: self.set_internet_connection(False),
|
||||
)
|
||||
return progress
|
||||
|
||||
@@ -105,6 +105,7 @@ class Node:
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
|
||||
er_send, er_recv = channel[ElectionResult]()
|
||||
@@ -188,6 +189,9 @@ class Node:
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.master.run)
|
||||
elif (
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
@@ -141,7 +140,7 @@ async def generate_chat_stream(
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
id=tool.id,
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
@@ -207,7 +206,7 @@ async def collect_chat_response(
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
id=tool.id,
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
@@ -179,7 +178,7 @@ async def collect_claude_response(
|
||||
for tool in chunk.tool_calls:
|
||||
tool_use_blocks.append(
|
||||
ClaudeToolUseBlock(
|
||||
id=f"toolu_{uuid4().hex[:24]}",
|
||||
id=f"toolu_{tool.id}",
|
||||
name=tool.name,
|
||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||
)
|
||||
@@ -264,7 +263,7 @@ async def generate_claude_stream(
|
||||
|
||||
# Emit tool_use content blocks
|
||||
for tool in chunk.tool_calls:
|
||||
tool_id = f"toolu_{uuid4().hex[:24]}"
|
||||
tool_id = f"toolu_{tool.id}"
|
||||
tool_input_json = tool.arguments
|
||||
|
||||
# content_block_start for tool_use
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
@@ -140,8 +139,8 @@ async def collect_responses_response(
|
||||
for tool in chunk.tool_calls:
|
||||
function_call_items.append(
|
||||
ResponseFunctionCallItem(
|
||||
id=f"fc_{uuid4().hex[:24]}",
|
||||
call_id=f"call_{uuid4().hex[:24]}",
|
||||
id=f"fc_{tool.id}",
|
||||
call_id=f"call_{tool.id}",
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
@@ -246,8 +245,8 @@ async def generate_responses_stream(
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
for tool in chunk.tool_calls:
|
||||
fc_id = f"fc_{uuid4().hex[:24]}"
|
||||
call_id = f"call_{uuid4().hex[:24]}"
|
||||
fc_id = f"fc_{tool.id}"
|
||||
call_id = f"call_{tool.id}"
|
||||
|
||||
# response.output_item.added for function_call
|
||||
fc_item = ResponseFunctionCallItem(
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
add_instance_to_placements,
|
||||
cancel_unnecessary_downloads,
|
||||
delete_instance,
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
@@ -16,6 +17,7 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
@@ -66,12 +68,9 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
# Receiving indexed events from the forwarder to be applied to state
|
||||
# Ideally these would be WorkerForwarderEvents but type system says no :(
|
||||
local_event_receiver: Receiver[ForwarderEvent],
|
||||
# Send events to the forwarder to be indexed (usually from command processing)
|
||||
# Ideally these would be MasterForwarderEvents but type system says no :(
|
||||
global_event_sender: Sender[ForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
@@ -81,6 +80,7 @@ class Master:
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
@@ -280,6 +280,14 @@ class Master:
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
for cmd in cancel_unnecessary_downloads(
|
||||
placement, self.state.downloads
|
||||
):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id, command=cmd
|
||||
)
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
placement = place_instance(
|
||||
|
||||
@@ -15,14 +15,20 @@ from exo.master.placement_utils import (
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -202,3 +208,29 @@ def get_transition_events(
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def cancel_unnecessary_downloads(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> Sequence[DownloadCommand]:
|
||||
commands: list[DownloadCommand] = []
|
||||
currently_downloading = [
|
||||
(k, v.shard_metadata.model_card.model_id)
|
||||
for k, vs in download_status.items()
|
||||
for v in vs
|
||||
if isinstance(v, (DownloadOngoing))
|
||||
]
|
||||
active_models = set(
|
||||
(
|
||||
node_id,
|
||||
instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,
|
||||
)
|
||||
for instance in instances.values()
|
||||
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
|
||||
)
|
||||
for pair in currently_downloading:
|
||||
if pair not in active_models:
|
||||
commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))
|
||||
|
||||
return commands
|
||||
|
||||
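`cancel_unnecessary_downloads` above boils down to a set difference: any (node, model) pair that is still downloading but no longer backs an active instance gets a `CancelDownload`. A minimal sketch of that core check, with plain strings standing in for exo's `NodeId` and `ModelId` types and illustrative model ids:

```python
# Hedged sketch of the (node, model) set difference behind cancel_unnecessary_downloads.
# Plain strings stand in for NodeId/ModelId; the ids are illustrative.
def pairs_to_cancel(
    downloading: set[tuple[str, str]],
    active: set[tuple[str, str]],
) -> set[tuple[str, str]]:
    return downloading - active

downloading = {
    ("node-a", "exolabs/FLUX.1-Kontext-dev-4bit"),
    ("node-b", "example/llm"),
}
active = {("node-b", "example/llm")}
assert pairs_to_cancel(downloading, active) == {
    ("node-a", "exolabs/FLUX.1-Kontext-dev-4bit")
}
```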
@@ -11,6 +11,7 @@ from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
@@ -47,6 +48,7 @@ async def test_master():
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[ForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +69,7 @@ async def test_master():
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
download_command_sender=fcds,
|
||||
)
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
@@ -60,6 +61,7 @@ class ChatCompletionMessageText(BaseModel):
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
@@ -272,6 +274,7 @@ class AdvancedImageParams(BaseModel):
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
|
||||
negative_prompt: str | None = None
|
||||
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
|
||||
|
||||
class ImageGenerationTaskParams(BaseModel):
|
||||
|
||||
@@ -72,7 +72,12 @@ class DeleteDownload(BaseCommand):
    model_id: ModelId


DownloadCommand = StartDownload | DeleteDownload
class CancelDownload(BaseCommand):
    target_node_id: NodeId
    model_id: ModelId


DownloadCommand = StartDownload | DeleteDownload | CancelDownload


Command = (
@@ -3,10 +3,11 @@
from collections.abc import Sequence

from mlx_lm.models.cache import (
    ArraysCache,
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
)

# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache]
@@ -1,5 +1,4 @@
from enum import Enum
from math import ceil

from pydantic import BaseModel

@@ -23,7 +22,7 @@ class ImageModelConfig(BaseModel):
    block_configs: tuple[TransformerBlockConfig, ...]

    default_steps: dict[str, int]  # {"low": X, "medium": Y, "high": Z}
    num_sync_steps_factor: float  # Fraction of steps for sync phase
    num_sync_steps: int  # Number of sync steps for distributed inference

    guidance_scale: float | None = None  # None or <= 1.0 disables CFG

@@ -45,6 +44,3 @@
    def get_steps_for_quality(self, quality: str) -> int:
        return self.default_steps[quality]

    def get_num_sync_steps(self, steps: int) -> int:
        return ceil(steps * self.num_sync_steps_factor)
@@ -150,7 +150,10 @@ class DistributedImageModel:
            guidance=guidance_override if guidance_override is not None else 4.0,
        )

        num_sync_steps = self._config.get_num_sync_steps(steps)
        if advanced_params is not None and advanced_params.num_sync_steps is not None:
            num_sync_steps = advanced_params.num_sync_steps
        else:
            num_sync_steps = self._config.num_sync_steps

        for result in self._runner.generate_image(
            runtime_config=config,
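The hunk above replaces the derived `get_num_sync_steps(steps)` with a fixed per-model `num_sync_steps`, overridable per request through `advanced_params.num_sync_steps`. The precedence in isolation, as a hedged sketch with simplified stand-in types:

```python
# Hedged sketch of the num_sync_steps precedence: the request override wins,
# otherwise the model config's fixed value applies. Stand-in dataclasses, not exo's models.
from dataclasses import dataclass

@dataclass
class _Config:
    num_sync_steps: int  # e.g. 4 for FLUX dev/Kontext, 1 for schnell (per the flux configs)

@dataclass
class _AdvancedParams:
    num_sync_steps: int | None = None

def resolve_num_sync_steps(config: _Config, advanced: _AdvancedParams | None) -> int:
    if advanced is not None and advanced.num_sync_steps is not None:
        return advanced.num_sync_steps
    return config.num_sync_steps

assert resolve_num_sync_steps(_Config(4), None) == 4
assert resolve_num_sync_steps(_Config(4), _AdvancedParams(10)) == 10
```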
@@ -5,7 +5,9 @@ from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import ModelAdapter
|
||||
from exo.worker.engines.image.models.flux import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_KONTEXT_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
FluxKontextModelAdapter,
|
||||
FluxModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.models.qwen import (
|
||||
@@ -26,13 +28,16 @@ AdapterFactory = Callable[
|
||||
# Registry maps model_family string to adapter factory
|
||||
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
|
||||
"flux": FluxModelAdapter,
|
||||
"flux-kontext": FluxKontextModelAdapter,
|
||||
"qwen-edit": QwenEditModelAdapter,
|
||||
"qwen": QwenModelAdapter,
|
||||
}
|
||||
|
||||
# Config registry: maps model ID patterns to configs
|
||||
# Order matters: longer/more-specific patterns must come before shorter ones
|
||||
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
|
||||
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
|
||||
"flux.1-kontext": FLUX_KONTEXT_CONFIG, # Must come before "flux.1-dev" for pattern matching
|
||||
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
|
||||
"flux.1-dev": FLUX_DEV_CONFIG,
|
||||
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
|
||||
|
||||
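The registry comment above says order matters because more specific patterns must match before shorter ones. A hedged sketch of a first-match lookup consistent with that comment; the matching rule (substring test on the lowercased model id) is an assumption for illustration, and strings stand in for the config objects:

```python
# Hedged sketch: first match wins over an insertion-ordered pattern registry.
# Substring matching on the lowercased model id is an assumed rule for illustration.
_CONFIG_PATTERNS: dict[str, str] = {
    "flux.1-schnell": "FLUX_SCHNELL_CONFIG",
    "flux.1-kontext": "FLUX_KONTEXT_CONFIG",  # before "flux.1-dev"
    "flux.1-krea-dev": "FLUX_DEV_CONFIG",     # before "flux.1-dev"
    "flux.1-dev": "FLUX_DEV_CONFIG",
}

def lookup_config(model_id: str) -> str | None:
    lowered = model_id.lower()
    for pattern, config in _CONFIG_PATTERNS.items():  # dicts preserve insertion order
        if pattern in lowered:
            return config
    return None

assert lookup_config("exolabs/FLUX.1-Kontext-dev-4bit") == "FLUX_KONTEXT_CONFIG"
```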
@@ -66,6 +66,19 @@ class PromptData(ABC):
        """
        ...

    @property
    @abstractmethod
    def kontext_image_ids(self) -> mx.array | None:
        """Kontext-style position IDs for image conditioning.

        For FLUX.1-Kontext models, returns position IDs with first_coord=1
        to distinguish conditioning tokens from generation tokens (first_coord=0).

        Returns:
            Position IDs array [1, seq_len, 3] for Kontext, None for other models.
        """
        ...

    @abstractmethod
    def get_batched_cfg_data(
        self,
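The docstring above pins the shape ([1, seq_len, 3]) and the first_coord convention (1 for conditioning tokens, 0 for generated tokens). A hedged sketch of building such IDs for a latent grid; using the remaining two coordinates as (row, column) indices is an assumption here, not something the abstract property specifies:

```python
# Hedged sketch: Kontext-style position IDs of shape [1, seq_len, 3].
# first_coord=1 marks conditioning tokens; the (row, col) layout is an assumption.
import mlx.core as mx

def make_kontext_image_ids(latent_h: int, latent_w: int) -> mx.array:
    rows = mx.repeat(mx.arange(latent_h), latent_w)   # row index per token
    cols = mx.tile(mx.arange(latent_w), latent_h)     # column index per token
    first = mx.ones_like(rows)                        # 1 => conditioning token
    ids = mx.stack([first, rows, cols], axis=-1)      # [seq_len, 3]
    return ids[None, :, :]                            # [1, seq_len, 3]

print(make_kontext_image_ids(2, 3).shape)  # (1, 6, 3)
```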
@@ -1,11 +1,17 @@
|
||||
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
|
||||
from exo.worker.engines.image.models.flux.config import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_KONTEXT_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
)
|
||||
from exo.worker.engines.image.models.flux.kontext_adapter import (
|
||||
FluxKontextModelAdapter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FluxModelAdapter",
|
||||
"FluxKontextModelAdapter",
|
||||
"FLUX_DEV_CONFIG",
|
||||
"FLUX_KONTEXT_CONFIG",
|
||||
"FLUX_SCHNELL_CONFIG",
|
||||
]
|
||||
|
||||
@@ -59,6 +59,10 @@ class FluxPromptData(PromptData):
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
|
||||
@@ -15,7 +15,7 @@ FLUX_SCHNELL_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 1, "medium": 2, "high": 4},
|
||||
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
|
||||
num_sync_steps=1,
|
||||
)
|
||||
|
||||
|
||||
@@ -30,5 +30,21 @@ FLUX_DEV_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
|
||||
num_sync_steps=4,
|
||||
)
|
||||
|
||||
|
||||
FLUX_KONTEXT_CONFIG = ImageModelConfig(
|
||||
model_family="flux-kontext",
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
|
||||
),
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=4,
|
||||
guidance_scale=4.0,
|
||||
)
|
||||
|
||||
348
src/exo/worker/engines/image/models/flux/kontext_adapter.py
Normal file
348
src/exo/worker/engines/image/models/flux/kontext_adapter.py
Normal file
@@ -0,0 +1,348 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, final
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
|
||||
from mflux.models.flux.variants.kontext.kontext_util import KontextUtil
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
ModelAdapter,
|
||||
PromptData,
|
||||
RotaryEmbeddings,
|
||||
)
|
||||
from exo.worker.engines.image.models.flux.wrappers import (
|
||||
FluxJointBlockWrapper,
|
||||
FluxSingleBlockWrapper,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FluxKontextPromptData(PromptData):
|
||||
"""Prompt data for FLUX.1-Kontext image editing.
|
||||
|
||||
Stores text embeddings along with conditioning latents and position IDs
|
||||
for the input image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
pooled_prompt_embeds: mx.array,
|
||||
conditioning_latents: mx.array,
|
||||
kontext_image_ids: mx.array,
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._pooled_prompt_embeds = pooled_prompt_embeds
|
||||
self._conditioning_latents = conditioning_latents
|
||||
self._kontext_image_ids = kontext_image_ids
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
return self._pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def cond_image_grid(
|
||||
self,
|
||||
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""VAE-encoded input image latents for Kontext conditioning."""
|
||||
return self._conditioning_latents
|
||||
|
||||
@property
|
||||
def kontext_image_ids(self) -> mx.array | None:
|
||||
"""Position IDs for Kontext conditioning (first_coord=1)."""
|
||||
return self._kontext_image_ids
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Kontext doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
None,
|
||||
self._pooled_prompt_embeds,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
# Kontext doesn't use CFG
|
||||
return None
|
||||
|
||||
|
||||
@final
class FluxKontextModelAdapter(ModelAdapter[Flux1Kontext, Transformer]):
    """Adapter for FLUX.1-Kontext image editing model.

    Key differences from standard FluxModelAdapter:
    - Takes an input image and computes output dimensions from it
    - Creates conditioning latents from the input image via VAE
    - Creates special position IDs (kontext_image_ids) for conditioning tokens
    - Creates pure noise latents (not img2img blending)
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        self._model = Flux1Kontext(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            model_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer

        # Stores image path and computed dimensions after set_image_dimensions
        self._image_path: str | None = None
        self._output_height: int | None = None
        self._output_width: int | None = None

    @property
    def hidden_dim(self) -> int:
        return self._transformer.x_embedder.weight.shape[0]  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]

    @property
    def needs_cfg(self) -> bool:
        return False

    def _get_latent_creator(self) -> type:
        return FluxLatentCreator

    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper[Any]]:
        """Create wrapped joint blocks for Flux Kontext."""
        return [
            FluxJointBlockWrapper(block, text_seq_len)
            for block in self._transformer.transformer_blocks
        ]

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper[Any]]:
        """Create wrapped single blocks for Flux Kontext."""
        return [
            FluxSingleBlockWrapper(block, text_seq_len)
            for block in self._transformer.single_transformer_blocks
        ]

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        all_joint = list(self._transformer.transformer_blocks)
        all_single = list(self._transformer.single_transformer_blocks)
        total_joint_blocks = len(all_joint)
        if end_layer <= total_joint_blocks:
            # All assigned are joint blocks
            joint_start, joint_end = start_layer, end_layer
            single_start, single_end = 0, 0
        elif start_layer >= total_joint_blocks:
            # All assigned are single blocks
            joint_start, joint_end = 0, 0
            single_start = start_layer - total_joint_blocks
            single_end = end_layer - total_joint_blocks
        else:
            # Spans both joint and single
            joint_start, joint_end = start_layer, total_joint_blocks
            single_start = 0
            single_end = end_layer - total_joint_blocks

        self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
        self._transformer.single_transformer_blocks = all_single[
            single_start:single_end
        ]
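        # Illustrative example (hypothetical numbers): with 19 joint blocks,
        # start_layer=15 and end_layer=25 keep joint blocks [15:19] and single
        # blocks [0:6]; the global layer index counts joint blocks first.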

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
        """Compute and store dimensions from input image.

        Also stores image_path for use in encode_prompt().

        Args:
            image_path: Path to the input image

        Returns:
            (output_width, output_height) for runtime config
        """
        from mflux.utils.image_util import ImageUtil

        pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
        image_size = pil_image.size

        # Compute output dimensions from input image aspect ratio
        # Target area of 1024x1024 = ~1M pixels
        target_area = 1024 * 1024
        ratio = image_size[0] / image_size[1]
        output_width = math.sqrt(target_area * ratio)
        output_height = output_width / ratio
        output_width = round(output_width / 32) * 32
        output_height = round(output_height / 32) * 32

        # Ensure multiple of 16 for VAE
        vae_scale_factor = 8
        multiple_of = vae_scale_factor * 2
        output_width = output_width // multiple_of * multiple_of
        output_height = output_height // multiple_of * multiple_of
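        # Illustrative example (hypothetical input): a 1600x900 image has ratio
        # ~1.78, giving roughly 1365x768 before rounding; rounding to 32 and
        # flooring to a multiple of 16 stores 1376x768.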

        self._image_path = str(image_path)
        self._output_width = int(output_width)
        self._output_height = int(output_height)

        return self._output_width, self._output_height

    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
        """Create initial noise latents for Kontext.

        Unlike standard img2img which blends noise with encoded input,
        Kontext uses pure noise latents. The input image is provided
        separately as conditioning.
        """
        return FluxLatentCreator.create_noise(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
        )

    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> FluxKontextPromptData:
        """Encode prompt and create conditioning from stored input image.

        Must call set_image_dimensions() before this method.

        Args:
            prompt: Text prompt for editing
            negative_prompt: Ignored (Kontext doesn't use CFG)

        Returns:
            FluxKontextPromptData with text embeddings and image conditioning
        """
        del negative_prompt  # Kontext doesn't support negative prompts or CFG

        if (
            self._image_path is None
            or self._output_height is None
            or self._output_width is None
        ):
            raise RuntimeError(
                "set_image_dimensions() must be called before encode_prompt() "
                "for FluxKontextModelAdapter"
            )

        assert isinstance(self.model.prompt_cache, dict)
        assert isinstance(self.model.tokenizers, dict)

        # Encode text prompt
        prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
            prompt=prompt,
            prompt_cache=self.model.prompt_cache,
            t5_tokenizer=self.model.tokenizers["t5"],  # pyright: ignore[reportAny]
            clip_tokenizer=self.model.tokenizers["clip"],  # pyright: ignore[reportAny]
            t5_text_encoder=self.model.t5_text_encoder,
            clip_text_encoder=self.model.clip_text_encoder,
        )

        # Create conditioning latents from input image
        conditioning_latents, kontext_image_ids = (
            KontextUtil.create_image_conditioning_latents(
                vae=self.model.vae,
                height=self._output_height,
                width=self._output_width,
                image_path=self._image_path,
            )
        )

        return FluxKontextPromptData(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            conditioning_latents=conditioning_latents,
            kontext_image_ids=kontext_image_ids,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        embedded_hidden = self._transformer.x_embedder(hidden_states)
        embedded_encoder = self._transformer.context_embedder(prompt_embeds)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        if pooled_prompt_embeds is None:
            raise ValueError(
                "pooled_prompt_embeds is required for Flux Kontext text embeddings"
            )

        return Transformer.compute_text_embeddings(
            t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
        )

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> RotaryEmbeddings:
        return Transformer.compute_rotary_embeddings(
            prompt_embeds,
            self._transformer.pos_embed,
            runtime_config,
            kontext_image_ids,
        )

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        raise NotImplementedError("Flux Kontext does not use classifier-free guidance")
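
A minimal usage sketch of the adapter above (illustrative only, not part of the diff; config, model_id, local_path and runtime_config are assumed to come from the surrounding pipeline):

    adapter = FluxKontextModelAdapter(config, model_id, local_path, quantize=8)
    width, height = adapter.set_image_dimensions(Path("input.png"))  # must precede encode_prompt()
    prompt_data = adapter.encode_prompt("make the sky purple")
    latents = adapter.create_latents(seed=0, runtime_config=runtime_config)
    # prompt_data.conditioning_latents and prompt_data.kontext_image_ids feed the transformer pass
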
@@ -69,6 +69,10 @@ class QwenPromptData(PromptData):
    def conditioning_latents(self) -> mx.array | None:
        return None

    @property
    def kontext_image_ids(self) -> mx.array | None:
        return None

    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:

@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
        ),
    ),
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.25,
    num_sync_steps=7,
    guidance_scale=3.5,  # Set to None or < 1.0 to disable CFG
)

@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
        ),
    ),
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.25,
    num_sync_steps=7,
    guidance_scale=3.5,
)

@@ -85,6 +85,10 @@ class QwenEditPromptData(PromptData):
    def qwen_image_ids(self) -> mx.array:
        return self._qwen_image_ids

    @property
    def kontext_image_ids(self) -> mx.array | None:
        return None

    @property
    def is_edit_mode(self) -> bool:
        return True

@@ -567,6 +567,7 @@ class DiffusionRunner:
        | list[tuple[int, int, int]]
        | None = None,
        conditioning_latents: mx.array | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        """Run a single forward pass through the transformer.
        Args:
@@ -578,6 +579,7 @@ class DiffusionRunner:
            encoder_hidden_states_mask: Attention mask for text (Qwen)
            cond_image_grid: Conditioning image grid dimensions (Qwen edit)
            conditioning_latents: Conditioning latents for edit mode
            kontext_image_ids: Position IDs for Kontext conditioning (Flux Kontext)

        Returns:
            Noise prediction tensor
@@ -610,6 +612,7 @@ class DiffusionRunner:
            config,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            cond_image_grid=cond_image_grid,
            kontext_image_ids=kontext_image_ids,
        )

        assert self.joint_block_wrappers is not None
@@ -681,6 +684,7 @@ class DiffusionRunner:
        prompt_data: PromptData,
    ) -> mx.array:
        cond_image_grid = prompt_data.cond_image_grid
        kontext_image_ids = prompt_data.kontext_image_ids
        results: list[tuple[bool, mx.array]] = []

        for branch in self._get_cfg_branches(prompt_data):
@@ -700,6 +704,7 @@ class DiffusionRunner:
                encoder_hidden_states_mask=branch.mask,
                cond_image_grid=cond_image_grid,
                conditioning_latents=branch.cond_latents,
                kontext_image_ids=kontext_image_ids,
            )
            results.append((branch.positive, noise))

@@ -902,10 +907,10 @@ class DiffusionRunner:
        config: Config,
        hidden_states: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        prev_latents = hidden_states
        cond_image_grid = prompt_data.cond_image_grid
        kontext_image_ids = prompt_data.kontext_image_ids

        scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t)  # pyright: ignore[reportAny]
        original_latent_tokens: int = scaled_hidden_states.shape[1]  # pyright: ignore[reportAny]
@@ -979,10 +984,10 @@ class DiffusionRunner:
        latents: mx.array,
        prompt_data: PromptData,
        is_first_async_step: bool,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        patch_latents, token_indices = self._create_patches(latents, config)
        cond_image_grid = prompt_data.cond_image_grid
        kontext_image_ids = prompt_data.kontext_image_ids

        prev_patch_latents = [p for p in patch_latents]

@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
    shard_linear,
    sum_gradients,
)
from mlx_lm.models.base import (
    scaled_dot_product_attention,  # pyright: ignore[reportUnknownVariableType]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
@@ -25,16 +28,21 @@ from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import MiniMaxAttention
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer

from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata

if TYPE_CHECKING:
    from mlx_lm.models.cache import Cache

TimeoutCallback = Callable[[], None]

@@ -378,7 +386,15 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
    elif isinstance(model, Glm4MoeModel):
        tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(
            group,
            all_to_sharded_linear,
            sharded_to_all_linear,
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, (Qwen3MoeModel, Qwen3NextModel)):
        tensor_parallel_sharding_strategy = QwenShardingStrategy(
            group,
            all_to_sharded_linear,
@@ -503,12 +519,21 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
            layer.self_attn.q_b_proj = self.all_to_sharded_linear(
                layer.self_attn.q_b_proj
            )
            layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
                layer.self_attn.kv_b_proj
            )

            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
            layer.self_attn.num_heads //= self.N

            # Logic from upstream mlx
            num_heads = layer.self_attn.num_heads
            sh = self.group.rank() * num_heads
            eh = sh + num_heads

            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
                return w[sh:eh]

            layer.self_attn.embed_q.apply(shard_heads)
            layer.self_attn.unembed_out.apply(shard_heads)

            # Shard the MLP
            if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
@@ -524,7 +549,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
                layer.mlp = ShardedDeepseekV3MoE(layer.mlp)  # type: ignore
                layer.mlp = ShardedMoE(layer.mlp)  # type: ignore
                layer.mlp.sharding_group = self.group

            mx.eval(layer)
@@ -532,7 +557,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
        return model


class ShardedDeepseekV3MoE(CustomMlxLayer):
class ShardedMoE(CustomMlxLayer):
    """Wraps any MoE layer with distributed sum_gradients / all_sum."""

    def __init__(self, layer: _LayerCallable):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None
@@ -603,25 +630,89 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
                layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp)  # type: ignore
                layer.mlp = ShardedMoE(layer.mlp)  # type: ignore
                layer.mlp.sharding_group = self.group  # type: ignore
            mx.eval(layer)

        return model


class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
    def __init__(self, layer: _LayerCallable):
class WrappedMiniMaxAttention(CustomMlxLayer):
    def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None
        self.group = group

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer.__call__(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y
    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        cache: "Cache | None" = None,
    ) -> mx.array:
        batch_dim, seq_dim, _ = x.shape

        self._original_layer = cast(MiniMaxAttention, self.original_layer)  # type: ignore

        queries: mx.array = self._original_layer.q_proj(x)
        keys: mx.array = self._original_layer.k_proj(x)
        values: mx.array = self._original_layer.v_proj(x)

        if getattr(self, "use_qk_norm", False):
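            # The q_norm/k_norm weights are not sharded, so gather every rank's
            # q/k projections, apply the full-width norms, then split back to
            # this rank's slice below.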
            q_dim = queries.shape[-1]
            k_dim = keys.shape[-1]
            n = self.group.size()

            qk = mx.concatenate(
                [queries, keys], axis=-1
            )  # (batch_dim, seq_dim, q_dim + k_dim)
            qk = mx.distributed.all_gather(
                qk, group=self.group
            )  # (n*batch_dim, seq_dim, q_dim + k_dim)

            qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)
            queries = qk[..., :q_dim].reshape(
                batch_dim, seq_dim, -1
            )  # (batch_dim, seq_dim, n * q_dim)
            keys = qk[..., q_dim:].reshape(
                batch_dim, seq_dim, -1
            )  # (batch_dim, seq_dim, n * k_dim)

            queries = self._original_layer.q_norm(queries)
            keys = self._original_layer.k_norm(keys)

            # Split back and take this rank's portion
            queries = mx.split(queries, n, axis=-1)[self.group.rank()]
            keys = mx.split(keys, n, axis=-1)[self.group.rank()]

        queries = queries.reshape(
            batch_dim, seq_dim, self._original_layer.num_attention_heads, -1
        ).transpose(0, 2, 1, 3)
        keys = keys.reshape(
            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
        ).transpose(0, 2, 1, 3)
        values = values.reshape(
            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1
        ).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self._original_layer.rope(queries, offset=cache.offset)
            keys = self._original_layer.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self._original_layer.rope(queries)
            keys = self._original_layer.rope(keys)

        output = scaled_dot_product_attention(
            queries,
            keys,
            values,
            cache=cache,
            scale=self._original_layer.scale,  # type: ignore
            mask=mask,
        )

        output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)

        return self._original_layer.o_proj(output)


class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
@@ -632,7 +723,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
        model = cast(MiniMaxModel, model)
        rank = self.group.rank()
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -643,18 +733,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)

            # Shard qk_norm weights if present (must match sharded head count)
            if getattr(layer.self_attn, "use_qk_norm", False):
                layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split(  # type: ignore
                    self.N, axis=-1
                )[rank]
                layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split(  # type: ignore
                    self.N, axis=-1
                )[rank]

            layer.self_attn.num_attention_heads //= self.N
            layer.self_attn.num_key_value_heads //= self.N

            layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group)  # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]

            # Shard the MoE. Shard in place since the MoE should be responsible
            # for aggregating the results.
            self.all_to_sharded_linear_in_place(
@@ -666,7 +749,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
            self.all_to_sharded_linear_in_place(
                layer.block_sparse_moe.switch_mlp.up_proj
            )
            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
            layer.block_sparse_moe = ShardedMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
            layer.block_sparse_moe.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
            mx.eval(layer)
        return model
@@ -679,28 +762,111 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
        model = cast(Qwen3MoeModel, model)
        model = cast(Qwen3MoeModel | Qwen3NextModel, model)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            # Shard the self attention
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
            layer.self_attn.n_heads //= self.N
            layer.self_attn.n_kv_heads //= self.N
            if isinstance(layer, Qwen3DecoderLayer):
                layer.self_attn.q_proj = self.all_to_sharded_linear(
                    layer.self_attn.q_proj
                )
                layer.self_attn.k_proj = self.all_to_sharded_linear(
                    layer.self_attn.k_proj
                )
                layer.self_attn.v_proj = self.all_to_sharded_linear(
                    layer.self_attn.v_proj
                )
                layer.self_attn.o_proj = self.sharded_to_all_linear(
                    layer.self_attn.o_proj
                )
            else:
                assert isinstance(layer, Qwen3NextDecoderLayer)
                if hasattr(layer, "linear_attn"):
                    linear_attn = layer.linear_attn

                    linear_attn.in_proj_qkvz = self.all_to_sharded_linear(
                        linear_attn.in_proj_qkvz
                    )
                    linear_attn.in_proj_ba = self.all_to_sharded_linear(
                        linear_attn.in_proj_ba
                    )
                    linear_attn.out_proj = self.sharded_to_all_linear(
                        linear_attn.out_proj
                    )

                    # Shard conv1d: depthwise conv with non-contiguous channel slicing.
                    # Channel layout is [q(key_dim), k(key_dim), v(value_dim)].
                    # Each rank takes its head-slice from each of the three sections.
                    rank = self.group.rank()
                    key_dim = linear_attn.key_dim
                    value_dim = linear_attn.value_dim
                    key_dim_shard = key_dim // self.N
                    value_dim_shard = value_dim // self.N

                    q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)
                    k_idx = mx.arange(
                        key_dim + rank * key_dim_shard,
                        key_dim + (rank + 1) * key_dim_shard,
                    )
                    v_idx = mx.arange(
                        2 * key_dim + rank * value_dim_shard,
                        2 * key_dim + (rank + 1) * value_dim_shard,
                    )
                    conv_indices = mx.concatenate([q_idx, k_idx, v_idx])
                    linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]
                    new_conv_dim = key_dim_shard * 2 + value_dim_shard
                    linear_attn.conv1d.groups = new_conv_dim

                    num_v_shard = linear_attn.num_v_heads // self.N
                    v_start = rank * num_v_shard
                    v_end = v_start + num_v_shard
                    linear_attn.A_log = linear_attn.A_log[v_start:v_end]
                    linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]

                    linear_attn.num_k_heads //= self.N
                    linear_attn.num_v_heads //= self.N
                    linear_attn.key_dim = (
                        linear_attn.head_k_dim * linear_attn.num_k_heads
                    )
                    linear_attn.value_dim = (
                        linear_attn.head_v_dim * linear_attn.num_v_heads
                    )
                    linear_attn.conv_dim = (
                        linear_attn.key_dim * 2 + linear_attn.value_dim
                    )
                else:
                    layer.self_attn.q_proj = self.all_to_sharded_linear(
                        layer.self_attn.q_proj
                    )
                    layer.self_attn.k_proj = self.all_to_sharded_linear(
                        layer.self_attn.k_proj
                    )
                    layer.self_attn.v_proj = self.all_to_sharded_linear(
                        layer.self_attn.v_proj
                    )
                    layer.self_attn.o_proj = self.sharded_to_all_linear(
                        layer.self_attn.o_proj
                    )
                    layer.self_attn.num_attention_heads //= self.N
                    layer.self_attn.num_key_value_heads //= self.N

            # Shard the MoE. Shard in place since the MoE should be responsible
            # for aggregating the results.
            if isinstance(
                layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
            ):
            if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
                layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
                if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
                    self.all_to_sharded_linear_in_place(
                        layer.mlp.shared_expert.gate_proj
                    )
                    self.sharded_to_all_linear_in_place(
                        layer.mlp.shared_expert.down_proj
                    )
                    self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)
                layer.mlp = ShardedMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
                layer.mlp.sharding_group = self.group

            # Shard the MLP
@@ -713,18 +879,50 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
        return model


class ShardedQwenMoE(CustomMlxLayer):
    def __init__(self, layer: _LayerCallable):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None
class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
        model = cast(Glm4MoeModel, model)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer.__call__(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
            layer.self_attn.n_heads //= self.N
            layer.self_attn.n_kv_heads //= self.N

            if isinstance(layer.mlp, MoE):
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
                if getattr(layer.mlp, "shared_experts", None) is not None:
                    self.all_to_sharded_linear_in_place(
                        layer.mlp.shared_experts.gate_proj
                    )
                    self.sharded_to_all_linear_in_place(
                        layer.mlp.shared_experts.down_proj
                    )
                    self.all_to_sharded_linear_in_place(
                        layer.mlp.shared_experts.up_proj
                    )
                layer.mlp = ShardedMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
                layer.mlp.sharding_group = self.group

            else:
                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)

            mx.eval(layer)
        return model


class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -762,21 +960,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
            self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
            self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)

            layer.mlp = ShardedGptOssMoE(layer.mlp)  # type: ignore
            layer.mlp = ShardedMoE(layer.mlp)  # type: ignore
            layer.mlp.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
            mx.eval(layer)
        return model


class ShardedGptOssMoE(CustomMlxLayer):
    def __init__(self, layer: nn.Module):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y

@@ -1,16 +1,14 @@
import os
from copy import deepcopy
from typing import Any, cast

import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
    ArraysCache,
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
    trim_prompt_cache,
)
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.shared.types.memory import Memory
@@ -26,51 +24,119 @@ _MEMORY_THRESHOLD = float(
)


class KVPrefixCache:
class CacheSnapshot:
    """Snapshot of states at a known token position."""

    def __init__(
        self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
        self, states: list[RotatingKVCache | ArraysCache | None], token_count: int
    ):
        self.states = states
        self.token_count = token_count


def snapshot_ssm_states(cache: KVCacheType) -> CacheSnapshot:
    states: list[ArraysCache | RotatingKVCache | None] = []
    for c in cache:
        if isinstance(c, (ArraysCache, RotatingKVCache)):
            states.append(deepcopy(c))
        else:
            states.append(None)
    token_count = cache_length(cache)
    return CacheSnapshot(states=states, token_count=token_count)


def _find_nearest_snapshot(
    snapshots: list[CacheSnapshot],
    target_token_count: int,
) -> CacheSnapshot | None:
    best: CacheSnapshot | None = None
    for snap in snapshots:
        if snap.token_count <= target_token_count and (
            best is None or snap.token_count > best.token_count
        ):
            best = snap
    return best
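
# Example (illustrative): snapshots at token counts [32, 96, 160] with a target of
# 120 reuse the 96-token snapshot; a target of 20, below the earliest snapshot,
# returns None.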


def has_non_kv_caches(cache: KVCacheType) -> bool:
    """Check if a cache contains any ArraysCache (SSM) or RotatingKVCache entries."""
    return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)


class KVPrefixCache:
    def __init__(self, group: mx.distributed.Group | None = None):
        self.prompts: list[mx.array] = []  # mx array of tokens (ints)
        self.caches: list[KVCacheType] = []
        self._snapshots: list[list[CacheSnapshot] | None] = []
        self._last_used: list[int] = []  # monotonic counter of last access per entry
        self._access_counter: int = 0
        self._tokenizer: TokenizerWrapper = tokenizer
        self._group = group

    def clear(self):
        """Clear all cached prompts and caches."""
        self.prompts.clear()
        self.caches.clear()
        self._snapshots.clear()
        self._last_used.clear()

    def add_kv_cache(self, prompt: str, cache: KVCacheType):
    def add_kv_cache(
        self,
        prompt_tokens: mx.array,
        cache: KVCacheType,
        ssm_snapshots: list[CacheSnapshot] | None = None,
    ):
        """Add a new cache entry. Evicts LRU entries if memory is high."""
        self._evict_if_needed()
        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
        self.prompts.append(tokenized_prompt)
        self.prompts.append(prompt_tokens)
        self.caches.append(deepcopy(cache))
        self._snapshots.append(ssm_snapshots)
        self._access_counter += 1
        self._last_used.append(self._access_counter)
        logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
        logger.info(f"KV cache added: {len(prompt_tokens)} tokens")

    def update_kv_cache(
        self,
        index: int,
        prompt: str,
        prompt_tokens: mx.array,
        cache: KVCacheType,
        snapshots: list[CacheSnapshot] | None,
        restore_pos: int,
    ):
        """Update an existing cache entry in-place."""
        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
        self.prompts[index] = tokenized_prompt
        old_snapshots = self._snapshots[index]
        merged: list[CacheSnapshot] = []
        if old_snapshots:
            merged = [s for s in old_snapshots if s.token_count <= restore_pos]
        if snapshots:
            merged.extend(snapshots)

        self.prompts[index] = prompt_tokens
        self.caches[index] = deepcopy(cache)
        self._snapshots[index] = merged or None
        self._access_counter += 1
        self._last_used[index] = self._access_counter
        logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
        logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")

    def _get_snapshot(
        self, entry_index: int, target_token_count: int
    ) -> tuple[int, CacheSnapshot | None]:
        if not has_non_kv_caches(self.caches[entry_index]):
            return target_token_count, None

        snapshots = self._snapshots[entry_index]
        if not snapshots:
            return 0, None

        snap = _find_nearest_snapshot(snapshots, target_token_count)
        if snap is not None:
            return snap.token_count, snap

        return 0, None

    def get_kv_cache(
        self,
        model: Model,
        prompt: str,
        prompt_tokens: mx.array,
    ) -> tuple[KVCacheType, mx.array, int | None]:
        """Get KV cache for prompt, returning remaining tokens to prefill.

@@ -79,76 +145,71 @@ class KVPrefixCache:
        - cache: KV cache to use for generation
        - remaining_tokens: tokens that still need prefilling
        - matched_index: index of the matched entry (None if no match)

        For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed
        to the nearest SSM snapshot position at or before the match point for correctness.
        The same applies to rotating KV caches.
        """
        tokenized_prompt = encode_prompt(self._tokenizer, prompt)
        max_length = len(tokenized_prompt)
        max_length = len(prompt_tokens)

        best_snapshot_index, best_snapshot_length = None, 0
        best_index: int | None = None
        best_length = 0
        is_exact = False

        # Find best cache
        for i, cached_prompt in enumerate(self.prompts):
            length = get_prefix_length(tokenized_prompt, cached_prompt)

            length = get_prefix_length(prompt_tokens, cached_prompt)
            if length > best_length:
                best_index, best_length = i, length
            if length == max_length:
                # Exact match - cached prompt starts with our entire prompt
                # Trim cache to prompt length - 1, return last token for stream_generate
                prompt_cache = deepcopy(self.caches[i])
                cached_length = cache_length(self.caches[i])
                tokens_to_trim = cached_length - (max_length - 1)
                if tokens_to_trim > 0:
                    trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
                self._access_counter += 1
                self._last_used[i] = self._access_counter
                logger.info(f"KV cache exact match: {max_length} tokens (instant)")
                return prompt_cache, tokenized_prompt[-1:], i
                is_exact = True
                best_index, best_length = i, length
                break

            if length > best_snapshot_length:
                best_snapshot_index, best_snapshot_length = i, length
        if best_index is None:
            return make_kv_cache(model), prompt_tokens, None

        if best_snapshot_index is not None:
            new_tokens = max_length - best_snapshot_length
            logger.info(
                f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
                f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
            )
        # For exact match: trim to max_length-1 so remaining has the last token
        # For partial match: trim to best_length, remaining has suffix to prefill
        # This ensures stream_generate always has at least one token to start with
        target = (max_length - 1) if is_exact else best_length
        restore_pos, restore_snap = self._get_snapshot(best_index, target)

            prompt_cache = deepcopy(self.caches[best_snapshot_index])
        # No usable snapshot: need a fresh cache
        if restore_snap is None and has_non_kv_caches(self.caches[best_index]):
            return make_kv_cache(model), prompt_tokens, None

            # Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
            cached_length = cache_length(self.caches[best_snapshot_index])
            tokens_to_trim = cached_length - best_snapshot_length
            if tokens_to_trim > 0:
                trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
        prompt_cache = deepcopy(self.caches[best_index])
        cached_length = cache_length(self.caches[best_index])
        tokens_to_trim = cached_length - restore_pos
        if tokens_to_trim > 0:
            trim_cache(prompt_cache, tokens_to_trim, restore_snap)
            # Reset cache offset to match trimmed length
            for c in prompt_cache:
                if hasattr(c, "offset"):
                    c.offset = restore_pos

            self._access_counter += 1
            self._last_used[best_snapshot_index] = self._access_counter
            remaining_tokens = tokenized_prompt[best_snapshot_length:]
            return prompt_cache, remaining_tokens, best_snapshot_index
        self._access_counter += 1
        self._last_used[best_index] = self._access_counter
        remaining = prompt_tokens[restore_pos:]

        else:
            prompt_cache = make_kv_cache(model)
            if len(self.prompts) == 0:
                logger.info(f"KV cache empty, need to prefill {max_length} tokens")
            else:
                logger.info(
                    f"KV cache no prefix match, need to prefill {max_length} tokens"
                )

        return prompt_cache, tokenized_prompt, None
        return prompt_cache, remaining, best_index

    def _evict_if_needed(self):
        """Evict least recently used entries while memory usage is high."""
        if len(self.caches) == 0:
            return

        # Evict LRU entries until below threshold or only one entry left
        # Evict LRU entries until below threshold
        while (
            len(self.caches) > 1
            len(self.caches) > 0
            and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
        ):
            lru_index = self._last_used.index(min(self._last_used))
            evicted_tokens = len(self.prompts[lru_index])
            self.prompts.pop(lru_index)
            self.caches.pop(lru_index)
            self._snapshots.pop(lru_index)
            self._last_used.pop(lru_index)
            logger.info(
                f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
@@ -169,6 +230,21 @@ class KVPrefixCache:
        return max_pressure


def trim_cache(
    cache: KVCacheType,
    num_tokens: int,
    snapshot: CacheSnapshot | None = None,
) -> None:
    for i, c in enumerate(cache):
        if isinstance(c, (ArraysCache, RotatingKVCache)):
            if snapshot is not None and snapshot.states[i] is not None:
                cache[i] = deepcopy(snapshot.states[i])  # type: ignore
            else:
                c.state = [None] * len(c.state)  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
        else:
            c.trim(num_tokens)  # pyright: ignore[reportUnknownMemberType]


def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
    """Encode a prompt string to token array.

@@ -177,14 +253,14 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
    that would corrupt the prompt structure.
    """
    # Chat templates define their own structure - don't add BOS/EOS
    tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
    return mx.array(tokenized_prompt)
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
    return mx.array(prompt_tokens)


def cache_length(cache: KVCacheType) -> int:
    """Get the number of tokens in a KV cache."""
    # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
    return max(c.offset for c in cache)  # type: ignore
    # Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).
    return max(getattr(c, "offset", 0) for c in cache)


def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
@@ -214,8 +290,7 @@ def make_kv_cache(
) -> KVCacheType:
    assert hasattr(model, "layers")

    # TODO: Do this for all models
    if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
    if hasattr(model, "make_cache"):
        logger.info("Using MLX LM's make cache")
        return model.make_cache()  # type: ignore

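A brief usage sketch of the prefix cache above (illustrative only; model and prompt_tokens are assumed to come from the surrounding runner code):

    kv_prefix_cache = KVPrefixCache(group=None)
    cache, remaining, matched_index = kv_prefix_cache.get_kv_cache(model, prompt_tokens)
    # prefill `remaining` into `cache`, generate, then store the updated entry:
    kv_prefix_cache.add_kv_cache(prompt_tokens, cache)
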
@@ -1,9 +1,10 @@
import time
from typing import Any, Callable, Generator, cast, get_args
from copy import deepcopy
from typing import Callable, Generator, cast, get_args

import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.models.cache import ArraysCache, RotatingKVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

@@ -23,7 +24,14 @@ from exo.shared.types.worker.runner_response import (
    GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.cache import (
    CacheSnapshot,
    KVPrefixCache,
    encode_prompt,
    has_non_kv_caches,
    make_kv_cache,
    snapshot_ssm_states,
)
from exo.worker.engines.mlx.constants import (
    DEFAULT_TOP_LOGPROBS,
    KV_BITS,
@@ -32,6 +40,7 @@ from exo.worker.engines.mlx.constants import (
)
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    fix_unmatched_think_end_tokens,
    mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -47,7 +56,7 @@ def prefill(
    sampler: Callable[[mx.array], mx.array],
    prompt_tokens: mx.array,
    cache: KVCacheType,
) -> tuple[float, int]:
) -> tuple[float, int, list[CacheSnapshot]]:
    """Prefill the KV cache with prompt tokens.

    This runs the model over the prompt tokens to populate the cache,
@@ -58,17 +67,21 @@
    """
    num_tokens = len(prompt_tokens)
    if num_tokens == 0:
        return 0.0, 0
        return 0.0, 0, []

    logger.debug(f"Prefilling {num_tokens} tokens...")
    start_time = time.perf_counter()
    has_ssm = has_non_kv_caches(cache)
    snapshots: list[CacheSnapshot] = []

    def progress_callback(processed: int, total: int) -> None:
        elapsed = time.time() - start_time
        elapsed = time.perf_counter() - start_time
        tok_per_sec = processed / elapsed if elapsed > 0 else 0
        logger.debug(
            f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
        )
        if has_ssm:
            snapshots.append(snapshot_ssm_states(cache))

    # Use max_tokens=1 because max_tokens=0 does not work.
    # We just throw away the generated token - we only care about filling the cache
@@ -85,7 +98,18 @@
        prompt_progress_callback=progress_callback,
    ):
        break  # Stop after first iteration - cache is now filled
    trim_prompt_cache(cast(list[Any], cache), 1)

    # stream_generate added 1 extra generated token to the cache, so we should trim it.
    # Because the arrays/rotating caches need to be rolled back, generation resumes from 2 tokens, so trim 1 more.
    pre_gen = deepcopy(snapshots[-2]) if has_ssm else None
    for i, c in enumerate(cache):
        if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):
            assert pre_gen is not None
            if pre_gen.states[i] is not None:
                cache[i] = deepcopy(pre_gen.states[i])  # type: ignore
        else:
            assert not isinstance(c, (ArraysCache, RotatingKVCache))
            c.trim(2)  # pyright: ignore[reportUnknownMemberType]

    elapsed = time.perf_counter() - start_time
    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
@@ -93,12 +117,14 @@
        f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
        f"({tokens_per_sec:.1f} tok/s)"
    )
    return tokens_per_sec, num_tokens
    # Exclude the last snapshot
    return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else []


def warmup_inference(
    model: Model,
    tokenizer: TokenizerWrapper,
    group: mx.distributed.Group | None = None,
) -> int:
    content = "Prompt to warm up the inference engine. Repeat this."

@@ -117,7 +143,7 @@
    )

    # Use a default sampler for warmup
    sampler = make_sampler(temp=0.7)
    sampler = make_sampler(temp=0.0)

    logger.info("Generating warmup tokens")
    for _r in stream_generate(
@@ -136,9 +162,7 @@

    logger.info("Generated ALL warmup tokens")

    # TODO: Do we want an mx_barrier?
    # At least this version is actively incorrect, as it should use mx_barrier(group)
    mx_barrier()
    mx_barrier(group)

    return tokens_generated

@@ -221,11 +245,17 @@ def mlx_generate(
    task: TextGenerationTaskParams,
    prompt: str,
    kv_prefix_cache: KVPrefixCache | None = None,
    group: mx.distributed.Group | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
    if task.seed is not None:
        mx.random.seed(task.seed)
    # TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.
    seed = task.seed or 42
    mx.random.seed(seed)

    # Encode prompt once at the top and fix unmatched think tags
    all_prompt_tokens = encode_prompt(tokenizer, prompt)
    all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)

    # Do not use the prefix cache if we are trying to do benchmarks.
    is_bench = task.bench
@@ -237,13 +267,16 @@
    matched_index: int | None = None
    if kv_prefix_cache is None:
        caches = make_kv_cache(model=model)
        prompt_tokens = encode_prompt(tokenizer, prompt)
        prompt_tokens = all_prompt_tokens
    else:
        caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, prompt
            model, all_prompt_tokens
        )
        all_prompt_tokens = encode_prompt(tokenizer, prompt)
        prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
        if prefix_hit_length > 0:
            logger.info(
                f"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
            )

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
@@ -265,13 +298,21 @@
    )
    max_stop_len = max((len(s) for s in stop_sequences), default=0)

    mx_barrier(group)
    logger.info("Ready to prefill")

    # Prefill cache with all tokens except the last one
    prefill_tps, prefill_tokens = prefill(
        model, tokenizer, sampler, prompt_tokens[:-1], caches
    prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
        model,
        tokenizer,
        sampler,
        prompt_tokens[:-1],
        caches,
    )
    cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None

    # stream_generate starts from the last token
    last_token = prompt_tokens[-1:]
    last_token = prompt_tokens[-2:]
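    # Prefill leaves the cache two tokens short (it rolls back one extra so SSM /
    # rotating caches can be restored), so generation resumes from the last two
    # prompt tokens rather than just the final one.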

    max_tokens = task.max_output_tokens or MAX_TOKENS
    accumulated_text = ""
@@ -299,7 +340,6 @@
        start=1,
    ):
        generated_text_parts.append(out.text)
        logger.info(out.text)
        accumulated_text += out.text

        if think_start is not None and out.text == think_start:
@@ -367,16 +407,6 @@
            selected_token=out.token,
        )

        yield GenerationResponse(
            text=text,
            token=out.token,
            logprob=logprob,
            top_logprobs=top_logprobs,
            finish_reason=finish_reason,
            stats=stats,
            usage=usage,
        )

        if is_done:
            # Log generation stats
            generation_elapsed = time.perf_counter() - generation_start_time
@@ -390,18 +420,44 @@
                f"{generation_tps:.1f} tok/s"
            )
            if kv_prefix_cache is not None:
                full_prompt = prompt + "".join(generated_text_parts)
                generated_tokens_array = mx.array(
                    tokenizer.encode(
                        "".join(generated_text_parts), add_special_tokens=False
                    )
                )
                full_prompt_tokens = mx.concatenate(
                    [all_prompt_tokens, generated_tokens_array]
                )
                if (
                    matched_index is not None
                    and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
                ):
                    kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
                    kv_prefix_cache.update_kv_cache(
                        matched_index,
                        full_prompt_tokens,
                        caches,
                        cache_snapshots,
                        restore_pos=prefix_hit_length,
                    )
                else:
                    kv_prefix_cache.add_kv_cache(full_prompt, caches)
                    kv_prefix_cache.add_kv_cache(
                        full_prompt_tokens, caches, cache_snapshots
                    )

        yield GenerationResponse(
            text=text,
            token=out.token,
            logprob=logprob,
            top_logprobs=top_logprobs,
            finish_reason=finish_reason,
            stats=stats,
            usage=usage,
        )

        if is_done:
            mx_barrier(group)
            break

        # Limit accumulated_text to what's needed for stop sequence detection
        if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
            accumulated_text = accumulated_text[-max_stop_len:]

    # TODO: Do we want an mx_barrier?

@@ -490,6 +490,30 @@ def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> b
    return think_token is not None and prompt.rstrip().endswith(think_token)


def fix_unmatched_think_end_tokens(
    tokens: mx.array, tokenizer: TokenizerWrapper
) -> mx.array:
    if not tokenizer.has_thinking:
        return tokens
    assert tokenizer.think_start_id
    assert tokenizer.think_end_id
    think_start_id: int = tokenizer.think_start_id
    think_end_id: int = tokenizer.think_end_id
    token_list: list[int] = cast(list[int], tokens.tolist())
    result: list[int] = []
    depth = 0
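    # Example (illustrative): a stray think_end with no matching think_start gets a
    # think_start inserted in front of it, so [.., end] becomes [.., start, end].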
    for token in token_list:
        if token == think_start_id:
            depth += 1
        elif token == think_end_id:
            if depth == 0:
                result.append(think_start_id)
            else:
                depth -= 1
        result.append(token)
    return mx.array(result)


class NullKVCache(KVCache):
    """
    A KVCache that pretends to exist but holds zero tokens.

@@ -193,7 +193,7 @@ def main(
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
kv_prefix_cache = KVPrefixCache(group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
@@ -226,6 +226,7 @@ def main(
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
group=group,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
@@ -274,6 +275,7 @@ def main(
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
group=group,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
@@ -627,7 +629,7 @@ def parse_thinking_models(
|
||||
yield response.model_copy(
|
||||
update={
|
||||
"text": tokenizer.think_start,
|
||||
"token": tokenizer.think_start_id, # type: ignore
|
||||
"token": tokenizer.think_start_id,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
@@ -808,8 +810,9 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
|
||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
||||
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
|
||||
_func_name_regex = re.compile(
|
||||
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||
)
|
||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||
|
||||
@@ -833,9 +836,10 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
original_func_name = func_name_match.group(1)
|
||||
tool_id = func_name_match.group(2)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = func_name[func_name.find(".") + 1 :]
|
||||
func_name = original_func_name[original_func_name.find(".") + 1 :]
|
||||
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
@@ -844,7 +848,11 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
|
||||
return dict(
|
||||
id=f"{original_func_name}:{tool_id}",
|
||||
name=func_name,
|
||||
arguments=arg_dct, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
@@ -927,7 +935,13 @@ def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
        and ((args := obj.get("arguments")) is not None)
        and isinstance(name, str)
    ):
        return ToolCallItem(name=name, arguments=json.dumps(args))
        raw_id: object = obj.get("id")
        extra = {"id": str(raw_id)} if raw_id is not None else {}
        return ToolCallItem(
            **extra,
            name=name,
            arguments=json.dumps(args),
        )
    else:
        raise ValidationError
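The **extra unpacking above only forwards an id when the parsed tool call actually carried one. A minimal standalone illustration of that pattern, using a plain dataclass as a stand-in for the real ToolCallItem:

from dataclasses import dataclass

@dataclass
class Item:  # illustrative stand-in, not the real ToolCallItem
    name: str
    arguments: str
    id: str | None = None

def build(name: str, arguments: str, raw_id: object | None = None) -> Item:
    # Only pass `id` through when one was present in the parsed dict.
    extra = {"id": str(raw_id)} if raw_id is not None else {}
    return Item(**extra, name=name, arguments=arguments)

print(build("multiply", '{"a": 2, "b": 3}', raw_id=0))  # id becomes "0"
print(build("multiply", '{"a": 2, "b": 3}'))            # id stays None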
@@ -88,12 +88,12 @@ class TestKVPrefix:
        return tokenizer

    def test_starts_empty(self, mock_tokenizer):
        cache = KVPrefixCache(mock_tokenizer)
        cache = KVPrefixCache()
        assert len(cache.prompts) == 0
        assert len(cache.caches) == 0

    def test_clear_empties_cache(self, mock_tokenizer):
        cache = KVPrefixCache(mock_tokenizer)
        cache = KVPrefixCache()
        cache.prompts.append(mx.array([1, 2, 3]))
        cache.caches.append([KVCache()])
        cache.clear()
@@ -101,7 +101,7 @@ class TestKVPrefix:
        assert len(cache.caches) == 0

    def test_clear_on_empty_cache(self, mock_tokenizer):
        cache = KVPrefixCache(mock_tokenizer)
        cache = KVPrefixCache()
        cache.clear()
        assert len(cache.prompts) == 0

@@ -142,10 +142,12 @@ class TestKVPrefixCacheWithModel:
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)

        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

        # Cache should now hold the prompt tokens
        assert cache_length(cache) == len(tokens)
        # Cache should now hold the prompt tokens minus one
        assert cache_length(cache) == len(tokens) - 1
        # Snapshots should be available for models with non-KV caches
        assert len(snapshots) > 0
    def test_add_and_get_exact_match(self, model_and_tokenizer):
        model, tokenizer = model_and_tokenizer
@@ -159,10 +161,10 @@ class TestKVPrefixCacheWithModel:
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)

        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)
        kv_prefix_cache = KVPrefixCache()
        kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)

        assert len(kv_prefix_cache.prompts) == 1
        stored_length = cache_length(kv_prefix_cache.caches[0])
@@ -170,7 +172,7 @@ class TestKVPrefixCacheWithModel:

        # Retrieve with same prompt: exact match
        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, prompt
            model, tokens
        )
        assert matched_index == 0
@@ -191,10 +193,12 @@ class TestKVPrefixCacheWithModel:
        short_tokens = encode_prompt(tokenizer, short_prompt)
        cache = make_kv_cache(model)

        prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
        _, _, snapshots = prefill(
            model, tokenizer, make_sampler(0.0), short_tokens, cache
        )

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(short_prompt, cache)
        kv_prefix_cache = KVPrefixCache()
        kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)

        # Query with longer prompt that shares the chat template prefix
        long_task = TextGenerationTaskParams(
@@ -212,13 +216,12 @@ class TestKVPrefixCacheWithModel:
        )

        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, long_prompt
            model, long_tokens
        )
        assert matched_index == 0

        # remaining_tokens should be the suffix after the shared prefix
        assert len(remaining_tokens) == len(long_tokens) - expected_prefix
        assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
        # remaining_tokens covers from snapshot restore position to end
        assert len(remaining_tokens) >= len(long_tokens) - expected_prefix
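These assertions all hinge on the shared-token-prefix bookkeeping: a stored entry covers some prefix of the incoming tokens, and only the remaining suffix still needs prefill. A tiny self-contained illustration (plain Python lists instead of mx.array; prefix_length here is a local stand-in for the get_prefix_length helper the tests import):

def prefix_length(a: list[int], b: list[int]) -> int:
    # Length of the common prefix of two token sequences.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

stored_prompt = [1, 2, 3, 4, 5]        # tokens already covered by a cached entry
new_tokens = [1, 2, 3, 4, 5, 6, 7, 8]  # incoming prompt sharing that prefix

hit = prefix_length(stored_prompt, new_tokens)
remaining_tokens = new_tokens[hit:]    # only these need a fresh prefill
print(hit, remaining_tokens)           # 5 [6, 7, 8]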
    def test_stored_cache_not_mutated_after_get_and_generation(
        self, model_and_tokenizer
@@ -235,15 +238,15 @@ class TestKVPrefixCacheWithModel:
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)

        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)
        kv_prefix_cache = KVPrefixCache()
        kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)

        stored_length = cache_length(kv_prefix_cache.caches[0])

        # Get cache and mutate it (simulating what generation does)
        result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
        result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)
        assert matched_index == 0

        # Simulate generation: feed many additional tokens through the cache
@@ -273,15 +276,15 @@ class TestKVPrefixCacheWithModel:
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)

        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
        _, _, snapshots = prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)
        kv_prefix_cache = KVPrefixCache()
        kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)

        stored_length = cache_length(kv_prefix_cache.caches[0])

        for i in range(3):
            result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
            result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)

            head_dim = result_cache[0].keys.shape[-1]
            num_heads = result_cache[0].keys.shape[1]
@@ -298,7 +301,7 @@ class TestKVPrefixCacheWithModel:
        """mlx_generate should save the cache after generation completes."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache = KVPrefixCache()
        task = TextGenerationTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            input=[InputMessage(role="user", content="Hello")],
@@ -328,7 +331,7 @@ class TestKVPrefixCacheWithModel:
        """Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache = KVPrefixCache()
        task = TextGenerationTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            input=[InputMessage(role="user", content="Reuse test")],
@@ -352,20 +355,20 @@ class TestKVPrefixCacheWithModel:
        # Second call should find a prefix match (the stored cache contains
        # prompt + generated tokens, which shares the prompt prefix)
        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, prompt
            model, prompt_tokens
        )
        # The stored cache is longer than the prompt (it includes generated tokens),
        # so this is a prefix match where our prompt is fully contained
        assert matched_index == 0
        # Exact match: remaining_tokens is just the last token
        assert len(remaining_tokens) == 1
        assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
        # Exact match: remaining_tokens is just the last token and the one before
        assert len(remaining_tokens) == 2
        assert mx.array_equal(remaining_tokens, prompt_tokens[-2:])

    def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
        """With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache = KVPrefixCache()

        # Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
        base_text = "The quick brown fox jumps over the lazy dog. "
@@ -444,7 +447,7 @@ class TestKVPrefixCacheWithModel:
        """After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache = KVPrefixCache()
        task = TextGenerationTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            input=[InputMessage(role="user", content="Immutable test")],
@@ -481,7 +484,7 @@ class TestKVPrefixCacheWithModel:
        """Under memory pressure, adding a new cache entry evicts the least recently used one."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache = KVPrefixCache()

        # Add three cache entries with different prompts
        prompts = ["First entry", "Second entry", "Third entry"]
@@ -495,7 +498,7 @@ class TestKVPrefixCacheWithModel:
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)
            kv_prefix_cache.add_kv_cache(tokens, cache)
            # Stagger _last_used so LRU order is deterministic
            kv_prefix_cache._last_used[i] = float(i)
@@ -505,19 +508,10 @@ class TestKVPrefixCacheWithModel:
        kv_prefix_cache._last_used[2] = 100.0
        # Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next

        # Simulate memory pressure: active memory exceeds threshold
        fake_limit = 1000
        fake_active = int(fake_limit * 0.90)  # Above _MEMORY_THRESHOLD (0.85)

        with (
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
                return_value=fake_active,
            ),
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.device_info",
                return_value={"max_recommended_working_set_size": fake_limit},
            ),
        # Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)
        with patch(
            "exo.worker.engines.mlx.cache.get_memory_used_percentage",
            return_value=0.95,
        ):
            # Trigger eviction by adding a new entry
            task = TextGenerationTaskParams(
@@ -529,14 +523,11 @@ class TestKVPrefixCacheWithModel:
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)
            kv_prefix_cache.add_kv_cache(tokens, cache)

            # LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
            # Since fake_active stays above threshold after each eviction (we don't change it),
            # all old entries get evicted, leaving only the newly added one
            assert len(kv_prefix_cache.prompts) == 1
            # The surviving entry should be the newly added one
            new_tokens = encode_prompt(tokenizer, prompt)
            assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
                new_tokens
            )
            assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)
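The old version of this test stubbed mx.metal directly; the new one patches a get_memory_used_percentage helper in exo.worker.engines.mlx.cache. Its implementation is not shown in this diff, but judging from the values the old test mocked, a plausible sketch is:

import mlx.core as mx

def get_memory_used_percentage() -> float:
    # Assumed shape of the patched helper: active memory as a fraction of the
    # recommended working-set size (compare the old mock's fake_active / fake_limit).
    limit = mx.metal.device_info()["max_recommended_working_set_size"]
    return mx.metal.get_active_memory() / limit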
@@ -34,6 +34,7 @@ TOKENIZER_FILE_PATTERNS = [
    "added_tokens.json",
    "tokenizer.model",
    "tokenization_*.py",  # Custom tokenizer implementations
    "tool_declaration_ts.py",  # Dependency of tokenization_kimi.py
]
53
tests/auto_bench.sh
Executable file
@@ -0,0 +1,53 @@
#!/usr/bin/env bash

[ $# -lt 1 ] && {
  echo "Usage: $0 host1 [host2 ...]"
  exit 1
}

[ -z "$(git status --porcelain)" ] || {
  echo "Uncommitted changes"
  exit 1
}

commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
  echo "Not pushed to origin"
  exit 1
}
hosts=("$@")
cleanup() {
  for host in "${hosts[@]}"; do
    ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
  done
  sleep 1
  jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM

for host; do
  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
    "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit" &
done
wait
for host; do
  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
    "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" &>/dev/null &
done

for host; do
  echo "Waiting for $host..." 1>&2
  until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done

echo "Waiting 30s for cluster setup" 1>&2
sleep 30
echo "EXO loaded" 1>&2
bench_runner="${hosts[0]}"
mkdir -p "./bench/$commit"
nix run .#exo-get-all-models-on-cluster -- "$bench_runner" | while IFS= read -r model; do
  echo "running bench for $model" 1>&2
  ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 "$bench_runner@$bench_runner" "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-bench -- --model $model --pp 128 4096 --tg 128 --stdout --skip-tensor-ring" >>"./bench/$commit/${model//\//--}.json"
  echo
done
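For reference, the script takes the cluster hostnames as positional arguments and, per its own checks, assumes a clean checkout whose HEAD is pushed to origin plus passwordless SSH as user $host on each machine. A typical invocation (hostnames here are examples) looks like:

./tests/auto_bench.sh mac-studio-1 mac-studio-2
# results accumulate under ./bench/<commit>/<model>.json on the invoking machine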
36
tests/get_all_models_on_cluster.py
Executable file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# pyright: reportAny=false
import json
import subprocess
import sys
from typing import Any, cast
from urllib.request import urlopen

h = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f"USAGE: {sys.argv[0]} host")
ts = subprocess.run(
    ["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = next(
    (sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None
) or sys.exit(f"{h} not found in tailscale")
with urlopen(f"http://{ip}:52415/state", timeout=5) as r:
    data = json.loads(r.read()).get("downloads", {})


def mid(x: dict[str, Any]) -> str | None:
    for k in (
        "DownloadCompleted",
        "shardMetadata",
        "PipelineShardMetadata",
        "modelCard",
        "modelId",
    ):
        x = x.get(k, {})
    return cast(str | None, x if x != {} else None)


common = set[str].intersection(
    *[{m for d in nid if (m := mid(d))} for nid in data.values()]
)
for c in common:
    print(c)
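The helper resolves the host's Tailscale IP, queries that node's /state endpoint, and prints the model ids every node in the cluster has fully downloaded, which is what the bench loop above pipes through. Run directly (hostname is an example):

./tests/get_all_models_on_cluster.py mac-studio-1
# or, via the flake attribute used by auto_bench.sh:
nix run .#exo-get-all-models-on-cluster -- mac-studio-1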
@@ -35,7 +35,7 @@ i=0
for host; do
  colour=${colours[i++ % 4]}
  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
    "/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
    "EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
    awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done
377
tmp/quantize_and_upload.py
Executable file
@@ -0,0 +1,377 @@
#!/usr/bin/env python3
"""
Download an mflux model, quantize it, and upload to HuggingFace.

Usage (run from mflux project directory):
    cd /path/to/mflux
    uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev
    uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit
    uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run

Requires:
    - Must be run from mflux project directory using `uv run`
    - huggingface_hub installed (add to mflux deps or install separately)
    - HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN
"""

from __future__ import annotations

import argparse
import re
import shutil
import sys
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from mflux.models.flux.variants.txt2img.flux import Flux1


HF_ORG = "exolabs"

def get_model_class(model_name: str) -> type:
    """Get the appropriate model class based on model name."""
    from mflux.models.fibo.variants.txt2img.fibo import FIBO
    from mflux.models.flux.variants.txt2img.flux import Flux1
    from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
    from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
    from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo

    model_name_lower = model_name.lower()
    if "qwen" in model_name_lower:
        return QwenImage
    elif "fibo" in model_name_lower:
        return FIBO
    elif "z-image" in model_name_lower or "zimage" in model_name_lower:
        return ZImageTurbo
    elif "flux2" in model_name_lower or "flux.2" in model_name_lower:
        return Flux2Klein
    else:
        return Flux1


def get_repo_name(model_name: str, bits: int | None) -> str:
    """Get the HuggingFace repo name for a model variant."""
    # Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
    base_name = model_name.split("/")[-1] if "/" in model_name else model_name
    suffix = f"-{bits}bit" if bits else ""
    return f"{HF_ORG}/{base_name}{suffix}"


def get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path:
    """Get the local save path for a model variant."""
    # Extract repo name from HF path (e.g., "black-forest-labs/FLUX.1-Kontext-dev" -> "FLUX.1-Kontext-dev")
    base_name = model_name.split("/")[-1] if "/" in model_name else model_name
    suffix = f"-{bits}bit" if bits else ""
    return output_dir / f"{base_name}{suffix}"

def copy_source_repo(
    source_repo: str,
    local_path: Path,
    dry_run: bool = False,
) -> None:
    """Copy all files from source repo (replicating original HF structure)."""
    print(f"\n{'=' * 60}")
    print(f"Copying full repo from source: {source_repo}")
    print(f"Output path: {local_path}")
    print(f"{'=' * 60}")

    if dry_run:
        print("[DRY RUN] Would download all files from source repo")
        return

    from huggingface_hub import snapshot_download

    # Download all files to our local path
    snapshot_download(
        repo_id=source_repo,
        local_dir=local_path,
    )

    # Remove root-level safetensors files (flux.1-dev.safetensors, etc.)
    # These are redundant with the component directories
    for f in local_path.glob("*.safetensors"):
        print(f"Removing root-level safetensors: {f.name}")
        if not dry_run:
            f.unlink()

    print(f"Source repo copied to {local_path}")


def load_and_save_quantized_model(
    model_name: str,
    bits: int,
    output_path: Path,
    dry_run: bool = False,
) -> None:
    """Load a model with quantization and save it in mflux format."""
    print(f"\n{'=' * 60}")
    print(f"Loading {model_name} with {bits}-bit quantization...")
    print(f"Output path: {output_path}")
    print(f"{'=' * 60}")

    if dry_run:
        print("[DRY RUN] Would load and save quantized model")
        return

    from mflux.models.common.config.model_config import ModelConfig

    model_class = get_model_class(model_name)
    model_config = ModelConfig.from_name(model_name=model_name, base_model=None)

    model: Flux1 = model_class(
        quantize=bits,
        model_config=model_config,
    )

    print(f"Saving model to {output_path}...")
    model.save_model(str(output_path))
    print(f"Model saved successfully to {output_path}")


def copy_source_metadata(
    source_repo: str,
    local_path: Path,
    dry_run: bool = False,
) -> None:
    """Copy metadata files (LICENSE, README, etc.) from source repo, excluding safetensors."""
    print(f"\n{'=' * 60}")
    print(f"Copying metadata from source repo: {source_repo}")
    print(f"{'=' * 60}")

    if dry_run:
        print("[DRY RUN] Would download metadata files (excluding *.safetensors)")
        return

    from huggingface_hub import snapshot_download

    # Download all files except safetensors to our local path
    snapshot_download(
        repo_id=source_repo,
        local_dir=local_path,
        ignore_patterns=["*.safetensors"],
    )
    print(f"Metadata files copied to {local_path}")

def upload_to_huggingface(
    local_path: Path,
    repo_id: str,
    dry_run: bool = False,
    clean_remote: bool = False,
) -> None:
    """Upload a saved model to HuggingFace."""
    print(f"\n{'=' * 60}")
    print(f"Uploading to HuggingFace: {repo_id}")
    print(f"Local path: {local_path}")
    print(f"Clean remote first: {clean_remote}")
    print(f"{'=' * 60}")

    if dry_run:
        print("[DRY RUN] Would upload to HuggingFace")
        return

    from huggingface_hub import HfApi

    api = HfApi()

    # Create the repo if it doesn't exist
    print(f"Creating/verifying repo: {repo_id}")
    api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

    # Clean remote repo if requested (delete old mflux-format files)
    if clean_remote:
        print("Cleaning old mflux-format files from remote...")
        try:
            # Pattern for mflux numbered shards: <dir>/<number>.safetensors
            numbered_pattern = re.compile(r".*/\d+\.safetensors$")

            repo_files = api.list_repo_files(repo_id=repo_id, repo_type="model")
            for file_path in repo_files:
                # Delete numbered safetensors (mflux format) and mflux index files
                if numbered_pattern.match(file_path) or file_path.endswith(
                    "/model.safetensors.index.json"
                ):
                    print(f"  Deleting: {file_path}")
                    api.delete_file(
                        path_in_repo=file_path, repo_id=repo_id, repo_type="model"
                    )
        except Exception as e:
            print(f"Warning: Could not clean remote files: {e}")

    # Upload the folder
    print("Uploading folder contents...")
    api.upload_folder(
        folder_path=str(local_path),
        repo_id=repo_id,
        repo_type="model",
    )
    print(f"Upload complete: https://huggingface.co/{repo_id}")


def clean_local_files(local_path: Path, dry_run: bool = False) -> None:
    """Remove local model files after upload."""
    print(f"\nCleaning up: {local_path}")
    if dry_run:
        print("[DRY RUN] Would remove local files")
        return

    if local_path.exists():
        shutil.rmtree(local_path)
        print(f"Removed {local_path}")

def main() -> int:
    parser = argparse.ArgumentParser(
        description="Download an mflux model, quantize it, and upload to HuggingFace.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process all variants (base, 4-bit, 8-bit) for FLUX.1-Kontext-dev
  python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev

  # Only process 4-bit variant
  python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit

  # Save locally without uploading
  python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-upload

  # Preview what would happen
  python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run
""",
    )

    parser.add_argument(
        "--model",
        "-m",
        required=True,
        help="HuggingFace model path (e.g., black-forest-labs/FLUX.1-Kontext-dev)",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("./tmp/models"),
        help="Local directory to save models (default: ./tmp/models)",
    )
    parser.add_argument(
        "--skip-base",
        action="store_true",
        help="Skip base model (no quantization)",
    )
    parser.add_argument(
        "--skip-4bit",
        action="store_true",
        help="Skip 4-bit quantized model",
    )
    parser.add_argument(
        "--skip-8bit",
        action="store_true",
        help="Skip 8-bit quantized model",
    )
    parser.add_argument(
        "--skip-download",
        action="store_true",
        help="Skip downloading/processing, only do upload/clean operations",
    )
    parser.add_argument(
        "--skip-upload",
        action="store_true",
        help="Only save locally, don't upload to HuggingFace",
    )
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Remove local files after upload",
    )
    parser.add_argument(
        "--clean-remote",
        action="store_true",
        help="Delete old mflux-format files from remote repo before uploading",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print actions without executing",
    )

    args = parser.parse_args()

    # Determine which variants to process
    variants: list[int | None] = []
    if not args.skip_base:
        variants.append(None)  # Base model (no quantization)
    if not args.skip_4bit:
        variants.append(4)
    if not args.skip_8bit:
        variants.append(8)

    if not variants:
        print("Error: All variants skipped. Nothing to do.")
        return 1

    # Create output directory
    args.output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Model: {args.model}")
    print(f"Output directory: {args.output_dir}")
    print(
        f"Variants to process: {['base' if v is None else f'{v}-bit' for v in variants]}"
    )
    print(f"Upload to HuggingFace: {not args.skip_upload}")
    print(f"Clean after upload: {args.clean}")
    if args.dry_run:
        print("\n*** DRY RUN MODE - No actual changes will be made ***")

    # Process each variant
    for bits in variants:
        local_path = get_local_path(args.output_dir, args.model, bits)
        repo_id = get_repo_name(args.model, bits)

        if not args.skip_download:
            if bits is None:
                # Base model: copy original HF repo structure (no mflux conversion)
                copy_source_repo(
                    source_repo=args.model,
                    local_path=local_path,
                    dry_run=args.dry_run,
                )
            else:
                # Quantized model: load, quantize, and save with mflux
                load_and_save_quantized_model(
                    model_name=args.model,
                    bits=bits,
                    output_path=local_path,
                    dry_run=args.dry_run,
                )

            # Copy metadata from source repo (LICENSE, README, etc.)
            copy_source_metadata(
                source_repo=args.model,
                local_path=local_path,
                dry_run=args.dry_run,
            )

        # Upload
        if not args.skip_upload:
            upload_to_huggingface(
                local_path=local_path,
                repo_id=repo_id,
                dry_run=args.dry_run,
                clean_remote=args.clean_remote,
            )

        # Clean up if requested
        if args.clean:
            clean_local_files(local_path, dry_run=args.dry_run)

    print("\n" + "=" * 60)
    print("All done!")
    print("=" * 60)

    return 0


if __name__ == "__main__":
    sys.exit(main())