Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
642b1bb1b4 migrate model cards to .toml files 2026-01-15 17:07:48 +00:00
42 changed files with 1312 additions and 821 deletions

View File

@@ -91,51 +91,6 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## API Reference
The API is served at `http://localhost:52415` by default. Key files:
- `docs/api.md`: Full API documentation
- `src/exo/master/api.py`: FastAPI implementation
- `src/exo/shared/types/api.py`: Request/response Pydantic models
### Key Endpoints
```
GET /node_id # Current master node ID
GET /state # Full cluster state (topology, instances, downloads, etc.)
GET /events # Event log for debugging
POST /instance # Create model instance
GET /instance/{id} # Get instance details
DELETE /instance/{id} # Delete instance
GET /instance/previews # Preview placements for a model
GET /instance/placement # Compute placement without creating
GET /models # List available models
GET /v1/models # OpenAI-compatible model list
POST /v1/chat/completions # OpenAI-compatible chat completions (streaming/non-streaming)
POST /bench/chat/completions # Chat completions with performance stats
```
### Useful curl Commands
```bash
# Check cluster state
curl -s http://localhost:52415/state | python3 -m json.tool
# List models
curl -s http://localhost:52415/models | python3 -m json.tool
# Preview placements for a model
curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | python3 -m json.tool
# Chat completion
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}]}'
```
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

View File

@@ -23,6 +23,7 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -0,0 +1,15 @@
short_id = "deepseek-v3.1-4bit"
model_id = "mlx-community/DeepSeek-V3.1-4bit"
name = "DeepSeek V3.1 (4-bit)"
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/DeepSeek-V3.1-4bit"
pretty_name = "DeepSeek V3.1 (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 405874409472

View File

@@ -0,0 +1,15 @@
short_id = "deepseek-v3.1-8bit"
model_id = "mlx-community/DeepSeek-V3.1-8bit"
name = "DeepSeek V3.1 (8-bit)"
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/DeepSeek-V3.1-8bit"
pretty_name = "DeepSeek V3.1 (8-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 765577920512

View File

@@ -0,0 +1,15 @@
short_id = "glm-4.5-air-8bit"
model_id = "mlx-community/GLM-4.5-Air-8bit"
name = "GLM 4.5 Air 8bit"
description = "GLM 4.5 Air 8bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.5-Air-8bit"
pretty_name = "GLM 4.5 Air 8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
[metadata.storage_size]
in_bytes = 122406567936

View File

@@ -0,0 +1,15 @@
short_id = "glm-4.5-air-bf16"
model_id = "mlx-community/GLM-4.5-Air-bf16"
name = "GLM 4.5 Air bf16"
description = "GLM 4.5 Air bf16"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.5-Air-bf16"
pretty_name = "GLM 4.5 Air bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 229780750336

View File

@@ -0,0 +1,15 @@
short_id = "glm-4.7-4bit"
model_id = "mlx-community/GLM-4.7-4bit"
name = "GLM 4.7 4bit"
description = "GLM 4.7 4bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-4bit"
pretty_name = "GLM 4.7 4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 198556925568

View File

@@ -0,0 +1,15 @@
short_id = "glm-4.7-6bit"
model_id = "mlx-community/GLM-4.7-6bit"
name = "GLM 4.7 6bit"
description = "GLM 4.7 6bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-6bit"
pretty_name = "GLM 4.7 6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 286737579648

View File

@@ -0,0 +1,15 @@
short_id = "glm-4.7-8bit-gs32"
model_id = "mlx-community/GLM-4.7-8bit-gs32"
name = "GLM 4.7 8bit (gs32)"
description = "GLM 4.7 8bit (gs32)"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-8bit-gs32"
pretty_name = "GLM 4.7 8bit (gs32)"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 396963397248

View File

@@ -0,0 +1,15 @@
short_id = "gpt-oss-120b-MXFP4-Q8"
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
description = "OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon."
tags = []
[metadata]
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
pretty_name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
n_layers = 36
hidden_size = 2880
supports_tensor = true
[metadata.storage_size]
in_bytes = 70652212224

View File

@@ -0,0 +1,15 @@
short_id = "gpt-oss-20b-4bit"
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
description = "OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization."
tags = []
[metadata]
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
pretty_name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
n_layers = 24
hidden_size = 2880
supports_tensor = true
[metadata.storage_size]
in_bytes = 12025908224

View File

