Compare commits

...

1 Commit

Author SHA1 Message Date
Ryuichi Leo Takashige  b0c32c30e2  Test solution  2026-01-18 15:58:43 +00:00
6 changed files with 740 additions and 13 deletions

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast
from typing import TYPE_CHECKING, Any, Callable, Protocol, cast
import mlx.core as mx
import mlx.nn as nn
@@ -106,10 +106,52 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output
class DistributedModelWrapper:
"""Wrapper that ensures distributed ops are evaluated during prefill.
mlx_lm's prefill loop only evaluates cache state, not logits. With pipeline
parallel, the send/recv ops are triggered when logits are evaluated. This
wrapper adds mx.depends between cache.state and logits, so when cache is
evaluated after each prefill chunk, the distributed ops are also evaluated.
"""
_inner: Any
_cache: Any
def __init__(
self,
inner_model: nn.Module,
cache: Any = None, # pyright: ignore[reportAny]
):
object.__setattr__(self, "_inner", inner_model)
object.__setattr__(self, "_cache", cache)
def __call__(
self,
*args: Any, # pyright: ignore[reportAny]
**kwargs: Any, # pyright: ignore[reportAny]
) -> mx.array:
logits: mx.array = self._inner(*args, **kwargs) # pyright: ignore[reportAny]
cache: Any = kwargs.get("cache") or self._cache # pyright: ignore[reportAny]
if cache is not None:
for c in cache: # pyright: ignore[reportAny]
if hasattr(c, "state") and c.state is not None: # pyright: ignore[reportAny]
c.state = mx.depends(c.state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
def __getattr__(self, name: str) -> Any: # pyright: ignore[reportAny]
return getattr(self._inner, name) # pyright: ignore[reportAny]
def __setattr__(self, name: str, value: Any) -> None: # pyright: ignore[reportAny]
if name in ("_inner", "_cache"):
object.__setattr__(self, name, value)
else:
setattr(self._inner, name, value) # pyright: ignore[reportAny]
def _inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
if isinstance(inner, nn.Module):
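The key mechanism in DistributedModelWrapper is the mx.depends call: the cache state is made to depend on the logits, so when mlx_lm's prefill loop evaluates only the cache, the pipeline send/recv ops baked into the logits graph are evaluated as well. A minimal sketch of that mechanism, using all_sum as a stand-in for the pipeline communication (illustrative only; assumes a distributed backend is available):

import mlx.core as mx

group = mx.distributed.init()                        # assumes a configured backend
x = mx.ones((1, 4))
logits = mx.distributed.all_sum(x, group=group)      # lazy distributed op
cache_state = mx.zeros((1, 4))

# Tie the cache state to the logits: evaluating cache_state now also forces
# the distributed op above, mirroring what the wrapper does per cache entry.
cache_state = mx.depends(cache_state, logits)
mx.eval(cache_state)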
@@ -168,12 +210,19 @@ def pipeline_auto_parallel(
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
# Handle case where layer type may not exist in this shard
try:
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
except ValueError:
inner_model_instance.swa_idx = -1
try:
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
except ValueError:
inner_model_instance.ga_idx = -1
_set_layers(model, layers)
@@ -181,7 +230,14 @@ def pipeline_auto_parallel(
"Expected a list of layers after auto-parallel initialisation"
)
return model
# Store pipeline group on model for token broadcasting in generate
model._pipeline_group = group
# Wrap model to ensure distributed ops are evaluated during prefill
wrapped = DistributedModelWrapper(model)
wrapped._pipeline_group = group
return wrapped # type: ignore[return-value]
def tensor_auto_parallel(

View File

@@ -148,6 +148,11 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Get pipeline group for token broadcasting (if using pipeline parallelism)
pipeline_group: mx.distributed.Group | None = getattr(
model, "_pipeline_group", None
)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
@@ -162,6 +167,17 @@ def mlx_generate(
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
# Broadcast token across all pipeline devices for synchronization
if pipeline_group is not None:
token_array = mx.array([[out.token]], dtype=mx.int32)
token_array = mx.distributed.all_gather(token_array, group=pipeline_group)
# Take the token from the last device (which has the full output)
# all_gather concatenates along first dim: [world_size, 1]
correct_token = int(token_array[-1, 0].item())
out.token = correct_token
# Re-decode text since our local token may have been wrong
out.text = tokenizer.decode([correct_token])
logger.info(out.text)
stats: GenerationStats | None = None
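For reference, the shape behaviour the broadcast relies on: all_gather over a (1, 1) token array concatenates along the first axis into (world_size, 1), in rank order, so indexing [-1, 0] selects the token from the last rank, which holds the final pipeline output. A standalone sketch (illustrative; assumes an initialised distributed group and uses a placeholder token value):

import mlx.core as mx

group = mx.distributed.init()
local_token = 42                                             # placeholder for out.token
token_array = mx.array([[local_token]], dtype=mx.int32)      # shape (1, 1)
gathered = mx.distributed.all_gather(token_array, group=group)  # shape (world_size, 1)
token = int(gathered[-1, 0].item())                          # take the last rank's token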

View File

@@ -147,6 +147,79 @@ def broadcast_from_zero(value: int, group: Group | None = None):
return int(m.item())
def _get_inner_model(model: nn.Module) -> nn.Module:
"""Get the inner model (handles wrapper patterns)."""
if hasattr(model, "model"):
inner: Any = model.model # type: ignore[attr-defined]
if isinstance(inner, nn.Module):
return inner
if hasattr(model, "transformer"):
inner = model.transformer # type: ignore[attr-defined]
if isinstance(inner, nn.Module):
return inner
return model
def _get_model_layers(inner: nn.Module) -> list[nn.Module]:
"""Get the transformer layers from the model."""
if hasattr(inner, "layers"):
layers: Any = inner.layers # type: ignore[attr-defined]
if isinstance(layers, list):
return cast(list[nn.Module], layers)
if hasattr(inner, "h"):
h_layers: Any = inner.h # type: ignore[attr-defined]
if isinstance(h_layers, list):
return cast(list[nn.Module], h_layers)
return []
def eval_parameters_per_layer(
model: nn.Module,
group: Group | None,
timeout_seconds: float,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate model parameters layer by layer to reduce Metal sync pressure.
This avoids the FAST_SYNCH deadlock by breaking evaluation into smaller chunks,
reducing peak Metal GPU memory/sync contention.
"""
inner = _get_inner_model(model)
layers = _get_model_layers(inner)
n_layers = len(layers)
# Reserve some time for embedding and output layers
per_layer_timeout = timeout_seconds / max(n_layers + 4, 1)
# Evaluate embedding/input layers first
if hasattr(inner, "embed_tokens"):
embed: Any = inner.embed_tokens # type: ignore[attr-defined]
if hasattr(embed, "parameters"): # pyright: ignore[reportUnknownArgumentType]
logger.debug("Evaluating embed_tokens")
eval_with_timeout(embed.parameters(), per_layer_timeout, on_timeout) # pyright: ignore[reportUnknownMemberType]
# Evaluate each transformer layer separately
for i, layer in enumerate(layers):
logger.debug(f"Evaluating layer {i}/{n_layers}")
eval_with_timeout(layer.parameters(), per_layer_timeout, on_timeout)
# Barrier between layers for distributed sync
if group is not None:
mx_barrier(group)
# Evaluate output layers
if hasattr(inner, "norm"):
norm: Any = inner.norm # type: ignore[attr-defined]
if hasattr(norm, "parameters"): # pyright: ignore[reportUnknownArgumentType]
logger.debug("Evaluating norm")
eval_with_timeout(norm.parameters(), per_layer_timeout, on_timeout) # pyright: ignore[reportUnknownMemberType]
if hasattr(model, "lm_head"):
lm_head: Any = model.lm_head # type: ignore[attr-defined]
if hasattr(lm_head, "parameters"): # pyright: ignore[reportUnknownArgumentType]
logger.debug("Evaluating lm_head")
eval_with_timeout(lm_head.parameters(), per_layer_timeout, on_timeout) # pyright: ignore[reportUnknownMemberType]
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -265,7 +338,9 @@ def shard_and_load(
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
# Use lazy=False to materialize weights before applying distributed sharding
# This avoids deadlock with FAST_SYNCH + lazy weights + distributed collectives
model, _ = load_model(model_path, lazy=False, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
@@ -303,12 +378,14 @@ def shard_and_load(
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"Evaluating model parameters per-layer with total timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
# Per-layer evaluation to avoid FAST_SYNCH deadlock with tensor parallel
eval_parameters_per_layer(model, group, timeout_seconds, on_timeout)
# Final model eval (should be fast now, weights already materialized)
mx.eval(model)
logger.debug("SHARDED")

View File

@@ -0,0 +1,271 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.worker.engines.mlx.auto_parallel import (
PipelineFirstLayer,
PipelineLastLayer,
)
class MockLayer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
self.use_sliding = True
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x * 2
def run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
result_queue: Any, # pyright: ignore[reportAny]
) -> None:
"""Worker function for pipeline parallel tests. Runs in a spawned process."""
import os
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
class MockLayerInner(mlx_nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
def __call__(
self, x: mlx_core.array, *args: object, **kwargs: object
) -> mlx_core.array:
return x * 2
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
x = mlx_core.ones((1, 4))
result = composed(x)
mlx_core.eval(result)
success = result.shape == x.shape
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
@dataclass(frozen=True)
class PipelineTestConfig:
model_path: Path
total_layers: int
base_port: int
max_tokens: int
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
import json
import tempfile
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
return hostfile_path, hosts
# Use GPT-OSS 20B for testing, as it is a model with a lot of strange behaviour
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
total_layers=24,
base_port=29600,
max_tokens=200,
)
def run_gpt_oss_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int, # noqa: ARG001 - kept for API compatibility
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 200,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import load_model
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model: mlx_nn.Module
model, _ = load_model(model_path, lazy=False, strict=False)
tokenizer: TokenizerWrapper = load_tokenizer_for_model_id(
"mlx-community/gpt-oss-20b-MXFP4-Q8", model_path
)
# Generate a prompt of exact token length
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
# Build prompt with approximate target length
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
# Truncate to exact target length
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=False,
),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=24,
)
model = pipeline_auto_parallel(model, group, shard_meta)
# Evaluate model parameters (required to avoid GPU timeout with distributed)
mlx_core.eval(model.parameters())
mlx_core.eval(model)
# Barrier before generation
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
# Create task params for mlx_generate
task = ChatCompletionTaskParams(
model="mlx-community/gpt-oss-20b-MXFP4-Q8",
messages=[
ChatCompletionMessage(role="user", content=prompt_text),
],
max_tokens=max_tokens,
)
# Use mlx_generate which has token broadcasting built in
generated_text = ""
for response in mlx_generate(model, tokenizer, task): # type: ignore[arg-type]
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
def run_gpt_oss_tensor_parallel_device(
rank: int,
world_size: int, # noqa: ARG001 - kept for API compatibility
hostfile_path: str,
model_path: Path,
prompt_tokens: int,
prefill_step_size: int, # noqa: ARG001 - kept for API compatibility
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 10,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import load_model
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model: mlx_nn.Module
model, _ = load_model(model_path, lazy=False, strict=False)
tokenizer: TokenizerWrapper = load_tokenizer_for_model_id(
"mlx-community/gpt-oss-20b-MXFP4-Q8", model_path
)
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
model = tensor_auto_parallel(model, group)
# Evaluate model parameters (required to avoid GPU timeout with distributed)
mlx_core.eval(model.parameters())
mlx_core.eval(model)
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
# Create task params for mlx_generate
task = ChatCompletionTaskParams(
model="mlx-community/gpt-oss-20b-MXFP4-Q8",
messages=[
ChatCompletionMessage(role="user", content=prompt_text),
],
max_tokens=max_tokens,
)
# Use mlx_generate
generated_text = ""
for response in mlx_generate(model, tokenizer, task): # type: ignore[arg-type]
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]

View File

@@ -0,0 +1,244 @@
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import Any, Callable
import pytest
from .conftest import (
DEFAULT_GPT_OSS_CONFIG,
create_hostfile,
run_gpt_oss_pipeline_device,
run_gpt_oss_tensor_parallel_device,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
pytestmark = [
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
),
]
@dataclass
class DistributedTestResult:
timed_out: bool
world_size: int
results: dict[int, tuple[bool, str]]
@property
def all_success(self) -> bool:
if len(self.results) != self.world_size:
return False
return all(r[0] for r in self.results.values())
def run_distributed_test(
world_size: int,
port_offset: int,
process_timeout: int,
target: Callable[..., None],
make_args: Callable[[int], tuple[Any, ...]],
) -> DistributedTestResult:
ctx = mp.get_context("spawn")
hostfile_path, _ = create_hostfile(
world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset
)
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
args = make_args(rank)
p = ctx.Process(
target=target,
args=(rank, world_size, hostfile_path, *args, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=process_timeout) # pyright: ignore[reportAny]
timed_out = any(p.is_alive() for p in processes) # pyright: ignore[reportAny]
for p in processes: # pyright: ignore[reportAny]
if p.is_alive(): # pyright: ignore[reportAny]
p.terminate() # pyright: ignore[reportAny]
p.join(timeout=5) # pyright: ignore[reportAny]
results: dict[int, tuple[bool, str]] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
results[rank] = (success, value)
return DistributedTestResult(
timed_out=timed_out, world_size=world_size, results=results
)
finally:
os.unlink(hostfile_path)
def run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
DEFAULT_GPT_OSS_CONFIG.model_path,
layer_splits,
prompt_tokens,
prefill_step_size,
)
return run_distributed_test(
world_size=len(layer_splits),
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_pipeline_device,
make_args=make_args,
)
def run_tensor_test(
prompt_tokens: int,
prefill_step_size: int,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
DEFAULT_GPT_OSS_CONFIG.model_path,
prompt_tokens,
prefill_step_size,
)
return run_distributed_test(
world_size=2,
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_tensor_parallel_device,
make_args=make_args,
)
class TestPipelineParallelFix:
"""Test that the lazy=False + token broadcast fix works for pipeline parallel."""
# This configuration previously triggered the bug
BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]
def test_pipeline_single_layer_first_device(self) -> None:
"""Test the edge case that previously caused deadlock."""
result = run_pipeline_test(
layer_splits=self.BUG_TRIGGER_SPLITS,
prompt_tokens=100,
prefill_step_size=64,
process_timeout=60,
)
assert not result.timed_out, "Unexpected timeout - fix may not be working"
assert result.all_success, f"Failures: {result.results}"
class TestPipelineSplitConfigurations:
"""Test various pipeline split configurations."""
@pytest.mark.parametrize(
"layer_splits",
[
[(0, 1), (1, 24)],
[(0, 6), (6, 24)],
[(0, 12), (12, 24)],
],
ids=["1_23", "6_18", "12_12"],
)
def test_pipeline_splits(
self,
layer_splits: list[tuple[int, int]],
) -> None:
result = run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=600,
prefill_step_size=512,
port_offset=100,
)
assert not result.timed_out, f"Timeout with {layer_splits}"
assert result.all_success, f"Failures with {layer_splits}: {result.results}"
class TestPrefillStepSizeBoundaries:
"""Test boundary conditions around prefill_step_size."""
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_boundary_conditions(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_pipeline_test(
layer_splits=[(0, 12), (12, 24)],
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
port_offset=200,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelFix:
"""Test that the lazy=False fix works for tensor parallelism."""
def test_tensor_parallel(self) -> None:
result = run_tensor_test(
prompt_tokens=100,
prefill_step_size=64,
port_offset=400,
)
assert not result.timed_out, "Unexpected timeout"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelBoundaries:
"""Test tensor parallel with various prefill boundaries."""
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_tensor_parallel_boundaries(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_tensor_test(
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
port_offset=500,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"

View File

@@ -0,0 +1,63 @@
import pytest
from .conftest import (
DEFAULT_GPT_OSS_CONFIG,
)
from .test_distributed_fix import run_tensor_test
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
pytestmark = [
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
),
]
class TestPerLayerEvaluation:
"""Test per-layer evaluation strategy."""
def test_per_layer_eval_completes(self) -> None:
"""Test that per-layer eval completes without deadlock."""
result = run_tensor_test(
prompt_tokens=100,
prefill_step_size=64,
port_offset=700,
process_timeout=120,
)
assert not result.timed_out, "Per-layer eval should complete without timeout"
assert result.all_success, f"Failures: {result.results}"
def test_per_layer_eval_with_large_prompt(self) -> None:
"""Test per-layer eval with larger prompt."""
result = run_tensor_test(
prompt_tokens=500,
prefill_step_size=256,
port_offset=800,
process_timeout=180,
)
assert not result.timed_out, "Per-layer eval should handle large prompts"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelWithFastSynch:
"""Test tensor parallel with FAST_SYNCH=1 using per-layer eval."""
def test_tensor_parallel_fast_synch_no_deadlock(self) -> None:
"""Test that per-layer eval prevents FAST_SYNCH deadlock.
This test verifies the per-layer evaluation strategy works
with the default FAST_SYNCH behavior (auto-enabled for JACCL).
"""
result = run_tensor_test(
prompt_tokens=100,
prefill_step_size=64,
port_offset=900,
process_timeout=120,
)
assert not result.timed_out, "Per-layer eval should prevent FAST_SYNCH deadlock"
assert result.all_success, f"Failures: {result.results}"