Compare commits


1 Commit

Author: Evan
SHA1: e4256fa284
Message: fix InstanceViewModel.swift (wasn't caught when we merged the API changes)
Date: 2026-02-02 17:57:02 +00:00
83 changed files with 945 additions and 1592 deletions

View File

@@ -142,4 +142,4 @@ jobs:
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

View File

@@ -118,37 +118,11 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
if process.isRunning {
process.terminate()
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
status = .stopped
}
func restart() {

View File

@@ -10,7 +10,6 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -19,9 +18,6 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -62,7 +58,6 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 405874409472

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 765577920512

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 122406567936

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 229780750336

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 198556925568

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 286737579648

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 396963397248

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 19327352832

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 22548578304

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 26843545600

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 34359738368

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 620622774272

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 706522120192

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 662498705408

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 729808896

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 1863319552

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 3501195264

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 76799803392

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 4637851648

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 8954839040

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 16882073600

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 100086644736

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 242986745856

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 342884352

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 698351616

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 141733920768

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 268435456000

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 17612931072

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 33279705088

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 289910292480

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 579820584960

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 46976204800

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 47080074240

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 70652212224

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 12025908224

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 144383672320

View File

@@ -1,5 +1,4 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
@@ -61,37 +60,10 @@ class DownloadCoordinator:
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
async def _check_internet_connection(self) -> None:
first_connection = True
while True:
await asyncio.sleep(10)
# Assume that internet connection is set to False on 443 errors.
if self.shard_downloader.internet_connection:
continue
self._test_internet_connection()
if first_connection and self.shard_downloader.internet_connection:
first_connection = False
self._tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
@@ -269,7 +241,7 @@ class DownloadCoordinator:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug(
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
@@ -302,10 +274,10 @@ class DownloadCoordinator:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug(
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(60)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"

View File

@@ -49,10 +49,6 @@ class HuggingFaceAuthenticationError(Exception):
"""Raised when HuggingFace returns 401/403 for a model download."""
class HuggingFaceRateLimitError(Exception):
"""429 Huggingface code"""
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
if status_code == 401 and token is None:
@@ -158,76 +154,49 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
_fetched_file_lists_this_session: set[str] = set()
async def fetch_file_list_with_cache(
model_id: ModelId,
revision: str = "main",
recursive: bool = False,
skip_internet: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_key = f"{model_id.normalize()}--{revision}"
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
cache_file
):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
if skip_internet:
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(
f"No internet connection and no cached file list for {model_id}"
)
# Always try fresh first
try:
file_list = await fetch_file_list_with_retry(
model_id,
revision,
recursive=recursive,
on_connection_lost=on_connection_lost,
model_id, revision, recursive=recursive
)
# Update cache with fresh data
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
# No cache available, propagate the error
raise
async def fetch_file_list_with_retry(
model_id: ModelId,
revision: str = "main",
path: str = "",
recursive: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 3
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(2.0**attempt)
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
@@ -247,11 +216,7 @@ async def _fetch_file_list(
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
elif response.status == 429:
raise HuggingFaceRateLimitError(
f"Couldn't download {model_id} because of HuggingFace rate limit."
)
elif response.status == 200:
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
@@ -284,7 +249,7 @@ def create_http_session(
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
@@ -359,9 +324,8 @@ async def download_file_with_retry(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> Path:
n_attempts = 3
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _download_file(
@@ -369,19 +333,14 @@ async def download_file_with_retry(
)
except HuggingFaceAuthenticationError:
raise
except HuggingFaceRateLimitError as e:
if attempt == n_attempts - 1:
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
break
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -583,9 +542,7 @@ async def download_shard(
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
max_parallel_downloads: int = 8,
skip_download: bool = False,
skip_internet: bool = False,
allow_patterns: list[str] | None = None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.debug(f"Downloading {shard.model_card.model_id=}")
@@ -605,11 +562,7 @@ async def download_shard(
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(
shard.model_card.model_id,
revision,
recursive=True,
skip_internet=skip_internet,
on_connection_lost=on_connection_lost,
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -719,7 +672,6 @@ async def download_shard(
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
)
if not skip_download:

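Note on the retry changes above: both fetch_file_list_with_retry and download_file_with_retry move from 3 attempts with a plain 2**attempt backoff to 30 attempts with a capped backoff of min(8, 0.1 * 2**attempt). A minimal, self-contained sketch of that schedule (illustrative only, not part of the commit):

def backoff_schedule(n_attempts: int = 30) -> list[float]:
    # One sleep after every failed attempt except the last, which re-raises.
    return [min(8.0, 0.1 * (2.0 ** attempt)) for attempt in range(n_attempts - 1)]

sleeps = backoff_schedule()
print(sleeps[:8])             # [0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 8.0]
print(round(sum(sleeps), 1))  # 188.7 seconds of cumulative sleep in the worst case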
View File

@@ -1,5 +1,4 @@
import asyncio
from asyncio import create_task
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -8,7 +7,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -50,10 +49,6 @@ class SingletonShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -90,10 +85,6 @@ class CachedShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -151,8 +142,6 @@ class ResumableShardDownloader(ShardDownloader):
self.on_progress_wrapper,
max_parallel_downloads=self.max_parallel_downloads,
allow_patterns=allow_patterns,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
return target_dir
@@ -165,24 +154,13 @@ class ResumableShardDownloader(ShardDownloader):
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)
return await download_shard(
shard,
self.on_progress_wrapper,
skip_download=True,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
shard, self.on_progress_wrapper, skip_download=True
)
semaphore = asyncio.Semaphore(self.max_parallel_downloads)
async def download_with_semaphore(
model_card: ModelCard,
) -> tuple[Path, RepoDownloadProgress]:
async with semaphore:
return await _status_for_model(model_card.model_id)
# Kick off download status coroutines concurrently
tasks = [
create_task(download_with_semaphore(model_card))
for model_card in await get_model_cards()
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in MODEL_CARDS.values()
]
for task in asyncio.as_completed(tasks):

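As rewritten above, the download-status pass no longer throttles with a semaphore: one task is created per MODEL_CARDS entry and results are consumed in completion order, which is presumably acceptable because each _status_for_model call runs download_shard with skip_download=True and only inspects metadata. A minimal sketch of the fan-out, reusing names from the diff (illustrative, not a verbatim copy):

tasks = [
    asyncio.create_task(_status_for_model(card.model_id))
    for card in MODEL_CARDS.values()
]
for task in asyncio.as_completed(tasks):
    target_dir, progress = await task  # download_shard returns tuple[Path, RepoDownloadProgress]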
View File

@@ -16,11 +16,6 @@ from exo.shared.types.worker.shards import (
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
internet_connection: bool = False
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
@abstractmethod
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False

View File

@@ -137,7 +137,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:

View File

@@ -66,9 +66,7 @@ def chat_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,

View File

@@ -141,9 +141,7 @@ def claude_request_to_text_generation(
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,

View File

@@ -43,10 +43,10 @@ def _extract_content(content: str | list[ResponseContentPart]) -> str:
def responses_request_to_text_generation(
request: ResponsesRequest,
) -> TextGenerationTaskParams:
input_value: list[InputMessage]
input_value: str | list[InputMessage]
built_chat_template: list[dict[str, Any]] | None = None
if isinstance(request.input, str):
input_value = [InputMessage(role="user", content=request.input)]
input_value = request.input
else:
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
@@ -95,11 +95,7 @@ def responses_request_to_text_generation(
}
)
input_value = (
input_messages
if input_messages
else [InputMessage(role="user", content="")]
)
input_value = input_messages if input_messages else ""
built_chat_template = chat_template_messages if chat_template_messages else None
return TextGenerationTaskParams(

View File

@@ -40,7 +40,6 @@ from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import (
DASHBOARD_DIR,
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
EXO_TRACING_CACHE_DIR,
@@ -48,9 +47,9 @@ from exo.shared.constants import (
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
get_model_cards,
)
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
from exo.shared.types.api import (
@@ -139,6 +138,7 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
@@ -146,6 +146,18 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -192,7 +204,7 @@ class API:
self.app.mount(
"/",
StaticFiles(
directory=DASHBOARD_DIR,
directory=find_dashboard(),
html=True,
),
name="dashboard",
@@ -369,7 +381,10 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
model_card = await ModelCard.load(model_id)
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []
for sharding in (Sharding.Pipeline, Sharding.Tensor):
for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):
@@ -384,93 +399,96 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
current_ids = set(self.state.instances.keys())
new_instances = [
instance
for instance_id, instance in placements.items()
if instance_id not in current_ids
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(sorted(placement_node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
if (
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
)
return PlacementPreviewResponse(previews=previews)
@@ -634,21 +652,23 @@ class API:
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
async def _resolve_and_validate_text_model(self, model: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(model)
resolved = model_card.model_id
if not any(
instance.shard_assignments.model_id == model_id
instance.shard_assignments.model_id == resolved
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(model_id)
await self._trigger_notify_user_to_download_model(resolved)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {model_id}",
detail=f"No instance found for model {resolved}",
)
return model_id
return resolved
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
@@ -1217,7 +1237,7 @@ class API:
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in await get_model_cards()
for card in MODEL_CARDS.values()
]
)

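The resolve_model_card helper added in this file tries three lookups in order: an exact MODEL_CARDS key, a scan of the cards' model_id fields, and finally ModelCard.from_hf for anything unknown locally. A short usage sketch; the first two IDs come from MODEL_CARDS entries in this commit, the third is a made-up ID showing the Hugging Face fallback, and the helper's import path is assumed:

from exo.shared.types.common import ModelId
# resolve_model_card is defined in the API module shown above; exact import path assumed.

async def demo() -> None:
    await resolve_model_card(ModelId("llama-3.2-1b"))  # direct MODEL_CARDS key
    await resolve_model_card(ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"))  # matched via card.model_id
    await resolve_model_card(ModelId("example-org/example-model"))  # unknown locally, falls back to ModelCard.from_hf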
View File

@@ -28,7 +28,7 @@ from exo.shared.types.profiling import (
)
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
InstanceMeta,
MlxRingInstance,
@@ -136,9 +136,7 @@ async def test_master():
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input=[
InputMessage(role="user", content="Hello, how are you?")
],
input="Hello, how are you?",
),
)
),
@@ -191,7 +189,7 @@ async def test_master():
assert isinstance(events[2].event.task, TextGenerationTask)
assert events[2].event.task.task_params == TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input=[InputMessage(role="user", content="Hello, how are you?")],
input="Hello, how are you?",
)
await master.shutdown()

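Tying the adapter and test hunks together: TextGenerationTaskParams.input now accepts either a plain string or a list of InputMessage, and empty inputs collapse to "" instead of a synthetic empty user message. A minimal construction sketch mirroring the updated test (import paths as used elsewhere in this commit, illustrative only):

from exo.shared.types.common import ModelId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams

# Plain-string form, as used in the updated test_master
params = TextGenerationTaskParams(model=ModelId("llama-3.2-1b"), input="Hello, how are you?")

# Structured form, still produced by the chat and claude adapters when messages are present
params_msgs = TextGenerationTaskParams(
    model=ModelId("llama-3.2-1b"),
    input=[InputMessage(role="user", content="Hello, how are you?")],
)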
View File

@@ -2,8 +2,6 @@ import os
import sys
from pathlib import Path
from exo.utils.dashboard_path import find_dashboard, find_resources
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
@@ -33,14 +31,6 @@ EXO_MODELS_DIR = (
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
)
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
)
# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"

View File

@@ -12,42 +12,16 @@ from pydantic import (
BaseModel,
Field,
PositiveInt,
ValidationError,
field_validator,
model_validator,
)
from tomlkit.exceptions import TOMLKitError
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS, RESOURCES_DIR
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
# kinda ugly...
# TODO: load search path from config.toml
_csp = [Path(RESOURCES_DIR) / "inference_model_cards"]
if EXO_ENABLE_IMAGE_MODELS:
_csp.append(Path(RESOURCES_DIR) / "image_model_cards")
CARD_SEARCH_PATH = _csp
_card_cache: dict[ModelId, "ModelCard"] = {}
async def _refresh_card_cache():
for path in CARD_SEARCH_PATH:
async for toml_file in path.rglob("*.toml"):
try:
card = await ModelCard.load_from_path(toml_file)
_card_cache[card.model_id] = card
except (ValidationError, TOMLKitError):
pass
async def get_model_cards() -> list["ModelCard"]:
if len(_card_cache) == 0:
await _refresh_card_cache()
return list(_card_cache.values())
_card_cache: dict[str, "ModelCard"] = {}
class ModelTask(str, Enum):
@@ -81,36 +55,31 @@ class ModelCard(CamelCaseModel):
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump(exclude_none=True)
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
async def save_to_default_path(self):
await self.save(Path(RESOURCES_DIR) / (self.model_id.normalize() + ".toml"))
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
# Is it okay that model card.load defaults to network access if the card doesn't exist? do we want to be more explicit here?
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id not in _card_cache:
await _refresh_card_cache()
if (mc := _card_cache.get(model_id)) is not None:
return mc
return await ModelCard.fetch_from_hf(model_id)
for card in MODEL_CARDS.values():
if card.model_id == model_id:
return card
return await ModelCard.from_hf(model_id)
@staticmethod
async def fetch_from_hf(model_id: ModelId) -> "ModelCard":
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
# TODO: failure if files do not exist
config_data = await fetch_config_data(model_id)
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await fetch_safetensors_size(model_id)
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
@@ -120,13 +89,544 @@ class ModelCard(CamelCaseModel):
supports_tensor=config_data.supports_tensor,
tasks=[ModelTask.TextGeneration],
)
await mc.save_to_default_path()
_card_cache[model_id] = mc
return mc
# TODO: quantizing and dynamically creating model cards
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"deepseek-v3.1-8bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2-thinking": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2.5": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2.5"),
storage_size=Memory.from_gb(617),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.1
"llama-3.1-8b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-8bit": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-bf16": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-70b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.2
"llama-3.2-1b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.3
"llama-3.3-70b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-fp16": ModelCard(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# qwen3
"qwen3-0.6b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-0.6b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(44900),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"glm-4.5-air-bf16": ModelCard(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-8bit-gs32": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"minimax-m2.1-3bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
}
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("exolabs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-dev"),
        storage_size=Memory.from_bytes(23802816640 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("exolabs/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}
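Each image card's top-level storage_size above is the sum of its shardable component weights (text_encoder_2 plus transformer). A small sanity-check sketch, assuming only that Memory values built from the same byte count compare equal:
# Illustration only: flux1-schnell decomposes as
#   text_encoder_2   9_524_621_312 bytes
# + transformer     23_782_357_120 bytes
# = storage_size    33_306_978_432 bytes
_schnell = _IMAGE_BASE_MODEL_CARDS["flux1-schnell"]
assert _schnell.storage_size == Memory.from_bytes(9_524_621_312 + 23_782_357_120)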
def _generate_image_model_quant_variants(
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
@@ -206,6 +706,15 @@ def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunctio
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
@@ -258,7 +767,7 @@ class ConfigData(BaseModel):
return data
async def fetch_config_data(model_id: ModelId) -> ConfigData:
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
from exo.download.download_utils import (
download_file_with_retry,
@@ -280,7 +789,7 @@ async def fetch_config_data(model_id: ModelId) -> ConfigData:
return ConfigData.model_validate_json(await f.read())
async def fetch_safetensors_size(model_id: ModelId) -> Memory:
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
from exo.download.download_utils import (
download_file_with_retry,

View File

@@ -28,7 +28,7 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
"""
model: ModelId
input: list[InputMessage]
input: str | list[InputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
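With this change a task accepts either a bare prompt string or an explicit message list. A minimal sketch of both forms, reusing a model id that appears earlier in this diff (purely illustrative):
# Both forms are valid after this change.
simple = TextGenerationTaskParams(
    model=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
    input="What is the capital of France?",
)
chat = TextGenerationTaskParams(
    model=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
    input=[InputMessage(role="user", content="What is the capital of France?")],
)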

View File

@@ -1,45 +1,31 @@
import os
import sys
from pathlib import Path
from typing import cast
def find_resources() -> Path:
resources = _find_resources_in_repo() or _find_resources_in_bundle()
if resources is None:
raise FileNotFoundError(
"Unable to locate resources. Did you clone the repo properly?"
)
return resources
def _find_resources_in_repo() -> Path | None:
current_module = Path(__file__).resolve()
for parent in current_module.parents:
build = parent / "resources"
if build.is_dir():
return build
return None
def _find_resources_in_bundle() -> Path | None:
frozen_root = cast(str | None, getattr(sys, "_MEIPASS", None))
if frozen_root is None:
return None
candidate = Path(frozen_root) / "resources"
if candidate.is_dir():
return candidate
return None
def find_dashboard() -> Path:
dashboard = _find_dashboard_in_repo() or _find_dashboard_in_bundle()
dashboard = (
_find_dashboard_in_env()
or _find_dashboard_in_repo()
or _find_dashboard_in_bundle()
)
if not dashboard:
raise FileNotFoundError(
"Unable to locate dashboard assets - you probably forgot to run `cd dashboard && npm install && npm run build && cd ..`"
"Unable to locate dashboard assets - make sure the dashboard has been built, or export DASHBOARD_DIR if you've built the dashboard elsewhere."
)
return dashboard
def _find_dashboard_in_env() -> Path | None:
env = os.environ.get("DASHBOARD_DIR")
if not env:
return None
resolved_env = Path(env).expanduser().resolve()
return resolved_env
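A short sketch of exercising the new override; the path is a placeholder, and the only assumption is that DASHBOARD_DIR points at an already-built dashboard directory:
import os

# Hypothetical usage: point exo at a dashboard built outside the repo.
os.environ["DASHBOARD_DIR"] = "/opt/exo/dashboard/build"  # placeholder path
dashboard_root = find_dashboard()  # the env override is checked first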
def _find_dashboard_in_repo() -> Path | None:
current_module = Path(__file__).resolve()
for parent in current_module.parents:

View File

@@ -17,7 +17,7 @@ from exo.shared.types.api import (
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
@@ -100,7 +100,7 @@ def warmup_inference(
tokenizer=tokenizer,
task_params=TextGenerationTaskParams(
model=ModelId(""),
input=[InputMessage(role="user", content=content)],
input=content,
),
)

View File

@@ -436,11 +436,16 @@ def apply_chat_template(
)
# Convert input to messages
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
if isinstance(task_params.input, str):
# Simple string input becomes a single user message
formatted_messages.append({"role": "user", "content": task_params.input})
else:
# List of InputMessage
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
prompt: str = tokenizer.apply_chat_template(
formatted_messages,

View File

@@ -918,10 +918,15 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
Extracts the first user input text and checks for debug triggers.
"""
if len(task_params.input) == 0:
logger.debug("Empty message list in debug prompt check")
return
prompt = task_params.input[0].content
prompt: str
if isinstance(task_params.input, str):
prompt = task_params.input
else:
# List of InputMessage - get first message content
if len(task_params.input) == 0:
logger.debug("Empty message list in debug prompt check")
return
prompt = task_params.input[0].content
if not prompt:
return

View File

@@ -14,7 +14,7 @@ from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
@@ -114,7 +114,7 @@ def run_gpt_oss_pipeline_device(
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content=prompt_text)],
input=prompt_text,
max_output_tokens=max_tokens,
)
@@ -182,7 +182,7 @@ def run_gpt_oss_tensor_parallel_device(
task = TextGenerationTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content=prompt_text)],
input=prompt_text,
max_output_tokens=max_tokens,
)

View File

@@ -16,7 +16,7 @@ from exo.download.download_utils import (
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
@@ -76,7 +76,7 @@ def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, ModelCard] = {}
for card in asyncio.run(get_model_cards()):
for card in MODEL_CARDS.values():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
@@ -296,7 +296,7 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
card for card in await get_model_cards() if "kimi" in card.model_id.lower()
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
]
if not kimi_models:
@@ -343,7 +343,7 @@ async def test_kimi_tokenizer_specifically():
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_model_cards = [
card for card in await get_model_cards() if "glm" in card.model_id.lower()
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
]
if not glm_model_cards:

View File

@@ -2,7 +2,7 @@ from typing import cast
import exo.worker.plan as plan_mod
from exo.shared.types.tasks import Task, TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
@@ -59,9 +59,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -108,9 +106,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -154,9 +150,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
instance_id=other_instance_id,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -204,9 +198,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Complete,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(
model=MODEL_A_ID, input=[InputMessage(role="user", content="")]
),
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
)
other_task_id = TaskId("other-task")

View File

@@ -22,7 +22,7 @@ from exo.shared.types.tasks import (
TaskStatus,
TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -86,7 +86,7 @@ SHUTDOWN_TASK = Shutdown(
CHAT_PARAMS = TextGenerationTaskParams(
model=MODEL_A_ID,
input=[InputMessage(role="user", content="hello")],
input="hello",
stream=True,
max_output_tokens=4,
temperature=0.0,

View File

@@ -1,20 +1,25 @@
import multiprocessing as mp
import socket
from typing import Literal
import time
import typing
import anyio
from fastapi import FastAPI
from fastapi.responses import Response, StreamingResponse
from fastapi.responses import StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.events import Event
from exo.shared.types.tasks import (
ConnectToGroup,
LoadModel,
@@ -23,7 +28,7 @@ from exo.shared.types.tasks import (
Task,
TextGeneration,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
@@ -31,14 +36,9 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import channel, mp_channel
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
@@ -46,37 +46,36 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
ibv_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "jaccl", "both"]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
iid = InstanceId("im testing here")
mp.set_start_method("spawn", force=True)
logger_setup(None)
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52414"
cfg.bind = "0.0.0.0:52415"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
ev = anyio.Event()
cfg.logger_class = InterceptLogger
app = FastAPI()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: ev.wait(),
shutdown_trigger=lambda: shutdown.wait(),
)
def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
async def tb_detection():
@@ -88,19 +87,29 @@ async def tb_detection():
return recv.collect()
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def run_test(test: Tests):
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -108,63 +117,31 @@ async def run_test(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = await ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["jaccl", "both"]:
i = await jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
card = await ModelCard.load(test.model_id)
card = MODEL_CARDS[test.model_id]
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52417,
ephemeral_port=52416,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=test.model_id,
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -186,79 +163,113 @@ async def ring_instance(test: Tests, hn: str) -> Instance | None:
return instance
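A rough illustration of the neighbour wiring the loop above produces, assuming three devs where the local hostname matches the middle entry (hostnames and IPs are placeholders):
# Hypothetical three-node example for ring_instance, with hn == "m2".
devs = [["m1", "10.0.0.1"], ["m2", "10.0.0.2"], ["m3", "10.0.0.3"]]
# For the local node (index 1) the loop sets:
#   hbn[0] = Host(ip="10.0.0.1", port=52416)  # previous neighbour
#   hbn[2] = Host(ip="10.0.0.3", port=52416)  # next neighbour
#   hbn[1] = Host(ip="0.0.0.0", port=52416)   # this node listens locally
# Edge nodes (index 0 or world_size - 1) only get the neighbour that exists;
# remaining slots keep the placeholder host they were initialised with.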
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
async def execute_test(test: Tests, instance: Instance, hn: str):
world_size = len(test.devs)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input=[
InputMessage(
role="user", content="What is the capital of France?"
)
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
for command in commands:
task_send.send(command)
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,
)
)
entrypoint(
bound_instance,
ev_send,
task_recv,
logger,
)
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
# TODO(evan): return ev_recv.collect()
return []
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = await ModelCard.load(test.model_id)
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
world_size = len(test.devs)
assert test.ibv_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=test.ibv_devs,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=test.model_id,
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(host[0]): TensorShardMetadata(
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=0,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i, host in enumerate(test.devs)
for i in range(world_size)
},
),
)
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
done
wait

View File

@@ -1,85 +0,0 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen
if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be jaccl or ring"
)
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)
def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]
def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:52414/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]
def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:52414/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)
with ThreadPoolExecutor(n) as exctr:
if kind in ("jaccl", "both"):
payloads = list(exctr.map(get_tb, ips))
u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
ibv_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
ibv_devs = None
models = set[str].intersection(*exctr.map(get_models, ips))
print("\n")
print("=" * 70)
print(f"Starting test with {models}")
print("=" * 70)
print("\n")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "ibv_devs": ibv_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))

tests/start_distributed_test.sh Executable file
View File

@@ -0,0 +1,54 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
  printf "%s is not a known test kind.\nCurrent test kinds are: %s\n" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
        2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || { echo "curl to ${hostnames[$i]} failed"; exit 1; }
} &
done
wait
done
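For reference, a minimal Python equivalent of the curl request above, built only from the /ring and /jaccl endpoints and the Tests payload the harness defines (hostnames and IPs are placeholders):
# Hypothetical single-request driver mirroring the shell loop above.
import json
from urllib.request import Request, urlopen

devs = [["m1", "10.0.0.1"], ["m2", "10.0.0.2"]]  # placeholder hostname/IP pairs
body = json.dumps({"model_id": "qwen3-30b", "devs": devs, "kind": "inference"}).encode()
req = Request(
    "http://10.0.0.1:52415/ring",  # or /jaccl
    data=body,
    method="POST",
    headers={"Content-Type": "application/json"},
)
with urlopen(req, timeout=300) as resp:
    for line in resp:  # the endpoint streams runner events as JSON lines
        print(line.decode(errors="replace"), end="")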