@@ -0,0 +1,15 @@
short_id = "kimi-k2-instruct-4bit"
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
name = "Kimi K2 Instruct (4-bit)"
description = "Kimi K2 is a large language model trained on the Kimi K2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
pretty_name = "Kimi K2 Instruct (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 620622774272

View File

@@ -0,0 +1,15 @@
short_id = "kimi-k2-thinking"
model_id = "mlx-community/Kimi-K2-Thinking"
name = "Kimi K2 Thinking (4-bit)"
description = "Kimi K2 Thinking is the latest, most capable version of open-source thinking model."
tags = []
[metadata]
model_id = "mlx-community/Kimi-K2-Thinking"
pretty_name = "Kimi K2 Thinking (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 706522120192

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.1-70b"
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
name = "Llama 3.1 70B (4-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
pretty_name = "Llama 3.1 70B (4-bit)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.1-8b-8bit"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
name = "Llama 3.1 8B (8-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
pretty_name = "Llama 3.1 8B (8-bit)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 8954839040

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.1-8b-bf16"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
name = "Llama 3.1 8B (BF16)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
pretty_name = "Llama 3.1 8B (BF16)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 16882073600

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.1-8b"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
name = "Llama 3.1 8B (4-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
pretty_name = "Llama 3.1 8B (4-bit)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 4637851648

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.2-1b"
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
name = "Llama 3.2 1B (4-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
pretty_name = "Llama 3.2 1B (4-bit)"
n_layers = 16
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 729808896

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.2-3b-8bit"
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
name = "Llama 3.2 3B (8-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
pretty_name = "Llama 3.2 3B (8-bit)"
n_layers = 28
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 3501195264

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.2-3b"
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
name = "Llama 3.2 3B (4-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
pretty_name = "Llama 3.2 3B (4-bit)"
n_layers = 28
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 1863319552

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.3-70b-8bit"
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
name = "Llama 3.3 70B (8-bit)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
pretty_name = "Llama 3.3 70B (8-bit)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 76799803392

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.3-70b-fp16"
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
name = "Llama 3.3 70B (FP16)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
pretty_name = "Llama 3.3 70B (FP16)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 144383672320

View File

@@ -0,0 +1,15 @@
short_id = "llama-3.3-70b"
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
name = "Llama 3.3 70B (4-bit)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
pretty_name = "Llama 3.3 70B"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 40652242944

View File

@@ -0,0 +1,15 @@
short_id = "minimax-m2.1-3bit"
model_id = "mlx-community/MiniMax-M2.1-3bit"
name = "MiniMax M2.1 3bit"
description = "MiniMax M2.1 3bit"
tags = []
[metadata]
model_id = "mlx-community/MiniMax-M2.1-3bit"
pretty_name = "MiniMax M2.1 3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 100086644736

View File

@@ -0,0 +1,15 @@
short_id = "minimax-m2.1-8bit"
model_id = "mlx-community/MiniMax-M2.1-8bit"
name = "MiniMax M2.1 8bit"
description = "MiniMax M2.1 8bit"
tags = []
[metadata]
model_id = "mlx-community/MiniMax-M2.1-8bit"
pretty_name = "MiniMax M2.1 8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 242986745856

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-0.6b-8bit"
model_id = "mlx-community/Qwen3-0.6B-8bit"
name = "Qwen3 0.6B (8-bit)"
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-0.6B-8bit"
pretty_name = "Qwen3 0.6B (8-bit)"
n_layers = 28
hidden_size = 1024
supports_tensor = false
[metadata.storage_size]
in_bytes = 698351616

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-0.6b"
model_id = "mlx-community/Qwen3-0.6B-4bit"
name = "Qwen3 0.6B (4-bit)"
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-0.6B-4bit"
pretty_name = "Qwen3 0.6B (4-bit)"
n_layers = 28
hidden_size = 1024
supports_tensor = false
[metadata.storage_size]
in_bytes = 342884352

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-235b-a22b-4bit"
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
name = "Qwen3 235B A22B (4-bit)"
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
pretty_name = "Qwen3 235B A22B (4-bit)"
n_layers = 94
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 141733920768

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-235b-a22b-8bit"
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
name = "Qwen3 235B A22B (8-bit)"
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
pretty_name = "Qwen3 235B A22B (8-bit)"
n_layers = 94
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 268435456000

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-30b-8bit"
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
name = "Qwen3 30B A3B (8-bit)"
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
pretty_name = "Qwen3 30B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 33279705088

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-30b"
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
name = "Qwen3 30B A3B (4-bit)"
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
pretty_name = "Qwen3 30B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 17612931072

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-80b-a3B-4bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
name = "Qwen3 80B A3B (4-bit)"
description = "Qwen3 80B"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
pretty_name = "Qwen3 80B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 46976204800

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-80b-a3B-8bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
name = "Qwen3 80B A3B (8-bit)"
description = "Qwen3 80B"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
pretty_name = "Qwen3 80B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-80b-a3B-thinking-4bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
name = "Qwen3 80B A3B Thinking (4-bit)"
description = "Qwen3 80B Reasoning model"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
pretty_name = "Qwen3 80B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-80b-a3B-thinking-8bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
name = "Qwen3 80B A3B Thinking (8-bit)"
description = "Qwen3 80B Reasoning model"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
pretty_name = "Qwen3 80B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-coder-480b-a35b-4bit"
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
name = "Qwen3 Coder 480B A35B (4-bit)"
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
pretty_name = "Qwen3 Coder 480B A35B (4-bit)"
n_layers = 62
hidden_size = 6144
supports_tensor = true
[metadata.storage_size]
in_bytes = 289910292480

