mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-21 04:22:21 -05:00
Compare commits
1 Commits
foo
...
revert-glm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0bbc4103c |
9
dashboard/package-lock.json
generated
9
dashboard/package-lock.json
generated
@@ -863,6 +863,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -902,6 +903,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1518,6 +1520,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1527,6 +1530,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1939,6 +1943,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2646,6 +2651,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2833,6 +2839,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2977,6 +2984,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -2998,6 +3006,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -17,8 +17,8 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx==0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
|
||||
@@ -276,7 +276,9 @@ def test_placement_selects_leaf_nodes(
|
||||
# arrange
|
||||
topology = Topology()
|
||||
|
||||
model_card.storage_size = Memory.from_bytes(1000)
|
||||
# Model requires more than any single node but fits within a 3-node cycle
|
||||
model_card.storage_size.in_bytes = 1500
|
||||
model_card.n_layers = 12
|
||||
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
|
||||
@@ -477,6 +477,53 @@ async def get_downloaded_size(path: Path) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
async def download_progress_for_local_path(
|
||||
repo_id: str, shard: ShardMetadata, local_path: Path
|
||||
) -> RepoDownloadProgress:
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
total_files = 0
|
||||
total_bytes = 0
|
||||
|
||||
if await aios.path.isdir(local_path):
|
||||
for root, _, files in os.walk(local_path):
|
||||
for f in files:
|
||||
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
|
||||
file_path = Path(root) / f
|
||||
size = (await aios.stat(file_path)).st_size
|
||||
rel_path = str(file_path.relative_to(local_path))
|
||||
file_progress[rel_path] = RepoFileDownloadProgress(
|
||||
repo_id=repo_id,
|
||||
repo_revision="local",
|
||||
file_path=rel_path,
|
||||
downloaded=Memory.from_bytes(size),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(size),
|
||||
speed=0,
|
||||
eta=timedelta(0),
|
||||
status="complete",
|
||||
start_time=time.time(),
|
||||
)
|
||||
total_files += 1
|
||||
total_bytes += size
|
||||
else:
|
||||
raise ValueError(f"Local path {local_path} is not a directory")
|
||||
|
||||
return RepoDownloadProgress(
|
||||
repo_id=repo_id,
|
||||
repo_revision="local",
|
||||
shard=shard,
|
||||
completed_files=total_files,
|
||||
total_files=total_files,
|
||||
downloaded_bytes=Memory.from_bytes(total_bytes),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(total_bytes),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(0),
|
||||
status="complete",
|
||||
file_progress=file_progress,
|
||||
)
|
||||
|
||||
|
||||
async def download_shard(
|
||||
shard: ShardMetadata,
|
||||
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
|
||||
@@ -487,6 +534,14 @@ async def download_shard(
|
||||
if not skip_download:
|
||||
logger.info(f"Downloading {shard.model_card.model_id=}")
|
||||
|
||||
# Handle local paths
|
||||
if await aios.path.exists(str(shard.model_card.model_id)):
|
||||
logger.info(f"Using local model path {shard.model_card.model_id}")
|
||||
local_path = Path(str(shard.model_card.model_id))
|
||||
return local_path, await download_progress_for_local_path(
|
||||
str(shard.model_card.model_id), shard, local_path
|
||||
)
|
||||
|
||||
revision = "main"
|
||||
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
|
||||
"/", "--"
|
||||
@@ -497,8 +552,7 @@ async def download_shard(
|
||||
if not allow_patterns:
|
||||
allow_patterns = await resolve_allow_patterns(shard)
|
||||
|
||||
if not skip_download:
|
||||
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
|
||||
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
|
||||
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -67,16 +67,27 @@ def eval_with_timeout(
|
||||
completed.set()
|
||||
|
||||
|
||||
class _LayerCallable(Protocol):
|
||||
"""Structural type that any compatible layer must satisfy.
|
||||
|
||||
We require a single positional input of type ``mx.array`` and an
|
||||
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
|
||||
protocol matches the vast majority of `mlx.nn.Module` subclasses.
|
||||
"""
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
|
||||
|
||||
|
||||
class CustomMlxLayer(nn.Module):
|
||||
"""Base class for replacing an MLX layer with a custom implementation."""
|
||||
|
||||
def __init__(self, original_layer: nn.Module):
|
||||
def __init__(self, original_layer: _LayerCallable):
|
||||
super().__init__()
|
||||
object.__setattr__(self, "_original_layer", original_layer)
|
||||
|
||||
@property
|
||||
def original_layer(self) -> nn.Module:
|
||||
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
|
||||
def original_layer(self) -> _LayerCallable:
|
||||
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
|
||||
|
||||
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
|
||||
if not TYPE_CHECKING:
|
||||
@@ -89,53 +100,52 @@ class CustomMlxLayer(nn.Module):
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
def patch_pipeline_first_layer(
|
||||
pipeline_layer: nn.Module, group: mx.distributed.Group
|
||||
) -> nn.Module:
|
||||
cls = type(pipeline_layer)
|
||||
orig_call = cast(Callable[..., mx.array], cls.__call__)
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
r: int,
|
||||
group: mx.distributed.Group,
|
||||
):
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.group = group
|
||||
|
||||
rank = group.rank()
|
||||
|
||||
class PatchedFirstLayer(cls):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if rank != 0:
|
||||
x = mx.distributed.recv_like(x, (rank - 1), group=group)
|
||||
return orig_call(self, x, *args, **kwargs)
|
||||
|
||||
pipeline_layer.__class__ = PatchedFirstLayer
|
||||
|
||||
return pipeline_layer
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if self.r != 0:
|
||||
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
|
||||
return self.original_layer(x, *args, **kwargs)
|
||||
|
||||
|
||||
def patch_pipeline_last_layer(
|
||||
pipeline_layer: nn.Module, group: mx.distributed.Group
|
||||
) -> nn.Module:
|
||||
cls = type(pipeline_layer)
|
||||
orig_call = cast(Callable[..., mx.array], cls.__call__)
|
||||
orig_call_sig = signature(orig_call)
|
||||
class PipelineLastLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
r: int,
|
||||
s: int,
|
||||
group: mx.distributed.Group,
|
||||
):
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.s: int = s
|
||||
self.group = group
|
||||
self.original_layer_signature = signature(self.original_layer.__call__)
|
||||
|
||||
rank = group.rank()
|
||||
size = group.size()
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = self.original_layer_signature.bind_partial(
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
class PatchedLastLayer(cls):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
|
||||
"cache", None
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output: mx.array = orig_call(self, x, *args, **kwargs)
|
||||
|
||||
if rank != size - 1:
|
||||
output = mx.distributed.send(output, (rank + 1) % size, group=group)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
return output
|
||||
|
||||
pipeline_layer.__class__ = PatchedLastLayer
|
||||
|
||||
return pipeline_layer
|
||||
return output
|
||||
|
||||
|
||||
def _inner_model(model: nn.Module) -> nn.Module:
|
||||
@@ -150,13 +160,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
|
||||
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
|
||||
|
||||
|
||||
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
|
||||
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
|
||||
# Handle both model.layers and model.h cases
|
||||
layers: list[nn.Module]
|
||||
layers: list[_LayerCallable]
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
layers = cast(list[nn.Module], inner_model_instance.layers)
|
||||
layers = cast(list[_LayerCallable], inner_model_instance.layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
layers = cast(list[nn.Module], inner_model_instance.h)
|
||||
layers = cast(list[_LayerCallable], inner_model_instance.h)
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
@@ -181,12 +191,15 @@ def pipeline_auto_parallel(
|
||||
layers = _get_layers(inner_model_instance)
|
||||
|
||||
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
layers[0] = patch_pipeline_first_layer(layers[0], group)
|
||||
layers[-1] = patch_pipeline_last_layer(
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[-1],
|
||||
group,
|
||||
device_rank,
|
||||
world_size,
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, GptOssMoeModel):
|
||||
@@ -433,7 +446,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
@@ -508,17 +521,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
|
||||
class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x) # type: ignore
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
@@ -552,8 +565,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
@@ -586,7 +599,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
# Shard the MLP
|
||||
@@ -599,17 +612,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
|
||||
class ShardedQwenMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x) # type: ignore
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
@@ -648,7 +661,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
@@ -661,7 +674,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x) # type: ignore
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -169,10 +169,10 @@ def mlx_distributed_init(
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
@@ -312,9 +312,6 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm-4.7-flash" in model_id_lower:
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif "glm" in model_id_lower:
|
||||
return [151336, 151329, 151338]
|
||||
return None
|
||||
|
||||
@@ -413,6 +413,11 @@ class Worker:
|
||||
)
|
||||
for nid in conns:
|
||||
for ip in conns[nid]:
|
||||
if "127.0.0.1" in ip or "localhost" in ip:
|
||||
logger.warning(
|
||||
f"Loopback connection should not happen: {ip=} for {nid=}"
|
||||
)
|
||||
|
||||
edge = SocketConnection(
|
||||
# nonsense multiaddr
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
|
||||
@@ -433,9 +438,6 @@ class Worker:
|
||||
for conn in self.state.topology.out_edges(self.node_id):
|
||||
if not isinstance(conn.edge, SocketConnection):
|
||||
continue
|
||||
# ignore mDNS discovered connections
|
||||
if conn.edge.sink_multiaddr.port != 52415:
|
||||
continue
|
||||
if (
|
||||
conn.sink not in conns
|
||||
or conn.edge.sink_multiaddr.ip_address
|
||||
@@ -449,7 +451,7 @@ class Worker:
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.debug("Fetching and emitting existing download progress...")
|
||||
logger.info("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
@@ -480,7 +482,7 @@ class Worker:
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.debug("Done emitting existing download progress.")
|
||||
logger.info("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
@@ -18,7 +18,7 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load
|
||||
|
||||
|
||||
class MockLayer(nn.Module):
|
||||
@@ -116,11 +116,12 @@ def run_gpt_oss_pipeline_device(
|
||||
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model, tokenizer=tokenizer, task=task, prompt=prompt
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
@@ -182,11 +183,11 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model, tokenizer=tokenizer, task=task, prompt=prompt
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
|
||||
@@ -10,8 +10,8 @@ import pytest
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
CustomMlxLayer,
|
||||
patch_pipeline_first_layer,
|
||||
patch_pipeline_last_layer,
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
patch_pipeline_model,
|
||||
)
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
|
||||
@@ -50,8 +50,8 @@ def run_pipeline_device(
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
mock = MockLayerInner()
|
||||
first = patch_pipeline_first_layer(mock, group)
|
||||
composed = patch_pipeline_last_layer(first, group)
|
||||
first = PipelineFirstLayer(mock, r=rank, group=group)
|
||||
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
|
||||
|
||||
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
|
||||
inner_model = MockModel([composed])
|
||||
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
group = mx.distributed.init()
|
||||
|
||||
first = patch_pipeline_first_layer(mock, group)
|
||||
composed = patch_pipeline_last_layer(first, group)
|
||||
first = PipelineFirstLayer(mock, r=0, group=group)
|
||||
composed = PipelineLastLayer(first, r=0, s=1, group=group)
|
||||
|
||||
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert composed.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
@@ -18,7 +18,6 @@ def _check_model_exists() -> bool:
|
||||
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _check_model_exists(),
|
||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
||||
|
||||
@@ -89,8 +89,6 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
|
||||
|
||||
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
|
||||
Reference in New Issue
Block a user