Compare commits

...

5 Commits

Author SHA1 Message Date
Alex Cheema
5ee257a13e feat: add tensor parallelism support for Step 3.5 Flash
Add Step3p5ShardingStrategy to auto_parallel.py following the
DeepSeek pattern (shared expert + routed experts). Shard attention
q/k/v/o projections across devices and MoE expert weights in-place
with all-reduce synchronization.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:52:28 -08:00
Alex Cheema
24511ab7cb feat: add Step 3.5 Flash model cards and update mlx-lm
Update mlx-lm to v0.30.6 which includes Step 3.5 Flash support
(ml-explore/mlx-lm#836). Add model cards for the 4bit, 6bit, and 8bit
quantizations from mlx-community.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:33:57 -08:00
Evan Quiney
d90605f198 migrate model cards to .toml files (#1354) 2026-02-03 12:32:06 +00:00
Evan Quiney
f400b4d7c5 fix InstanceViewModel.swift (#1359)
wasn't caught when we merged the API changes
2026-02-02 18:43:27 +00:00
Evan Quiney
d97bca88e6 improve distributed testing (#1300)
Our distributed test now does a full query cycle for every model loaded
onto the relevant machine. This will help find bugs early, as it already
has found one with Qwen3 Next! I didn't write down what the error was
though. Gooooooood luck with that!

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-02 18:25:39 +00:00
73 changed files with 1477 additions and 883 deletions

View File

@@ -142,4 +142,4 @@ jobs:
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

View File

@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
let promptPreview: String?
let errorMessage: String?
let subtitle: String?
let parameters: ChatCompletionTaskParameters?
let parameters: TextGenerationTaskParameters?
var title: String {
switch kind {

View File

@@ -10,6 +10,7 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -18,6 +19,9 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -58,6 +62,7 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,45 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,35 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 405874409472

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 765577920512

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 122406567936

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 229780750336

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 198556925568

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 286737579648

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 396963397248

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 19327352832

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 22548578304

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 26843545600

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 34359738368

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 620622774272

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 706522120192

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 662498705408

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 729808896

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 1863319552

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 3501195264

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 76799803392

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 4637851648

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 8954839040

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 16882073600

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 100086644736

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 342884352

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 698351616

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 141733920768

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 268435456000

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 17612931072

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 33279705088

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 289910292480

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 579820584960

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 46976204800

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 47080074240

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 114572190076

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 159039627774

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 209082699847

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 70652212224

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 12025908224

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 144383672320

View File

@@ -7,7 +7,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -160,7 +160,7 @@ class ResumableShardDownloader(ShardDownloader):
# Kick off download status coroutines concurrently
tasks = [
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in MODEL_CARDS.values()
for model_card in await get_model_cards()
]
for task in asyncio.as_completed(tasks):

View File

@@ -40,6 +40,7 @@ from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
EXO_TRACING_CACHE_DIR,
@@ -47,9 +48,9 @@ from exo.shared.constants import (
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
get_model_cards,
)
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
from exo.shared.types.api import (
@@ -138,7 +139,6 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
@@ -146,18 +146,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -204,7 +192,7 @@ class API:
self.app.mount(
"/",
StaticFiles(
directory=find_dashboard(),
directory=DASHBOARD_DIR,
html=True,
),
name="dashboard",
@@ -381,10 +369,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
model_card = await ModelCard.load(model_id)
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
for sharding in (Sharding.Pipeline, Sharding.Tensor):
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
@@ -399,96 +384,93 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
instance=None,
error=str(exc),
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
return PlacementPreviewResponse(previews=previews)
@@ -652,23 +634,21 @@ class API:
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _resolve_and_validate_text_model(self, model: ModelId) -> ModelId:
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(model)
resolved = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved
instance.shard_assignments.model_id == model_id
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved)
await self._trigger_notify_user_to_download_model(model_id)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {resolved}",
detail=f"No instance found for model {model_id}",
)
return resolved
return model_id
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
@@ -1237,7 +1217,7 @@ class API:
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
for card in await get_model_cards()
]
)

View File

@@ -2,6 +2,8 @@ import os
import sys
from pathlib import Path
from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
@@ -31,6 +33,14 @@ EXO_MODELS_DIR = (
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"

View File

@@ -12,16 +12,42 @@ from pydantic import (
BaseModel,
Field,
PositiveInt,
ValidationError,
field_validator,
model_validator,
)
from tomlkit.exceptions import TOMLKitError
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS, RESOURCES_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
# kinda ugly...
# TODO: load search path from config.toml
_csp = [Path(RESOURCES_DIR) / "inference_model_cards"]
if EXO_ENABLE_IMAGE_MODELS:
_csp.append(Path(RESOURCES_DIR) / "image_model_cards")
CARD_SEARCH_PATH = _csp
_card_cache: dict[ModelId, "ModelCard"] = {}
async def _refresh_card_cache():
for path in CARD_SEARCH_PATH:
async for toml_file in path.rglob("*.toml"):
try:
card = await ModelCard.load_from_path(toml_file)
_card_cache[card.model_id] = card
except (ValidationError, TOMLKitError):
pass
async def get_model_cards() -> list["ModelCard"]:
if len(_card_cache) == 0:
await _refresh_card_cache()
return list(_card_cache.values())
class ModelTask(str, Enum):
@@ -55,28 +81,33 @@ class ModelCard(CamelCaseModel):
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
py = self.model_dump(exclude_none=True)
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
async def save_to_default_path(self):
await self.save(Path(RESOURCES_DIR) / (self.model_id.normalize() + ".toml"))
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
# Is it okay that model card.load defaults to network access if the card doesn't exist? do we want to be more explicit here?
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
for card in MODEL_CARDS.values():
if card.model_id == model_id:
return card
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if model_id not in _card_cache:
await _refresh_card_cache()
if (mc := _card_cache.get(model_id)) is not None:
return mc
return await ModelCard.fetch_from_hf(model_id)
@staticmethod
async def fetch_from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
# TODO: failure if files do not exist
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
@@ -89,544 +120,13 @@ class ModelCard(CamelCaseModel):
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
)
await mc.save_to_default_path()
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"deepseek-v3.1-8bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2-thinking": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2.5": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2.5"),
storage_size=Memory.from_gb(617),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.1
"llama-3.1-8b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-8bit": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-bf16": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-70b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.2
"llama-3.2-1b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.3
"llama-3.3-70b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-fp16": ModelCard(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# qwen3
"qwen3-0.6b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-0.6b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(44900),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"glm-4.5-air-bf16": ModelCard(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-8bit-gs32": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"minimax-m2.1-3bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
}
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("exolabs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("exolabs/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}
def _generate_image_model_quant_variants(
# TODO: quantizing and dynamically creating model cards
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
@@ -706,15 +206,6 @@ def _generate_image_model_quant_variants(
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
@@ -742,6 +233,7 @@ class ConfigData(BaseModel):
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
["Step3p5ForCausalLM"],
]
@model_validator(mode="before")

View File

@@ -1,31 +1,45 @@
import os
import sys
from pathlib import Path
from typing import cast
def find_resources() -> Path:
resources = _find_resources_in_repo() or _find_resources_in_bundle()
if resources is None:
raise FileNotFoundError(
"Unable to locate resources. Did you clone the repo properly?"
)
return resources
def _find_resources_in_repo() -> Path | None:
current_module = Path(__file__).resolve()
for parent in current_module.parents:
build = parent / "resources"
if build.is_dir():
return build
return None
def _find_resources_in_bundle() -> Path | None:
frozen_root = cast(str | None, getattr(sys, "_MEIPASS", None))
if frozen_root is None:
return None
candidate = Path(frozen_root) / "resources"
if candidate.is_dir():
return candidate
return None
def find_dashboard() -> Path:
dashboard = (
_find_dashboard_in_env()
or _find_dashboard_in_repo()
or _find_dashboard_in_bundle()
)
dashboard = _find_dashboard_in_repo() or _find_dashboard_in_bundle()
if not dashboard:
raise FileNotFoundError(
"Unable to locate dashboard assets - make sure the dashboard has been built, or export DASHBOARD_DIR if you've built the dashboard elsewhere."
"Unable to locate dashboard assets - you probably forgot to run `cd dashboard && npm install && npm run build && cd ..`"
)
return dashboard
def _find_dashboard_in_env() -> Path | None:
env = os.environ.get("DASHBOARD_DIR")
if not env:
return None
resolved_env = Path(env).expanduser().resolve()
return resolved_env
def _find_dashboard_in_repo() -> Path | None:
current_module = Path(__file__).resolve()
for parent in current_module.parents:

View File

@@ -31,6 +31,8 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step3p5Model
from mlx_lm.models.step3p5 import Step3p5MLP
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
@@ -380,6 +382,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Step3p5Model):
tensor_parallel_sharding_strategy = Step3p5ShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
@@ -774,3 +784,57 @@ class ShardedGptOssMoE(CustomMlxLayer):
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class Step3p5ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Step3p5Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N # pyright: ignore[reportUnknownMemberType]
layer.self_attn.num_kv_heads //= self.N # pyright: ignore[reportUnknownMemberType]
if isinstance(layer.mlp, Step3p5MLP):
# Dense MLP layer
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
else:
# MoE layer: shared expert + routed experts
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedStep3p5MoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class ShardedStep3p5MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -16,7 +16,7 @@ from exo.download.download_utils import (
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
@@ -76,7 +76,7 @@ def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
for card in asyncio.run(get_model_cards()):
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
@@ -296,7 +296,7 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
card for card in await get_model_cards() if "kimi" in card.model_id.lower()
]
if not kimi_models:
@@ -343,7 +343,7 @@ async def test_kimi_tokenizer_specifically():
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
card for card in await get_model_cards() if "glm" in card.model_id.lower()
]
if not glm_model_cards:

View File

@@ -1,25 +1,20 @@
import multiprocessing as mp
import socket
import time
import typing
from typing import Literal
import anyio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
@@ -36,9 +31,14 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.channels import channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
@@ -46,36 +46,37 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
ibv_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "jaccl", "both"]
mp.set_start_method("spawn", force=True)
logger_setup(None)
iid = InstanceId("im testing here")
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:52414"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
ev = anyio.Event()
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: shutdown.wait(),
shutdown_trigger=lambda: ev.wait(),
)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)
async def tb_detection():
@@ -87,29 +88,19 @@ async def tb_detection():
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
async def run_test(test: Tests):
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -117,31 +108,63 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = await ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["jaccl", "both"]:
i = await jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
async def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
card = await ModelCard.load(test.model_id)
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
ephemeral_port=52417,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,113 +186,75 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,
)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
for command in commands:
task_send.send(command)
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
entrypoint(
bound_instance,
ev_send,
task_recv,
logger,
)
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
# TODO(evan): return ev_recv.collect()
return []
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = await ModelCard.load(test.model_id)
world_size = len(test.devs)
assert test.ibv_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
jaccl_devices=test.ibv_devs,
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
RunnerId(host[0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
for i, host in enumerate(test.devs)
},
),
)
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)

54
tests/run_exo_on.sh Executable file
View File

@@ -0,0 +1,54 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
done
wait

85
tests/start_distributed_test.py Executable file
View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen
if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be jaccl or ring"
)
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)
def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]
def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:52414/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]
def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:52414/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)
with ThreadPoolExecutor(n) as exctr:
if kind in ("jaccl", "both"):
payloads = list(exctr.map(get_tb, ips))
u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
ibv_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
ibv_devs = None
models = set[str].intersection(*exctr.map(get_models, ips))
print("\n")
print("=" * 70)
print(f"Starting test with {models}")
print("=" * 70)
print("\n")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "ibv_devs": ibv_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done

4
uv.lock generated
View File

@@ -1072,8 +1072,8 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.5"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
version = "0.30.6"
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#ab050d1fac2ef1d7bea6b8d870f1e5717d7f59f5" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },