Compare commits


8 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
8f6f2f3065 Add fixes 2026-01-20 17:13:02 +00:00
Evan
e6af53c2ae foo 2026-01-20 17:12:31 +00:00
Alex Cheema
ea9c6d6bdf Remove dead local paths code from download_shard (#1227)
## Motivation

The `download_progress_for_local_path` function and the "Handle local
paths" code block in `download_shard` are dead code that cannot be
reached in normal usage. The code checks if `model_id` (e.g.,
"mlx-community/Llama-3.2-3B-Instruct-4bit") exists as a filesystem path,
but model IDs are constrained to HuggingFace repo format and there's no
API pathway to pass local paths.

## Changes

- Removed `download_progress_for_local_path()` function (45 lines)
- Removed the "Handle local paths" block in `download_shard()` (7 lines)

## Why It Works

This code was added in PR #669 as part of a "feature-local-models"
branch, but the feature was never fully integrated. The check
`aios.path.exists(str(shard.model_card.model_id))` would only return
`True` if a directory literally named
"mlx-community/Llama-3.2-3B-Instruct-4bit" existed in the cwd, which
doesn't happen in practice. Offline caching is already handled by
`fetch_file_list_with_cache`.
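
For illustration, a minimal standalone sketch of why the check was dead
(using `os.path` in place of the async `aios.path` wrapper; not exo
source):

```python
# A HuggingFace-style model ID resolves relative to the cwd, so the
# existence check only succeeds if a directory with that literal
# name has been created there.
import os

model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"

# False unless ./mlx-community/Llama-3.2-3B-Instruct-4bit/ exists
print(os.path.exists(model_id))
```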

## Test Plan

### Manual Testing
- Run exo normally and verify downloads still work

### Automated Testing
- Existing tests pass (this code had no test coverage)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 17:07:27 +00:00
Alex Cheema
4ea66d427b Reduce download log spam (#1225)
## Motivation

When `skip_download=True`, exo logged many unnecessary messages during periodic download status checks, producing spammy logs that buried important messages.

## Changes

- Only log "Downloading ... with allow_patterns=..." when actually downloading (not when `skip_download` is `True`)
- Changed the periodic download progress check logs from INFO to DEBUG level

## Why It Works

The `skip_download=True` parameter is used when checking download status without actually downloading. Guarding the log behind `if not skip_download:` avoids logging on every status check, and moving the periodic progress-emission logs to DEBUG level reduces noise while keeping them available for debugging.
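
A minimal sketch of the guard (hypothetical helper names; the real
change is in the `download_shard` hunk further down this page):

```python
from loguru import logger

def maybe_log_download(
    model_id: str, allow_patterns: list[str], *, skip_download: bool
) -> None:
    # Hypothetical helper: status-only checks (skip_download=True)
    # stay silent; only a real download emits the INFO line.
    if not skip_download:
        logger.info(f"Downloading {model_id=} with {allow_patterns=}")

# Periodic status check: no log emitted.
maybe_log_download("m/model", ["*.safetensors"], skip_download=True)
# Actual download: INFO line emitted.
maybe_log_download("m/model", ["*.safetensors"], skip_download=False)
```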

## Test Plan

### Manual Testing
- Run exo and observe that logs are less spammy during normal operation
- Use -v or -vv flags to see DEBUG logs when needed

### Automated Testing
- Existing tests cover this code path
2026-01-20 16:57:05 +00:00
rltakashige
8b709e68b2 Mark slow tests as slow (#1220)
2026-01-20 15:03:46 +00:00
Evan Quiney
4da6eeb11f fix a test broken by #1204 (#1219)
bad merge broke a test - fix it
2026-01-20 14:56:20 +00:00
Evan
3d2eee4884 quiet localhost log
this log is just noise - remove it
2026-01-20 14:51:26 +00:00
Evan
116558839e don't clear mdns discovered connections
the pinger currently removes mDNS-discovered connections - these systems
should be independent
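
The filter (visible in the worker diff below) keys off the connection's
port - a minimal sketch of the rule, assuming 52415 is exo's own listen
port and anything on another port was discovered via mDNS:

```python
# Sketch of the rule the worker diff adds: the pinger's cleanup loop
# only reconciles connections on exo's own port; connections on other
# ports are assumed mDNS-discovered and left alone.
EXO_PORT = 52415

def pinger_manages(port: int) -> bool:
    return port == EXO_PORT

assert pinger_manages(52415)      # pinger-managed connection
assert not pinger_manages(49321)  # mDNS-discovered, skipped
```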
2026-01-20 14:46:20 +00:00
19 changed files with 264 additions and 593 deletions

View File

@@ -1 +0,0 @@
# Download package - centralized download management for exo

View File

@@ -1,304 +0,0 @@
import asyncio
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
class DownloadCoordinator:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
download_command_receiver: Receiver[ForwarderDownloadCommand],
local_event_sender: Sender[ForwarderEvent],
):
self.node_id = node_id
self.session_id = session_id
self.shard_downloader = shard_downloader
self.download_command_receiver = download_command_receiver
self.local_event_sender = local_event_sender
# Local state
self.download_status: dict[ModelId, DownloadProgress] = {}
self.active_downloads: dict[ModelId, asyncio.Task[None]] = {}
# Internal event channel for forwarding
self.event_sender, self.event_receiver = channel[Event]()
self.local_event_index = 0
self._tg: TaskGroup = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
# Only process commands targeting this node
if cmd.command.target_node_id != self.node_id:
continue
match cmd.command:
case StartDownload(shard_metadata=shard):
await self._start_download(shard)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
# Check if already downloading or complete
if model_id in self.download_status:
status = self.download_status[model_id]
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
logger.debug(
f"Download for {model_id} already in progress or complete, skipping"
)
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
# Check initial status from downloader
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
return
# Start actual download
self._start_download_task(shard, initial_progress)
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
model_id = shard.model_card.model_id
# Emit ongoing status
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal last_progress_time
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[callback_shard.model_card.model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
# Clean up active download tracking
if callback_shard.model_card.model_id in self.active_downloads:
del self.active_downloads[callback_shard.model_card.model_id]
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[callback_shard.model_card.model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
except Exception as e:
logger.error(f"Download failed for {model_id}: {e}")
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
finally:
if model_id in self.active_downloads:
del self.active_downloads[model_id]
# Track and start the download task
# asyncio.create_task() immediately starts the coroutine, no need for start_soon
task = asyncio.create_task(download_wrapper())
self.active_downloads[model_id] = task
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads:
logger.info(f"Cancelling download for {model_id}")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Update status if we have shard metadata
if model_id in self.download_status:
current_status = self.download_status[model_id]
if hasattr(current_status, "shard_metadata"):
failed = DownloadFailed(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
error_message="Download cancelled",
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
async def _delete_download(self, model_id: ModelId) -> None:
# Cancel if active
if model_id in self.active_downloads:
logger.info(f"Cancelling active download for {model_id} before deletion")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Delete from disk
logger.info(f"Deleting model files for {model_id}")
deleted = await delete_model(str(model_id))
if deleted:
logger.info(f"Successfully deleted model {model_id}")
else:
logger.warning(f"Model {model_id} was not found on disk")
# Remove from status tracking
if model_id in self.download_status:
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
fe = ForwarderEvent(
origin_idx=self.local_event_index,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"
)

View File

@@ -12,8 +12,6 @@ from loguru import logger
from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -23,6 +21,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -30,7 +29,6 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
@@ -51,22 +49,8 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -83,11 +67,11 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
else:
worker = None
@@ -115,24 +99,13 @@ class Node:
election_result_sender=er_send,
)
return cls(
router,
download_coordinator,
worker,
election,
er_recv,
master,
api,
node_id,
)
return cls(router, worker, election, er_recv, master, api, node_id)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -197,24 +170,13 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -223,9 +185,6 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -267,7 +226,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -310,11 +268,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--no-downloads",
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -276,9 +276,7 @@ def test_placement_selects_leaf_nodes(
# arrange
topology = Topology()
# Model requires more than any single node but fits within a 3-node cycle
model_card.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_card.storage_size = Memory.from_bytes(1000)
node_id_a = NodeId()
node_id_b = NodeId()

View File

@@ -3,7 +3,7 @@ from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.events import (
ForwarderEvent,
)
@@ -45,6 +45,3 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)

View File

@@ -6,13 +6,13 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.download.download_utils import (
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
class ConfigData(BaseModel):

View File

@@ -1,10 +1,10 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -43,24 +43,6 @@ class RequestEventLog(BaseCommand):
since_idx: int
class StartDownload(BaseCommand):
target_node_id: NodeId
shard_metadata: ShardMetadata
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | CancelDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
@@ -75,8 +57,3 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
command: DownloadCommand

View File

@@ -25,16 +25,16 @@ from pydantic import (
TypeAdapter,
)
from exo.download.huggingface_utils import (
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
class ModelSafetensorsIndexMetadata(BaseModel):
@@ -477,53 +477,6 @@ async def get_downloaded_size(path: Path) -> int:
return 0
async def download_progress_for_local_path(
repo_id: str, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
file_progress: dict[str, RepoFileDownloadProgress] = {}
total_files = 0
total_bytes = 0
if await aios.path.isdir(local_path):
for root, _, files in os.walk(local_path):
for f in files:
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
file_path = Path(root) / f
size = (await aios.stat(file_path)).st_size
rel_path = str(file_path.relative_to(local_path))
file_progress[rel_path] = RepoFileDownloadProgress(
repo_id=repo_id,
repo_revision="local",
file_path=rel_path,
downloaded=Memory.from_bytes(size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(size),
speed=0,
eta=timedelta(0),
status="complete",
start_time=time.time(),
)
total_files += 1
total_bytes += size
else:
raise ValueError(f"Local path {local_path} is not a directory")
return RepoDownloadProgress(
repo_id=repo_id,
repo_revision="local",
shard=shard,
completed_files=total_files,
total_files=total_files,
downloaded_bytes=Memory.from_bytes(total_bytes),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(total_bytes),
overall_speed=0,
overall_eta=timedelta(0),
status="complete",
file_progress=file_progress,
)
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -534,14 +487,6 @@ async def download_shard(
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
"/", "--"
@@ -552,7 +497,8 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.

View File

@@ -3,14 +3,14 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:

View File

@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, cast
import mlx.core as mx
import mlx.nn as nn
@@ -67,27 +67,16 @@ def eval_with_timeout(
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
"""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: _LayerCallable):
def __init__(self, original_layer: nn.Module):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
def original_layer(self) -> nn.Module:
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -100,52 +89,53 @@ class CustomMlxLayer(nn.Module):
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.group = group
def patch_pipeline_first_layer(
pipeline_layer: nn.Module, group: mx.distributed.Group
) -> nn.Module:
cls = type(pipeline_layer)
orig_call = cast(Callable[..., mx.array], cls.__call__)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
return self.original_layer(x, *args, **kwargs)
rank = group.rank()
class PatchedFirstLayer(cls):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if rank != 0:
x = mx.distributed.recv_like(x, (rank - 1), group=group)
return orig_call(self, x, *args, **kwargs)
pipeline_layer.__class__ = PatchedFirstLayer
return pipeline_layer
class PipelineLastLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
s: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
def patch_pipeline_last_layer(
pipeline_layer: nn.Module, group: mx.distributed.Group
) -> nn.Module:
cls = type(pipeline_layer)
orig_call = cast(Callable[..., mx.array], cls.__call__)
orig_call_sig = signature(orig_call)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
rank = group.rank()
size = group.size()
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
class PatchedLastLayer(cls):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
"cache", None
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
output: mx.array = orig_call(self, x, *args, **kwargs)
if rank != size - 1:
output = mx.distributed.send(output, (rank + 1) % size, group=group)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
pipeline_layer.__class__ = PatchedLastLayer
return pipeline_layer
def _inner_model(model: nn.Module) -> nn.Module:
@@ -160,13 +150,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
# Handle both model.layers and model.h cases
layers: list[_LayerCallable]
layers: list[nn.Module]
if hasattr(inner_model_instance, "layers"):
layers = cast(list[_LayerCallable], inner_model_instance.layers)
layers = cast(list[nn.Module], inner_model_instance.layers)
elif hasattr(inner_model_instance, "h"):
layers = cast(list[_LayerCallable], inner_model_instance.h)
layers = cast(list[nn.Module], inner_model_instance.h)
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -191,15 +181,12 @@ def pipeline_auto_parallel(
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[0] = patch_pipeline_first_layer(layers[0], group)
layers[-1] = patch_pipeline_last_layer(
layers[-1],
device_rank,
world_size,
group=group,
group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
@@ -446,7 +433,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
@@ -521,17 +508,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
@@ -565,7 +552,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
@@ -599,7 +586,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -612,17 +599,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -674,7 +661,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
y = self.original_layer(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore

View File

@@ -40,7 +40,6 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -55,6 +54,7 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,

View File

@@ -2,25 +2,21 @@ from datetime import datetime, timezone
from random import random
import anyio
from anyio import CancelScope, create_task_group, fail_after
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -37,12 +33,22 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -52,6 +58,7 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
@@ -59,21 +66,23 @@ class Worker:
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.local_event_index = 0
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
@@ -94,6 +103,7 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -103,7 +113,6 @@ class Worker:
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
@@ -148,22 +157,14 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
def _get_local_download_status(self) -> dict[ModelId, DownloadProgress]:
"""Extract this node's download status from global state."""
downloads = self.state.downloads.get(self.node_id, [])
return {dp.shard_metadata.model_card.model_id: dp for dp in downloads}
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# Get download status from state (event-sourced)
local_download_status = self._get_local_download_status()
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
self.runners,
local_download_status,
self.download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -185,23 +186,42 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
# Send StartDownload command to DownloadCoordinator
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
# Mark task as running - completion will be detected
# when download status in state changes to DownloadCompleted
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -306,6 +326,65 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
self,
task: DownloadModel,
initial_progress: RepoDownloadProgress,
):
"""Manages the shard download process with progress tracking."""
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
@@ -334,11 +413,6 @@ class Worker:
)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -359,6 +433,9 @@ class Worker:
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
# ignore mDNS discovered connections
if conn.edge.sink_multiaddr.port != 52415:
continue
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
@@ -368,3 +445,42 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -18,7 +18,7 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template
class MockLayer(nn.Module):
@@ -116,12 +116,11 @@ def run_gpt_oss_pipeline_device(
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
model=model, tokenizer=tokenizer, task=task, prompt=prompt
):
generated_text += response.text
if response.finish_reason is not None:
@@ -183,11 +182,11 @@ def run_gpt_oss_tensor_parallel_device(
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
model=model, tokenizer=tokenizer, task=task, prompt=prompt
):
generated_text += response.text
if response.finish_reason is not None:

View File

@@ -10,8 +10,8 @@ import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
PipelineFirstLayer,
PipelineLastLayer,
patch_pipeline_first_layer,
patch_pipeline_last_layer,
patch_pipeline_model,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -50,8 +50,8 @@ def run_pipeline_device(
group = mx.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
first = patch_pipeline_first_layer(mock, group)
composed = patch_pipeline_last_layer(first, group)
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
inner_model = MockModel([composed])
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
mock = MockLayer()
group = mx.distributed.init()
first = PipelineFirstLayer(mock, r=0, group=group)
composed = PipelineLastLayer(first, r=0, s=1, group=group)
first = patch_pipeline_first_layer(mock, group)
composed = patch_pipeline_last_layer(first, group)
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
assert composed.use_sliding is True # type: ignore[attr-defined]

View File

@@ -18,6 +18,7 @@ def _check_model_exists() -> bool:
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",

View File

@@ -11,12 +11,12 @@ from pathlib import Path
import pytest
from exo.download.download_utils import (
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
@@ -89,6 +89,8 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
def event_loop():

View File

@@ -11,10 +11,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -40,6 +36,10 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint