Update mlx and mlx-lm packages

Co-authored-by: Evan <evanev7@gmail.com>
This commit is contained in:
rltakashige
2025-10-31 01:34:43 +00:00
committed by GitHub
parent a346af3477
commit 91c635ca7a
15 changed files with 966 additions and 620 deletions

View File

@@ -26,8 +26,8 @@ dependencies = [
"sqlalchemy[asyncio]>=2.0.43",
"greenlet>=3.2.4",
"huggingface-hub>=0.33.4",
"mlx==0.26.3",
"mlx-lm==0.26.4",
"mlx==0.29.3",
"mlx-lm==0.28.3",
"psutil>=7.0.0",
"transformers>=4.55.2",
"cobs>=1.2.2",

View File

@@ -1,4 +1,4 @@
from typing import Protocol, cast, override
from typing import cast, override, Protocol, TYPE_CHECKING
import mlx.core as mx
import mlx.nn as nn # pyright: ignore[reportMissingTypeStubs]
@@ -22,10 +22,29 @@ class _LayerCallable(Protocol):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class PipelineFirstLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: _LayerCallable):
super().__init__()
# Set twice to avoid __setattr__ recursion
object.__setattr__(self, "_original_layer", original_layer)
self.original_layer: _LayerCallable = original_layer
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
return object.__getattribute__(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
@@ -36,10 +55,9 @@ class PipelineFirstLayer(nn.Module):
return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(nn.Module):
class PipelineLastLayer(CustomMlxLayer):
def __init__(self, original_layer: _LayerCallable, r: int, s: int):
super().__init__()
self.original_layer: _LayerCallable = original_layer
super().__init__(original_layer)
self.r: int = r
self.s: int = s
@@ -48,7 +66,7 @@ class PipelineLastLayer(nn.Module):
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
output = mx.distributed.send(output, (self.r + 1) % self.s)
output = mx.distributed.all_gather(output)[-output.shape[0] :] # pyright: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output)[-output.shape[0] :]
return output

View File

@@ -50,7 +50,7 @@ def broadcast_from_zero(value: int) -> int:
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu))
mx.eval(m) # type: ignore
return int(m.item()) # type: ignore
return int(m.item())
class HostList(RootModel[list[str]]):
@@ -65,7 +65,9 @@ def mlx_setup(
wired_frac_of_mrwss: float = 0.00, # start with no wiring
) -> None:
if not mx.metal.is_available():
logger.warning("Metal is not available. Skipping MLX memory wired limits setup.")
logger.warning(
"Metal is not available. Skipping MLX memory wired limits setup."
)
return
info = mx.metal.device_info()
mrwss = int(info["max_recommended_working_set_size"]) # bytes
@@ -216,8 +218,8 @@ class NullKVCache(KVCache):
def __init__(self, dtype: mx.Dtype = mx.float16):
super().__init__()
# zero-length K/V so shapes/dtypes are defined but empty
self.keys = mx.zeros((1, 1, 0, 1), dtype=dtype) # pyright: ignore[reportUnknownMemberType]
self.values = mx.zeros((1, 1, 0, 1), dtype=dtype) # pyright: ignore[reportUnknownMemberType]
self.keys = mx.zeros((1, 1, 0, 1), dtype=dtype)
self.values = mx.zeros((1, 1, 0, 1), dtype=dtype)
self.offset = 0
@property
@@ -247,11 +249,11 @@ def mlx_force_oom(size: int = 40000) -> None:
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
"""
mx.set_default_device(mx.gpu) # type: ignore
a = mx.random.uniform(shape=(size, size), dtype=mx.float32) # type: ignore
b = mx.random.uniform(shape=(size, size), dtype=mx.float32) # type: ignore
a = mx.random.uniform(shape=(size, size), dtype=mx.float32)
b = mx.random.uniform(shape=(size, size), dtype=mx.float32)
mx.eval(a, b) # type: ignore
c = mx.matmul(a, b) # type: ignore
d = mx.matmul(a, c) # type: ignore
e = mx.matmul(b, c) # type: ignore
f = mx.sigmoid(d + e) # type: ignore
c = mx.matmul(a, b)
d = mx.matmul(a, c)
e = mx.matmul(b, c)
f = mx.sigmoid(d + e)
mx.eval(f) # type: ignore

View File

@@ -63,7 +63,8 @@ def get_instance_placements_after_create(
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle for cycle in smallest_cycles
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
]

View File

@@ -300,7 +300,9 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
)
# Act
placements = get_instance_placements_after_create(create_instance_command, topology, {})
placements = get_instance_placements_after_create(
create_instance_command, topology, {}
)
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.

View File

@@ -50,7 +50,10 @@ class Topology:
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
return (
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map

View File

@@ -10,7 +10,7 @@ from exo.utils.pydantic_ext import TaggedModel
class TaskId(Id):
pass
class TaskStatus(str, Enum):
Pending = "Pending"

View File

@@ -16,6 +16,7 @@ class DownloadProgressData(CamelCaseModel):
files: dict[str, "DownloadProgressData"]
class BaseDownloadProgress(TaggedModel):
node_id: NodeId

View File

@@ -61,7 +61,7 @@ class RepoFileDownloadProgress(BaseModel):
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen = True)
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
@@ -78,16 +78,18 @@ class RepoDownloadProgress(BaseModel):
status: Literal["not_started", "in_progress", "complete"]
file_progress: Dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(
frozen = True
)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
return etag
def map_repo_file_download_progress_to_download_progress_data(repo_file_download_progress: RepoFileDownloadProgress) -> DownloadProgressData:
def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
@@ -98,7 +100,11 @@ def map_repo_file_download_progress_to_download_progress_data(repo_file_download
eta_ms=int(repo_file_download_progress.eta.total_seconds() * 1000),
files={},
)
def map_repo_download_progress_to_download_progress_data(repo_download_progress: RepoDownloadProgress) -> DownloadProgressData:
def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
@@ -107,9 +113,15 @@ def map_repo_download_progress_to_download_progress_data(repo_download_progress:
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
eta_ms=int(repo_download_progress.overall_eta.total_seconds() * 1000),
files={file_path: map_repo_file_download_progress_to_download_progress_data(file_progress) for file_path, file_progress in repo_download_progress.file_progress.items()},
files={
file_path: map_repo_file_download_progress_to_download_progress_data(
file_progress
)
for file_path, file_progress in repo_download_progress.file_progress.items()
},
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_HOME / "models" / model_id.replace("/", "--")
@@ -235,6 +247,7 @@ async def _fetch_file_list(
async def get_download_headers() -> dict[str, str]:
return {**(await get_auth_headers()), "Accept-Encoding": "identity"}
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
@@ -260,6 +273,7 @@ def create_http_session(
),
)
async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -> str:
hasher = hashlib.sha1() if hash_type == "sha1" else hashlib.sha256()
if hash_type == "sha1":
@@ -395,8 +409,12 @@ def calculate_repo_progress(
all_start_time: float,
) -> RepoDownloadProgress:
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum((p.downloaded.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes_this_session = sum((p.downloaded_this_session.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum(
(p.downloaded.in_bytes for p in file_progress.values()), 0
)
all_downloaded_bytes_this_session = sum(
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
@@ -422,7 +440,9 @@ def calculate_repo_progress(
),
total_files=len(file_progress),
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(all_downloaded_bytes_this_session),
downloaded_bytes_this_session=Memory.from_bytes(
all_downloaded_bytes_this_session
),
total_bytes=Memory.from_bytes(all_total_bytes),
overall_speed=all_speed,
overall_eta=all_eta,

View File

@@ -1,8 +1,8 @@
import asyncio
import time
from asyncio import Queue
from functools import partial
from random import random
import time
from typing import AsyncGenerator, Optional
import anyio
@@ -198,6 +198,10 @@ class Worker:
async for event in self.execute_op(op):
await self.event_publisher(event)
except Exception as e:
logger.opt(exception=e).warning(
f"Error occurred when executing task", flush=True
)
if isinstance(op, ExecuteTaskOp):
generator = self.fail_task(
e, runner_id=op.runner_id, task_id=op.task.task_id
@@ -319,7 +323,9 @@ class Worker:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=map_repo_download_progress_to_download_progress_data(initial_progress),
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
)
yield assigned_runner.status_update_event()
@@ -373,7 +379,9 @@ class Worker:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=map_repo_download_progress_to_download_progress_data(progress),
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
)
yield assigned_runner.status_update_event()
@@ -621,8 +629,6 @@ class Worker:
async for event in self.fail_runner(e, runner_id):
yield event
# This function is re-entrant, take care!
async def event_publisher(self, event: Event) -> None:
fe = ForwarderEvent(

View File

@@ -30,6 +30,7 @@ from exo.shared.types.worker.communication import (
runner_print,
)
generation_stream = mx.new_stream(mx.default_device())
@@ -82,7 +83,7 @@ def generate_step(
logits = logits[:, -1, :]
logprobs = logits - mx.logsumexp(logits, keepdims=True) # pyright: ignore[reportUnknownMemberType]
logprobs = logits - mx.logsumexp(logits, keepdims=True)
sampled = sampler(logprobs)
return sampled, logprobs.squeeze(0)
@@ -220,7 +221,7 @@ async def warmup_inference(
def _generate_warmup():
nonlocal tokens_generated
for _ in stream_generate(
for token in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=warmup_prompt,
@@ -228,6 +229,7 @@ async def warmup_inference(
sampler=sampler,
conn=None,
):
runner_print("Generated warmup token: " + str(token.text))
tokens_generated += 1
await loop.run_in_executor(mlx_executor, _generate_warmup)

View File

@@ -65,7 +65,7 @@ def get_init_timeout(model_shard_meta: ShardMetadata) -> float:
kbps_read = 1024 * 1024 * LB_DISK_GBPS / 3
return weights_size.in_kb / kbps_read + 2.0
return weights_size.in_kb / kbps_read + 30.0
def _prefill_flops_for_shard(model_shard_meta: ShardMetadata, s: int) -> float:

View File

@@ -117,7 +117,14 @@ def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
download_progress=DownloadOngoing(
node_id=node_id,
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0), downloaded_bytes_this_session=Memory.from_bytes(0), completed_files=0, total_files=0, speed=0, eta_ms=0, files={}
total_bytes=Memory.from_bytes(1),
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
completed_files=0,
total_files=0,
speed=0,
eta_ms=0,
files={},
),
)
)

View File

@@ -51,12 +51,16 @@ async def get_memory_profile_async() -> MemoryPerformanceProfile:
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes if override_memory_env else None
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_bytes(
ram_total=int(vm.total),
ram_available=int(override_memory) if override_memory else int(vm.available),
ram_available=int(override_memory)
if override_memory
else int(vm.available),
swap_total=int(sm.total),
swap_available=int(sm.free),
)

1432
uv.lock generated
View File

File diff suppressed because it is too large Load Diff