mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-15 17:39:45 -05:00
Compare commits
8 Commits
main
...
debug/gpt-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f036add84f | ||
|
|
d63c8c86a8 | ||
|
|
80608eaf64 | ||
|
|
fc32199653 | ||
|
|
028e29a6d8 | ||
|
|
3941855ad6 | ||
|
|
1933b224c9 | ||
|
|
737d97a2d4 |
45
AGENTS.md
45
AGENTS.md
@@ -91,6 +91,51 @@ From .cursorrules:
|
||||
- Catch exceptions only where you can handle them meaningfully
|
||||
- Use `@final` and immutability wherever applicable
|
||||
|
||||
## API Reference
|
||||
|
||||
The API is served at `http://localhost:52415` by default. Key files:
|
||||
- `docs/api.md`: Full API documentation
|
||||
- `src/exo/master/api.py`: FastAPI implementation
|
||||
- `src/exo/shared/types/api.py`: Request/response Pydantic models
|
||||
|
||||
### Key Endpoints
|
||||
|
||||
```
|
||||
GET /node_id # Current master node ID
|
||||
GET /state # Full cluster state (topology, instances, downloads, etc.)
|
||||
GET /events # Event log for debugging
|
||||
|
||||
POST /instance # Create model instance
|
||||
GET /instance/{id} # Get instance details
|
||||
DELETE /instance/{id} # Delete instance
|
||||
GET /instance/previews # Preview placements for a model
|
||||
GET /instance/placement # Compute placement without creating
|
||||
|
||||
GET /models # List available models
|
||||
GET /v1/models # OpenAI-compatible model list
|
||||
|
||||
POST /v1/chat/completions # OpenAI-compatible chat completions (streaming/non-streaming)
|
||||
POST /bench/chat/completions # Chat completions with performance stats
|
||||
```
|
||||
|
||||
### Useful curl Commands
|
||||
|
||||
```bash
|
||||
# Check cluster state
|
||||
curl -s http://localhost:52415/state | python3 -m json.tool
|
||||
|
||||
# List models
|
||||
curl -s http://localhost:52415/models | python3 -m json.tool
|
||||
|
||||
# Preview placements for a model
|
||||
curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | python3 -m json.tool
|
||||
|
||||
# Chat completion
|
||||
curl -X POST http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
|
||||
|
||||
@@ -230,9 +230,12 @@ def tensor_auto_parallel(
|
||||
|
||||
if hasattr(model, "shard"):
|
||||
try:
|
||||
logger.info("Using model's built-in shard method")
|
||||
model.shard(group) # type: ignore
|
||||
logger.info("model.shard(group) completed")
|
||||
return model
|
||||
except (AttributeError, TypeError, NameError):
|
||||
except (AttributeError, TypeError, NameError) as e:
|
||||
logger.info(f"model.shard failed with {e}, falling back to manual sharding")
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
@@ -270,6 +273,7 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
logger.info("Using GptOssShardingStrategy for tensor parallelism")
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -481,7 +485,6 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
|
||||
@@ -162,7 +162,9 @@ def mlx_distributed_init(
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
logger.info(f"rank {rank} BEFORE mx.distributed.init(backend='jaccl')")
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
logger.info(f"rank {rank} AFTER mx.distributed.init - group created")
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
@@ -199,10 +201,12 @@ def load_mlx_items(
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
logger.info("Starting distributed shard_and_load")
|
||||
start_time = time.perf_counter()
|
||||
logger.info(f"BEFORE shard_and_load for model {bound_instance.bound_shard.model_meta.model_id}")
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"AFTER shard_and_load completed")
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
@@ -217,8 +221,10 @@ def shard_and_load(
|
||||
group: Group,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
logger.info(f"shard_and_load: model_path={model_path}")
|
||||
logger.info("BEFORE load_model (lazy=True)")
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
logger.info("AFTER load_model")
|
||||
logger.debug(model)
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
pass
|
||||
@@ -252,8 +258,6 @@ def shard_and_load(
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
logger.debug("SHARDED")
|
||||
|
||||
@@ -17,11 +17,9 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
# NOTE: MLX_METAL_FAST_SYNCH is set AFTER model loading in runner.py.
|
||||
# Setting it before loading causes hangs with lazy weight evaluation
|
||||
# on certain models (e.g., gpt-oss-20b) with the jaccl backend.
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -20,7 +21,7 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
@@ -111,6 +112,15 @@ def main(
|
||||
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
|
||||
# Enable fast sync AFTER model loading to avoid a hang with lazily evaluated weights
|
||||
# See: https://github.com/exo-explore/exo/issues/XXX
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
logger.info("Enabled MLX_METAL_FAST_SYNCH after model loading")
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
|
||||
Reference in New Issue
Block a user