Compare commits

...

2 Commits

Author SHA1 Message Date
rltakashige
0a38644536 Merge branch 'main' into leo/load-layers-individually 2026-01-19 22:10:16 +00:00
Ryuichi Leo Takashige
209d618d5a Load model layers individually but eagerly 2026-01-19 22:00:31 +00:00
2 changed files with 134 additions and 55 deletions

View File

@@ -1,7 +1,10 @@
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast
from typing import TYPE_CHECKING, Any, Protocol, cast
import mlx.core as mx
import mlx.nn as nn
@@ -29,6 +32,40 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
@@ -225,9 +262,37 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
return model
def patch_tensor_model[T](model: T) -> T:
"""Patch model's __call__ to ensure distributed ops sync during inference."""
cls = model.__class__
original_call = cls.__call__
call_signature = signature(original_call)
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # pyright: ignore[reportAny]
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
cls.__call__ = patched_call
return model
def tensor_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -272,7 +337,7 @@ def tensor_auto_parallel(
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
@@ -322,7 +387,10 @@ def tensor_auto_parallel(
else:
raise ValueError(f"Unsupported model type: {type(model)}")
return tensor_parallel_sharding_strategy.shard_model(model)
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
class TensorParallelShardingStrategy(ABC):
@@ -342,13 +410,27 @@ class TensorParallelShardingStrategy(ABC):
self.N = group.size()
@abstractmethod
def shard_model(self, model: nn.Module) -> nn.Module: ...
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -391,9 +473,17 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -445,9 +535,17 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -474,9 +572,17 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Qwen3MoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -520,10 +626,18 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)

View File

@@ -2,9 +2,7 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -59,6 +57,8 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
eval_with_timeout,
pipeline_auto_parallel,
tensor_auto_parallel,
)
@@ -88,41 +88,6 @@ class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -296,14 +261,6 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
@@ -312,7 +269,15 @@ def shard_and_load(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)