Compare commits

...

2 Commits

Author                 SHA1        Message    Date
Ryuichi Leo Takashige  8f6f2f3065  Add fixes  2026-01-20 17:13:02 +00:00
Evan                   e6af53c2ae  foo        2026-01-20 17:12:31 +00:00
3 changed files with 77 additions and 91 deletions

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from functools import partial
 from inspect import signature
-from typing import TYPE_CHECKING, Any, Protocol, cast
+from typing import TYPE_CHECKING, Any, cast

 import mlx.core as mx
 import mlx.nn as nn
@@ -67,27 +67,16 @@ def eval_with_timeout(
         completed.set()


-class _LayerCallable(Protocol):
-    """Structural type that any compatible layer must satisfy.
-
-    We require a single positional input of type ``mx.array`` and an
-    ``mx.array`` output, while permitting arbitrary *args / **kwargs so this
-    protocol matches the vast majority of `mlx.nn.Module` subclasses.
-    """
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
-
-
 class CustomMlxLayer(nn.Module):
     """Base class for replacing an MLX layer with a custom implementation."""

-    def __init__(self, original_layer: _LayerCallable):
+    def __init__(self, original_layer: nn.Module):
         super().__init__()
         object.__setattr__(self, "_original_layer", original_layer)

     @property
-    def original_layer(self) -> _LayerCallable:
-        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
+    def original_layer(self) -> nn.Module:
+        return cast(nn.Module, object.__getattribute__(self, "_original_layer"))

     # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
     if not TYPE_CHECKING:
@@ -100,52 +89,53 @@ class CustomMlxLayer(nn.Module):
             return getattr(original_layer, name)


-class PipelineFirstLayer(CustomMlxLayer):
-    def __init__(
-        self,
-        original_layer: _LayerCallable,
-        r: int,
-        group: mx.distributed.Group,
-    ):
-        super().__init__(original_layer)
-        self.r: int = r
-        self.group = group
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
-        if self.r != 0:
-            x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
-        return self.original_layer(x, *args, **kwargs)
+def patch_pipeline_first_layer(
+    pipeline_layer: nn.Module, group: mx.distributed.Group
+) -> nn.Module:
+    cls = type(pipeline_layer)
+    orig_call = cast(Callable[..., mx.array], cls.__call__)
+
+    rank = group.rank()
+
+    class PatchedFirstLayer(cls):
+        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
+            if rank != 0:
+                x = mx.distributed.recv_like(x, (rank - 1), group=group)
+            return orig_call(self, x, *args, **kwargs)
+
+    pipeline_layer.__class__ = PatchedFirstLayer
+    return pipeline_layer


-class PipelineLastLayer(CustomMlxLayer):
-    def __init__(
-        self,
-        original_layer: _LayerCallable,
-        r: int,
-        s: int,
-        group: mx.distributed.Group,
-    ):
-        super().__init__(original_layer)
-        self.r: int = r
-        self.s: int = s
-        self.group = group
-        self.original_layer_signature = signature(self.original_layer.__call__)
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
-        cache = self.original_layer_signature.bind_partial(
-            x, *args, **kwargs
-        ).arguments.get("cache", None)
-
-        output: mx.array = self.original_layer(x, *args, **kwargs)
-
-        if self.r != self.s - 1:
-            output = mx.distributed.send(
-                output, (self.r + 1) % self.s, group=self.group
-            )
-            if cache is not None:
-                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
-        return output
+def patch_pipeline_last_layer(
+    pipeline_layer: nn.Module, group: mx.distributed.Group
+) -> nn.Module:
+    cls = type(pipeline_layer)
+    orig_call = cast(Callable[..., mx.array], cls.__call__)
+    orig_call_sig = signature(orig_call)
+
+    rank = group.rank()
+    size = group.size()
+
+    class PatchedLastLayer(cls):
+        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
+            cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
+                "cache", None
+            )
+
+            output: mx.array = orig_call(self, x, *args, **kwargs)
+
+            if rank != size - 1:
+                output = mx.distributed.send(output, (rank + 1) % size, group=group)
+                if cache is not None:
+                    cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
+            return output
+
+    pipeline_layer.__class__ = PatchedLastLayer
+    return pipeline_layer


 def _inner_model(model: nn.Module) -> nn.Module:
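
Aside: the new patch_pipeline_first_layer / patch_pipeline_last_layer helpers rely on reassigning a live instance's __class__ to a dynamically created subclass that overrides __call__. A minimal, hedged sketch of that pattern follows; Layer and patch_layer are illustrative names only, not code from this diff.

# Sketch of the instance-class-swap pattern: subclass the layer's concrete
# class, override __call__, then reassign the instance's __class__ so existing
# references to the object are patched in place.
class Layer:
    def __call__(self, x: int) -> int:
        return x + 1

def patch_layer(layer: Layer) -> Layer:
    cls = type(layer)
    orig_call = cls.__call__

    class Patched(cls):  # type: ignore[misc, valid-type]
        def __call__(self, x: int) -> int:
            # Pre-process the input, then defer to the original __call__,
            # analogous to the first layer receiving before forwarding.
            return orig_call(self, x * 10)

    layer.__class__ = Patched
    return layer

layer = Layer()
patch_layer(layer)
assert layer(1) == 11  # (1 * 10) + 1

Because the patched object keeps its original class as a base, isinstance checks and attribute lookups on the layer continue to behave as before.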
@@ -160,13 +150,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
     raise ValueError("Model must either have a 'model' or 'transformer' attribute")


-def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
+def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
     # Handle both model.layers and model.h cases
-    layers: list[_LayerCallable]
+    layers: list[nn.Module]
     if hasattr(inner_model_instance, "layers"):
-        layers = cast(list[_LayerCallable], inner_model_instance.layers)
+        layers = cast(list[nn.Module], inner_model_instance.layers)
     elif hasattr(inner_model_instance, "h"):
-        layers = cast(list[_LayerCallable], inner_model_instance.h)
+        layers = cast(list[nn.Module], inner_model_instance.h)
     else:
         raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -191,15 +181,12 @@ def pipeline_auto_parallel(
     layers = _get_layers(inner_model_instance)

     start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
-    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size

     layers = layers[start_layer:end_layer]
-    layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
-    layers[-1] = PipelineLastLayer(
+    layers[0] = patch_pipeline_first_layer(layers[0], group)
+    layers[-1] = patch_pipeline_last_layer(
         layers[-1],
-        device_rank,
-        world_size,
-        group=group,
+        group,
     )

     if isinstance(inner_model_instance, GptOssMoeModel):
@@ -446,7 +433,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
         return model


-def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
+def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
     inner_model_instance = _inner_model(model)
     if hasattr(inner_model_instance, "layers"):
         inner_model_instance.layers = layers
@@ -521,17 +508,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
 class ShardedDeepseekV3MoE(CustomMlxLayer):
-    def __init__(self, layer: _LayerCallable):
+    def __init__(self, layer: nn.Module):
         super().__init__(layer)
         self.sharding_group: mx.distributed.Group | None = None

     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer.__call__(x)
+        y = self.original_layer.__call__(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore


 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
@@ -565,7 +552,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
             self.all_to_sharded_linear_in_place(
                 layer.block_sparse_moe.switch_mlp.up_proj
             )
-            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue]
             layer.block_sparse_moe.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]

         return model
@@ -599,7 +586,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
             self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
             self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
             self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
-            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue]
             layer.mlp.sharding_group = self.group

             # Shard the MLP
@@ -612,17 +599,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
 class ShardedQwenMoE(CustomMlxLayer):
-    def __init__(self, layer: _LayerCallable):
+    def __init__(self, layer: nn.Module):
         super().__init__(layer)
         self.sharding_group: mx.distributed.Group | None = None

     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer.__call__(x)
+        y = self.original_layer.__call__(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore


 class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -674,7 +661,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer(x)
+        y = self.original_layer(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore
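
Aside: patch_pipeline_last_layer above recovers the cache argument through inspect.signature(...).bind_partial(...).arguments, so it finds cache whether it was passed positionally or by keyword. A minimal sketch of that mechanic, using only the standard library; forward is an illustrative stand-in, not a function from this diff.

from inspect import signature

def forward(self, x, mask=None, cache=None):
    return x

sig = signature(forward)

# bind_partial maps whatever was supplied back to parameter names, so the
# same lookup works for keyword, positional, or missing cache arguments.
print(sig.bind_partial(None, 1, cache="kv").arguments.get("cache", None))  # kv
print(sig.bind_partial(None, 1, None, "kv").arguments.get("cache", None))  # kv
print(sig.bind_partial(None, 1).arguments.get("cache", None))              # None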

View File

@@ -18,7 +18,7 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
 from exo.worker.engines.mlx import Model
 from exo.worker.engines.mlx.generator.generate import mlx_generate
-from exo.worker.engines.mlx.utils_mlx import shard_and_load
+from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template


 class MockLayer(nn.Module):
@@ -116,12 +116,11 @@ def run_gpt_oss_pipeline_device(
         messages=[ChatCompletionMessage(role="user", content=prompt_text)],
         max_tokens=max_tokens,
     )

+    prompt = apply_chat_template(tokenizer, task)
     generated_text = ""
     for response in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
+        model=model, tokenizer=tokenizer, task=task, prompt=prompt
     ):
         generated_text += response.text
         if response.finish_reason is not None:
@@ -183,11 +182,11 @@ def run_gpt_oss_tensor_parallel_device(
         max_tokens=max_tokens,
     )

+    prompt = apply_chat_template(tokenizer, task)
     generated_text = ""
     for response in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
+        model=model, tokenizer=tokenizer, task=task, prompt=prompt
     ):
         generated_text += response.text
         if response.finish_reason is not None:

View File

@@ -10,8 +10,8 @@ import pytest
 from exo.worker.engines.mlx.auto_parallel import (
     CustomMlxLayer,
-    PipelineFirstLayer,
-    PipelineLastLayer,
+    patch_pipeline_first_layer,
+    patch_pipeline_last_layer,
     patch_pipeline_model,
 )
 from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -50,8 +50,8 @@ def run_pipeline_device(
     group = mx.distributed.init(backend="ring", strict=True)
     mock = MockLayerInner()
-    first = PipelineFirstLayer(mock, r=rank, group=group)
-    composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
+    first = patch_pipeline_first_layer(mock, group)
+    composed = patch_pipeline_last_layer(first, group)

     # Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
     inner_model = MockModel([composed])
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
     mock = MockLayer()
     group = mx.distributed.init()
-    first = PipelineFirstLayer(mock, r=0, group=group)
-    composed = PipelineLastLayer(first, r=0, s=1, group=group)
+    first = patch_pipeline_first_layer(mock, group)
+    composed = patch_pipeline_last_layer(first, group)

     assert composed.custom_attr == "test_value"  # type: ignore[attr-defined]
     assert composed.use_sliding is True  # type: ignore[attr-defined]
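
Aside: the delegation assertions above still hold after the refactor because swapping an instance's __class__ to a subclass leaves its existing attributes and inherited lookups intact. A minimal, hedged sketch with illustrative names (this MockLayer is a stand-in, not the test fixture):

# After the class swap, instance attributes set in __init__ are untouched and
# attribute lookup falls through the subclass to the original class as usual.
class MockLayer:
    def __init__(self) -> None:
        self.custom_attr = "test_value"
        self.use_sliding = True

class Patched(MockLayer):
    pass

layer = MockLayer()
layer.__class__ = Patched
assert layer.custom_attr == "test_value"
assert layer.use_sliding is True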