mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-16 01:51:03 -05:00
Compare commits
13 Commits
model-card
...
debug/gpt-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
771a94d944 | ||
|
|
0c266151ca | ||
|
|
556f5a0f6d | ||
|
|
1d0b121457 | ||
|
|
f036add84f | ||
|
|
d63c8c86a8 | ||
|
|
80608eaf64 | ||
|
|
fc32199653 | ||
|
|
028e29a6d8 | ||
|
|
3941855ad6 | ||
|
|
1933b224c9 | ||
|
|
737d97a2d4 | ||
|
|
3e623ccf0d |
45
AGENTS.md
45
AGENTS.md
@@ -91,6 +91,51 @@ From .cursorrules:
|
||||
- Catch exceptions only where you can handle them meaningfully
|
||||
- Use `@final` and immutability wherever applicable
|
||||
|
||||
## API Reference
|
||||
|
||||
The API is served at `http://localhost:52415` by default. Key files:
|
||||
- `docs/api.md`: Full API documentation
|
||||
- `src/exo/master/api.py`: FastAPI implementation
|
||||
- `src/exo/shared/types/api.py`: Request/response Pydantic models
|
||||
|
||||
### Key Endpoints
|
||||
|
||||
```
|
||||
GET /node_id # Current master node ID
|
||||
GET /state # Full cluster state (topology, instances, downloads, etc.)
|
||||
GET /events # Event log for debugging
|
||||
|
||||
POST /instance # Create model instance
|
||||
GET /instance/{id} # Get instance details
|
||||
DELETE /instance/{id} # Delete instance
|
||||
GET /instance/previews # Preview placements for a model
|
||||
GET /instance/placement # Compute placement without creating
|
||||
|
||||
GET /models # List available models
|
||||
GET /v1/models # OpenAI-compatible model list
|
||||
|
||||
POST /v1/chat/completions # OpenAI-compatible chat completions (streaming/non-streaming)
|
||||
POST /bench/chat/completions # Chat completions with performance stats
|
||||
```
|
||||
|
||||
### Useful curl Commands
|
||||
|
||||
```bash
|
||||
# Check cluster state
|
||||
curl -s http://localhost:52415/state | python3 -m json.tool
|
||||
|
||||
# List models
|
||||
curl -s http://localhost:52415/models | python3 -m json.tool
|
||||
|
||||
# Preview placements for a model
|
||||
curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | python3 -m json.tool
|
||||
|
||||
# Chat completion
|
||||
curl -X POST http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
|
||||
|
||||
@@ -228,15 +228,10 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard"):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return model
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
logger.info(f"tensor_auto_parallel: model type = {type(model).__name__}")
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
logger.info("Using LlamaShardingStrategy")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -245,7 +240,7 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
logger.info("Using DeepSeekShardingStrategy")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -254,6 +249,7 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, MiniMaxModel):
|
||||
logger.info("Using MiniMaxShardingStrategy")
|
||||
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -262,6 +258,7 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
logger.info("Using QwenShardingStrategy")
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -270,6 +267,7 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
logger.info("Using GptOssShardingStrategy for tensor parallelism")
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -352,6 +350,8 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
dense_count = 0
|
||||
moe_count = 0
|
||||
for layer in model.layers:
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
@@ -370,6 +370,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
dense_count += 1
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
@@ -377,6 +378,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
else:
|
||||
moe_count += 1
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
|
||||
@@ -386,6 +388,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
logger.info(f"DeepSeekShardingStrategy: {dense_count} dense layers (shard_linear), {moe_count} MoE layers (shard_inplace)")
|
||||
return model
|
||||
|
||||
|
||||
@@ -481,7 +484,6 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
|
||||
@@ -162,7 +162,9 @@ def mlx_distributed_init(
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
logger.info(f"rank {rank} BEFORE mx.distributed.init(backend='jaccl')")
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
logger.info(f"rank {rank} AFTER mx.distributed.init - group created")
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
@@ -199,10 +201,12 @@ def load_mlx_items(
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
logger.info("Starting distributed shard_and_load")
|
||||
start_time = time.perf_counter()
|
||||
logger.info(f"BEFORE shard_and_load for model {bound_instance.bound_shard.model_meta.model_id}")
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"AFTER shard_and_load completed")
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
@@ -217,8 +221,10 @@ def shard_and_load(
|
||||
group: Group,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
logger.info(f"shard_and_load: model_path={model_path}")
|
||||
logger.info("BEFORE load_model (lazy=True)")
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
logger.info("AFTER load_model")
|
||||
logger.debug(model)
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
pass
|
||||
@@ -252,8 +258,6 @@ def shard_and_load(
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
logger.debug("SHARDED")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import http.client
|
||||
import time
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
@@ -6,6 +7,8 @@ from loguru import logger
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
BAD_STATUSLINE_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
@@ -15,8 +18,9 @@ async def check_reachability(
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
|
||||
def _fetch_remote_node_id() -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
|
||||
# TODO: use an async http client
|
||||
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
|
||||
try:
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
@@ -32,7 +36,16 @@ async def check_reachability(
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.HTTPException:
|
||||
except http.client.BadStatusLine:
|
||||
if attempt >= BAD_STATUSLINE_ATTEMPTS:
|
||||
logger.warning(
|
||||
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
|
||||
)
|
||||
return None
|
||||
time.sleep(1)
|
||||
return _fetch_remote_node_id(attempt=attempt + 1)
|
||||
except http.client.HTTPException as e:
|
||||
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
Reference in New Issue
Block a user