View File

@@ -0,0 +1,15 @@
short_id = "qwen3-coder-480b-a35b-8bit"
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
name = "Qwen3 Coder 480B A35B (8-bit)"
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
pretty_name = "Qwen3 Coder 480B A35B (8-bit)"
n_layers = 62
hidden_size = 6144
supports_tensor = true
[metadata.storage_size]
in_bytes = 579820584960

View File

@@ -1,3 +1,6 @@
from anyio import Path, open_file
import tomlkit
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
@@ -11,6 +14,21 @@ class ModelCard(CamelCaseModel):
tags: list[str]
metadata: ModelMetadata
@staticmethod
async def load(path: Path) -> "ModelCard":
async with await open_file(path) as f:
data = await f.read()
py = tomlkit.loads(data)
return ModelCard.model_validate(py)
async def save(self, path: Path):
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3

View File

@@ -228,10 +228,15 @@ def tensor_auto_parallel(
group=group,
)
logger.info(f"tensor_auto_parallel: model type = {type(model).__name__}")
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.info("Using LlamaShardingStrategy")
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -240,7 +245,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.info("Using DeepSeekShardingStrategy")
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -249,7 +254,6 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, MiniMaxModel):
logger.info("Using MiniMaxShardingStrategy")
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
@@ -258,7 +262,6 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
logger.info("Using QwenShardingStrategy")
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -267,7 +270,6 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
logger.info("Using GptOssShardingStrategy for tensor parallelism")
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
@@ -350,8 +352,6 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
dense_count = 0
moe_count = 0
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
@@ -370,7 +370,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
dense_count += 1
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -378,7 +377,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
moe_count += 1
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
@@ -388,7 +386,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
logger.info(f"DeepSeekShardingStrategy: {dense_count} dense layers (shard_linear), {moe_count} MoE layers (shard_inplace)")
return model
@@ -484,6 +481,7 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)

View File

@@ -162,9 +162,7 @@ def mlx_distributed_init(
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
logger.info(f"rank {rank} BEFORE mx.distributed.init(backend='jaccl')")
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} AFTER mx.distributed.init - group created")
logger.info(f"Rank {rank} mlx distributed initialization complete")
@@ -201,12 +199,10 @@ def load_mlx_items(
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed shard_and_load")
logger.info("Starting distributed init")
start_time = time.perf_counter()
logger.info(f"BEFORE shard_and_load for model {bound_instance.bound_shard.model_meta.model_id}")
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
logger.info(f"AFTER shard_and_load completed")
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
@@ -221,10 +217,8 @@ def shard_and_load(
group: Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
logger.info(f"shard_and_load: model_path={model_path}")
logger.info("BEFORE load_model (lazy=True)")
model, _ = load_model(model_path, lazy=True, strict=False)
logger.info("AFTER load_model")
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
@@ -258,6 +252,8 @@ def shard_and_load(
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# TODO: Do we need this?
mx.eval(model)
logger.debug("SHARDED")

View File

@@ -1,5 +1,4 @@
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
@@ -7,8 +6,6 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
@@ -18,9 +15,8 @@ async def check_reachability(
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
@@ -36,16 +32,7 @@ async def check_reachability(
return NodeId(body) or None
except OSError:
return None
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
except http.client.HTTPException:
return None
finally:
connection.close()

1493
uv.lock generated
View File

File diff suppressed because it is too large Load Diff