Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-20 11:58:57 -05:00)

Compare commits: upstream-s ... leo/test-s (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 67deec88ca | |
| | 209d618d5a | |
@@ -1,7 +1,10 @@
+import os
+import threading
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from functools import partial
 from inspect import signature
-from typing import TYPE_CHECKING, Callable, Protocol, cast
+from typing import TYPE_CHECKING, Any, Protocol, cast
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -29,6 +32,40 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
 from exo.shared.logging import logger
 from exo.shared.types.worker.shards import PipelineShardMetadata
 
+TimeoutCallback = Callable[[], None]
+
+
+def eval_with_timeout(
+    mlx_item: Any,  # pyright: ignore[reportAny]
+    timeout_seconds: float = 60.0,
+    on_timeout: TimeoutCallback | None = None,
+) -> None:
+    """Evaluate MLX item with a hard timeout.
+
+    If on_timeout callback is provided, it will be called before terminating
+    the process. This allows the runner to send a failure event before exit.
+    """
+    completed = threading.Event()
+
+    def watchdog() -> None:
+        if not completed.wait(timeout=timeout_seconds):
+            logger.error(
+                f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
+                "This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
+                "Terminating process."
+            )
+            if on_timeout is not None:
+                on_timeout()
+            os._exit(1)
+
+    watchdog_thread = threading.Thread(target=watchdog, daemon=True)
+    watchdog_thread.start()
+
+    try:
+        mx.eval(mlx_item)  # pyright: ignore[reportAny]
+    finally:
+        completed.set()
+
 
 class _LayerCallable(Protocol):
     """Structural type that any compatible layer must satisfy.
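The function added above is a generic watchdog: a daemon thread waits on a threading.Event and only acts if the main evaluation overruns its budget. A minimal standalone sketch of the same pattern (the name run_with_watchdog and the toy workload are hypothetical, and the process-killing os._exit(1) is replaced with a print so the sketch is safe to run):

```python
import threading
import time


def run_with_watchdog(fn, timeout_seconds: float = 2.0, on_timeout=None) -> None:
    # Same shape as eval_with_timeout: the daemon thread fires only if the
    # Event is not set within the budget.
    completed = threading.Event()

    def watchdog() -> None:
        if not completed.wait(timeout=timeout_seconds):
            print(f"work timed out after {timeout_seconds:.0f}s")
            if on_timeout is not None:
                on_timeout()
            # eval_with_timeout calls os._exit(1) at this point; the sketch only reports.

    threading.Thread(target=watchdog, daemon=True).start()
    try:
        fn()
    finally:
        completed.set()


# Hypothetical usage: a 3s task against a 1s budget triggers the callback.
run_with_watchdog(lambda: time.sleep(3), timeout_seconds=1.0,
                  on_timeout=lambda: print("sending failure event"))
```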
@@ -211,9 +248,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
             "cache", None
         )
 
-        # Add dependency to last cache entry to ensure distributed ops are evaluated
-        if cache is not None:
-            cache[-1].state = mx.depends(cache[-1].state, logits)  # type: ignore
+        # # Add dependency to last cache entry to ensure distributed ops are evaluated
+        # if cache is not None:
+        #     cache[-1].state = mx.depends(cache[-1].state, logits)  # type: ignore
 
         logits = mx.distributed.all_gather(logits, group=group)[
             -logits.shape[0] :
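For reference, mx.depends is what the patched __call__ implementations in this file rely on to force the distributed ops to be materialized: it makes one value's evaluation depend on another's. A tiny illustration, assuming mx.depends accepts a single array for both arguments as the lines above suggest:

```python
import mlx.core as mx

a = mx.zeros((2,))
b = mx.ones((2,)) * 3  # stands in for the logits carrying the distributed op
a = mx.depends(a, b)   # evaluating `a` now also forces `b`
mx.eval(a)
```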
@@ -225,9 +262,37 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
     return model
 
 
+def patch_tensor_model[T](model: T) -> T:
+    """Patch model's __call__ to ensure distributed ops sync during inference."""
+    cls = model.__class__
+    original_call = cls.__call__
+    call_signature = signature(original_call)
+
+    def patched_call(
+        self: T,
+        *args: object,
+        **kwargs: object,
+    ) -> mx.array:
+        logits: mx.array = original_call(self, *args, **kwargs)  # pyright: ignore[reportAny]
+        cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
+            "cache", None
+        )
+
+        # Add dependency to last cache entry to ensure distributed ops are evaluated
+        if cache is not None and len(cache) > 0:  # pyright: ignore[reportAny]
+            cache[-1].state = mx.depends(cache[-1].state, logits)  # pyright: ignore[reportAny,reportUnknownMemberType]
+
+        return logits
+
+    cls.__call__ = patched_call
+    return model
+
+
 def tensor_auto_parallel(
     model: nn.Module,
     group: mx.distributed.Group,
+    timeout_seconds: float = 60.0,
+    on_timeout: TimeoutCallback | None = None,
 ) -> nn.Module:
     all_to_sharded_linear = partial(
         shard_linear,
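The interesting mechanics in patch_tensor_model are that the patch is installed on the class (special-method lookup ignores instance attributes) and that inspect.signature().bind_partial recovers the cache argument whether it was passed positionally or by keyword. A pure-Python sketch of the same technique, with hypothetical Toy/patch_toy names and a plain list standing in for the KV cache:

```python
from inspect import signature


class Toy:
    def __call__(self, x, cache=None):
        return x * 2


def patch_toy(cls):
    original_call = cls.__call__
    sig = signature(original_call)

    def patched_call(self, *args, **kwargs):
        out = original_call(self, *args, **kwargs)
        # bind_partial maps whatever was actually passed back onto parameter
        # names, so `cache` is found however the caller supplied it.
        cache = sig.bind_partial(self, *args, **kwargs).arguments.get("cache")
        if cache is not None:
            cache.append(out)  # stand-in for the mx.depends bookkeeping
        return out

    cls.__call__ = patched_call  # patch the class, not the instance
    return cls


patch_toy(Toy)
trace = []
print(Toy()(3, cache=trace), trace)  # -> 6 [6]
```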
@@ -272,7 +337,7 @@ def tensor_auto_parallel(
     if hasattr(model, "shard"):
         try:
             model.shard(group)  # type: ignore
-            return model
+            return patch_tensor_model(model)
         except (AttributeError, TypeError, NameError):
             pass
 
@@ -322,7 +387,10 @@ def tensor_auto_parallel(
     else:
         raise ValueError(f"Unsupported model type: {type(model)}")
 
-    return tensor_parallel_sharding_strategy.shard_model(model)
+    model = tensor_parallel_sharding_strategy.shard_model(
+        model, timeout_seconds, on_timeout
+    )
+    return patch_tensor_model(model)
 
 
 class TensorParallelShardingStrategy(ABC):
@@ -342,13 +410,27 @@ class TensorParallelShardingStrategy(ABC):
         self.N = group.size()
 
     @abstractmethod
-    def shard_model(self, model: nn.Module) -> nn.Module: ...
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module: ...
 
 
 class LlamaShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(self, model: nn.Module) -> nn.Module:
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
         model = cast(LlamaModel, model)
         for layer in model.layers:
+            # Force load weights before sharding to avoid FAST_SYNCH deadlock
+            eval_with_timeout(
+                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
+            )
             layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
             layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
             layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
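Each concrete strategy below repeats the same two changes shown for LlamaShardingStrategy: the widened shard_model signature and a per-layer eval_with_timeout before sharding. A hedged skeleton of that contract (the class name DummyShardingStrategy is hypothetical):

```python
class DummyShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
        # Split the overall budget evenly: e.g. 60s across 30 layers leaves
        # 2s for each layer's weights to materialize.
        per_layer_budget = timeout_seconds / len(model.layers)
        for layer in model.layers:
            eval_with_timeout(layer.parameters(), per_layer_budget, on_timeout)
            # ...shard this layer's projections as in the strategies below...
        return model
```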
@@ -391,9 +473,17 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
 
 
 class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(self, model: nn.Module) -> nn.Module:
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
         model = cast(DeepseekV3Model, model)
         for layer in model.layers:
+            eval_with_timeout(
+                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
+            )
             # Shard the self attention
             if layer.self_attn.q_lora_rank is None:
                 layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -445,9 +535,17 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
 
 
 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(self, model: nn.Module) -> nn.Module:
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
         model = cast(MiniMaxModel, model)
         for layer in model.layers:
+            eval_with_timeout(
+                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
+            )
             # Shard the self attention
             layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
             layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -474,9 +572,17 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
 
 
 class QwenShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(self, model: nn.Module) -> nn.Module:
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
         model = cast(Qwen3MoeModel, model)
         for layer in model.layers:
+            eval_with_timeout(
+                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
+            )
             # Shard the self attention
             layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
             layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -520,10 +626,18 @@ class ShardedQwenMoE(CustomMlxLayer):
 
 
 class GptOssShardingStrategy(TensorParallelShardingStrategy):
-    def shard_model(self, model: nn.Module) -> nn.Module:
+    def shard_model(
+        self,
+        model: nn.Module,
+        timeout_seconds: float,
+        on_timeout: TimeoutCallback | None,
+    ) -> nn.Module:
         model = cast(GptOssMoeModel, model)
 
         for layer in model.layers:
+            eval_with_timeout(
+                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
+            )
             layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
             layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
             layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -2,9 +2,7 @@ import json
 import os
 import resource
 import sys
-import threading
 import time
-from collections.abc import Callable
 from pathlib import Path
 from typing import Any, cast
 
@@ -59,6 +57,8 @@ from exo.shared.types.worker.shards import (
 from exo.worker.download.download_utils import build_model_path
 from exo.worker.engines.mlx import Model
 from exo.worker.engines.mlx.auto_parallel import (
+    TimeoutCallback,
+    eval_with_timeout,
     pipeline_auto_parallel,
     tensor_auto_parallel,
 )
@@ -88,41 +88,6 @@ class ModelLoadingTimeoutError(Exception):
     pass
 
 
-TimeoutCallback = Callable[[], None]
-
-
-def eval_with_timeout(
-    mlx_item: Any,  # pyright: ignore[reportAny]
-    timeout_seconds: float = 60.0,
-    on_timeout: TimeoutCallback | None = None,
-) -> None:
-    """Evaluate MLX item with a hard timeout.
-
-    If on_timeout callback is provided, it will be called before terminating
-    the process. This allows the runner to send a failure event before exit.
-    """
-    completed = threading.Event()
-
-    def watchdog() -> None:
-        if not completed.wait(timeout=timeout_seconds):
-            logger.error(
-                f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
-                "This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
-                "Terminating process."
-            )
-            if on_timeout is not None:
-                on_timeout()
-            os._exit(1)
-
-    watchdog_thread = threading.Thread(target=watchdog, daemon=True)
-    watchdog_thread.start()
-
-    try:
-        mx.eval(mlx_item)  # pyright: ignore[reportAny]
-    finally:
-        completed.set()
-
-
 def mx_barrier(group: Group | None = None):
     mx.eval(
         mx.distributed.all_sum(
@@ -296,14 +261,6 @@ def shard_and_load(
 
     logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
 
-    match shard_metadata:
-        case TensorShardMetadata():
-            logger.info(f"loading model from {model_path} with tensor parallelism")
-            model = tensor_auto_parallel(model, group)
-        case PipelineShardMetadata():
-            logger.info(f"loading model from {model_path} with pipeline parallelism")
-            model = pipeline_auto_parallel(model, group, shard_metadata)
-
     # Estimate timeout based on model size
     base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
    model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
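The two inputs to the timeout estimate are visible in the hunk above; how they are combined into timeout_seconds sits outside this hunk, so it is not reproduced here. For orientation, with a hypothetical 24 GiB shard:

```python
import os

base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))  # seconds
weights_bytes = 24 * 1024**3               # hypothetical 24 GiB of shard weights
model_size_gb = weights_bytes / (1024**3)  # -> 24.0 (GiB, though named "gb")
print(base_timeout, model_size_gb)
```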
@@ -312,7 +269,15 @@ def shard_and_load(
         f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
         f"(model size: {model_size_gb:.1f}GB)"
     )
-    eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
+
+    match shard_metadata:
+        case TensorShardMetadata():
+            logger.info(f"loading model from {model_path} with tensor parallelism")
+            model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
+        case PipelineShardMetadata():
+            logger.info(f"loading model from {model_path} with pipeline parallelism")
+            model = pipeline_auto_parallel(model, group, shard_metadata)
+            eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
 
     # TODO: Do we need this?
     mx.eval(model)
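Taken together, the last two hunks reorder shard_and_load so that the parallelism step owns the timeout: the tensor path now evaluates weights layer by layer inside tensor_auto_parallel, while the pipeline path still evaluates parameters afterwards. A condensed, hedged sketch of the new ordering (the wrapper name apply_parallelism_with_timeout and the callback body are hypothetical; imports are assumed to match the runner's):

```python
def apply_parallelism_with_timeout(model, group, shard_metadata, timeout_seconds):
    def on_timeout() -> None:
        # Placeholder: the real runner would emit a failure event here,
        # just before the watchdog calls os._exit(1).
        logger.error("model load timed out")

    match shard_metadata:
        case TensorShardMetadata():
            # Weight eval now happens per layer inside tensor_auto_parallel.
            model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
        case PipelineShardMetadata():
            model = pipeline_auto_parallel(model, group, shard_metadata)
            eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
    return model
```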