Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-23 05:23:11 -05:00)

Compare commits: foo...leo/invest (1 commit, b0c32c30e2)
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from functools import partial
 from inspect import signature
-from typing import TYPE_CHECKING, Callable, Protocol, cast
+from typing import TYPE_CHECKING, Any, Callable, Protocol, cast

 import mlx.core as mx
 import mlx.nn as nn
@@ -106,10 +106,52 @@ class PipelineLastLayer(CustomMlxLayer):
         if cache is not None:
             cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

         output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
         return output


+class DistributedModelWrapper:
+    """Wrapper that ensures distributed ops are evaluated during prefill.
+
+    mlx_lm's prefill loop only evaluates cache state, not logits. With pipeline
+    parallel, the send/recv ops are triggered when logits are evaluated. This
+    wrapper adds mx.depends between cache.state and logits, so when cache is
+    evaluated after each prefill chunk, the distributed ops are also evaluated.
+    """
+
+    _inner: Any
+    _cache: Any
+
+    def __init__(
+        self,
+        inner_model: nn.Module,
+        cache: Any = None,  # pyright: ignore[reportAny]
+    ):
+        object.__setattr__(self, "_inner", inner_model)
+        object.__setattr__(self, "_cache", cache)
+
+    def __call__(
+        self,
+        *args: Any,  # pyright: ignore[reportAny]
+        **kwargs: Any,  # pyright: ignore[reportAny]
+    ) -> mx.array:
+        logits: mx.array = self._inner(*args, **kwargs)  # pyright: ignore[reportAny]
+        cache: Any = kwargs.get("cache") or self._cache  # pyright: ignore[reportAny]
+        if cache is not None:
+            for c in cache:  # pyright: ignore[reportAny]
+                if hasattr(c, "state") and c.state is not None:  # pyright: ignore[reportAny]
+                    c.state = mx.depends(c.state, logits)  # pyright: ignore[reportAny,reportUnknownMemberType]
+        return logits
+
+    def __getattr__(self, name: str) -> Any:  # pyright: ignore[reportAny]
+        return getattr(self._inner, name)  # pyright: ignore[reportAny]
+
+    def __setattr__(self, name: str, value: Any) -> None:  # pyright: ignore[reportAny]
+        if name in ("_inner", "_cache"):
+            object.__setattr__(self, name, value)
+        else:
+            setattr(self._inner, name, value)  # pyright: ignore[reportAny]
+
+
 def _inner_model(model: nn.Module) -> nn.Module:
     inner = getattr(model, "model", None)
     if isinstance(inner, nn.Module):
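Both the existing `cache.keys = mx.depends(...)` line and the new `DistributedModelWrapper` rely on the same trick: `mx.depends` returns its first argument unchanged but records a graph dependency on the second, so evaluating the cache also forces the logits (and the distributed send/recv behind them). A minimal, non-distributed sketch of that behaviour with toy arrays (not part of the diff):

import mlx.core as mx

a = mx.ones((2, 2))  # stands in for cache state
b = a * 3            # stands in for logits produced by distributed ops

# `a` comes back unchanged but now depends on `b`, so evaluating `a`
# also materialises `b` and everything upstream of it.
a = mx.depends(a, b)
mx.eval(a)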
@@ -168,12 +210,19 @@ def pipeline_auto_parallel(
         inner_model_instance.layer_types = inner_model_instance.layer_types[  # type: ignore
             start_layer:end_layer
         ]
-        inner_model_instance.swa_idx = inner_model_instance.layer_types.index(  # type: ignore
-            "sliding_attention"
-        )
-        inner_model_instance.ga_idx = inner_model_instance.layer_types.index(  # type: ignore
-            "full_attention"
-        )
+        # Handle case where layer type may not exist in this shard
+        try:
+            inner_model_instance.swa_idx = inner_model_instance.layer_types.index(  # type: ignore
+                "sliding_attention"
+            )
+        except ValueError:
+            inner_model_instance.swa_idx = -1
+        try:
+            inner_model_instance.ga_idx = inner_model_instance.layer_types.index(  # type: ignore
+                "full_attention"
+            )
+        except ValueError:
+            inner_model_instance.ga_idx = -1

     _set_layers(model, layers)

@@ -181,7 +230,14 @@ def pipeline_auto_parallel(
             "Expected a list of layers after auto-parallel initialisation"
         )

-    return model
+    # Store pipeline group on model for token broadcasting in generate
+    model._pipeline_group = group
+
+    # Wrap model to ensure distributed ops are evaluated during prefill
+    wrapped = DistributedModelWrapper(model)
+    wrapped._pipeline_group = group
+
+    return wrapped  # type: ignore[return-value]


 def tensor_auto_parallel(
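A short usage sketch of the wrapper's delegation (the toy `nn.Linear` module is an assumption, not from the diff; it assumes `DistributedModelWrapper` is importable from the module patched above): writes to anything other than `_inner`/`_cache` are forwarded to the wrapped module by `__setattr__`, reads fall back to it via `__getattr__`, and calls go through `__call__`. This is why the `_pipeline_group` set here is still visible to `getattr(model, "_pipeline_group", None)` in `mlx_generate`.

import mlx.core as mx
import mlx.nn as nn

inner = nn.Linear(4, 4)
wrapped = DistributedModelWrapper(inner)

wrapped._pipeline_group = "dummy-group"          # forwarded onto `inner` by __setattr__
assert wrapped._pipeline_group == "dummy-group"  # read back through __getattr__

out = wrapped(mx.ones((1, 4)))  # __call__ delegates to the inner module
mx.eval(out)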
@@ -148,6 +148,11 @@ def mlx_generate(
         top_p=task.top_p if task.top_p is not None else 1.0,
     )

+    # Get pipeline group for token broadcasting (if using pipeline parallelism)
+    pipeline_group: mx.distributed.Group | None = getattr(
+        model, "_pipeline_group", None
+    )
+
     max_tokens = task.max_tokens or MAX_TOKENS
     for out in stream_generate(
         model=model,
@@ -162,6 +167,17 @@ def mlx_generate(
         kv_group_size=KV_GROUP_SIZE,
         kv_bits=KV_BITS,
     ):
+        # Broadcast token across all pipeline devices for synchronization
+        if pipeline_group is not None:
+            token_array = mx.array([[out.token]], dtype=mx.int32)
+            token_array = mx.distributed.all_gather(token_array, group=pipeline_group)
+            # Take the token from the last device (which has the full output)
+            # all_gather concatenates along first dim: [world_size, 1]
+            correct_token = int(token_array[-1, 0].item())
+            out.token = correct_token
+            # Re-decode text since our local token may have been wrong
+            out.text = tokenizer.decode([correct_token])
+
         logger.info(out.text)

         stats: GenerationStats | None = None
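The broadcast step in isolation, runnable on a single process: without a distributed launch, `mx.distributed.init()` returns a size-1 group and `all_gather` is effectively a no-op, but the shapes mirror the multi-device case, where the gathered array is `[world_size, 1]` and the last pipeline rank's entry is the one that holds the full output. The literal token value is a placeholder.

import mlx.core as mx

group = mx.distributed.init()  # size-1 group when not launched distributed

local_token = 42                                          # stands in for out.token
token_array = mx.array([[local_token]], dtype=mx.int32)   # shape [1, 1]
gathered = mx.distributed.all_gather(token_array, group=group)  # shape [world_size, 1]
agreed = int(gathered[-1, 0].item())  # last rank's token is authoritative
print(agreed)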
@@ -147,6 +147,79 @@ def broadcast_from_zero(value: int, group: Group | None = None):
     return int(m.item())


+def _get_inner_model(model: nn.Module) -> nn.Module:
+    """Get the inner model (handles wrapper patterns)."""
+    if hasattr(model, "model"):
+        inner: Any = model.model  # type: ignore[attr-defined]
+        if isinstance(inner, nn.Module):
+            return inner
+    if hasattr(model, "transformer"):
+        inner = model.transformer  # type: ignore[attr-defined]
+        if isinstance(inner, nn.Module):
+            return inner
+    return model
+
+
+def _get_model_layers(inner: nn.Module) -> list[nn.Module]:
+    """Get the transformer layers from the model."""
+    if hasattr(inner, "layers"):
+        layers: Any = inner.layers  # type: ignore[attr-defined]
+        if isinstance(layers, list):
+            return cast(list[nn.Module], layers)
+    if hasattr(inner, "h"):
+        h_layers: Any = inner.h  # type: ignore[attr-defined]
+        if isinstance(h_layers, list):
+            return cast(list[nn.Module], h_layers)
+    return []
+
+
+def eval_parameters_per_layer(
+    model: nn.Module,
+    group: Group | None,
+    timeout_seconds: float,
+    on_timeout: TimeoutCallback | None = None,
+) -> None:
+    """Evaluate model parameters layer by layer to reduce Metal sync pressure.
+
+    This avoids the FAST_SYNCH deadlock by breaking evaluation into smaller chunks,
+    reducing peak Metal GPU memory/sync contention.
+    """
+    inner = _get_inner_model(model)
+    layers = _get_model_layers(inner)
+
+    n_layers = len(layers)
+    # Reserve some time for embedding and output layers
+    per_layer_timeout = timeout_seconds / max(n_layers + 4, 1)
+
+    # Evaluate embedding/input layers first
+    if hasattr(inner, "embed_tokens"):
+        embed: Any = inner.embed_tokens  # type: ignore[attr-defined]
+        if hasattr(embed, "parameters"):  # pyright: ignore[reportUnknownArgumentType]
+            logger.debug("Evaluating embed_tokens")
+            eval_with_timeout(embed.parameters(), per_layer_timeout, on_timeout)  # pyright: ignore[reportUnknownMemberType]
+
+    # Evaluate each transformer layer separately
+    for i, layer in enumerate(layers):
+        logger.debug(f"Evaluating layer {i}/{n_layers}")
+        eval_with_timeout(layer.parameters(), per_layer_timeout, on_timeout)
+
+        # Barrier between layers for distributed sync
+        if group is not None:
+            mx_barrier(group)
+
+    # Evaluate output layers
+    if hasattr(inner, "norm"):
+        norm: Any = inner.norm  # type: ignore[attr-defined]
+        if hasattr(norm, "parameters"):  # pyright: ignore[reportUnknownArgumentType]
+            logger.debug("Evaluating norm")
+            eval_with_timeout(norm.parameters(), per_layer_timeout, on_timeout)  # pyright: ignore[reportUnknownMemberType]
+    if hasattr(model, "lm_head"):
+        lm_head: Any = model.lm_head  # type: ignore[attr-defined]
+        if hasattr(lm_head, "parameters"):  # pyright: ignore[reportUnknownArgumentType]
+            logger.debug("Evaluating lm_head")
+            eval_with_timeout(lm_head.parameters(), per_layer_timeout, on_timeout)  # pyright: ignore[reportUnknownMemberType]
+
+
 class HostList(RootModel[list[str]]):
     @classmethod
     def from_hosts(cls, hosts: list[Host]) -> "HostList":
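The core idea of `eval_parameters_per_layer`, stripped down to a toy loop (the `nn.Linear` stand-ins are assumptions; the real helper adds `eval_with_timeout` and an `mx_barrier` between layers): materialise weights one layer at a time instead of in a single `mx.eval` over the whole model.

import mlx.core as mx
import mlx.nn as nn

layers = [nn.Linear(8, 8) for _ in range(4)]  # stand-ins for transformer blocks

for i, layer in enumerate(layers):
    # Evaluate only this layer's parameters, keeping each Metal sync small.
    mx.eval(layer.parameters())
    print(f"evaluated layer {i}/{len(layers)}")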
@@ -265,7 +338,9 @@ def shard_and_load(
 ) -> tuple[nn.Module, TokenizerWrapper]:
     model_path = build_model_path(shard_metadata.model_meta.model_id)

-    model, _ = load_model(model_path, lazy=True, strict=False)
+    # Use lazy=False to materialize weights before applying distributed sharding
+    # This avoids deadlock with FAST_SYNCH + lazy weights + distributed collectives
+    model, _ = load_model(model_path, lazy=False, strict=False)
     logger.debug(model)
     if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model):  # type: ignore
         pass
@@ -303,12 +378,14 @@ def shard_and_load(
     model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
     timeout_seconds = base_timeout + model_size_gb / 5
     logger.info(
-        f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
+        f"Evaluating model parameters per-layer with total timeout of {timeout_seconds:.0f}s "
         f"(model size: {model_size_gb:.1f}GB)"
     )
-    eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
-
-    # TODO: Do we need this?
+    # Per-layer evaluation to avoid FAST_SYNCH deadlock with tensor parallel
+    eval_parameters_per_layer(model, group, timeout_seconds, on_timeout)
+
+    # Final model eval (should be fast now, weights already materialized)
     mx.eval(model)

     logger.debug("SHARDED")
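For scale: `timeout_seconds = base_timeout + model_size_gb / 5` adds one second per 5 GB of weights, so the roughly 12 GB GPT-OSS shard used in the tests below contributes about 2.4 s on top of `base_timeout` (whose value is not shown in this hunk).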
src/exo/worker/tests/unittests/test_mlx/conftest.py (new file, 271 lines)
@@ -0,0 +1,271 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import mlx.core as mx
import mlx.nn as nn

from exo.shared.constants import EXO_MODELS_DIR
from exo.worker.engines.mlx.auto_parallel import (
    PipelineFirstLayer,
    PipelineLastLayer,
)


class MockLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.custom_attr = "test_value"
        self.use_sliding = True

    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
        return x * 2


def run_pipeline_device(
    rank: int,
    world_size: int,
    hostfile_path: str,
    result_queue: Any,  # pyright: ignore[reportAny]
) -> None:
    """Worker function for pipeline parallel tests. Runs in a spawned process."""
    import os

    os.environ["MLX_HOSTFILE"] = hostfile_path
    os.environ["MLX_RANK"] = str(rank)

    import mlx.core as mlx_core
    import mlx.nn as mlx_nn

    class MockLayerInner(mlx_nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.custom_attr = "test_value"

        def __call__(
            self, x: mlx_core.array, *args: object, **kwargs: object
        ) -> mlx_core.array:
            return x * 2

    try:
        group = mlx_core.distributed.init(backend="ring", strict=True)

        mock = MockLayerInner()
        first = PipelineFirstLayer(mock, r=rank, group=group)
        composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)

        x = mlx_core.ones((1, 4))
        result = composed(x)
        mlx_core.eval(result)

        success = result.shape == x.shape
        result_queue.put((rank, success, result))  # pyright: ignore[reportAny]
    except Exception as e:
        result_queue.put((rank, False, str(e)))  # pyright: ignore[reportAny]


@dataclass(frozen=True)
class PipelineTestConfig:
    model_path: Path
    total_layers: int
    base_port: int
    max_tokens: int


def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
    import json
    import tempfile

    hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]

    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        json.dump(hosts, f)
        hostfile_path = f.name

    return hostfile_path, hosts


# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour

DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
    model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
    total_layers=24,
    base_port=29600,
    max_tokens=200,
)


def run_gpt_oss_pipeline_device(
    rank: int,
    world_size: int,
    hostfile_path: str,
    model_path: Path,
    layer_splits: list[tuple[int, int]],
    prompt_tokens: int,
    prefill_step_size: int,  # noqa: ARG001 - kept for API compatibility
    result_queue: Any,  # pyright: ignore[reportAny]
    max_tokens: int = 200,
) -> None:
    import os
    import traceback

    os.environ["MLX_HOSTFILE"] = hostfile_path
    os.environ["MLX_RANK"] = str(rank)

    import mlx.core as mlx_core
    import mlx.nn as mlx_nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper
    from mlx_lm.utils import load_model

    from exo.shared.types.api import ChatCompletionMessage
    from exo.shared.types.memory import Memory
    from exo.shared.types.models import ModelId, ModelMetadata
    from exo.shared.types.tasks import ChatCompletionTaskParams
    from exo.shared.types.worker.shards import PipelineShardMetadata
    from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
    from exo.worker.engines.mlx.generator.generate import mlx_generate
    from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id

    try:
        group = mlx_core.distributed.init(backend="ring", strict=True)

        model: mlx_nn.Module
        model, _ = load_model(model_path, lazy=False, strict=False)
        tokenizer: TokenizerWrapper = load_tokenizer_for_model_id(
            "mlx-community/gpt-oss-20b-MXFP4-Q8", model_path
        )

        # Generate a prompt of exact token length
        base_text = "The quick brown fox jumps over the lazy dog. "
        base_tokens = tokenizer.encode(base_text)
        base_len = len(base_tokens)

        # Build prompt with approximate target length
        repeats = (prompt_tokens // base_len) + 2
        long_text = base_text * repeats
        tokens = tokenizer.encode(long_text)
        # Truncate to exact target length
        tokens = tokens[:prompt_tokens]
        prompt_text = tokenizer.decode(tokens)

        start_layer, end_layer = layer_splits[rank]

        shard_meta = PipelineShardMetadata(
            model_meta=ModelMetadata(
                model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
                pretty_name="GPT-OSS 20B",
                storage_size=Memory.from_gb(12),
                n_layers=24,
                hidden_size=2880,
                supports_tensor=False,
            ),
            device_rank=rank,
            world_size=world_size,
            start_layer=start_layer,
            end_layer=end_layer,
            n_layers=24,
        )

        model = pipeline_auto_parallel(model, group, shard_meta)

        # Evaluate model parameters (required to avoid GPU timeout with distributed)
        mlx_core.eval(model.parameters())
        mlx_core.eval(model)

        # Barrier before generation
        barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
        mlx_core.eval(barrier)

        # Create task params for mlx_generate
        task = ChatCompletionTaskParams(
            model="mlx-community/gpt-oss-20b-MXFP4-Q8",
            messages=[
                ChatCompletionMessage(role="user", content=prompt_text),
            ],
            max_tokens=max_tokens,
        )

        # Use mlx_generate which has token broadcasting built in
        generated_text = ""
        for response in mlx_generate(model, tokenizer, task):  # type: ignore[arg-type]
            generated_text += response.text

        result_queue.put((rank, True, generated_text))  # pyright: ignore[reportAny]

    except Exception as e:
        result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))  # pyright: ignore[reportAny]


def run_gpt_oss_tensor_parallel_device(
    rank: int,
    world_size: int,  # noqa: ARG001 - kept for API compatibility
    hostfile_path: str,
    model_path: Path,
    prompt_tokens: int,
    prefill_step_size: int,  # noqa: ARG001 - kept for API compatibility
    result_queue: Any,  # pyright: ignore[reportAny]
    max_tokens: int = 10,
) -> None:
    import os
    import traceback

    os.environ["MLX_HOSTFILE"] = hostfile_path
    os.environ["MLX_RANK"] = str(rank)

    import mlx.core as mlx_core
    import mlx.nn as mlx_nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper
    from mlx_lm.utils import load_model

    from exo.shared.types.api import ChatCompletionMessage
    from exo.shared.types.tasks import ChatCompletionTaskParams
    from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
    from exo.worker.engines.mlx.generator.generate import mlx_generate
    from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id

    try:
        group = mlx_core.distributed.init(backend="ring", strict=True)

        model: mlx_nn.Module
        model, _ = load_model(model_path, lazy=False, strict=False)
        tokenizer: TokenizerWrapper = load_tokenizer_for_model_id(
            "mlx-community/gpt-oss-20b-MXFP4-Q8", model_path
        )

        base_text = "The quick brown fox jumps over the lazy dog. "
        base_tokens = tokenizer.encode(base_text)
        base_len = len(base_tokens)

        repeats = (prompt_tokens // base_len) + 2
        long_text = base_text * repeats
        tokens = tokenizer.encode(long_text)
        tokens = tokens[:prompt_tokens]
        prompt_text = tokenizer.decode(tokens)

        model = tensor_auto_parallel(model, group)

        # Evaluate model parameters (required to avoid GPU timeout with distributed)
        mlx_core.eval(model.parameters())
        mlx_core.eval(model)

        barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
        mlx_core.eval(barrier)

        # Create task params for mlx_generate
        task = ChatCompletionTaskParams(
            model="mlx-community/gpt-oss-20b-MXFP4-Q8",
            messages=[
                ChatCompletionMessage(role="user", content=prompt_text),
            ],
            max_tokens=max_tokens,
        )

        # Use mlx_generate
        generated_text = ""
        for response in mlx_generate(model, tokenizer, task):  # type: ignore[arg-type]
            generated_text += response.text

        result_queue.put((rank, True, generated_text))  # pyright: ignore[reportAny]

    except Exception as e:
        result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}"))  # pyright: ignore[reportAny]
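How the conftest pieces fit together, as a sketch (the values come from DEFAULT_GPT_OSS_CONFIG above):

# Two local "hosts" on consecutive ports, written to a temporary JSON hostfile.
hostfile_path, hosts = create_hostfile(world_size=2, base_port=29600)
# hosts == ["127.0.0.1:29600", "127.0.0.1:29601"]
#
# Each spawned worker (run_pipeline_device / run_gpt_oss_*_device) then sets
#   MLX_HOSTFILE=<hostfile_path>   MLX_RANK=<rank>
# before calling mx.distributed.init(backend="ring", strict=True), so the ring
# backend can wire the processes together.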
src/exo/worker/tests/unittests/test_mlx/test_distributed_fix.py (new file, 244 lines)
@@ -0,0 +1,244 @@
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import Any, Callable

import pytest

from .conftest import (
    DEFAULT_GPT_OSS_CONFIG,
    create_hostfile,
    run_gpt_oss_pipeline_device,
    run_gpt_oss_tensor_parallel_device,
)


def _check_model_exists() -> bool:
    return DEFAULT_GPT_OSS_CONFIG.model_path.exists()


pytestmark = [
    pytest.mark.skipif(
        not _check_model_exists(),
        reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
    ),
]


@dataclass
class DistributedTestResult:
    timed_out: bool
    world_size: int
    results: dict[int, tuple[bool, str]]

    @property
    def all_success(self) -> bool:
        if len(self.results) != self.world_size:
            return False
        return all(r[0] for r in self.results.values())


def run_distributed_test(
    world_size: int,
    port_offset: int,
    process_timeout: int,
    target: Callable[..., None],
    make_args: Callable[[int], tuple[Any, ...]],
) -> DistributedTestResult:
    ctx = mp.get_context("spawn")
    hostfile_path, _ = create_hostfile(
        world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset
    )

    try:
        result_queue: Any = ctx.Queue()
        processes: list[Any] = []

        for rank in range(world_size):
            args = make_args(rank)
            p = ctx.Process(
                target=target,
                args=(rank, world_size, hostfile_path, *args, result_queue),
            )
            p.start()
            processes.append(p)

        for p in processes:  # pyright: ignore[reportAny]
            p.join(timeout=process_timeout)  # pyright: ignore[reportAny]

        timed_out = any(p.is_alive() for p in processes)  # pyright: ignore[reportAny]

        for p in processes:  # pyright: ignore[reportAny]
            if p.is_alive():  # pyright: ignore[reportAny]
                p.terminate()  # pyright: ignore[reportAny]
                p.join(timeout=5)  # pyright: ignore[reportAny]

        results: dict[int, tuple[bool, str]] = {}
        while not result_queue.empty():  # pyright: ignore[reportAny]
            rank, success, value = result_queue.get()  # pyright: ignore[reportAny]
            results[rank] = (success, value)

        return DistributedTestResult(
            timed_out=timed_out, world_size=world_size, results=results
        )

    finally:
        os.unlink(hostfile_path)


def run_pipeline_test(
    layer_splits: list[tuple[int, int]],
    prompt_tokens: int,
    prefill_step_size: int,
    port_offset: int = 0,
    process_timeout: int = 60,
) -> DistributedTestResult:
    def make_args(rank: int) -> tuple[Any, ...]:
        return (
            DEFAULT_GPT_OSS_CONFIG.model_path,
            layer_splits,
            prompt_tokens,
            prefill_step_size,
        )

    return run_distributed_test(
        world_size=len(layer_splits),
        port_offset=port_offset,
        process_timeout=process_timeout,
        target=run_gpt_oss_pipeline_device,
        make_args=make_args,
    )


def run_tensor_test(
    prompt_tokens: int,
    prefill_step_size: int,
    port_offset: int = 0,
    process_timeout: int = 60,
) -> DistributedTestResult:
    def make_args(rank: int) -> tuple[Any, ...]:
        return (
            DEFAULT_GPT_OSS_CONFIG.model_path,
            prompt_tokens,
            prefill_step_size,
        )

    return run_distributed_test(
        world_size=2,
        port_offset=port_offset,
        process_timeout=process_timeout,
        target=run_gpt_oss_tensor_parallel_device,
        make_args=make_args,
    )


class TestPipelineParallelFix:
    """Test that the lazy=False + token broadcast fix works for pipeline parallel."""

    # This configuration previously triggered the bug
    BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]

    def test_pipeline_single_layer_first_device(self) -> None:
        """Test the edge case that previously caused deadlock."""
        result = run_pipeline_test(
            layer_splits=self.BUG_TRIGGER_SPLITS,
            prompt_tokens=100,
            prefill_step_size=64,
            process_timeout=60,
        )
        assert not result.timed_out, "Unexpected timeout - fix may not be working"
        assert result.all_success, f"Failures: {result.results}"


class TestPipelineSplitConfigurations:
    """Test various pipeline split configurations."""

    @pytest.mark.parametrize(
        "layer_splits",
        [
            [(0, 1), (1, 24)],
            [(0, 6), (6, 24)],
            [(0, 12), (12, 24)],
        ],
        ids=["1_23", "6_18", "12_12"],
    )
    def test_pipeline_splits(
        self,
        layer_splits: list[tuple[int, int]],
    ) -> None:
        result = run_pipeline_test(
            layer_splits=layer_splits,
            prompt_tokens=600,
            prefill_step_size=512,
            port_offset=100,
        )
        assert not result.timed_out, f"Timeout with {layer_splits}"
        assert result.all_success, f"Failures with {layer_splits}: {result.results}"


class TestPrefillStepSizeBoundaries:
    """Test boundary conditions around prefill_step_size."""

    @pytest.mark.parametrize(
        "prefill_step_size,prompt_tokens",
        [
            (512, 511),
            (512, 512),
            (512, 513),
            (512, 1024),
        ],
        ids=["under", "exact", "over", "double"],
    )
    def test_boundary_conditions(
        self,
        prefill_step_size: int,
        prompt_tokens: int,
    ) -> None:
        result = run_pipeline_test(
            layer_splits=[(0, 12), (12, 24)],
            prompt_tokens=prompt_tokens,
            prefill_step_size=prefill_step_size,
            port_offset=200,
        )
        assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
        assert result.all_success, f"Failures: {result.results}"


class TestTensorParallelFix:
    """Test that the lazy=False fix works for tensor parallelism."""

    def test_tensor_parallel(self) -> None:
        result = run_tensor_test(
            prompt_tokens=100,
            prefill_step_size=64,
            port_offset=400,
        )
        assert not result.timed_out, "Unexpected timeout"
        assert result.all_success, f"Failures: {result.results}"


class TestTensorParallelBoundaries:
    """Test tensor parallel with various prefill boundaries."""

    @pytest.mark.parametrize(
        "prefill_step_size,prompt_tokens",
        [
            (512, 511),
            (512, 512),
            (512, 513),
            (512, 1024),
        ],
        ids=["under", "exact", "over", "double"],
    )
    def test_tensor_parallel_boundaries(
        self,
        prefill_step_size: int,
        prompt_tokens: int,
    ) -> None:
        result = run_tensor_test(
            prompt_tokens=prompt_tokens,
            prefill_step_size=prefill_step_size,
            port_offset=500,
        )
        assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
        assert result.all_success, f"Failures: {result.results}"
New file (path not shown):
@@ -0,0 +1,63 @@
import pytest

from .conftest import (
    DEFAULT_GPT_OSS_CONFIG,
)
from .test_distributed_fix import run_tensor_test


def _check_model_exists() -> bool:
    return DEFAULT_GPT_OSS_CONFIG.model_path.exists()


pytestmark = [
    pytest.mark.skipif(
        not _check_model_exists(),
        reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
    ),
]


class TestPerLayerEvaluation:
    """Test per-layer evaluation strategy."""

    def test_per_layer_eval_completes(self) -> None:
        """Test that per-layer eval completes without deadlock."""
        result = run_tensor_test(
            prompt_tokens=100,
            prefill_step_size=64,
            port_offset=700,
            process_timeout=120,
        )
        assert not result.timed_out, "Per-layer eval should complete without timeout"
        assert result.all_success, f"Failures: {result.results}"

    def test_per_layer_eval_with_large_prompt(self) -> None:
        """Test per-layer eval with larger prompt."""
        result = run_tensor_test(
            prompt_tokens=500,
            prefill_step_size=256,
            port_offset=800,
            process_timeout=180,
        )
        assert not result.timed_out, "Per-layer eval should handle large prompts"
        assert result.all_success, f"Failures: {result.results}"


class TestTensorParallelWithFastSynch:
    """Test tensor parallel with FAST_SYNCH=1 using per-layer eval."""

    def test_tensor_parallel_fast_synch_no_deadlock(self) -> None:
        """Test that per-layer eval prevents FAST_SYNCH deadlock.

        This test verifies the per-layer evaluation strategy works
        with the default FAST_SYNCH behavior (auto-enabled for JACCL).
        """
        result = run_tensor_test(
            prompt_tokens=100,
            prefill_step_size=64,
            port_offset=900,
            process_timeout=120,
        )
        assert not result.timed_out, "Per-layer eval should prevent FAST_SYNCH deadlock"
        assert result.all_success, f"Failures: {result.results}"