Compare commits

..

5 Commits

Author SHA1 Message Date
Evan
5019ac489e foo 2026-01-20 16:52:28 +00:00
rltakashige
8b709e68b2 Mark slow tests as slow (#1220)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-20 15:03:46 +00:00
Evan Quiney
4da6eeb11f fix a test broken by #1204 (#1219)
bad merge broke a test - fix it
2026-01-20 14:56:20 +00:00
Evan
3d2eee4884 quiet localhost log
this log is just noise - remove it
2026-01-20 14:51:26 +00:00
Evan
116558839e don't clear mdns discovered connections
pingers currently removes mdns discovered connections - these systems
should be independent
2026-01-20 14:46:20 +00:00
6 changed files with 68 additions and 172 deletions

View File

@@ -276,9 +276,7 @@ def test_placement_selects_leaf_nodes(
# arrange
topology = Topology()
# Model requires more than any single node but fits within a 3-node cycle
model_card.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_card.storage_size = Memory.from_bytes(1000)
node_id_a = NodeId()
node_id_b = NodeId()

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, cast
import mlx.core as mx
import mlx.nn as nn
@@ -66,28 +66,16 @@ def eval_with_timeout(
finally:
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
"""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: _LayerCallable):
def __init__(self, original_layer: nn.Module):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
def original_layer(self) -> nn.Module:
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -100,53 +88,49 @@ class CustomMlxLayer(nn.Module):
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.group = group
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
s: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
def patch_pipeline_first_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
rank = group.rank()
class PatchedFirstLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if rank != 0:
x = mx.distributed.recv_like(x, (rank - 1), group=group)
return orig_call(x, *args, **kwargs)
output: mx.array = self.original_layer(x, *args, **kwargs)
pipeline_layer.__class__ = PatchedFirstLayer
return pipeline_layer
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
def patch_pipeline_last_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
orig_call_sig = signature(orig_call)
return output
rank = group.rank()
size = group.size()
class PatchedLastLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = orig_call_sig.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
output: mx.array = orig_call(x, *args, **kwargs)
if rank != size - 1:
output = mx.distributed.send(
output, (rank + 1) % size, group=group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
pipeline_layer.__class__ = PatchedLastLayer
return pipeline_layer
def _inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
@@ -160,13 +144,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
# Handle both model.layers and model.h cases
layers: list[_LayerCallable]
layers: list[nn.Module]
if hasattr(inner_model_instance, "layers"):
layers = cast(list[_LayerCallable], inner_model_instance.layers)
layers = cast(list[nn.Module], inner_model_instance.layers)
elif hasattr(inner_model_instance, "h"):
layers = cast(list[_LayerCallable], inner_model_instance.h)
layers = cast(list[nn.Module], inner_model_instance.h)
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -191,15 +175,12 @@ def pipeline_auto_parallel(
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[0] = patch_pipeline_first_layer(layers[0], group)
layers[-1] = patch_pipeline_last_layer(
layers[-1],
device_rank,
world_size,
group=group,
group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
@@ -446,7 +427,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
@@ -521,17 +502,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
@@ -565,7 +546,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
@@ -599,7 +580,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -612,17 +593,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -674,7 +655,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
y = self.original_layer(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore

View File

@@ -413,11 +413,6 @@ class Worker:
)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -438,6 +433,9 @@ class Worker:
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
# ignore mDNS discovered connections
if conn.edge.sink_multiaddr.port != 52415:
continue
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address

View File

@@ -18,6 +18,7 @@ def _check_model_exists() -> bool:
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",

View File

@@ -89,6 +89,8 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
def event_loop():

View File

@@ -1,84 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
PREFS="${PREFS:-/Library/Preferences/SystemConfiguration/preferences.plist}"
tmpdir="$(mktemp -d)"
trap 'rm -rf "$tmpdir"' EXIT
injson="$tmpdir/in.json"
outjson="$tmpdir/out.json"
plutil -convert json -o "$injson" "$PREFS"
perl -Mstrict -Mwarnings -MJSON::PP -e '
my ($in, $out) = @ARGV;
open my $fh, "<", $in or die "open $in: $!";
local $/;
my $txt = <$fh>;
close $fh;
my $json = JSON::PP->new->utf8->relaxed(1);
my $d = $json->decode($txt);
if (ref($d->{VirtualNetworkInterfaces}) eq "HASH"
&& ref($d->{VirtualNetworkInterfaces}{Bridge}) eq "HASH") {
delete $d->{VirtualNetworkInterfaces}{Bridge}{bridge0};
}
my @bridge_svcs;
if (ref($d->{NetworkServices}) eq "HASH") {
for my $k (keys %{ $d->{NetworkServices} }) {
my $svc = $d->{NetworkServices}{$k};
next unless ref($svc) eq "HASH";
my $iface = $svc->{Interface};
next unless ref($iface) eq "HASH";
my $dev = $iface->{DeviceName};
if (defined $dev && $dev eq "bridge0") {
push @bridge_svcs, $k;
}
}
delete @{ $d->{NetworkServices} }{ @bridge_svcs } if @bridge_svcs;
}
my %is_bridge = map { $_ => 1 } @bridge_svcs;
if (ref($d->{Sets}) eq "HASH") {
for my $setk (keys %{ $d->{Sets} }) {
my $set = $d->{Sets}{$setk};
next unless ref($set) eq "HASH";
my $net = $set->{Network};
next unless ref($net) eq "HASH";
if (ref($net->{Interface}) eq "HASH") {
delete $net->{Interface}{bridge0};
}
if (ref($net->{Service}) eq "HASH" && @bridge_svcs) {
for my $svc (@bridge_svcs) {
delete $net->{Service}{$svc};
}
}
my $g = $net->{Global};
if (ref($g) eq "HASH"
&& ref($g->{IPv4}) eq "HASH"
&& ref($g->{IPv4}{ServiceOrder}) eq "ARRAY"
&& @bridge_svcs) {
my @so = @{ $g->{IPv4}{ServiceOrder} };
@so = grep { !defined($_) || !$is_bridge{$_} } @so;
$g->{IPv4}{ServiceOrder} = \@so;
}
}
}
open my $oh, ">", $out or die "open $out: $!";
print $oh JSON::PP->new->utf8->canonical(1)->pretty(1)->encode($d);
close $oh;
' "$injson" "$outjson"
# Convert JSON -> plist (write back as binary1; change to xml1 if you prefer)
plutil -convert xml1 -o "$PREFS" "$outjson"
# Ask configd to reload SystemConfiguration state
killall -HUP configd 2>/dev/null || true