Compare commits

...

4 Commits

Author SHA1 Message Date
Evan
02b3bbfb3d aaa 2026-01-28 15:57:35 +00:00
Evan Quiney
748a026071 fix configdata validation for kimi-k2 (#1314)
## motivation
our shard downloader could not correctly fetch data for kimi-k2, as it
deferred some values to a text_config field.
## changes
config_data now prioritizes this field if it exists in information like
layer_count
2026-01-28 14:29:36 +00:00
Alex Cheema
f1a2d054ec Update tagline to "Run frontier AI locally" (#1313)
- Update README tagline from "Run your own AI cluster at home with
everyday devices" to "Run frontier AI locally"
2026-01-28 12:38:14 +00:00
Alex Cheema
b3c8f85fc8 Update MLX to 0.30.4 (#1311)
## Summary
- Bump mlx from 0.30.3 to 0.30.4

## Test plan
- [x] `uv lock` succeeds
- [x] Type checking passes (`uv run basedpyright`)
- [x] Run inference tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 04:30:21 -08:00
9 changed files with 1181 additions and 1182 deletions

View File

@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>

View File

@@ -6,8 +6,6 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"pydantic>=2.11.7",
"fastapi>=0.116.1",
"filelock>=3.18.0",
@@ -17,8 +15,8 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -121,6 +121,7 @@ class DownloadCoordinator:
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
logger.warning("starting download for {shard}")
model_id = shard.model_card.model_id
# Emit ongoing status

View File

@@ -8,13 +8,13 @@ import traceback
from collections.abc import Awaitable
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from typing import Callable, Literal, cast
from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
import httpx
from huggingface_hub import (
snapshot_download, # pyright: ignore[reportUnknownVariableType]
)
@@ -176,7 +176,7 @@ async def fetch_file_list_with_cache(
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
f"{type(e).__name__}: Failed to fetch file list for {model_id}, using cached data"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
@@ -196,7 +196,7 @@ async def fetch_file_list_with_retry(
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
await asyncio.sleep(min(16, 0.5 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
@@ -211,26 +211,25 @@ async def _fetch_file_list(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
response = await session.get(url, headers=headers)
if response.status_code in [401, 403]:
msg = await _build_auth_error_message(response.status_code, model_id)
raise HuggingFaceAuthenticationError(msg)
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
if response.status_code != 200:
raise Exception(f"Failed to fetch file list: {response.status_code}")
data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
async def get_download_headers() -> dict[str, str]:
@@ -238,34 +237,29 @@ async def get_download_headers() -> dict[str, str]:
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
) -> aiohttp.ClientSession:
) -> httpx.AsyncClient:
if timeout_profile == "short":
total_timeout = 30
connect_timeout = 10
sock_read_timeout = 30
sock_connect_timeout = 10
read_timeout = 30
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
read_timeout = 1800
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
# default here is to load env vars
return httpx.AsyncClient(
verify=ssl_context,
timeout=httpx.Timeout(
connect=connect_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
read=read_timeout,
write=total_timeout,
pool=total_timeout,
),
)
@@ -292,26 +286,28 @@ async def file_meta(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.head(url, headers=headers) as r,
session.stream("HEAD", url, headers=headers) as r,
):
if r.status == 307:
if r.status_code == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
redirected_location = cast(str | None, r.headers.get("location"))
return await file_meta(model_id, revision, path, redirected_location)
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
if r.status_code in [401, 403]:
msg = await _build_auth_error_message(r.status_code, model_id)
raise HuggingFaceAuthenticationError(msg)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
content_length = cast(
str | None,
r.headers.get("x-linked-size") or r.headers.get("content-length"),
)
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
content_length = 0 if content_length is None else int(content_length)
etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
etag = trim_etag(etag)
@@ -340,7 +336,7 @@ async def download_file_with_retry(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
await asyncio.sleep(min(16, 0.5 * (2.0**attempt)))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -353,6 +349,7 @@ async def _download_file(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
logger.warning(f"downloading {path} from {model_id} to {target_dir}")
target_path = target_dir / path
if await aios.path.exists(target_path):
@@ -392,20 +389,20 @@ async def _download_file(
n_read = resume_byte_pos or 0
async with (
create_http_session(timeout_profile="long") as session,
session.get(url, headers=headers) as r,
session.stream("GET", url, headers=headers, follow_redirects=True) as r,
):
if r.status == 404:
if r.status_code == 404:
raise FileNotFoundError(f"File not found: {url}")
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
if r.status_code in [401, 403]:
msg = await _build_auth_error_message(r.status_code, model_id)
raise HuggingFaceAuthenticationError(msg)
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
assert r.status_code in [200, 206], (
f"Failed to download {path} from {url}: {r.status_code}"
)
async with aiofiles.open(
partial_path, "ab" if resume_byte_pos else "wb"
) as f:
while chunk := await r.content.read(8 * 1024 * 1024):
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
n_read = n_read + (await f.write(chunk))
on_progress(n_read, length, False)

View File

@@ -168,7 +168,8 @@ class ResumableShardDownloader(ShardDownloader):
yield await task
# TODO: except Exception
except Exception as e:
logger.error("Error downloading shard:", e)
task.cancel()
logger.opt(exception=e).error("Error downloading shard")
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated
from typing import Annotated, Any
import aiofiles
import aiofiles.os as aios
@@ -7,7 +7,14 @@ import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt, field_validator
from pydantic import (
AliasChoices,
BaseModel,
Field,
PositiveInt,
field_validator,
model_validator,
)
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.types.common import ModelId
@@ -711,15 +718,18 @@ if EXO_ENABLE_IMAGE_MODELS:
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
hidden_size: Annotated[int, Field(ge=0)] | None = None
layer_count: int = Field(
validation_alias=AliasChoices(
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
)
)
@property
def supports_tensor(self) -> bool:
@@ -734,25 +744,27 @@ class ConfigData(BaseModel):
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
@model_validator(mode="before")
@classmethod
def defer_to_text_config(cls, data: dict[str, Any]):
text_config = data.get("text_config")
if text_config is None:
return data
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
for field in [
"architectures",
"hidden_size",
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
]:
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
data[field] = val
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
return data
async def get_config_data(model_id: ModelId) -> ConfigData:

View File

@@ -1,10 +1,8 @@
import time
from typing import Generic, TypeVar
K = TypeVar("K")
from collections.abc import Hashable
class KeyedBackoff(Generic[K]):
class KeyedBackoff[K: Hashable]:
"""Tracks exponential backoff state per key."""
def __init__(self, base: float = 0.5, cap: float = 10.0):

View File

@@ -165,12 +165,11 @@ def mlx_distributed_init(
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

2167
uv.lock generated
View File

File diff suppressed because it is too large Load Diff