Compare commits

...

3 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
484ed1d879 Test stuff 2026-01-20 00:57:27 +00:00
Ryuichi Leo Takashige
209d618d5a Load model layers individually but eagerly 2026-01-19 22:00:31 +00:00
rltakashige
5fd55594c9 Wrap pipeline models for explicit mx.depends between cache and logits (#1206)
## Motivation

GPU timeouts often occur when the prompt size exceeds prefill_step_size. They
also happen for seemingly random models.

## Changes

Make the cache state depend on the logits via mx.depends.
Perform the all_gather at the model level rather than per layer, reducing the
amount of data sent.

## Why It Works

mlx_lm's prefill loop only evaluates the cache state, not the logits.
When the prompt is longer than prefill_step_size, the all_gather is never
evaluated, causing a GPU timeout.
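
A condensed sketch of the patched model __call__ (argument handling simplified; the real pipeline version is patch_pipeline_model in auto_parallel.py, and the tensor-parallel variant adds the same dependency but skips the all_gather):

```python
import mlx.core as mx

def patch_model_call(model, group):
    original_call = type(model).__call__

    def patched_call(self, x, *args, cache=None, **kwargs):
        logits = original_call(self, x, *args, cache=cache, **kwargs)
        # mlx_lm's prefill loop only evaluates the cache state, never the
        # logits, so make the last cache entry depend on the logits: any
        # mx.eval of the cache now also forces the pending distributed ops.
        if cache:
            cache[-1].state = mx.depends(cache[-1].state, logits)
        # Gather once at the model level instead of once per layer.
        return mx.distributed.all_gather(logits, group=group)[-logits.shape[0]:]

    type(model).__call__ = patched_call
    return model
```

The real implementation binds the call signature with bind_partial so it finds the cache whether it is passed positionally or by keyword.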

## Test Plan

### Manual Testing

### Automated Testing
Added failing test cases and then resolved them.
2026-01-19 17:49:42 +00:00
6 changed files with 529 additions and 161 deletions

View File

@@ -16,9 +16,6 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
@@ -490,17 +487,17 @@ def main() -> int:
logger.debug(f" warmup {i + 1}/{args.warmup} done")
for pp in pp_list:
if (
pp * n_nodes > 2048
and "ring" in instance_meta.lower()
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.metadata.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
)
continue
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):

View File

@@ -1,7 +1,10 @@
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast
from typing import TYPE_CHECKING, Any, Protocol, cast
import mlx.core as mx
import mlx.nn as nn
@@ -29,6 +32,40 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
@@ -108,7 +145,6 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output
@@ -137,10 +173,30 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
return layers
class _IdentityModule(nn.Module):
"""Identity module that returns input unchanged. Used to skip computation."""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x
class _IdentityLmHead(nn.Module):
"""Identity lm_head that returns zeros. Used for non-final pipeline ranks."""
def __init__(self, vocab_size: int, dtype: mx.Dtype = mx.float16):
super().__init__()
self.vocab_size = vocab_size
self.dtype = dtype
def __call__(self, x: mx.array) -> mx.array:
# Return zeros with correct shape (batch, seq, vocab_size)
return mx.zeros((*x.shape[:-1], self.vocab_size), dtype=self.dtype)
def pipeline_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
model_shard_meta: PipelineShardMetadata,
model_shard_meta: PipelineShardMetadata
) -> nn.Module:
"""
Automatically parallelize a model across multiple devices.
@@ -158,6 +214,7 @@ def pipeline_auto_parallel(
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],
@@ -193,12 +250,70 @@ def pipeline_auto_parallel(
"Expected a list of layers after auto-parallel initialisation"
)
return patch_pipeline_model(model, group)
def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
# Patch __call__ on the model's class
cls = model.__class__
original_call = cls.__call__ # type: ignore
call_signature = signature(original_call) # type: ignore
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # type: ignore
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type: ignore
return logits
cls.__call__ = patched_call
return model
def patch_tensor_model[T](model: T) -> T:
"""Patch model's __call__ to ensure distributed ops sync during inference."""
cls = model.__class__
original_call = cls.__call__
call_signature = signature(original_call)
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # pyright: ignore[reportAny]
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
cls.__call__ = patched_call
return model
def tensor_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -243,7 +358,7 @@ def tensor_auto_parallel(
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
@@ -293,7 +408,10 @@ def tensor_auto_parallel(
else:
raise ValueError(f"Unsupported model type: {type(model)}")
return tensor_parallel_sharding_strategy.shard_model(model)
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
class TensorParallelShardingStrategy(ABC):
@@ -313,13 +431,27 @@ class TensorParallelShardingStrategy(ABC):
self.N = group.size()
@abstractmethod
def shard_model(self, model: nn.Module) -> nn.Module: ...
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -362,9 +494,17 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -416,9 +556,17 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -445,9 +593,17 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Qwen3MoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -491,10 +647,18 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)

View File

@@ -2,9 +2,7 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -59,6 +57,8 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
eval_with_timeout,
pipeline_auto_parallel,
tensor_auto_parallel,
)
@@ -88,41 +88,6 @@ class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -296,14 +261,6 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
@@ -312,10 +269,22 @@ def shard_and_load(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(
model, group, shard_metadata
)
# Skip eval for pipeline parallel to avoid fast synch issues
mx_barrier(group)
return model, tokenizer
# Eager eval for tensor parallel (ranks have same operations on sharded data)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
logger.debug("SHARDED")
logger.debug(model)

View File

@@ -1,12 +1,24 @@
# type: ignore
import json
import os
import tempfile
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
class MockLayer(nn.Module):
@@ -28,9 +40,6 @@ class PipelineTestConfig:
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
import json
import tempfile
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
@@ -50,35 +59,45 @@ DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
)
DEFAULT_GPT_OSS_MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
def run_gpt_oss_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 200,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
group = mx.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=False,
),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=24,
)
model, tokenizer = shard_and_load(shard_meta, group)
model = cast(Model, model)
# Generate a prompt of exact token length
base_text = "The quick brown fox jumps over the lazy dog. "
@@ -93,45 +112,21 @@ def run_gpt_oss_pipeline_device(
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=False,
),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=24,
)
model = pipeline_auto_parallel(model, group, shard_meta)
# Barrier before generation
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
task=task,
):
generated_text += response.text
if response.finish_reason is not None:
break
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
@@ -143,27 +138,36 @@ def run_gpt_oss_tensor_parallel_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 10,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
group = mx.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
# For tensor parallelism, all devices run all layers
shard_meta = TensorShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
device_rank=rank,
world_size=world_size,
start_layer=0,
end_layer=24,
n_layers=24,
)
model, tokenizer = shard_and_load(shard_meta, group)
model = cast(Model, model)
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
@@ -175,26 +179,21 @@ def run_gpt_oss_tensor_parallel_device(
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
model = tensor_auto_parallel(model, group)
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
task=task,
):
generated_text += response.text
if response.finish_reason is not None:
break
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]

View File

@@ -1,13 +1,18 @@
import json
import multiprocessing as mp
import os
import tempfile
from typing import Any
import mlx.core as mx
import mlx.nn as mlx_nn
import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
PipelineFirstLayer,
PipelineLastLayer,
patch_pipeline_model,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -23,30 +28,38 @@ def run_pipeline_device(
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
class MockLayerInner(mlx_nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
def __call__(
self, x: mlx_core.array, *args: object, **kwargs: object
) -> mlx_core.array:
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x * 2
class MockModel(mlx_nn.Module):
def __init__(self, layers: list[mlx_nn.Module]) -> None:
super().__init__()
self.layers = layers
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
for layer in self.layers:
x = layer(x, *args, **kwargs) # pyright: ignore[reportUnknownVariableType]
return x # pyright: ignore[reportUnknownVariableType]
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
group = mx.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
x = mlx_core.ones((1, 4))
result = composed(x)
mlx_core.eval(result)
# Wrap in a mock model, then patch it with patch_pipeline_model for the model-level all_gather
inner_model = MockModel([composed])
model = patch_pipeline_model(inner_model, group)
x = mx.ones((1, 4))
result = model(x)
mx.eval(result)
success = result.shape == x.shape
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
except Exception as e:
@@ -81,10 +94,6 @@ def test_missing_attribute_raises() -> None:
def test_composed_call_works() -> None:
import json
import os
import tempfile
ctx = mp.get_context("spawn")
world_size = 2

View File

@@ -0,0 +1,230 @@
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import Any, Callable
import pytest
from exo.worker.tests.unittests.test_mlx.conftest import (
DEFAULT_GPT_OSS_CONFIG,
create_hostfile,
run_gpt_oss_pipeline_device,
run_gpt_oss_tensor_parallel_device,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
pytestmark = [
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
),
]
@dataclass
class DistributedTestResult:
timed_out: bool
world_size: int
results: dict[int, tuple[bool, str]]
@property
def all_success(self) -> bool:
if len(self.results) != self.world_size:
return False
return all(r[0] for r in self.results.values())
def run_distributed_test(
world_size: int,
port_offset: int,
process_timeout: int,
target: Callable[..., None],
make_args: Callable[[int], tuple[Any, ...]],
) -> DistributedTestResult:
ctx = mp.get_context("spawn")
hostfile_path, _ = create_hostfile(
world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset
)
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
args = make_args(rank)
p = ctx.Process(
target=target,
args=(rank, world_size, hostfile_path, *args, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=process_timeout) # pyright: ignore[reportAny]
timed_out = any(p.is_alive() for p in processes) # pyright: ignore[reportAny]
for p in processes: # pyright: ignore[reportAny]
if p.is_alive(): # pyright: ignore[reportAny]
p.terminate() # pyright: ignore[reportAny]
p.join(timeout=5) # pyright: ignore[reportAny]
results: dict[int, tuple[bool, str]] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
results[rank] = (success, value)
return DistributedTestResult(
timed_out=timed_out, world_size=world_size, results=results
)
finally:
os.unlink(hostfile_path)
def run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
layer_splits,
prompt_tokens,
prefill_step_size,
)
return run_distributed_test(
world_size=len(layer_splits),
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_pipeline_device,
make_args=make_args,
)
def run_tensor_test(
prompt_tokens: int,
prefill_step_size: int,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
prompt_tokens,
prefill_step_size,
)
return run_distributed_test(
world_size=2,
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_tensor_parallel_device,
make_args=make_args,
)
class TestPipelineParallelFix:
BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]
def test_pipeline_single_layer_first_device(self) -> None:
result = run_pipeline_test(
layer_splits=self.BUG_TRIGGER_SPLITS,
prompt_tokens=100,
prefill_step_size=64,
process_timeout=60,
)
assert not result.timed_out, "Unexpected timeout - fix may not be working"
assert result.all_success, f"Failures: {result.results}"
class TestPipelineSplitConfigurations:
@pytest.mark.parametrize(
"layer_splits",
[
[(0, 1), (1, 24)],
[(0, 6), (6, 24)],
[(0, 12), (12, 24)],
],
ids=["1_23", "6_18", "12_12"],
)
def test_pipeline_splits(
self,
layer_splits: list[tuple[int, int]],
) -> None:
result = run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=600,
prefill_step_size=512,
port_offset=100,
)
assert not result.timed_out, f"Timeout with {layer_splits}"
assert result.all_success, f"Failures with {layer_splits}: {result.results}"
class TestPrefillStepSizeBoundaries:
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_boundary_conditions(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_pipeline_test(
layer_splits=[(0, 12), (12, 24)],
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
port_offset=200,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelFix:
def test_tensor_parallel(self) -> None:
result = run_tensor_test(
prompt_tokens=100,
prefill_step_size=64,
port_offset=400,
)
assert not result.timed_out, "Unexpected timeout"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelBoundaries:
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_tensor_parallel_boundaries(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_tensor_test(
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
port_offset=500,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"