Compare commits

...

4 Commits
aiohttp ... foo

Author SHA1 Message Date
Ryuichi Leo Takashige
8f6f2f3065 Add fixes 2026-01-20 17:13:02 +00:00
Evan
e6af53c2ae foo 2026-01-20 17:12:31 +00:00
Alex Cheema
ea9c6d6bdf Remove dead local paths code from download_shard (#1227)
## Motivation

The `download_progress_for_local_path` function and the "Handle local
paths" code block in `download_shard` are dead code that cannot be
reached in normal usage. The code checks whether `model_id` (e.g.,
"mlx-community/Llama-3.2-3B-Instruct-4bit") exists as a filesystem path,
but model IDs are constrained to the HuggingFace repo format and there is
no API pathway for passing local paths.

## Changes

- Removed `download_progress_for_local_path()` function (45 lines)
- Removed the "Handle local paths" block in `download_shard()` (7 lines)

## Why It Works

This code was added in PR #669 as part of a "feature-local-models"
branch, but the feature was never fully integrated. The check
`aios.path.exists(str(shard.model_card.model_id))` would only return
`True` if a directory literally named
"mlx-community/Llama-3.2-3B-Instruct-4bit" existed under the current
working directory, which doesn't happen in practice. Offline caching is
already handled by `fetch_file_list_with_cache`.
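
To make the dead-code claim concrete, here is a minimal standalone sketch (not exo code; it assumes `aios` is `aiofiles.os`, whose `path.exists` simply runs `os.path.exists` in a thread). An existence check on a repo-ID-shaped string is a relative path lookup, so it only succeeds if a matching nested directory happens to exist in the current working directory:

```python
import os

# A HuggingFace repo ID, which is what exo's model_id actually holds.
repo_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"

# The removed branch effectively did os.path.exists(repo_id): a relative
# lookup that is only True if ./mlx-community/Llama-3.2-3B-Instruct-4bit
# exists under the current working directory.
print(os.path.exists(repo_id))  # False in any normal working directory
```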

## Test Plan

### Manual Testing
- Run exo normally and verify downloads still work

### Automated Testing
- Existing tests pass (this code had no test coverage)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 17:07:27 +00:00
Alex Cheema
4ea66d427b Reduce download log spam (#1225)
## Motivation

When `skip_download=True`, exo logged a number of unnecessary messages during periodic download status checks, cluttering the logs and burying important messages.

## Changes

- Only log "Downloading ... with allow_patterns=..." when actually downloading (not when `skip_download` is true)
- Changed the periodic download-progress check logs from INFO to DEBUG level

## Why It Works

The `skip_download=True` parameter is used when checking download status without actually downloading. Guarding the log behind `if not skip_download:` avoids logging on every status check, and moving the periodic progress-emission logs to DEBUG level reduces noise while keeping them available for debugging.
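
A minimal sketch of the guard pattern (illustrative only; `announce_download` and its parameters are not exo's actual API):

```python
import logging

logger = logging.getLogger("exo.download")


def announce_download(model_id: str, allow_patterns: list[str], skip_download: bool) -> None:
    # Announce the download only when one will actually happen; periodic
    # status checks pass skip_download=True and stay quiet at INFO level.
    if not skip_download:
        logger.info(f"Downloading {model_id=} with {allow_patterns=}")

    # Routine progress checks remain traceable at DEBUG (enabled with -v / -vv).
    logger.debug("Checked download status for %s", model_id)
```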

## Test Plan

### Manual Testing
- Run exo and observe that logs are less spammy during normal operation
- Use the `-v` or `-vv` flags to see DEBUG logs when needed

### Automated Testing
- Existing tests cover this code path
2026-01-20 16:57:05 +00:00
5 changed files with 81 additions and 149 deletions

View File

@@ -477,53 +477,6 @@ async def get_downloaded_size(path: Path) -> int:
     return 0
 
 
-async def download_progress_for_local_path(
-    repo_id: str, shard: ShardMetadata, local_path: Path
-) -> RepoDownloadProgress:
-    file_progress: dict[str, RepoFileDownloadProgress] = {}
-    total_files = 0
-    total_bytes = 0
-    if await aios.path.isdir(local_path):
-        for root, _, files in os.walk(local_path):
-            for f in files:
-                if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
-                    file_path = Path(root) / f
-                    size = (await aios.stat(file_path)).st_size
-                    rel_path = str(file_path.relative_to(local_path))
-                    file_progress[rel_path] = RepoFileDownloadProgress(
-                        repo_id=repo_id,
-                        repo_revision="local",
-                        file_path=rel_path,
-                        downloaded=Memory.from_bytes(size),
-                        downloaded_this_session=Memory.from_bytes(0),
-                        total=Memory.from_bytes(size),
-                        speed=0,
-                        eta=timedelta(0),
-                        status="complete",
-                        start_time=time.time(),
-                    )
-                    total_files += 1
-                    total_bytes += size
-    else:
-        raise ValueError(f"Local path {local_path} is not a directory")
-
-    return RepoDownloadProgress(
-        repo_id=repo_id,
-        repo_revision="local",
-        shard=shard,
-        completed_files=total_files,
-        total_files=total_files,
-        downloaded_bytes=Memory.from_bytes(total_bytes),
-        downloaded_bytes_this_session=Memory.from_bytes(0),
-        total_bytes=Memory.from_bytes(total_bytes),
-        overall_speed=0,
-        overall_eta=timedelta(0),
-        status="complete",
-        file_progress=file_progress,
-    )
-
-
 async def download_shard(
     shard: ShardMetadata,
     on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -534,14 +487,6 @@ async def download_shard(
     if not skip_download:
         logger.info(f"Downloading {shard.model_card.model_id=}")
 
-    # Handle local paths
-    if await aios.path.exists(str(shard.model_card.model_id)):
-        logger.info(f"Using local model path {shard.model_card.model_id}")
-        local_path = Path(str(shard.model_card.model_id))
-        return local_path, await download_progress_for_local_path(
-            str(shard.model_card.model_id), shard, local_path
-        )
-
     revision = "main"
     target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
         "/", "--"
@@ -552,7 +497,8 @@ async def download_shard(
     if not allow_patterns:
        allow_patterns = await resolve_allow_patterns(shard)
 
-    logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
+    if not skip_download:
+        logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
 
     all_start_time = time.time()
     # TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from functools import partial
 from inspect import signature
-from typing import TYPE_CHECKING, Any, Protocol, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -67,27 +67,16 @@ def eval_with_timeout(
         completed.set()
 
 
-class _LayerCallable(Protocol):
-    """Structural type that any compatible layer must satisfy.
-
-    We require a single positional input of type ``mx.array`` and an
-    ``mx.array`` output, while permitting arbitrary *args / **kwargs so this
-    protocol matches the vast majority of `mlx.nn.Module` subclasses.
-    """
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
-
-
 class CustomMlxLayer(nn.Module):
     """Base class for replacing an MLX layer with a custom implementation."""
 
-    def __init__(self, original_layer: _LayerCallable):
+    def __init__(self, original_layer: nn.Module):
         super().__init__()
         object.__setattr__(self, "_original_layer", original_layer)
 
     @property
-    def original_layer(self) -> _LayerCallable:
-        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
+    def original_layer(self) -> nn.Module:
+        return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
 
     # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
     if not TYPE_CHECKING:
@@ -100,52 +89,53 @@ class CustomMlxLayer(nn.Module):
             return getattr(original_layer, name)
 
 
-class PipelineFirstLayer(CustomMlxLayer):
-    def __init__(
-        self,
-        original_layer: _LayerCallable,
-        r: int,
-        group: mx.distributed.Group,
-    ):
-        super().__init__(original_layer)
-        self.r: int = r
-        self.group = group
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
-        if self.r != 0:
-            x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
-        return self.original_layer(x, *args, **kwargs)
+def patch_pipeline_first_layer(
+    pipeline_layer: nn.Module, group: mx.distributed.Group
+) -> nn.Module:
+    cls = type(pipeline_layer)
+    orig_call = cast(Callable[..., mx.array], cls.__call__)
+
+    rank = group.rank()
+
+    class PatchedFirstLayer(cls):
+        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
+            if rank != 0:
+                x = mx.distributed.recv_like(x, (rank - 1), group=group)
+            return orig_call(self, x, *args, **kwargs)
+
+    pipeline_layer.__class__ = PatchedFirstLayer
+    return pipeline_layer
 
 
-class PipelineLastLayer(CustomMlxLayer):
-    def __init__(
-        self,
-        original_layer: _LayerCallable,
-        r: int,
-        s: int,
-        group: mx.distributed.Group,
-    ):
-        super().__init__(original_layer)
-        self.r: int = r
-        self.s: int = s
-        self.group = group
-        self.original_layer_signature = signature(self.original_layer.__call__)
-
-    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
-        cache = self.original_layer_signature.bind_partial(
-            x, *args, **kwargs
-        ).arguments.get("cache", None)
-
-        output: mx.array = self.original_layer(x, *args, **kwargs)
-
-        if self.r != self.s - 1:
-            output = mx.distributed.send(
-                output, (self.r + 1) % self.s, group=self.group
-            )
-            if cache is not None:
-                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
-        return output
+def patch_pipeline_last_layer(
+    pipeline_layer: nn.Module, group: mx.distributed.Group
+) -> nn.Module:
+    cls = type(pipeline_layer)
+    orig_call = cast(Callable[..., mx.array], cls.__call__)
+    orig_call_sig = signature(orig_call)
+
+    rank = group.rank()
+    size = group.size()
+
+    class PatchedLastLayer(cls):
+        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
+            cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
+                "cache", None
+            )
+
+            output: mx.array = orig_call(self, x, *args, **kwargs)
+
+            if rank != size - 1:
+                output = mx.distributed.send(output, (rank + 1) % size, group=group)
+                if cache is not None:
+                    cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
+            return output
+
+    pipeline_layer.__class__ = PatchedLastLayer
+    return pipeline_layer
 
 
 def _inner_model(model: nn.Module) -> nn.Module:
@@ -160,13 +150,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
     raise ValueError("Model must either have a 'model' or 'transformer' attribute")
 
 
-def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
+def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
     # Handle both model.layers and model.h cases
-    layers: list[_LayerCallable]
+    layers: list[nn.Module]
     if hasattr(inner_model_instance, "layers"):
-        layers = cast(list[_LayerCallable], inner_model_instance.layers)
+        layers = cast(list[nn.Module], inner_model_instance.layers)
     elif hasattr(inner_model_instance, "h"):
-        layers = cast(list[_LayerCallable], inner_model_instance.h)
+        layers = cast(list[nn.Module], inner_model_instance.h)
     else:
         raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -191,15 +181,12 @@ def pipeline_auto_parallel(
     layers = _get_layers(inner_model_instance)
     start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
-    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
 
     layers = layers[start_layer:end_layer]
-    layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
-    layers[-1] = PipelineLastLayer(
+    layers[0] = patch_pipeline_first_layer(layers[0], group)
+    layers[-1] = patch_pipeline_last_layer(
         layers[-1],
-        device_rank,
-        world_size,
-        group=group,
+        group,
     )
 
     if isinstance(inner_model_instance, GptOssMoeModel):
@@ -446,7 +433,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
         return model
 
 
-def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
+def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
     inner_model_instance = _inner_model(model)
     if hasattr(inner_model_instance, "layers"):
         inner_model_instance.layers = layers
@@ -521,17 +508,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
 class ShardedDeepseekV3MoE(CustomMlxLayer):
-    def __init__(self, layer: _LayerCallable):
+    def __init__(self, layer: nn.Module):
         super().__init__(layer)
         self.sharding_group: mx.distributed.Group | None = None
 
     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer.__call__(x)
+        y = self.original_layer.__call__(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore
 
 
 class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
@@ -565,7 +552,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
             self.all_to_sharded_linear_in_place(
                 layer.block_sparse_moe.switch_mlp.up_proj
             )
-            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue]
             layer.block_sparse_moe.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
 
         return model
@@ -599,7 +586,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
             self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
             self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
             self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
-            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue]
             layer.mlp.sharding_group = self.group
 
             # Shard the MLP
@@ -612,17 +599,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
 class ShardedQwenMoE(CustomMlxLayer):
-    def __init__(self, layer: _LayerCallable):
+    def __init__(self, layer: nn.Module):
         super().__init__(layer)
         self.sharding_group: mx.distributed.Group | None = None
 
     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer.__call__(x)
+        y = self.original_layer.__call__(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore
 
 
 class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -674,7 +661,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
     def __call__(self, x: mx.array) -> mx.array:
         if self.sharding_group is not None:
             x = sum_gradients(self.sharding_group)(x)
-        y = self.original_layer(x)
+        y = self.original_layer(x)  # type: ignore
         if self.sharding_group is not None:
-            y = mx.distributed.all_sum(y, group=self.sharding_group)
-        return y
+            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
+        return y  # type: ignore
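
For readers unfamiliar with the pattern used by the new `patch_pipeline_first_layer` / `patch_pipeline_last_layer` helpers above: they reassign a single instance's `__class__` to a dynamically created subclass, so only that layer picks up the overridden `__call__` while every other instance of the class is untouched. A minimal standalone sketch of the same technique with toy classes (no MLX dependency; all names here are illustrative):

```python
from collections.abc import Callable
from typing import cast


class Layer:
    def __call__(self, x: int) -> int:
        return x + 1


def patch_with_logging(layer: Layer) -> Layer:
    cls = type(layer)
    orig_call = cast(Callable[..., int], cls.__call__)

    class Patched(cls):  # dynamically created subclass of this layer's own class
        def __call__(self, x: int) -> int:
            print(f"calling {cls.__name__}({x})")
            return orig_call(self, x)

    # Swap the class of this one instance; other Layer instances are unaffected.
    layer.__class__ = Patched
    return layer


patched = patch_with_logging(Layer())
assert patched(1) == 2             # original behaviour preserved, hook applied
assert isinstance(patched, Layer)  # still passes isinstance checks against Layer
```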

View File

@@ -449,7 +449,7 @@ class Worker:
     async def _emit_existing_download_progress(self) -> None:
         try:
             while True:
-                logger.info("Fetching and emitting existing download progress...")
+                logger.debug("Fetching and emitting existing download progress...")
                 async for (
                     _,
                     progress,
@@ -480,7 +480,7 @@ class Worker:
                     await self.event_sender.send(
                         NodeDownloadProgress(download_progress=status)
                     )
-                logger.info("Done emitting existing download progress.")
+                logger.debug("Done emitting existing download progress.")
                 await anyio.sleep(5 * 60)  # 5 minutes
         except Exception as e:
             logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -18,7 +18,7 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
 from exo.worker.engines.mlx import Model
 from exo.worker.engines.mlx.generator.generate import mlx_generate
-from exo.worker.engines.mlx.utils_mlx import shard_and_load
+from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template
 
 
 class MockLayer(nn.Module):
@@ -116,12 +116,11 @@ def run_gpt_oss_pipeline_device(
         messages=[ChatCompletionMessage(role="user", content=prompt_text)],
         max_tokens=max_tokens,
     )
+    prompt = apply_chat_template(tokenizer, task)
 
     generated_text = ""
     for response in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
+        model=model, tokenizer=tokenizer, task=task, prompt=prompt
     ):
         generated_text += response.text
         if response.finish_reason is not None:
@@ -183,11 +182,11 @@ def run_gpt_oss_tensor_parallel_device(
         max_tokens=max_tokens,
     )
+    prompt = apply_chat_template(tokenizer, task)
 
     generated_text = ""
     for response in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
+        model=model, tokenizer=tokenizer, task=task, prompt=prompt
     ):
         generated_text += response.text
         if response.finish_reason is not None:

View File

@@ -10,8 +10,8 @@ import pytest
 from exo.worker.engines.mlx.auto_parallel import (
     CustomMlxLayer,
-    PipelineFirstLayer,
-    PipelineLastLayer,
+    patch_pipeline_first_layer,
+    patch_pipeline_last_layer,
     patch_pipeline_model,
 )
 from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -50,8 +50,8 @@ def run_pipeline_device(
     group = mx.distributed.init(backend="ring", strict=True)
     mock = MockLayerInner()
-    first = PipelineFirstLayer(mock, r=rank, group=group)
-    composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
+    first = patch_pipeline_first_layer(mock, group)
+    composed = patch_pipeline_last_layer(first, group)
 
     # Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
     inner_model = MockModel([composed])
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
     mock = MockLayer()
     group = mx.distributed.init()
-    first = PipelineFirstLayer(mock, r=0, group=group)
-    composed = PipelineLastLayer(first, r=0, s=1, group=group)
+    first = patch_pipeline_first_layer(mock, group)
+    composed = patch_pipeline_last_layer(first, group)
 
     assert composed.custom_attr == "test_value"  # type: ignore[attr-defined]
     assert composed.use_sliding is True  # type: ignore[attr-defined]