Compare commits

...

13 Commits

Author SHA1 Message Date
Alex Cheema
771a94d944 debug: log dense vs MoE layer counts in DeepSeekShardingStrategy
This will show how many layers use shard_linear (dense) vs
shard_inplace (MoE) for kimi-k2 and similar models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:29:53 +00:00
Alex Cheema
0c266151ca fix: remove model.shard() bypass - always use custom strategies
The model.shard() call was bypassing our custom sharding strategies
for models that have a built-in shard method. This could be causing
the inconsistent behavior between different models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:08:40 +00:00
Alex Cheema
556f5a0f6d debug: add logging to identify which sharding strategy is used
Log the model type, whether it has a built-in shard method, and which
strategy is selected. This will help identify patterns between
working and broken models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:04:39 +00:00
Alex Cheema
1d0b121457 revert: restore MLX_METAL_FAST_SYNCH to original location
Revert to setting MLX_METAL_FAST_SYNCH in bootstrap.py before model
loading. Setting it after loading doesn't work properly.

The hang issue with certain models (gpt-oss-20b) + jaccl + fast_synch
needs further investigation into why those specific models trigger
the fence polling deadlock.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:44:40 +00:00
Alex Cheema
f036add84f fix: defer MLX_METAL_FAST_SYNCH until after model loading
MLX_METAL_FAST_SYNCH=1 causes hangs during lazy weight evaluation
with certain models (e.g., gpt-oss-20b) on the jaccl backend. The
fast sync mode appears to conflict with lazy array materialization.

Fix by setting MLX_METAL_FAST_SYNCH=1 only AFTER model loading
completes. This preserves the performance benefit during inference
while avoiding the loading hang.

Also cleaned up debug logging added during investigation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:35:07 +00:00
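
As a rough sketch only (not the repo's exact code, and the shard_and_load arguments are assumed), the ordering this commit describes - which the revert above then found not to work properly:

```python
import os

# While MLX_METAL_FAST_SYNCH stays unset, lazy weight evaluation on the
# jaccl backend avoids the hang seen with gpt-oss-20b.
model, tokenizer = shard_and_load(bound_shard, group=group)  # repo function, see diff below

# Enable fast sync only after loading completes, keeping the inference-time
# benefit (per this commit; the revert above moves the setting back to
# bootstrap.py before loading).
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
```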
Alex Cheema
d63c8c86a8 fix: use tree_flatten for nested parameter dict
model.parameters() returns nested dicts, not a flat dict. Use
mlx.utils.tree_flatten to get a flat list of (name, array) tuples.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:32:55 +00:00
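
For reference, a hedged sketch of the flattening pattern this commit describes (MLX's tree utilities live in the mlx.utils module); `model` is assumed to be an already-loaded nn.Module:

```python
import mlx.core as mx
from mlx.utils import tree_flatten

# model.parameters() is a nested dict; tree_flatten turns it into a flat
# list of ("layers.0.self_attn.q_proj.weight", array)-style tuples, so each
# weight can be evaluated (and logged) individually.
for name, param in tree_flatten(model.parameters()):
    mx.eval(param)  # materialize one parameter at a time to localize a hang
```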
Alex Cheema
80608eaf64 debug: more granular logging to find exact hang location
Log before/after each step: model.parameters(), dict conversion,
and each individual param eval to isolate where the hang occurs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:32:32 +00:00
Alex Cheema
fc32199653 debug: eval parameters one-by-one to identify hang location
Iterate through model.parameters() and eval each one individually
with logging to pinpoint exactly which parameter causes the hang.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:30:32 +00:00
Alex Cheema
028e29a6d8 test: try barrier-only fix without preloading all weights
Remove the early mx.eval that loads the entire model - just keep the
barrier to sync nodes before sharding. This is important because
preloading the entire model on each node would OOM for large models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:20:02 +00:00
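
The repo's barrier helper isn't shown in this view; a common MLX idiom for such a sync point, offered only as an assumed sketch, is an all_sum over a scalar:

```python
import mlx.core as mx

def barrier(group: mx.distributed.Group) -> None:
    # Every rank must reach this collective before any rank can proceed,
    # which is the synchronization the barrier-only fix relies on.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), group=group))
```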
Alex Cheema
3941855ad6 debug: add logging around shard_linear and shard_inplace calls
Adding logging to understand where distributed communication happens
during tensor parallelism setup.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:11:39 +00:00
Alex Cheema
1933b224c9 fix: materialize lazy weights before distributed sharding
The jaccl backend deadlocks when mx.eval() is called on lazy weights
that have been wrapped with distributed sharding operations. The issue
is that lazy weight loading (downloading from HF) and distributed
communication were happening simultaneously.

Fix by:
1. Calling mx.eval(model.parameters()) BEFORE tensor_auto_parallel
2. Adding a barrier to ensure all nodes have weights before sharding
3. Then applying sharding to already-materialized weights

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:10:49 +00:00
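
A hedged sketch of the three-step ordering described above (the later commit higher in this list drops step 1 again to avoid OOM on large models); tensor_auto_parallel is the repo's function from the diff below, with its exact signature assumed:

```python
import mlx.core as mx
import mlx.nn as nn

def materialize_then_shard(model: nn.Module, group: mx.distributed.Group) -> nn.Module:
    # 1. Force the lazily loaded weights to materialize before any
    #    distributed sharding ops are attached to them.
    mx.eval(model.parameters())
    # 2. Barrier so every node holds its weights before sharding starts
    #    (same all_sum idiom as the sketch above).
    mx.eval(mx.distributed.all_sum(mx.array(1.0), group=group))
    # 3. Only now wrap the already-materialized weights with sharding.
    #    tensor_auto_parallel is a repo function (not imported here).
    return tensor_auto_parallel(model, group=group)  # signature assumed
```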
Alex Cheema
737d97a2d4 Add detailed logging for jaccl/tensor parallel model loading
Add logging at critical points to debug MlxJacclInstance stuck in
RunnerLoading state:

- Before/after mx.distributed.init(backend="jaccl")
- Before/after shard_and_load, load_model
- Before/after tensor_auto_parallel with sharding strategy info
- Progress logs during GptOss layer sharding
- Before/after mx.eval(model.parameters()) and mx.eval(model)
- Before/after mx_barrier(group) sync

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 20:58:25 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps connections are timing
out? let's also retry after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
4 changed files with 80 additions and 16 deletions

View File

@@ -91,6 +91,51 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## API Reference
The API is served at `http://localhost:52415` by default. Key files:
- `docs/api.md`: Full API documentation
- `src/exo/master/api.py`: FastAPI implementation
- `src/exo/shared/types/api.py`: Request/response Pydantic models
### Key Endpoints
```
GET /node_id # Current master node ID
GET /state # Full cluster state (topology, instances, downloads, etc.)
GET /events # Event log for debugging
POST /instance # Create model instance
GET /instance/{id} # Get instance details
DELETE /instance/{id} # Delete instance
GET /instance/previews # Preview placements for a model
GET /instance/placement # Compute placement without creating
GET /models # List available models
GET /v1/models # OpenAI-compatible model list
POST /v1/chat/completions # OpenAI-compatible chat completions (streaming/non-streaming)
POST /bench/chat/completions # Chat completions with performance stats
```
### Useful curl Commands
```bash
# Check cluster state
curl -s http://localhost:52415/state | python3 -m json.tool
# List models
curl -s http://localhost:52415/models | python3 -m json.tool
# Preview placements for a model
curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | python3 -m json.tool
# Chat completion
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}]}'
```
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
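
Given `asyncio_mode = "auto"`, async tests need no explicit marker; a minimal hypothetical example (the test name is made up, not from the repo):

```python
import os

async def test_exo_tests_env_flag_is_set():
    # EXO_TESTS=1 is set by the test harness, per the note above.
    assert os.environ.get("EXO_TESTS") == "1"
```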

View File

@@ -228,15 +228,10 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
logger.info(f"tensor_auto_parallel: model type = {type(model).__name__}")
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
logger.info("Using LlamaShardingStrategy")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -245,7 +240,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
logger.info("Using DeepSeekShardingStrategy")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -254,6 +249,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, MiniMaxModel):
logger.info("Using MiniMaxShardingStrategy")
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
@@ -262,6 +258,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
logger.info("Using QwenShardingStrategy")
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -270,6 +267,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
logger.info("Using GptOssShardingStrategy for tensor parallelism")
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
@@ -352,6 +350,8 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
dense_count = 0
moe_count = 0
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
@@ -370,6 +370,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
dense_count += 1
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -377,6 +378,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
moe_count += 1
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
@@ -386,6 +388,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
logger.info(f"DeepSeekShardingStrategy: {dense_count} dense layers (shard_linear), {moe_count} MoE layers (shard_inplace)")
return model
@@ -481,7 +484,6 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)

View File

@@ -162,7 +162,9 @@ def mlx_distributed_init(
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
logger.info(f"rank {rank} BEFORE mx.distributed.init(backend='jaccl')")
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} AFTER mx.distributed.init - group created")
logger.info(f"Rank {rank} mlx distributed initialization complete")
@@ -199,10 +201,12 @@ def load_mlx_items(
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
logger.info("Starting distributed shard_and_load")
start_time = time.perf_counter()
logger.info(f"BEFORE shard_and_load for model {bound_instance.bound_shard.model_meta.model_id}")
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
logger.info(f"AFTER shard_and_load completed")
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
@@ -217,8 +221,10 @@ def shard_and_load(
group: Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
logger.info(f"shard_and_load: model_path={model_path}")
logger.info("BEFORE load_model (lazy=True)")
model, _ = load_model(model_path, lazy=True, strict=False)
logger.info("AFTER load_model")
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
@@ -252,8 +258,6 @@ def shard_and_load(
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# TODO: Do we need this?
mx.eval(model)
logger.debug("SHARDED")

View File

@@ -1,4 +1,5 @@
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
@@ -6,6 +7,8 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
@@ -15,8 +18,9 @@ async def check_reachability(
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
@@ -32,7 +36,16 @@ async def check_reachability(
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()
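
Since _fetch_remote_node_id is synchronous and retries with a blocking time.sleep, it presumably runs off the event loop through the anyio to_thread import added above; a sketch of that call pattern (a fragment, not the repo's exact code):

```python
# Inside async check_reachability(...), something along these lines keeps
# the blocking HTTP calls and sleeps in a worker thread:
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
```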