mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-05 19:52:16 -05:00
Compare commits
1 Commits
ciaran/num
...
JakeHillio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0806c92ac9 |
@@ -148,15 +148,6 @@
|
||||
setImageGenerationParams({ guidance: null });
|
||||
}
|
||||
|
||||
function handleNumSyncStepsChange(event: Event) {
|
||||
const value = parseInt((event.target as HTMLInputElement).value, 10);
|
||||
setImageGenerationParams({ numSyncSteps: value });
|
||||
}
|
||||
|
||||
function clearNumSyncSteps() {
|
||||
setImageGenerationParams({ numSyncSteps: null });
|
||||
}
|
||||
|
||||
function handleReset() {
|
||||
resetImageGenerationParams();
|
||||
showAdvanced = false;
|
||||
@@ -166,8 +157,7 @@
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null,
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
|
||||
);
|
||||
</script>
|
||||
|
||||
@@ -588,50 +578,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 3: Sync Steps -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span
|
||||
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
|
||||
>SYNC STEPS:</span
|
||||
>
|
||||
<div class="flex items-center gap-2 flex-1 max-w-xs">
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
max="100"
|
||||
value={params.numSyncSteps ?? 1}
|
||||
oninput={handleNumSyncStepsChange}
|
||||
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
|
||||
/>
|
||||
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
|
||||
{params.numSyncSteps ?? "--"}
|
||||
</span>
|
||||
{#if params.numSyncSteps !== null}
|
||||
<button
|
||||
type="button"
|
||||
onclick={clearNumSyncSteps}
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
title="Clear"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Row 4: Negative Prompt -->
|
||||
<!-- Row 3: Negative Prompt -->
|
||||
<div class="flex flex-col gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>NEGATIVE PROMPT:</span
|
||||
|
||||
@@ -298,7 +298,6 @@ export interface ImageGenerationParams {
|
||||
numInferenceSteps: number | null;
|
||||
guidance: number | null;
|
||||
negativePrompt: string | null;
|
||||
numSyncSteps: number | null;
|
||||
// Edit mode params
|
||||
inputFidelity: "low" | "high";
|
||||
}
|
||||
@@ -320,7 +319,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
numInferenceSteps: null,
|
||||
guidance: null,
|
||||
negativePrompt: null,
|
||||
numSyncSteps: null,
|
||||
inputFidelity: "low",
|
||||
};
|
||||
|
||||
@@ -2398,9 +2396,7 @@ class AppStore {
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null;
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "");
|
||||
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model,
|
||||
@@ -2425,9 +2421,6 @@ class AppStore {
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
...(params.numSyncSteps !== null && {
|
||||
num_sync_steps: params.numSyncSteps,
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -2677,19 +2670,11 @@ class AppStore {
|
||||
formData.append("input_fidelity", params.inputFidelity);
|
||||
|
||||
// Advanced params
|
||||
const hasAdvancedParams =
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "") ||
|
||||
params.numSyncSteps !== null;
|
||||
|
||||
if (hasAdvancedParams) {
|
||||
if (params.seed !== null) {
|
||||
formData.append(
|
||||
"advanced_params",
|
||||
JSON.stringify({
|
||||
...(params.seed !== null && { seed: params.seed }),
|
||||
seed: params.seed,
|
||||
...(params.numInferenceSteps !== null && {
|
||||
num_inference_steps: params.numInferenceSteps,
|
||||
}),
|
||||
@@ -2698,9 +2683,24 @@ class AppStore {
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
...(params.numSyncSteps !== null && {
|
||||
num_sync_steps: params.numSyncSteps,
|
||||
}),
|
||||
);
|
||||
} else if (
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "")
|
||||
) {
|
||||
formData.append(
|
||||
"advanced_params",
|
||||
JSON.stringify({
|
||||
...(params.numInferenceSteps !== null && {
|
||||
num_inference_steps: params.numInferenceSteps,
|
||||
}),
|
||||
...(params.guidance !== null && { guidance: params.guidance }),
|
||||
...(params.negativePrompt !== null &&
|
||||
params.negativePrompt.trim() !== "" && {
|
||||
negative_prompt: params.negativePrompt,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -63,7 +63,6 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
|
||||
@@ -272,7 +272,6 @@ class AdvancedImageParams(BaseModel):
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
|
||||
negative_prompt: str | None = None
|
||||
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
|
||||
|
||||
class ImageGenerationTaskParams(BaseModel):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -22,7 +23,7 @@ class ImageModelConfig(BaseModel):
|
||||
block_configs: tuple[TransformerBlockConfig, ...]
|
||||
|
||||
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
|
||||
num_sync_steps: int # Number of sync steps for distributed inference
|
||||
num_sync_steps_factor: float # Fraction of steps for sync phase
|
||||
|
||||
guidance_scale: float | None = None # None or <= 1.0 disables CFG
|
||||
|
||||
@@ -44,3 +45,6 @@ class ImageModelConfig(BaseModel):
|
||||
|
||||
def get_steps_for_quality(self, quality: str) -> int:
|
||||
return self.default_steps[quality]
|
||||
|
||||
def get_num_sync_steps(self, steps: int) -> int:
|
||||
return ceil(steps * self.num_sync_steps_factor)
|
||||
|
||||
@@ -150,10 +150,7 @@ class DistributedImageModel:
|
||||
guidance=guidance_override if guidance_override is not None else 4.0,
|
||||
)
|
||||
|
||||
if advanced_params is not None and advanced_params.num_sync_steps is not None:
|
||||
num_sync_steps = advanced_params.num_sync_steps
|
||||
else:
|
||||
num_sync_steps = self._config.num_sync_steps
|
||||
num_sync_steps = self._config.get_num_sync_steps(steps)
|
||||
|
||||
for result in self._runner.generate_image(
|
||||
runtime_config=config,
|
||||
|
||||
@@ -15,7 +15,7 @@ FLUX_SCHNELL_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 1, "medium": 2, "high": 4},
|
||||
num_sync_steps=1,
|
||||
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
|
||||
)
|
||||
|
||||
|
||||
@@ -30,5 +30,5 @@ FLUX_DEV_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=4,
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=7,
|
||||
num_sync_steps_factor=0.25,
|
||||
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
|
||||
)
|
||||
|
||||
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps=7,
|
||||
num_sync_steps_factor=0.25,
|
||||
guidance_scale=3.5,
|
||||
)
|
||||
|
||||
10
uv.lock
generated
10
uv.lock
generated
@@ -415,7 +415,7 @@ requires-dist = [
|
||||
{ name = "mflux", specifier = "==0.15.4" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.4" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.4" },
|
||||
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.6" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
@@ -1072,8 +1072,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx-lm"
|
||||
version = "0.30.5"
|
||||
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
|
||||
version = "0.30.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
@@ -1083,6 +1083,10 @@ dependencies = [
|
||||
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
|
||||
Reference in New Issue
Block a user