mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Update mlx and mlx-lm packages
Co-authored-by: Evan <evanev7@gmail.com>
This commit is contained in:
@@ -26,8 +26,8 @@ dependencies = [
|
||||
"sqlalchemy[asyncio]>=2.0.43",
|
||||
"greenlet>=3.2.4",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"mlx==0.26.3",
|
||||
"mlx-lm==0.26.4",
|
||||
"mlx==0.29.3",
|
||||
"mlx-lm==0.28.3",
|
||||
"psutil>=7.0.0",
|
||||
"transformers>=4.55.2",
|
||||
"cobs>=1.2.2",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Protocol, cast, override
|
||||
from typing import cast, override, Protocol, TYPE_CHECKING
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn # pyright: ignore[reportMissingTypeStubs]
|
||||
@@ -22,10 +22,29 @@ class _LayerCallable(Protocol):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
|
||||
|
||||
|
||||
class PipelineFirstLayer(nn.Module):
|
||||
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
|
||||
class CustomMlxLayer(nn.Module):
    """Base class for replacing an MLX layer with a custom implementation.

    The wrapped layer is stored twice: as the regular ``original_layer``
    attribute and in a private ``_original_layer`` slot written with
    ``object.__setattr__``. Attribute lookups that fail on the wrapper are
    forwarded to the wrapped layer, so code that reads attributes of the
    replaced layer (e.g. ``use_sliding``) keeps working.
    """

    def __init__(self, original_layer: _LayerCallable):
        super().__init__()
        # Set twice to avoid __setattr__ recursion: the private copy is read
        # back in __getattr__ through object.__getattribute__, which never
        # re-enters __getattr__.
        object.__setattr__(self, "_original_layer", original_layer)
        self.original_layer: _LayerCallable = original_layer

    # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
    # NOTE(review): presumably guarded by TYPE_CHECKING so static checkers do
    # not accept arbitrary attribute names on this class — confirm.
    if not TYPE_CHECKING:

        def __getattr__(self, name):
            """Resolve ``name`` on the wrapper, then on the wrapped layer."""
            try:
                return super().__getattr__(name)
            except AttributeError:
                original_layer = object.__getattribute__(self, "_original_layer")
                # Use getattr so the wrapped layer's own __getattr__ hook runs
                # too; object.__getattribute__ would skip it and miss anything
                # the wrapped layer only exposes through that hook.
                return getattr(original_layer, name)
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.s: int = s
|
||||
|
||||
@@ -36,10 +55,9 @@ class PipelineFirstLayer(nn.Module):
|
||||
return self.original_layer(x, *args, **kwargs)
|
||||
|
||||
|
||||
class PipelineLastLayer(nn.Module):
|
||||
class PipelineLastLayer(CustomMlxLayer):
|
||||
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
|
||||
super().__init__()
|
||||
self.original_layer: _LayerCallable = original_layer
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.s: int = s
|
||||
|
||||
@@ -48,7 +66,7 @@ class PipelineLastLayer(nn.Module):
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(output, (self.r + 1) % self.s)
|
||||
output = mx.distributed.all_gather(output)[-output.shape[0] :] # pyright: ignore[reportUnknownMemberType]
|
||||
output = mx.distributed.all_gather(output)[-output.shape[0] :]
|
||||
return output
|
||||
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def broadcast_from_zero(value: int) -> int:
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu))
|
||||
mx.eval(m) # type: ignore
|
||||
return int(m.item()) # type: ignore
|
||||
return int(m.item())
|
||||
|
||||
|
||||
class HostList(RootModel[list[str]]):
|
||||
@@ -65,7 +65,9 @@ def mlx_setup(
|
||||
wired_frac_of_mrwss: float = 0.00, # start with no wiring
|
||||
) -> None:
|
||||
if not mx.metal.is_available():
|
||||
logger.warning("Metal is not available. Skipping MLX memory wired limits setup.")
|
||||
logger.warning(
|
||||
"Metal is not available. Skipping MLX memory wired limits setup."
|
||||
)
|
||||
return
|
||||
info = mx.metal.device_info()
|
||||
mrwss = int(info["max_recommended_working_set_size"]) # bytes
|
||||
@@ -216,8 +218,8 @@ class NullKVCache(KVCache):
|
||||
def __init__(self, dtype: mx.Dtype = mx.float16):
|
||||
super().__init__()
|
||||
# zero-length K/V so shapes/dtypes are defined but empty
|
||||
self.keys = mx.zeros((1, 1, 0, 1), dtype=dtype) # pyright: ignore[reportUnknownMemberType]
|
||||
self.values = mx.zeros((1, 1, 0, 1), dtype=dtype) # pyright: ignore[reportUnknownMemberType]
|
||||
self.keys = mx.zeros((1, 1, 0, 1), dtype=dtype)
|
||||
self.values = mx.zeros((1, 1, 0, 1), dtype=dtype)
|
||||
self.offset = 0
|
||||
|
||||
@property
|
||||
@@ -247,11 +249,11 @@ def mlx_force_oom(size: int = 40000) -> None:
|
||||
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
|
||||
"""
|
||||
mx.set_default_device(mx.gpu) # type: ignore
|
||||
a = mx.random.uniform(shape=(size, size), dtype=mx.float32) # type: ignore
|
||||
b = mx.random.uniform(shape=(size, size), dtype=mx.float32) # type: ignore
|
||||
a = mx.random.uniform(shape=(size, size), dtype=mx.float32)
|
||||
b = mx.random.uniform(shape=(size, size), dtype=mx.float32)
|
||||
mx.eval(a, b) # type: ignore
|
||||
c = mx.matmul(a, b) # type: ignore
|
||||
d = mx.matmul(a, c) # type: ignore
|
||||
e = mx.matmul(b, c) # type: ignore
|
||||
f = mx.sigmoid(d + e) # type: ignore
|
||||
c = mx.matmul(a, b)
|
||||
d = mx.matmul(a, c)
|
||||
e = mx.matmul(b, c)
|
||||
f = mx.sigmoid(d + e)
|
||||
mx.eval(f) # type: ignore
|
||||
|
||||
@@ -63,7 +63,8 @@ def get_instance_placements_after_create(
|
||||
smallest_cycles = smallest_tb_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
|
||||
cycle for cycle in smallest_cycles
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
if any(topology.node_is_leaf(node.node_id) for node in cycle)
|
||||
]
|
||||
|
||||
|
||||
@@ -300,7 +300,9 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
)
|
||||
|
||||
# Act
|
||||
placements = get_instance_placements_after_create(create_instance_command, topology, {})
|
||||
placements = get_instance_placements_after_create(
|
||||
create_instance_command, topology, {}
|
||||
)
|
||||
|
||||
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
|
||||
# D-E-F has more total memory.
|
||||
|
||||
@@ -50,7 +50,10 @@ class Topology:
|
||||
self._rx_id_to_node_id_map[rx_id] = node.node_id
|
||||
|
||||
def node_is_leaf(self, node_id: NodeId) -> bool:
    """Return True when ``node_id`` is known and has exactly one neighbour.

    Unknown node ids yield False without querying the graph.
    """
    if node_id not in self._node_id_to_rx_id_map:
        return False
    rx_id = self._node_id_to_rx_id_map[node_id]
    neighbour_count = len(self._graph.neighbors(rx_id))
    return neighbour_count == 1
|
||||
|
||||
def contains_node(self, node_id: NodeId) -> bool:
    """Return True when ``node_id`` has an entry in the node-id to rx-id map."""
    known_ids = self._node_id_to_rx_id_map
    return node_id in known_ids
|
||||
|
||||
@@ -10,7 +10,7 @@ from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
class TaskId(Id):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
Pending = "Pending"
|
||||
|
||||
@@ -16,6 +16,7 @@ class DownloadProgressData(CamelCaseModel):
|
||||
|
||||
files: dict[str, "DownloadProgressData"]
|
||||
|
||||
|
||||
class BaseDownloadProgress(TaggedModel):
|
||||
node_id: NodeId
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class RepoFileDownloadProgress(BaseModel):
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
model_config = ConfigDict(frozen = True)
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class RepoDownloadProgress(BaseModel):
|
||||
@@ -78,16 +78,18 @@ class RepoDownloadProgress(BaseModel):
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
file_progress: Dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(
|
||||
frozen = True
|
||||
)
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
def trim_etag(etag: str) -> str:
    """Strip one pair of matching surrounding quotes from an ETag value.

    Args:
        etag: Raw ETag string, possibly wrapped in single or double quotes.

    Returns:
        The value without its surrounding quotes when it both starts and ends
        with the same quote character; otherwise ``etag`` unchanged. Values
        that do not start with a quote (for example ``W/"..."``) are returned
        as-is.
    """
    # Require at least two characters: an empty string has no etag[0], and a
    # lone quote character is not a quoted value.
    if len(etag) >= 2 and etag[0] == etag[-1] and etag[0] in ('"', "'"):
        return etag[1:-1]
    return etag
|
||||
|
||||
def map_repo_file_download_progress_to_download_progress_data(repo_file_download_progress: RepoFileDownloadProgress) -> DownloadProgressData:
|
||||
|
||||
def map_repo_file_download_progress_to_download_progress_data(
|
||||
repo_file_download_progress: RepoFileDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
downloaded_bytes=repo_file_download_progress.downloaded,
|
||||
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
|
||||
@@ -98,7 +100,11 @@ def map_repo_file_download_progress_to_download_progress_data(repo_file_download
|
||||
eta_ms=int(repo_file_download_progress.eta.total_seconds() * 1000),
|
||||
files={},
|
||||
)
|
||||
def map_repo_download_progress_to_download_progress_data(repo_download_progress: RepoDownloadProgress) -> DownloadProgressData:
|
||||
|
||||
|
||||
def map_repo_download_progress_to_download_progress_data(
|
||||
repo_download_progress: RepoDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
total_bytes=repo_download_progress.total_bytes,
|
||||
downloaded_bytes=repo_download_progress.downloaded_bytes,
|
||||
@@ -107,9 +113,15 @@ def map_repo_download_progress_to_download_progress_data(repo_download_progress:
|
||||
total_files=repo_download_progress.total_files,
|
||||
speed=repo_download_progress.overall_speed,
|
||||
eta_ms=int(repo_download_progress.overall_eta.total_seconds() * 1000),
|
||||
files={file_path: map_repo_file_download_progress_to_download_progress_data(file_progress) for file_path, file_progress in repo_download_progress.file_progress.items()},
|
||||
files={
|
||||
file_path: map_repo_file_download_progress_to_download_progress_data(
|
||||
file_progress
|
||||
)
|
||||
for file_path, file_progress in repo_download_progress.file_progress.items()
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_model_path(model_id: str) -> DirectoryPath:
    """Return the directory under ``EXO_HOME / "models"`` for ``model_id``.

    Every ``/`` in the model id is replaced by ``--`` so the id becomes a
    single path component.
    """
    directory_name = "--".join(model_id.split("/"))
    return EXO_HOME / "models" / directory_name
|
||||
|
||||
@@ -235,6 +247,7 @@ async def _fetch_file_list(
|
||||
async def get_download_headers() -> dict[str, str]:
    """Return the auth headers with ``Accept-Encoding`` forced to ``identity``."""
    headers = dict(await get_auth_headers())
    # Assigned after copying so this value wins over any existing entry.
    headers["Accept-Encoding"] = "identity"
    return headers
|
||||
|
||||
|
||||
def create_http_session(
|
||||
auto_decompress: bool = False,
|
||||
timeout_profile: Literal["short", "long"] = "long",
|
||||
@@ -260,6 +273,7 @@ def create_http_session(
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -> str:
|
||||
hasher = hashlib.sha1() if hash_type == "sha1" else hashlib.sha256()
|
||||
if hash_type == "sha1":
|
||||
@@ -395,8 +409,12 @@ def calculate_repo_progress(
|
||||
all_start_time: float,
|
||||
) -> RepoDownloadProgress:
|
||||
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
|
||||
all_downloaded_bytes = sum((p.downloaded.in_bytes for p in file_progress.values()), 0)
|
||||
all_downloaded_bytes_this_session = sum((p.downloaded_this_session.in_bytes for p in file_progress.values()), 0)
|
||||
all_downloaded_bytes = sum(
|
||||
(p.downloaded.in_bytes for p in file_progress.values()), 0
|
||||
)
|
||||
all_downloaded_bytes_this_session = sum(
|
||||
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
|
||||
)
|
||||
elapsed_time = time.time() - all_start_time
|
||||
all_speed = (
|
||||
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
|
||||
@@ -422,7 +440,9 @@ def calculate_repo_progress(
|
||||
),
|
||||
total_files=len(file_progress),
|
||||
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(all_downloaded_bytes_this_session),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(
|
||||
all_downloaded_bytes_this_session
|
||||
),
|
||||
total_bytes=Memory.from_bytes(all_total_bytes),
|
||||
overall_speed=all_speed,
|
||||
overall_eta=all_eta,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from functools import partial
|
||||
from random import random
|
||||
import time
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import anyio
|
||||
@@ -198,6 +198,10 @@ class Worker:
|
||||
async for event in self.execute_op(op):
|
||||
await self.event_publisher(event)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Error occurred when executing task", flush=True
|
||||
)
|
||||
|
||||
if isinstance(op, ExecuteTaskOp):
|
||||
generator = self.fail_task(
|
||||
e, runner_id=op.runner_id, task_id=op.task.task_id
|
||||
@@ -319,7 +323,9 @@ class Worker:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(initial_progress),
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
@@ -373,7 +379,9 @@ class Worker:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(progress),
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
@@ -621,8 +629,6 @@ class Worker:
|
||||
async for event in self.fail_runner(e, runner_id):
|
||||
yield event
|
||||
|
||||
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def event_publisher(self, event: Event) -> None:
|
||||
fe = ForwarderEvent(
|
||||
|
||||
@@ -30,6 +30,7 @@ from exo.shared.types.worker.communication import (
|
||||
runner_print,
|
||||
)
|
||||
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
@@ -82,7 +83,7 @@ def generate_step(
|
||||
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True) # pyright: ignore[reportUnknownMemberType]
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
sampled = sampler(logprobs)
|
||||
return sampled, logprobs.squeeze(0)
|
||||
|
||||
@@ -220,7 +221,7 @@ async def warmup_inference(
|
||||
|
||||
def _generate_warmup():
|
||||
nonlocal tokens_generated
|
||||
for _ in stream_generate(
|
||||
for token in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=warmup_prompt,
|
||||
@@ -228,6 +229,7 @@ async def warmup_inference(
|
||||
sampler=sampler,
|
||||
conn=None,
|
||||
):
|
||||
runner_print("Generated warmup token: " + str(token.text))
|
||||
tokens_generated += 1
|
||||
|
||||
await loop.run_in_executor(mlx_executor, _generate_warmup)
|
||||
|
||||
@@ -65,7 +65,7 @@ def get_init_timeout(model_shard_meta: ShardMetadata) -> float:
|
||||
|
||||
kbps_read = 1024 * 1024 * LB_DISK_GBPS / 3
|
||||
|
||||
return weights_size.in_kb / kbps_read + 2.0
|
||||
return weights_size.in_kb / kbps_read + 30.0
|
||||
|
||||
|
||||
def _prefill_flops_for_shard(model_shard_meta: ShardMetadata, s: int) -> float:
|
||||
|
||||
@@ -117,7 +117,14 @@ def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=node_id,
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0), downloaded_bytes_this_session=Memory.from_bytes(0), completed_files=0, total_files=0, speed=0, eta_ms=0, files={}
|
||||
total_bytes=Memory.from_bytes(1),
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
speed=0,
|
||||
eta_ms=0,
|
||||
files={},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -51,12 +51,16 @@ async def get_memory_profile_async() -> MemoryPerformanceProfile:
|
||||
|
||||
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
|
||||
override_memory: int | None = (
|
||||
Memory.from_mb(int(override_memory_env)).in_bytes if override_memory_env else None
|
||||
Memory.from_mb(int(override_memory_env)).in_bytes
|
||||
if override_memory_env
|
||||
else None
|
||||
)
|
||||
|
||||
return MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=int(vm.total),
|
||||
ram_available=int(override_memory) if override_memory else int(vm.available),
|
||||
ram_available=int(override_memory)
|
||||
if override_memory
|
||||
else int(vm.available),
|
||||
swap_total=int(sm.total),
|
||||
swap_available=int(sm.free),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user