Compare commits

...

9 Commits

Author SHA1 Message Date
rltakashige
867f376b73 Merge branch 'main' into leo/fix-placement-pipeline-edge-case 2026-01-19 11:51:47 +00:00
Ryuichi Leo Takashige
06073248a2 Move patch file to generate directory 2026-01-18 02:54:45 +00:00
Ryuichi Leo Takashige
f8803e2456 Format and lint 2026-01-18 02:51:22 +00:00
Ryuichi Leo Takashige
f84ee9c29f Ensure tensor parallel is not broken 2026-01-18 02:42:03 +00:00
Ryuichi Leo Takashige
a49c459bc3 Distributed patch 2026-01-18 02:13:44 +00:00
Ryuichi Leo Takashige
18ef398a23 Formatting 2026-01-17 21:47:31 +00:00
Ryuichi Leo Takashige
911bef6a88 Fix single layer pipeline parallel
- GPT OSS swa/ga idx
- Composed Pipeline first and last layer
- Tests
2026-01-17 21:17:06 +00:00
Ryuichi Leo Takashige
6e691b561a Fix placement algorithm for pipeline parallel 2026-01-17 00:42:49 +00:00
Ryuichi Leo Takashige
fd09a6dea5 Failing test cases 2026-01-17 00:35:45 +00:00
10 changed files with 913 additions and 52 deletions

View File

@@ -1,3 +1,10 @@
import models as models
import tokenizer_utils as tokenizer_utils
from generate import *
from .generate import (
GenerationResponse as GenerationResponse,
generate as generate,
generate_step as generate_step,
stream_generate as stream_generate,
)
from .utils import (
load as load,
load_model as load_model,
)

View File

@@ -4,7 +4,7 @@ This type stub file was generated by pyright.
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Type
import mlx.nn as nn
from transformers.utils.auto_docstring import ModelArgs
@@ -26,9 +26,9 @@ def load_model(
strict: bool = True,
model_config: dict[str, Any] = {},
get_model_classes: Callable[
[dict[str, Any]], Tuple[Type[nn.Module], Type[ModelArgs]]
[dict[str, Any]], tuple[Type[nn.Module], Type[ModelArgs]]
] = ...,
) -> Tuple[nn.Module, dict[str, Any]]:
) -> tuple[nn.Module, dict[str, Any]]:
"""
Load and initialize the model from a given path.
@@ -41,12 +41,12 @@ def load_model(
match. Default: ``True``
model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
get_model_classes (Callable[[dict], tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config.
Defaults to the ``_get_classes`` function.
Returns:
Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.
tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.
Raises:
FileNotFoundError: If the weight files (.safetensors) are not found.
@@ -55,16 +55,13 @@ def load_model(
def load(
path_or_hf_repo: str,
tokenizer_config=...,
model_config=...,
adapter_path: Optional[str] = ...,
tokenizer_config: dict[str, Any] | None = ...,
model_config: dict[str, Any] | None = ...,
adapter_path: str | None = ...,
lazy: bool = ...,
return_config: bool = ...,
revision: str = ...,
) -> Union[
Tuple[nn.Module, TokenizerWrapper],
Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],
]:
revision: str | None = ...,
) -> tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
@@ -82,8 +79,7 @@ def load(
return_config (bool): If ``True``, return the model config as the last item.
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
Returns:
Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]:
A tuple containing the loaded model, tokenizer and, if requested, the model config.
tuple[nn.Module, TokenizerWrapper]: The loaded model and tokenizer.
Raises:
FileNotFoundError: If config file or safetensors are not found.
@@ -102,15 +98,13 @@ def make_shards(weights: dict, max_file_size_gb: int = ...) -> list:
list: List of weight shards.
"""
def create_model_card(
path: Union[str, Path], hf_path: Union[str, Path, None]
): # -> None:
def create_model_card(path: str | Path, hf_path: str | Path | None): # -> None:
"""
Uploads the model to Hugging Face hub.
Args:
path (Union[str, Path]): Local path to the model.
hf_path (Union[str, Path, None]): Path to the original Hugging Face model.
path (str | Path): Local path to the model.
hf_path (str | Path | None): Path to the original Hugging Face model.
"""
def upload_to_hub(path: str, upload_repo: str): # -> None:
@@ -123,7 +117,7 @@ def upload_to_hub(path: str, upload_repo: str): # -> None:
"""
def save_model(
save_path: Union[str, Path], model: nn.Module, *, donate_model: bool = ...
save_path: str | Path, model: nn.Module, *, donate_model: bool = ...
) -> None:
"""Save model weights and metadata index into specified directory."""
@@ -133,8 +127,8 @@ def quantize_model(
group_size: int,
bits: int,
mode: str = ...,
quant_predicate: Optional[Callable[[str, nn.Module], Union[bool, dict]]] = ...,
) -> Tuple[nn.Module, dict]:
quant_predicate: Callable[[str, nn.Module], bool | dict] | None = ...,
) -> tuple[nn.Module, dict]:
"""
Applies quantization to the model weights.
@@ -150,25 +144,25 @@ def quantize_model(
a dict of quantization parameters to pass to `to_quantized`.
Returns:
Tuple: Tuple containing quantized model and config.
tuple: Tuple containing quantized model and config.
"""
def save_config(config: dict, config_path: Union[str, Path]) -> None:
def save_config(config: dict, config_path: str | Path) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
config_path (str | Path): Model configuration file path.
"""
def save(
dst_path: Union[str, Path],
src_path_or_repo: Union[str, Path],
dst_path: str | Path,
src_path_or_repo: str | Path,
model: nn.Module,
tokenizer: TokenizerWrapper,
config: Dict[str, Any],
config: dict[str, Any],
donate_model: bool = ...,
): # -> None:
...

View File

@@ -49,33 +49,87 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
"""Allocate layers proportionally using largest remainder method.
Guarantees each node gets at least 1 layer and total equals exactly total_layers.
"""
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder method: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1 # This should always be true
result[max_idx] -= 1
result[i] = 1
return result
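As a quick illustration of the largest remainder method used above, here is the allocation traced as a standalone sketch for the (312, 518, 1024) KB case that appears in the parametrized tests further down; the sketch is illustrative and not part of the changed files.

# Sketch: largest remainder allocation for memory 312/518/1024 KB over 12 layers.
total_layers = 12
memory_kb = [312, 518, 1024]
fractions = [m / sum(memory_kb) for m in memory_kb]

raw = [f * total_layers for f in fractions]   # ~[2.02, 3.35, 6.63]
result = [int(r) for r in raw]                # floors: [2, 3, 6], sum = 11
# One leftover layer goes to the index with the largest fractional part (index 2).
by_remainder = sorted(range(len(raw)), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
    result[by_remainder[i]] += 1

assert result == [2, 3, 7] and sum(result) == total_layers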
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
if not selected_cycle:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
for node in selected_cycle
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node.node_profile.memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node.node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(

View File

@@ -3,6 +3,7 @@ from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
@@ -165,6 +166,9 @@ def test_get_smallest_cycles(
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
@@ -397,3 +401,96 @@ def test_get_mlx_jaccl_coordinators(
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node(900 * 1024, node_a_id)
node_b = create_node(50 * 1024, node_b_id)
node_c = create_node(10 * 1024, node_c_id) # Insufficient memory
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline)
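For reference, a small sketch of the per-layer memory check exercised by the test above, using the same numbers (1000 KB model, 20 layers, a 10 KB node); it assumes Memory.from_kb and in_bytes behave as they are used elsewhere in this diff.

from exo.shared.types.memory import Memory

# Numbers from the insufficient-memory test above: 1000 KB model, 20 layers.
storage = Memory.from_kb(1000)
total_layers = 20
memory_per_layer = storage.in_bytes / total_layers   # 50 KB per layer

node_available = Memory.from_kb(10)                  # node_c's available RAM
node_layers = 1                                      # the minimum allocation
required = node_layers * memory_per_layer

# This is the condition that makes get_shard_assignments raise ValueError.
assert required > node_available.in_bytes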

View File

@@ -46,9 +46,11 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
# Set twice to avoid __setattr__ recursion
object.__setattr__(self, "_original_layer", original_layer)
self.original_layer: _LayerCallable = original_layer
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -58,7 +60,7 @@ class CustomMlxLayer(nn.Module):
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
return object.__getattribute__(original_layer, name)
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
@@ -155,7 +157,10 @@ def pipeline_auto_parallel(
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
# assume that at least one layer is assigned to the shard from placement
layers = layers[start_layer:end_layer]
# pipeline last layer can be composed with pipeline first layer
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],
@@ -168,11 +173,18 @@ def pipeline_auto_parallel(
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
layer_types: list[str] = inner_model_instance.layer_types # type: ignore
# Default to 0 if layer type not present - the mask will be created but unused in mlx lm
inner_model_instance.swa_idx = (
0
if "sliding_attention" not in layer_types
else layer_types.index("sliding_attention")
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
inner_model_instance.ga_idx = (
0
if "full_attention" not in layer_types
else layer_types.index("full_attention")
)
_set_layers(model, layers)

View File

@@ -0,0 +1,72 @@
from typing import Any, Generator
_patch_applied = False
def patch_mlx_lm_for_distributed() -> None:
"""
Patches mlx_lm's generate_step to work with distributed inference.
mlx_lm's prefill loop only evaluates cache state, not logits.
With distributed inference, model() triggers mx.distributed.all_gather() which must be
evaluated for all devices to synchronize. When prompt > prefill_step_size, the
all_gather is never evaluated, causing GPU timeout.
This patch uses mx.depends to make cache state depend on logits, ensuring all_gather is
evaluated when cache is eval'd.
"""
global _patch_applied
if _patch_applied:
return
_patch_applied = True
import importlib
import mlx.core as mx
gen_module = importlib.import_module("mlx_lm.generate")
original_generate_step = gen_module.generate_step # pyright: ignore[reportAny]
def patched_generate_step(
prompt: mx.array,
model: Any, # pyright: ignore[reportAny]
**kwargs: Any, # pyright: ignore[reportAny]
) -> Generator[Any, None, None]:
"""Patched generate_step that works with distributed inference."""
prompt_cache = kwargs.get("prompt_cache")
class DistributedModelWrapper:
"""Wrapper that adds mx.depends between logits and cache state."""
def __init__(
self,
inner_model: Any, # pyright: ignore[reportAny]
cache: Any, # pyright: ignore[reportAny]
) -> None:
self._inner: Any = inner_model
self._cache: Any = cache
def __call__(
self,
*args: Any, # pyright: ignore[reportAny]
**kw: Any, # pyright: ignore[reportAny]
) -> mx.array:
logits: mx.array = self._inner(*args, **kw) # pyright: ignore[reportAny]
cache: Any = kw.get("cache") or self._cache # pyright: ignore[reportAny]
if cache is not None:
for c in cache: # pyright: ignore[reportAny]
if hasattr(c, "state") and c.state is not None: # pyright: ignore[reportAny]
c.state = mx.depends(c.state, logits) # pyright: ignore[reportAny, reportUnknownMemberType]
return logits
def __getattr__(self, name: str) -> Any: # pyright: ignore[reportAny]
return getattr(self._inner, name) # pyright: ignore[reportAny]
if prompt_cache is None:
prompt_cache = model.make_cache() # pyright: ignore[reportAny]
kwargs["prompt_cache"] = prompt_cache
wrapped_model = DistributedModelWrapper(model, prompt_cache)
yield from original_generate_step(prompt, wrapped_model, **kwargs)
gen_module.generate_step = patched_generate_step # pyright: ignore[reportAttributeAccessIssue]
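A minimal single-process sketch (assuming mx.depends and mx.eval behave as they are used above) of the dependency trick the patch relies on: once the cache state depends on the logits, evaluating the cache also forces the logits graph, so any all_gather inside it runs on every device.

import mlx.core as mx

x = mx.ones((2, 2))
logits = x * 3.0        # stands in for the model output (which would contain the all_gather)
cache_state = x + 1.0   # stands in for a KV-cache entry's state

# Mirror the patch: make the cache state depend on the logits.
cache_state = mx.depends(cache_state, logits)

# mlx_lm's prefill loop only evals the cache; with the dependency in place,
# evaluating it also evaluates `logits`.
mx.eval(cache_state)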

View File

@@ -62,6 +62,9 @@ from exo.worker.engines.mlx.auto_parallel import (
pipeline_auto_parallel,
tensor_auto_parallel,
)
from exo.worker.engines.mlx.generator.distributed_patch import (
patch_mlx_lm_for_distributed,
)
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
@@ -290,6 +293,9 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# prompts > prefill_step_size cause GPU timeout otherwise
patch_mlx_lm_for_distributed()
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")

View File

@@ -0,0 +1,263 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.worker.engines.mlx.auto_parallel import (
PipelineFirstLayer,
PipelineLastLayer,
)
class MockLayer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
self.use_sliding = True
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x * 2
def run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
result_queue: Any, # pyright: ignore[reportAny]
) -> None:
"""Worker function for pipeline parallel tests. Runs in a spawned process."""
import os
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
class MockLayerInner(mlx_nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
def __call__(
self, x: mlx_core.array, *args: object, **kwargs: object
) -> mlx_core.array:
return x * 2
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
x = mlx_core.ones((1, 4))
result = composed(x)
mlx_core.eval(result)
success = result.shape == x.shape
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
@dataclass(frozen=True)
class PipelineTestConfig:
model_path: Path
total_layers: int
base_port: int
max_tokens: int
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
import json
import tempfile
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
return hostfile_path, hosts
# Use GPT-OSS 20B for testing, as it is a model with a lot of strange behaviour
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
total_layers=24,
base_port=29600,
max_tokens=200,
)
def run_gpt_oss_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
use_patch: bool,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 200,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
if use_patch:
from exo.worker.engines.mlx.generator.distributed_patch import (
patch_mlx_lm_for_distributed,
)
patch_mlx_lm_for_distributed()
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
# Generate a prompt of exact token length
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
# Build prompt with approximate target length
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
# Truncate to exact target length
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=False,
),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=24,
)
model = pipeline_auto_parallel(model, group, shard_meta)
# Barrier before generation
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
):
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
def run_gpt_oss_tensor_parallel_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
prompt_tokens: int,
prefill_step_size: int,
use_patch: bool,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 10,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
if use_patch:
from exo.worker.engines.mlx.generator.distributed_patch import (
patch_mlx_lm_for_distributed,
)
patch_mlx_lm_for_distributed()
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
model = tensor_auto_parallel(model, group)
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
):
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]

View File

@@ -0,0 +1,100 @@
import multiprocessing as mp
from typing import Any
import mlx.core as mx
import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
PipelineFirstLayer,
PipelineLastLayer,
)
from .conftest import MockLayer, run_pipeline_device
def test_single_wrapper_delegates_attributes() -> None:
mock = MockLayer()
wrapped = CustomMlxLayer(mock)
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
assert wrapped.use_sliding is True # type: ignore[attr-defined]
def test_composed_wrappers_delegate_attributes() -> None:
mock = MockLayer()
group = mx.distributed.init()
first = PipelineFirstLayer(mock, r=0, group=group)
composed = PipelineLastLayer(first, r=0, s=1, group=group)
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
assert composed.use_sliding is True # type: ignore[attr-defined]
def test_missing_attribute_raises() -> None:
mock = MockLayer()
wrapped = CustomMlxLayer(mock)
with pytest.raises(AttributeError):
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
def test_composed_call_works() -> None:
import json
import os
import tempfile
ctx = mp.get_context("spawn")
world_size = 2
base_port = 29500
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=run_pipeline_device,
args=(rank, world_size, hostfile_path, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=10) # pyright: ignore[reportAny]
results: dict[int, Any] = {}
errors: dict[int, str] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
if success:
results[rank] = value
else:
errors[rank] = value
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}. Errors: {errors}"
)
# Device 0: input ones -> MockLayer(x*2) -> sends twos to device 1
# Device 1: receives twos -> MockLayer(x*2) -> outputs fours
# all_gather returns last batch, which is device 1's output (4.0)
for rank in range(world_size):
assert rank in results, (
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
)
result_array = results[rank]
# Both devices see the final result (4.0) after all_gather
assert (result_array == 4.0).all(), (
f"Device {rank}: expected 4.0, got {result_array}"
)
finally:
os.unlink(hostfile_path)

View File

@@ -0,0 +1,256 @@
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import Any, Callable
import pytest
from .conftest import (
DEFAULT_GPT_OSS_CONFIG,
create_hostfile,
run_gpt_oss_pipeline_device,
run_gpt_oss_tensor_parallel_device,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
),
]
@dataclass
class DistributedTestResult:
timed_out: bool
world_size: int
results: dict[int, tuple[bool, str]]
@property
def all_success(self) -> bool:
if len(self.results) != self.world_size:
return False
return all(r[0] for r in self.results.values())
def run_distributed_test(
world_size: int,
port_offset: int,
process_timeout: int,
target: Callable[..., None],
make_args: Callable[[int], tuple[Any, ...]],
) -> DistributedTestResult:
ctx = mp.get_context("spawn")
hostfile_path, _ = create_hostfile(
world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset
)
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
args = make_args(rank)
p = ctx.Process(
target=target,
args=(rank, world_size, hostfile_path, *args, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=process_timeout) # pyright: ignore[reportAny]
timed_out = any(p.is_alive() for p in processes) # pyright: ignore[reportAny]
for p in processes: # pyright: ignore[reportAny]
if p.is_alive(): # pyright: ignore[reportAny]
p.terminate() # pyright: ignore[reportAny]
p.join(timeout=5) # pyright: ignore[reportAny]
results: dict[int, tuple[bool, str]] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
results[rank] = (success, value)
return DistributedTestResult(
timed_out=timed_out, world_size=world_size, results=results
)
finally:
os.unlink(hostfile_path)
def run_pipeline_test(
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
use_patch: bool,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
DEFAULT_GPT_OSS_CONFIG.model_path,
layer_splits,
prompt_tokens,
prefill_step_size,
use_patch,
)
return run_distributed_test(
world_size=len(layer_splits),
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_pipeline_device,
make_args=make_args,
)
def run_tensor_test(
prompt_tokens: int,
prefill_step_size: int,
use_patch: bool,
port_offset: int = 0,
process_timeout: int = 60,
) -> DistributedTestResult:
def make_args(rank: int) -> tuple[Any, ...]:
return (
DEFAULT_GPT_OSS_CONFIG.model_path,
prompt_tokens,
prefill_step_size,
use_patch,
)
return run_distributed_test(
world_size=2,
port_offset=port_offset,
process_timeout=process_timeout,
target=run_gpt_oss_tensor_parallel_device,
make_args=make_args,
)
class TestPipelineParallelPrefillBug:
BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]
def test_prefill_bug_without_patch(self) -> None:
result = run_pipeline_test(
layer_splits=self.BUG_TRIGGER_SPLITS,
prompt_tokens=100,
prefill_step_size=64,
use_patch=False,
process_timeout=30,
)
assert result.timed_out or not result.all_success, (
"Expected timeout/failure WITHOUT patch. "
"If this fails, mlx_lm may have been fixed upstream."
)
def test_prefill_fixed_with_patch(self) -> None:
result = run_pipeline_test(
layer_splits=self.BUG_TRIGGER_SPLITS,
prompt_tokens=100,
prefill_step_size=64,
use_patch=True,
)
assert not result.timed_out, "Unexpected timeout with patch"
assert result.all_success, f"Failures: {result.results}"
class TestPipelineSplitConfigurations:
@pytest.mark.parametrize(
"layer_splits",
[
[(0, 1), (1, 24)],
[(0, 6), (6, 24)],
[(0, 12), (12, 24)],
],
ids=["1_23", "6_18", "12_12"],
)
def test_pipeline_splits_with_patch(
self,
layer_splits: list[tuple[int, int]],
) -> None:
result = run_pipeline_test(
layer_splits=layer_splits,
prompt_tokens=600,
prefill_step_size=512,
use_patch=True,
port_offset=100,
)
assert not result.timed_out, f"Timeout with {layer_splits}"
assert result.all_success, f"Failures with {layer_splits}: {result.results}"
class TestPrefillStepSizeBoundaries:
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_boundary_conditions_with_patch(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_pipeline_test(
layer_splits=[(0, 12), (12, 24)],
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
use_patch=True,
port_offset=200,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelWithPatch:
"""Test that the patch does not break tensor parallelism."""
def test_tensor_parallel(self) -> None:
result = run_tensor_test(
prompt_tokens=100,
prefill_step_size=64,
use_patch=True,
port_offset=400,
)
assert not result.timed_out, "Unexpected timeout with patch"
assert result.all_success, f"Failures: {result.results}"
class TestTensorParallelBoundaries:
@pytest.mark.parametrize(
"prefill_step_size,prompt_tokens",
[
(512, 511),
(512, 512),
(512, 513),
(512, 1024),
],
ids=["under", "exact", "over", "double"],
)
def test_tensor_parallel_boundaries_with_patch(
self,
prefill_step_size: int,
prompt_tokens: int,
) -> None:
result = run_tensor_test(
prompt_tokens=prompt_tokens,
prefill_step_size=prefill_step_size,
use_patch=True,
port_offset=500,
)
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
assert result.all_success, f"Failures: {result.results}"