mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-04 11:11:45 -05:00
Compare commits
4 Commits
aiohttp
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4800713524 | ||
|
|
a0f4f36355 | ||
|
|
acb97127bf | ||
|
|
d90605f198 |
2
.github/workflows/pipeline.yml
vendored
2
.github/workflows/pipeline.yml
vendored
@@ -142,4 +142,4 @@ jobs:
|
||||
# Run pytest outside sandbox (needs GPU access for MLX)
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
@@ -10,6 +10,7 @@ PROJECT_ROOT = Path.cwd()
|
||||
SOURCE_ROOT = PROJECT_ROOT / "src"
|
||||
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
|
||||
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
|
||||
RESOURCES_DIR = PROJECT_ROOT / "resources"
|
||||
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
|
||||
|
||||
if not ENTRYPOINT.is_file():
|
||||
@@ -18,6 +19,9 @@ if not ENTRYPOINT.is_file():
|
||||
if not DASHBOARD_DIR.is_dir():
|
||||
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
|
||||
|
||||
if not RESOURCES_DIR.is_dir():
|
||||
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
|
||||
|
||||
if not EXO_SHARED_MODELS_DIR.is_dir():
|
||||
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
|
||||
|
||||
@@ -58,6 +62,7 @@ HIDDEN_IMPORTS = sorted(
|
||||
|
||||
DATAS: list[tuple[str, str]] = [
|
||||
(str(DASHBOARD_DIR), "dashboard"),
|
||||
(str(RESOURCES_DIR), "resources"),
|
||||
(str(MLX_LIB_DIR), "mlx/lib"),
|
||||
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
|
||||
]
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-Krea-dev.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-Krea-dev.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-dev-4bit.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-dev-4bit.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-dev-8bit.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-dev-8bit.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-dev.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-dev.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15470210592
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5945589280
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21415799872
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11891178560
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-schnell.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-schnell.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-schnell"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33306978432
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23782357120
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
35
resources/image_model_cards/exolabs--Qwen-Image-4bit.toml
Normal file
35
resources/image_model_cards/exolabs--Qwen-Image-4bit.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-4bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 10215200544
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
35
resources/image_model_cards/exolabs--Qwen-Image-8bit.toml
Normal file
35
resources/image_model_cards/exolabs--Qwen-Image-8bit.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-8bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 20430401088
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 10215200544
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 20430401088
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 40860802176
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
35
resources/image_model_cards/exolabs--Qwen-Image.toml
Normal file
35
resources/image_model_cards/exolabs--Qwen-Image.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 40860802176
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-8bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-8bit"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-bf16"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-4bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-6bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-4bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-5bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-6bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-8bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 620622774272
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Kimi-K2-Thinking"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Kimi-K2.5"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
n_layers = 16
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 729808896
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1863319552
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 3501195264
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 40652242944
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 76799803392
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 40652242944
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 4637851648
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 8954839040
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 16882073600
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-3bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-8bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-4bit"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-8bit"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 289910292480
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 579820584960
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 45644286500
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57657697020
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 68899327465
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 89357758772
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 157548627945
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 46976204800
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
n_layers = 36
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 70652212224
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
n_layers = 24
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 12025908224
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 144383672320
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator
|
||||
|
||||
@@ -60,10 +61,37 @@ class DownloadCoordinator:
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
self._test_internet_connection()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
|
||||
def _test_internet_connection(self) -> None:
|
||||
try:
|
||||
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
|
||||
self.shard_downloader.set_internet_connection(True)
|
||||
except OSError:
|
||||
self.shard_downloader.set_internet_connection(False)
|
||||
logger.debug(
|
||||
f"Internet connectivity: {self.shard_downloader.internet_connection}"
|
||||
)
|
||||
|
||||
async def _check_internet_connection(self) -> None:
|
||||
first_connection = True
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Assume that internet connection is set to False on 443 errors.
|
||||
if self.shard_downloader.internet_connection:
|
||||
continue
|
||||
|
||||
self._test_internet_connection()
|
||||
|
||||
if first_connection and self.shard_downloader.internet_connection:
|
||||
first_connection = False
|
||||
self._tg.start_soon(self._emit_existing_download_progress)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_scope.cancel()
|
||||
@@ -241,7 +269,7 @@ class DownloadCoordinator:
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"DownloadCoordinator: Fetching and emitting existing download progress..."
|
||||
)
|
||||
async for (
|
||||
@@ -274,10 +302,10 @@ class DownloadCoordinator:
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"DownloadCoordinator: Done emitting existing download progress."
|
||||
)
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
await anyio.sleep(60)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DownloadCoordinator: Error emitting existing download progress: {e}"
|
||||
|
||||
@@ -49,6 +49,10 @@ class HuggingFaceAuthenticationError(Exception):
|
||||
"""Raised when HuggingFace returns 401/403 for a model download."""
|
||||
|
||||
|
||||
class HuggingFaceRateLimitError(Exception):
|
||||
"""429 Huggingface code"""
|
||||
|
||||
|
||||
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
|
||||
token = await get_hf_token()
|
||||
if status_code == 401 and token is None:
|
||||
@@ -154,49 +158,76 @@ async def seed_models(seed_dir: str | Path):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
_fetched_file_lists_this_session: set[str] = set()
|
||||
|
||||
|
||||
async def fetch_file_list_with_cache(
|
||||
model_id: ModelId, revision: str = "main", recursive: bool = False
|
||||
model_id: ModelId,
|
||||
revision: str = "main",
|
||||
recursive: bool = False,
|
||||
skip_internet: bool = False,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
) -> list[FileListEntry]:
|
||||
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
|
||||
cache_key = f"{model_id.normalize()}--{revision}"
|
||||
|
||||
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
|
||||
cache_file
|
||||
):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
|
||||
if skip_internet:
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
raise FileNotFoundError(
|
||||
f"No internet connection and no cached file list for {model_id}"
|
||||
)
|
||||
|
||||
# Always try fresh first
|
||||
try:
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
model_id,
|
||||
revision,
|
||||
recursive=recursive,
|
||||
on_connection_lost=on_connection_lost,
|
||||
)
|
||||
# Update cache with fresh data
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
|
||||
)
|
||||
_fetched_file_lists_this_session.add(cache_key)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
# Fetch failed - try cache fallback
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id}, using cached data: {e}"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
# No cache available, propagate the error
|
||||
raise
|
||||
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
|
||||
|
||||
|
||||
async def fetch_file_list_with_retry(
|
||||
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
|
||||
model_id: ModelId,
|
||||
revision: str = "main",
|
||||
path: str = "",
|
||||
recursive: bool = False,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
) -> list[FileListEntry]:
|
||||
n_attempts = 30
|
||||
n_attempts = 3
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _fetch_file_list(model_id, revision, path, recursive)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
on_connection_lost()
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
raise Exception(
|
||||
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
|
||||
)
|
||||
@@ -216,7 +247,11 @@ async def _fetch_file_list(
|
||||
if response.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(response.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
if response.status == 200:
|
||||
elif response.status == 429:
|
||||
raise HuggingFaceRateLimitError(
|
||||
f"Couldn't download {model_id} because of HuggingFace rate limit."
|
||||
)
|
||||
elif response.status == 200:
|
||||
data_json = await response.text()
|
||||
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
|
||||
files: list[FileListEntry] = []
|
||||
@@ -249,7 +284,7 @@ def create_http_session(
|
||||
else:
|
||||
total_timeout = 1800
|
||||
connect_timeout = 60
|
||||
sock_read_timeout = 1800
|
||||
sock_read_timeout = 60
|
||||
sock_connect_timeout = 60
|
||||
|
||||
ssl_context = ssl.create_default_context(
|
||||
@@ -324,8 +359,9 @@ async def download_file_with_retry(
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
) -> Path:
|
||||
n_attempts = 30
|
||||
n_attempts = 3
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
@@ -333,14 +369,19 @@ async def download_file_with_retry(
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
|
||||
except HuggingFaceRateLimitError as e:
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
except Exception as e:
|
||||
on_connection_lost()
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
break
|
||||
raise Exception(
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
@@ -542,7 +583,9 @@ async def download_shard(
|
||||
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
|
||||
max_parallel_downloads: int = 8,
|
||||
skip_download: bool = False,
|
||||
skip_internet: bool = False,
|
||||
allow_patterns: list[str] | None = None,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
if not skip_download:
|
||||
logger.debug(f"Downloading {shard.model_card.model_id=}")
|
||||
@@ -562,7 +605,11 @@ async def download_shard(
|
||||
|
||||
all_start_time = time.time()
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
shard.model_card.model_id, revision, recursive=True
|
||||
shard.model_card.model_id,
|
||||
revision,
|
||||
recursive=True,
|
||||
skip_internet=skip_internet,
|
||||
on_connection_lost=on_connection_lost,
|
||||
)
|
||||
filtered_file_list = list(
|
||||
filter_repo_objects(
|
||||
@@ -672,6 +719,7 @@ async def download_shard(
|
||||
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
|
||||
file, curr_bytes, total_bytes, is_renamed
|
||||
),
|
||||
on_connection_lost=on_connection_lost,
|
||||
)
|
||||
|
||||
if not skip_download:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from asyncio import create_task
|
||||
from collections.abc import Awaitable
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
@@ -7,7 +8,7 @@ from loguru import logger
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
@@ -49,6 +50,10 @@ class SingletonShardDownloader(ShardDownloader):
|
||||
self.shard_downloader = shard_downloader
|
||||
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
|
||||
|
||||
def set_internet_connection(self, value: bool) -> None:
|
||||
self.internet_connection = value
|
||||
self.shard_downloader.set_internet_connection(value)
|
||||
|
||||
def on_progress(
|
||||
self,
|
||||
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
|
||||
@@ -85,6 +90,10 @@ class CachedShardDownloader(ShardDownloader):
|
||||
self.shard_downloader = shard_downloader
|
||||
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
|
||||
|
||||
def set_internet_connection(self, value: bool) -> None:
|
||||
self.internet_connection = value
|
||||
self.shard_downloader.set_internet_connection(value)
|
||||
|
||||
def on_progress(
|
||||
self,
|
||||
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
|
||||
@@ -142,6 +151,8 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
self.on_progress_wrapper,
|
||||
max_parallel_downloads=self.max_parallel_downloads,
|
||||
allow_patterns=allow_patterns,
|
||||
skip_internet=not self.internet_connection,
|
||||
on_connection_lost=lambda: self.set_internet_connection(False),
|
||||
)
|
||||
return target_dir
|
||||
|
||||
@@ -154,13 +165,24 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
"""Helper coroutine that builds the shard for a model and gets its download status."""
|
||||
shard = await build_full_shard(model_id)
|
||||
return await download_shard(
|
||||
shard, self.on_progress_wrapper, skip_download=True
|
||||
shard,
|
||||
self.on_progress_wrapper,
|
||||
skip_download=True,
|
||||
skip_internet=not self.internet_connection,
|
||||
on_connection_lost=lambda: self.set_internet_connection(False),
|
||||
)
|
||||
|
||||
# Kick off download status coroutines concurrently
|
||||
semaphore = asyncio.Semaphore(self.max_parallel_downloads)
|
||||
|
||||
async def download_with_semaphore(
|
||||
model_card: ModelCard,
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
async with semaphore:
|
||||
return await _status_for_model(model_card.model_id)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(_status_for_model(model_card.model_id))
|
||||
for model_card in MODEL_CARDS.values()
|
||||
create_task(download_with_semaphore(model_card))
|
||||
for model_card in await get_model_cards()
|
||||
]
|
||||
|
||||
for task in asyncio.as_completed(tasks):
|
||||
|
||||
@@ -16,6 +16,11 @@ from exo.shared.types.worker.shards import (
|
||||
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
|
||||
class ShardDownloader(ABC):
|
||||
internet_connection: bool = False
|
||||
|
||||
def set_internet_connection(self, value: bool) -> None:
|
||||
self.internet_connection = value
|
||||
|
||||
@abstractmethod
|
||||
async def ensure_shard(
|
||||
self, shard: ShardMetadata, config_only: bool = False
|
||||
|
||||
@@ -66,7 +66,9 @@ def chat_request_to_text_generation(
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages if input_messages else "",
|
||||
input=input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")],
|
||||
instructions=instructions,
|
||||
max_output_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
|
||||
@@ -141,7 +141,9 @@ def claude_request_to_text_generation(
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages if input_messages else "",
|
||||
input=input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")],
|
||||
instructions=instructions,
|
||||
max_output_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
|
||||
@@ -43,10 +43,10 @@ def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||
def responses_request_to_text_generation(
|
||||
request: ResponsesRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
input_value: str | list[InputMessage]
|
||||
input_value: list[InputMessage]
|
||||
built_chat_template: list[dict[str, Any]] | None = None
|
||||
if isinstance(request.input, str):
|
||||
input_value = request.input
|
||||
input_value = [InputMessage(role="user", content=request.input)]
|
||||
else:
|
||||
input_messages: list[InputMessage] = []
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
@@ -95,7 +95,11 @@ def responses_request_to_text_generation(
|
||||
}
|
||||
)
|
||||
|
||||
input_value = input_messages if input_messages else ""
|
||||
input_value = (
|
||||
input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")]
|
||||
)
|
||||
built_chat_template = chat_template_messages if chat_template_messages else None
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
|
||||
@@ -40,6 +40,7 @@ from exo.master.image_store import ImageStore
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import (
|
||||
DASHBOARD_DIR,
|
||||
EXO_IMAGE_CACHE_DIR,
|
||||
EXO_MAX_CHUNK_SIZE,
|
||||
EXO_TRACING_CACHE_DIR,
|
||||
@@ -47,9 +48,9 @@ from exo.shared.constants import (
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import (
|
||||
MODEL_CARDS,
|
||||
ModelCard,
|
||||
ModelId,
|
||||
get_model_cards,
|
||||
)
|
||||
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
|
||||
from exo.shared.types.api import (
|
||||
@@ -138,7 +139,6 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
|
||||
@@ -146,18 +146,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
return f"image/{image_format or 'png'}"
|
||||
|
||||
|
||||
async def resolve_model_card(model_id: ModelId) -> ModelCard:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == ModelId(model_id):
|
||||
return card
|
||||
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -204,7 +192,7 @@ class API:
|
||||
self.app.mount(
|
||||
"/",
|
||||
StaticFiles(
|
||||
directory=find_dashboard(),
|
||||
directory=DASHBOARD_DIR,
|
||||
html=True,
|
||||
),
|
||||
name="dashboard",
|
||||
@@ -381,10 +369,7 @@ class API:
|
||||
if len(list(self.state.topology.list_nodes())) == 0:
|
||||
return PlacementPreviewResponse(previews=[])
|
||||
|
||||
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
|
||||
if not cards:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
|
||||
|
||||
model_card = await ModelCard.load(model_id)
|
||||
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
|
||||
for sharding in (Sharding.Pipeline, Sharding.Tensor):
|
||||
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
|
||||
@@ -399,96 +384,93 @@ class API:
|
||||
# TODO: PDD
|
||||
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
|
||||
|
||||
for model_card in cards:
|
||||
for sharding, instance_meta, min_nodes in instance_combinations:
|
||||
try:
|
||||
placements = get_instance_placements(
|
||||
PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
required_nodes=required_nodes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
current_ids = set(self.state.instances.keys())
|
||||
new_instances = [
|
||||
instance
|
||||
for instance_id, instance in placements.items()
|
||||
if instance_id not in current_ids
|
||||
]
|
||||
|
||||
if len(new_instances) != 1:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error="Expected exactly one new instance from placement",
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
instance = new_instances[0]
|
||||
shard_assignments = instance.shard_assignments
|
||||
placement_node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
|
||||
memory_delta_by_node: dict[str, int] = {}
|
||||
if placement_node_ids:
|
||||
total_bytes = model_card.storage_size.in_bytes
|
||||
per_node = total_bytes // len(placement_node_ids)
|
||||
remainder = total_bytes % len(placement_node_ids)
|
||||
for index, node_id in enumerate(
|
||||
sorted(placement_node_ids, key=str)
|
||||
):
|
||||
extra = 1 if index < remainder else 0
|
||||
memory_delta_by_node[str(node_id)] = per_node + extra
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
for sharding, instance_meta, min_nodes in instance_combinations:
|
||||
try:
|
||||
placements = get_instance_placements(
|
||||
PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
required_nodes=required_nodes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
instance=None,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
current_ids = set(self.state.instances.keys())
|
||||
new_instances = [
|
||||
instance
|
||||
for instance_id, instance in placements.items()
|
||||
if instance_id not in current_ids
|
||||
]
|
||||
|
||||
if len(new_instances) != 1:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=None,
|
||||
error="Expected exactly one new instance from placement",
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, 0))
|
||||
continue
|
||||
|
||||
instance = new_instances[0]
|
||||
shard_assignments = instance.shard_assignments
|
||||
placement_node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
|
||||
memory_delta_by_node: dict[str, int] = {}
|
||||
if placement_node_ids:
|
||||
total_bytes = model_card.storage_size.in_bytes
|
||||
per_node = total_bytes // len(placement_node_ids)
|
||||
remainder = total_bytes % len(placement_node_ids)
|
||||
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
|
||||
extra = 1 if index < remainder else 0
|
||||
memory_delta_by_node[str(node_id)] = per_node + extra
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
|
||||
return PlacementPreviewResponse(previews=previews)
|
||||
|
||||
@@ -652,23 +634,21 @@ class API:
|
||||
response = await self._collect_text_generation_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _resolve_and_validate_text_model(self, model: ModelId) -> ModelId:
|
||||
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
|
||||
"""Validate a text model exists and return the resolved model ID.
|
||||
|
||||
Raises HTTPException 404 if no instance is found for the model.
|
||||
"""
|
||||
model_card = await resolve_model_card(model)
|
||||
resolved = model_card.model_id
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved
|
||||
instance.shard_assignments.model_id == model_id
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(resolved)
|
||||
await self._trigger_notify_user_to_download_model(model_id)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No instance found for model {resolved}",
|
||||
detail=f"No instance found for model {model_id}",
|
||||
)
|
||||
return resolved
|
||||
return model_id
|
||||
|
||||
async def _validate_image_model(self, model: ModelId) -> ModelId:
|
||||
"""Validate model exists and return resolved model ID.
|
||||
@@ -1237,7 +1217,7 @@ class API:
|
||||
supports_tensor=card.supports_tensor,
|
||||
tasks=[task.value for task in card.tasks],
|
||||
)
|
||||
for card in MODEL_CARDS.values()
|
||||
for card in await get_model_cards()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from exo.shared.types.profiling import (
|
||||
)
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceMeta,
|
||||
MlxRingInstance,
|
||||
@@ -136,7 +136,9 @@ async def test_master():
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("llama-3.2-1b"),
|
||||
input="Hello, how are you?",
|
||||
input=[
|
||||
InputMessage(role="user", content="Hello, how are you?")
|
||||
],
|
||||
),
|
||||
)
|
||||
),
|
||||
@@ -189,7 +191,7 @@ async def test_master():
|
||||
assert isinstance(events[2].event.task, TextGenerationTask)
|
||||
assert events[2].event.task.task_params == TextGenerationTaskParams(
|
||||
model=ModelId("llama-3.2-1b"),
|
||||
input="Hello, how are you?",
|
||||
input=[InputMessage(role="user", content="Hello, how are you?")],
|
||||
)
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from exo.utils.dashboard_path import find_dashboard, find_resources
|
||||
|
||||
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
|
||||
|
||||
|
||||
@@ -31,6 +33,14 @@ EXO_MODELS_DIR = (
|
||||
if _EXO_MODELS_DIR_ENV is None
|
||||
else Path.home() / _EXO_MODELS_DIR_ENV
|
||||
)
|
||||
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
|
||||
RESOURCES_DIR = (
|
||||
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
|
||||
)
|
||||
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
|
||||
DASHBOARD_DIR = (
|
||||
find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
|
||||
)
|
||||
|
||||
# Log files (data/logs or cache)
|
||||
EXO_LOG = EXO_CACHE_HOME / "exo.log"
|
||||
|
||||
@@ -12,16 +12,42 @@ from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PositiveInt,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from tomlkit.exceptions import TOMLKitError
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS, RESOURCES_DIR
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
_card_cache: dict[str, "ModelCard"] = {}
|
||||
# kinda ugly...
|
||||
# TODO: load search path from config.toml
|
||||
_csp = [Path(RESOURCES_DIR) / "inference_model_cards"]
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
_csp.append(Path(RESOURCES_DIR) / "image_model_cards")
|
||||
|
||||
CARD_SEARCH_PATH = _csp
|
||||
|
||||
_card_cache: dict[ModelId, "ModelCard"] = {}
|
||||
|
||||
|
||||
async def _refresh_card_cache():
|
||||
for path in CARD_SEARCH_PATH:
|
||||
async for toml_file in path.rglob("*.toml"):
|
||||
try:
|
||||
card = await ModelCard.load_from_path(toml_file)
|
||||
_card_cache[card.model_id] = card
|
||||
except (ValidationError, TOMLKitError):
|
||||
pass
|
||||
|
||||
|
||||
async def get_model_cards() -> list["ModelCard"]:
|
||||
if len(_card_cache) == 0:
|
||||
await _refresh_card_cache()
|
||||
return list(_card_cache.values())
|
||||
|
||||
|
||||
class ModelTask(str, Enum):
|
||||
@@ -55,31 +81,36 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
async def save(self, path: Path) -> None:
|
||||
async with await open_file(path, "w") as f:
|
||||
py = self.model_dump()
|
||||
py = self.model_dump(exclude_none=True)
|
||||
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
|
||||
await f.write(data)
|
||||
|
||||
async def save_to_default_path(self):
|
||||
await self.save(Path(RESOURCES_DIR) / (self.model_id.normalize() + ".toml"))
|
||||
|
||||
@staticmethod
|
||||
async def load_from_path(path: Path) -> "ModelCard":
|
||||
async with await open_file(path, "r") as f:
|
||||
py = tomlkit.loads(await f.read())
|
||||
return ModelCard.model_validate(py)
|
||||
|
||||
# Is it okay that model card.load defaults to network access if the card doesn't exist? do we want to be more explicit here?
|
||||
@staticmethod
|
||||
async def load(model_id: ModelId) -> "ModelCard":
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == model_id:
|
||||
return card
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
@staticmethod
|
||||
async def from_hf(model_id: ModelId) -> "ModelCard":
|
||||
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
|
||||
if model_id not in _card_cache:
|
||||
await _refresh_card_cache()
|
||||
if (mc := _card_cache.get(model_id)) is not None:
|
||||
return mc
|
||||
config_data = await get_config_data(model_id)
|
||||
|
||||
return await ModelCard.fetch_from_hf(model_id)
|
||||
|
||||
@staticmethod
|
||||
async def fetch_from_hf(model_id: ModelId) -> "ModelCard":
|
||||
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
|
||||
# TODO: failure if files do not exist
|
||||
config_data = await fetch_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
mem_size_bytes = await fetch_safetensors_size(model_id)
|
||||
|
||||
mc = ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
@@ -89,544 +120,13 @@ class ModelCard(CamelCaseModel):
|
||||
supports_tensor=config_data.supports_tensor,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
await mc.save_to_default_path()
|
||||
_card_cache[model_id] = mc
|
||||
return mc
|
||||
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
"deepseek-v3.1-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
storage_size=Memory.from_gb(378),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"deepseek-v3.1-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
storage_size=Memory.from_gb(713),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
storage_size=Memory.from_gb(578),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2-thinking": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
storage_size=Memory.from_gb(658),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2.5": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2.5"),
|
||||
storage_size=Memory.from_gb(617),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(4423),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-8b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(8540),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-8b-bf16": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
storage_size=Memory.from_mb(16100),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-70b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.2
|
||||
"llama-3.2-1b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(696),
|
||||
n_layers=16,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.2-3b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(1777),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.2-3b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(3339),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.3
|
||||
"llama-3.3-70b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.3-70b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(73242),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.3-70b-fp16": ModelCard(
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
storage_size=Memory.from_mb(137695),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# qwen3
|
||||
"qwen3-0.6b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
storage_size=Memory.from_mb(327),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-0.6b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
storage_size=Memory.from_mb(666),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-30b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
storage_size=Memory.from_mb(16797),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-30b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
storage_size=Memory.from_mb(31738),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
storage_size=Memory.from_mb(44800),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
storage_size=Memory.from_mb(44900),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-235b-a22b-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
storage_size=Memory.from_gb(132),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
storage_size=Memory.from_gb(250),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-coder-480b-a35b-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
storage_size=Memory.from_gb(270),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-coder-480b-a35b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
storage_size=Memory.from_gb(540),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# gpt-oss
|
||||
"gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
storage_size=Memory.from_kb(68_996_301),
|
||||
n_layers=36,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"gpt-oss-20b-MXFP4-Q8": ModelCard(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.5
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
# Needs to be quantized g32 or g16 to work with tensor parallel
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
storage_size=Memory.from_gb(114),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.5-air-bf16": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
storage_size=Memory.from_gb(214),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
storage_size=Memory.from_bytes(198556925568),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-6bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
storage_size=Memory.from_bytes(286737579648),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-8bit-gs32": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
storage_size=Memory.from_bytes(396963397248),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.7 flash
|
||||
"glm-4.7-flash-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
|
||||
storage_size=Memory.from_gb(18),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-5bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
|
||||
storage_size=Memory.from_gb(21),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-6bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
|
||||
storage_size=Memory.from_gb(25),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
|
||||
storage_size=Memory.from_gb(32),
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# minimax-m2
|
||||
"minimax-m2.1-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
storage_size=Memory.from_bytes(242986745856),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"minimax-m2.1-3bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
storage_size=Memory.from_bytes(100086644736),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
}
|
||||
|
||||
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("exolabs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("exolabs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-krea-dev": ModelCard(
|
||||
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
|
||||
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("exolabs/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _generate_image_model_quant_variants(
|
||||
# TODO: quantizing and dynamically creating model cards
|
||||
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
@@ -706,15 +206,6 @@ def _generate_image_model_quant_variants(
|
||||
return variants
|
||||
|
||||
|
||||
_image_model_cards: dict[str, ModelCard] = {}
|
||||
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
|
||||
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
|
||||
_IMAGE_MODEL_CARDS = _image_model_cards
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
@@ -767,7 +258,7 @@ class ConfigData(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
async def fetch_config_data(model_id: ModelId) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
@@ -789,7 +280,7 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
return ConfigData.model_validate_json(await f.read())
|
||||
|
||||
|
||||
async def get_safetensors_size(model_id: ModelId) -> Memory:
|
||||
async def fetch_safetensors_size(model_id: ModelId) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
|
||||
@@ -28,7 +28,7 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
|
||||
"""
|
||||
|
||||
model: ModelId
|
||||
input: str | list[InputMessage]
|
||||
input: list[InputMessage]
|
||||
instructions: str | None = None
|
||||
max_output_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -1,31 +1,45 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
|
||||
def find_resources() -> Path:
|
||||
resources = _find_resources_in_repo() or _find_resources_in_bundle()
|
||||
if resources is None:
|
||||
raise FileNotFoundError(
|
||||
"Unable to locate resources. Did you clone the repo properly?"
|
||||
)
|
||||
return resources
|
||||
|
||||
|
||||
def _find_resources_in_repo() -> Path | None:
|
||||
current_module = Path(__file__).resolve()
|
||||
for parent in current_module.parents:
|
||||
build = parent / "resources"
|
||||
if build.is_dir():
|
||||
return build
|
||||
return None
|
||||
|
||||
|
||||
def _find_resources_in_bundle() -> Path | None:
|
||||
frozen_root = cast(str | None, getattr(sys, "_MEIPASS", None))
|
||||
if frozen_root is None:
|
||||
return None
|
||||
candidate = Path(frozen_root) / "resources"
|
||||
if candidate.is_dir():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def find_dashboard() -> Path:
|
||||
dashboard = (
|
||||
_find_dashboard_in_env()
|
||||
or _find_dashboard_in_repo()
|
||||
or _find_dashboard_in_bundle()
|
||||
)
|
||||
dashboard = _find_dashboard_in_repo() or _find_dashboard_in_bundle()
|
||||
if not dashboard:
|
||||
raise FileNotFoundError(
|
||||
"Unable to locate dashboard assets - make sure the dashboard has been built, or export DASHBOARD_DIR if you've built the dashboard elsewhere."
|
||||
"Unable to locate dashboard assets - you probably forgot to run `cd dashboard && npm install && npm run build && cd ..`"
|
||||
)
|
||||
return dashboard
|
||||
|
||||
|
||||
def _find_dashboard_in_env() -> Path | None:
|
||||
env = os.environ.get("DASHBOARD_DIR")
|
||||
if not env:
|
||||
return None
|
||||
resolved_env = Path(env).expanduser().resolve()
|
||||
|
||||
return resolved_env
|
||||
|
||||
|
||||
def _find_dashboard_in_repo() -> Path | None:
|
||||
current_module = Path(__file__).resolve()
|
||||
for parent in current_module.parents:
|
||||
|
||||
@@ -17,7 +17,7 @@ from exo.shared.types.api import (
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
@@ -100,7 +100,7 @@ def warmup_inference(
|
||||
tokenizer=tokenizer,
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId(""),
|
||||
input=content,
|
||||
input=[InputMessage(role="user", content=content)],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -436,16 +436,11 @@ def apply_chat_template(
|
||||
)
|
||||
|
||||
# Convert input to messages
|
||||
if isinstance(task_params.input, str):
|
||||
# Simple string input becomes a single user message
|
||||
formatted_messages.append({"role": "user", "content": task_params.input})
|
||||
else:
|
||||
# List of InputMessage
|
||||
for msg in task_params.input:
|
||||
if not msg.content:
|
||||
logger.warning("Received message with empty content, skipping")
|
||||
continue
|
||||
formatted_messages.append({"role": msg.role, "content": msg.content})
|
||||
for msg in task_params.input:
|
||||
if not msg.content:
|
||||
logger.warning("Received message with empty content, skipping")
|
||||
continue
|
||||
formatted_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
|
||||
@@ -918,15 +918,10 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
|
||||
|
||||
Extracts the first user input text and checks for debug triggers.
|
||||
"""
|
||||
prompt: str
|
||||
if isinstance(task_params.input, str):
|
||||
prompt = task_params.input
|
||||
else:
|
||||
# List of InputMessage - get first message content
|
||||
if len(task_params.input) == 0:
|
||||
logger.debug("Empty message list in debug prompt check")
|
||||
return
|
||||
prompt = task_params.input[0].content
|
||||
if len(task_params.input) == 0:
|
||||
logger.debug("Empty message list in debug prompt check")
|
||||
return
|
||||
prompt = task_params.input[0].content
|
||||
|
||||
if not prompt:
|
||||
return
|
||||
|
||||
@@ -14,7 +14,7 @@ from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
@@ -114,7 +114,7 @@ def run_gpt_oss_pipeline_device(
|
||||
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=prompt_text,
|
||||
input=[InputMessage(role="user", content=prompt_text)],
|
||||
max_output_tokens=max_tokens,
|
||||
)
|
||||
|
||||
@@ -182,7 +182,7 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
|
||||
task = TextGenerationTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
input=prompt_text,
|
||||
input=[InputMessage(role="user", content=prompt_text)],
|
||||
max_output_tokens=max_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from exo.download.download_utils import (
|
||||
ensure_models_dir,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
get_eos_token_ids_for_model,
|
||||
load_tokenizer_for_model_id,
|
||||
@@ -76,7 +76,7 @@ def get_test_models() -> list[ModelCard]:
|
||||
"""Get a representative sample of models to test."""
|
||||
# Pick one model from each family to test
|
||||
families: dict[str, ModelCard] = {}
|
||||
for card in MODEL_CARDS.values():
|
||||
for card in asyncio.run(get_model_cards()):
|
||||
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
|
||||
parts = card.model_id.short().split("-")
|
||||
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
|
||||
@@ -296,7 +296,7 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
|
||||
async def test_kimi_tokenizer_specifically():
|
||||
"""Test Kimi tokenizer with its specific patches and quirks."""
|
||||
kimi_models = [
|
||||
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
|
||||
card for card in await get_model_cards() if "kimi" in card.model_id.lower()
|
||||
]
|
||||
|
||||
if not kimi_models:
|
||||
@@ -343,7 +343,7 @@ async def test_kimi_tokenizer_specifically():
|
||||
async def test_glm_tokenizer_specifically():
|
||||
"""Test GLM tokenizer with its specific EOS tokens."""
|
||||
glm_model_cards = [
|
||||
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
|
||||
card for card in await get_model_cards() if "glm" in card.model_id.lower()
|
||||
]
|
||||
|
||||
if not glm_model_cards:
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import cast
|
||||
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus, TextGeneration
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerIdle,
|
||||
@@ -59,7 +59,9 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_status=TaskStatus.Pending,
|
||||
command_id=COMMAND_1_ID,
|
||||
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
|
||||
),
|
||||
)
|
||||
|
||||
result = plan_mod.plan(
|
||||
@@ -106,7 +108,9 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_status=TaskStatus.Pending,
|
||||
command_id=COMMAND_1_ID,
|
||||
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
|
||||
),
|
||||
)
|
||||
|
||||
result = plan_mod.plan(
|
||||
@@ -150,7 +154,9 @@ def test_plan_does_not_forward_tasks_for_other_instances():
|
||||
instance_id=other_instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
command_id=COMMAND_1_ID,
|
||||
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
|
||||
),
|
||||
)
|
||||
|
||||
result = plan_mod.plan(
|
||||
@@ -198,7 +204,9 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_status=TaskStatus.Complete,
|
||||
command_id=COMMAND_1_ID,
|
||||
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
|
||||
),
|
||||
)
|
||||
|
||||
other_task_id = TaskId("other-task")
|
||||
|
||||
@@ -22,7 +22,7 @@ from exo.shared.types.tasks import (
|
||||
TaskStatus,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -86,7 +86,7 @@ SHUTDOWN_TASK = Shutdown(
|
||||
|
||||
CHAT_PARAMS = TextGenerationTaskParams(
|
||||
model=MODEL_A_ID,
|
||||
input="hello",
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
stream=True,
|
||||
max_output_tokens=4,
|
||||
temperature=0.0,
|
||||
|
||||
@@ -10,7 +10,7 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.commands import CommandId
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
@@ -23,7 +23,7 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
Instance,
|
||||
@@ -114,13 +114,13 @@ async def run_test(test: Tests):
|
||||
|
||||
instances: list[Instance] = []
|
||||
if test.kind in ["ring", "both"]:
|
||||
i = ring_instance(test, hn)
|
||||
i = await ring_instance(test, hn)
|
||||
if i is None:
|
||||
yield "no model found"
|
||||
return
|
||||
instances.append(i)
|
||||
if test.kind in ["rdma", "both"]:
|
||||
i = jaccl_instance(test)
|
||||
if test.kind in ["jaccl", "both"]:
|
||||
i = await jaccl_instance(test)
|
||||
if i is None:
|
||||
yield "no model found"
|
||||
return
|
||||
@@ -145,7 +145,7 @@ async def run_test(test: Tests):
|
||||
return StreamingResponse(run())
|
||||
|
||||
|
||||
def ring_instance(test: Tests, hn: str) -> Instance | None:
|
||||
async def ring_instance(test: Tests, hn: str) -> Instance | None:
|
||||
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
|
||||
world_size = len(test.devs)
|
||||
for i in range(world_size):
|
||||
@@ -158,11 +158,7 @@ def ring_instance(test: Tests, hn: str) -> Instance | None:
|
||||
else:
|
||||
raise ValueError(f"{hn} not in {test.devs}")
|
||||
|
||||
card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
|
||||
)
|
||||
if card is None:
|
||||
return None
|
||||
card = await ModelCard.load(test.model_id)
|
||||
instance = MlxRingInstance(
|
||||
instance_id=iid,
|
||||
ephemeral_port=52417,
|
||||
@@ -200,7 +196,11 @@ async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=test.model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
input="What is the capital of France?",
|
||||
input=[
|
||||
InputMessage(
|
||||
role="user", content="What is the capital of France?"
|
||||
)
|
||||
],
|
||||
),
|
||||
command_id=CommandId("yo"),
|
||||
instance_id=iid,
|
||||
@@ -230,12 +230,8 @@ async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
|
||||
return []
|
||||
|
||||
|
||||
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
|
||||
card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
|
||||
)
|
||||
if card is None:
|
||||
return None
|
||||
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
|
||||
card = await ModelCard.load(test.model_id)
|
||||
world_size = len(test.devs)
|
||||
assert test.ibv_devs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user