mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-20 11:58:57 -05:00
Compare commits
1 Commits
upstream-s
...
new-bridge
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1bca96747d |
@@ -276,7 +276,9 @@ def test_placement_selects_leaf_nodes(
|
||||
# arrange
|
||||
topology = Topology()
|
||||
|
||||
model_card.storage_size = Memory.from_bytes(1000)
|
||||
# Model requires more than any single node but fits within a 3-node cycle
|
||||
model_card.storage_size.in_bytes = 1500
|
||||
model_card.n_layers = 12
|
||||
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
|
||||
@@ -13,11 +13,17 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
@@ -335,7 +341,33 @@ def tensor_auto_parallel(
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, MiniMaxModel):
|
||||
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -343,6 +375,15 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -377,6 +418,34 @@ class TensorParallelShardingStrategy(ABC):
|
||||
) -> nn.Module: ...
|
||||
|
||||
|
||||
class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(LlamaModel, model)
|
||||
for layer in model.layers:
|
||||
# Force load weights before sharding to avoid FAST_SYNCH deadlock
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
if layer.self_attn.n_kv_heads is not None:
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
@@ -403,6 +472,105 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
else:
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -455,3 +623,58 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
layer.self_attn.num_key_value_groups = (
|
||||
layer.self_attn.num_attention_heads
|
||||
// layer.self_attn.num_key_value_heads
|
||||
)
|
||||
|
||||
layer.self_attn.sinks = layer.self_attn.sinks[
|
||||
layer.self_attn.num_attention_heads
|
||||
* self.group.rank() : layer.self_attn.num_attention_heads
|
||||
* (self.group.rank() + 1)
|
||||
]
|
||||
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -413,6 +413,11 @@ class Worker:
|
||||
)
|
||||
for nid in conns:
|
||||
for ip in conns[nid]:
|
||||
if "127.0.0.1" in ip or "localhost" in ip:
|
||||
logger.warning(
|
||||
f"Loopback connection should not happen: {ip=} for {nid=}"
|
||||
)
|
||||
|
||||
edge = SocketConnection(
|
||||
# nonsense multiaddr
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
|
||||
@@ -433,9 +438,6 @@ class Worker:
|
||||
for conn in self.state.topology.out_edges(self.node_id):
|
||||
if not isinstance(conn.edge, SocketConnection):
|
||||
continue
|
||||
# ignore mDNS discovered connections
|
||||
if conn.edge.sink_multiaddr.port != 52415:
|
||||
continue
|
||||
if (
|
||||
conn.sink not in conns
|
||||
or conn.edge.sink_multiaddr.ip_address
|
||||
|
||||
@@ -18,7 +18,6 @@ def _check_model_exists() -> bool:
|
||||
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _check_model_exists(),
|
||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
||||
|
||||
@@ -89,8 +89,6 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
|
||||
|
||||
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
|
||||
84
tmp/kill_bridge_plist.sh
Normal file
84
tmp/kill_bridge_plist.sh
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="${PREFS:-/Library/Preferences/SystemConfiguration/preferences.plist}"
|
||||
|
||||
tmpdir="$(mktemp -d)"
|
||||
trap 'rm -rf "$tmpdir"' EXIT
|
||||
injson="$tmpdir/in.json"
|
||||
outjson="$tmpdir/out.json"
|
||||
plutil -convert json -o "$injson" "$PREFS"
|
||||
|
||||
perl -Mstrict -Mwarnings -MJSON::PP -e '
|
||||
my ($in, $out) = @ARGV;
|
||||
|
||||
open my $fh, "<", $in or die "open $in: $!";
|
||||
local $/;
|
||||
my $txt = <$fh>;
|
||||
close $fh;
|
||||
|
||||
my $json = JSON::PP->new->utf8->relaxed(1);
|
||||
my $d = $json->decode($txt);
|
||||
|
||||
if (ref($d->{VirtualNetworkInterfaces}) eq "HASH"
|
||||
&& ref($d->{VirtualNetworkInterfaces}{Bridge}) eq "HASH") {
|
||||
delete $d->{VirtualNetworkInterfaces}{Bridge}{bridge0};
|
||||
}
|
||||
|
||||
my @bridge_svcs;
|
||||
if (ref($d->{NetworkServices}) eq "HASH") {
|
||||
for my $k (keys %{ $d->{NetworkServices} }) {
|
||||
my $svc = $d->{NetworkServices}{$k};
|
||||
next unless ref($svc) eq "HASH";
|
||||
my $iface = $svc->{Interface};
|
||||
next unless ref($iface) eq "HASH";
|
||||
my $dev = $iface->{DeviceName};
|
||||
if (defined $dev && $dev eq "bridge0") {
|
||||
push @bridge_svcs, $k;
|
||||
}
|
||||
}
|
||||
delete @{ $d->{NetworkServices} }{ @bridge_svcs } if @bridge_svcs;
|
||||
}
|
||||
|
||||
my %is_bridge = map { $_ => 1 } @bridge_svcs;
|
||||
|
||||
if (ref($d->{Sets}) eq "HASH") {
|
||||
for my $setk (keys %{ $d->{Sets} }) {
|
||||
my $set = $d->{Sets}{$setk};
|
||||
next unless ref($set) eq "HASH";
|
||||
my $net = $set->{Network};
|
||||
next unless ref($net) eq "HASH";
|
||||
|
||||
if (ref($net->{Interface}) eq "HASH") {
|
||||
delete $net->{Interface}{bridge0};
|
||||
}
|
||||
|
||||
if (ref($net->{Service}) eq "HASH" && @bridge_svcs) {
|
||||
for my $svc (@bridge_svcs) {
|
||||
delete $net->{Service}{$svc};
|
||||
}
|
||||
}
|
||||
|
||||
my $g = $net->{Global};
|
||||
if (ref($g) eq "HASH"
|
||||
&& ref($g->{IPv4}) eq "HASH"
|
||||
&& ref($g->{IPv4}{ServiceOrder}) eq "ARRAY"
|
||||
&& @bridge_svcs) {
|
||||
|
||||
my @so = @{ $g->{IPv4}{ServiceOrder} };
|
||||
@so = grep { !defined($_) || !$is_bridge{$_} } @so;
|
||||
$g->{IPv4}{ServiceOrder} = \@so;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
open my $oh, ">", $out or die "open $out: $!";
|
||||
print $oh JSON::PP->new->utf8->canonical(1)->pretty(1)->encode($d);
|
||||
close $oh;
|
||||
' "$injson" "$outjson"
|
||||
|
||||
# Convert JSON -> plist (write back as binary1; change to xml1 if you prefer)
|
||||
plutil -convert xml1 -o "$PREFS" "$outjson"
|
||||
|
||||
# Ask configd to reload SystemConfiguration state
|
||||
killall -HUP configd 2>/dev/null || true
|
||||
Reference in New Issue
Block a user