Compare commits

..

2 Commits

Author SHA1 Message Date
Evan
02b3bbfb3d aaa 2026-01-28 15:57:35 +00:00
Evan Quiney
748a026071 fix configdata validation for kimi-k2 (#1314)
## motivation
Our shard downloader could not correctly fetch data for kimi-k2, because that
model's config defers some values to a nested text_config field.
## changes
ConfigData now prefers values from the text_config field, when it exists, for
information such as layer_count.
2026-01-28 14:29:36 +00:00
8 changed files with 97 additions and 97 deletions

View File

@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 20
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")

View File

@@ -6,8 +6,6 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"pydantic>=2.11.7",
"fastapi>=0.116.1",
"filelock>=3.18.0",

View File

@@ -121,6 +121,7 @@ class DownloadCoordinator:
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
logger.warning("starting download for {shard}")
model_id = shard.model_card.model_id
# Emit ongoing status

View File

@@ -8,13 +8,13 @@ import traceback
from collections.abc import Awaitable
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from typing import Callable, Literal, cast
from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
import httpx
from huggingface_hub import (
snapshot_download, # pyright: ignore[reportUnknownVariableType]
)
@@ -176,7 +176,7 @@ async def fetch_file_list_with_cache(
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
f"{type(e).__name__}: Failed to fetch file list for {model_id}, using cached data"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
@@ -196,7 +196,7 @@ async def fetch_file_list_with_retry(
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
await asyncio.sleep(min(16, 0.5 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
@@ -211,26 +211,25 @@ async def _fetch_file_list(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
response = await session.get(url, headers=headers)
if response.status_code in [401, 403]:
msg = await _build_auth_error_message(response.status_code, model_id)
raise HuggingFaceAuthenticationError(msg)
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
if response.status_code != 200:
raise Exception(f"Failed to fetch file list: {response.status_code}")
data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
async def get_download_headers() -> dict[str, str]:
@@ -238,34 +237,29 @@ async def get_download_headers() -> dict[str, str]:
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
) -> aiohttp.ClientSession:
) -> httpx.AsyncClient:
if timeout_profile == "short":
total_timeout = 30
connect_timeout = 10
sock_read_timeout = 30
sock_connect_timeout = 10
read_timeout = 30
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
read_timeout = 1800
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
# default here is to load env vars
return httpx.AsyncClient(
verify=ssl_context,
timeout=httpx.Timeout(
connect=connect_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
read=read_timeout,
write=total_timeout,
pool=total_timeout,
),
)
@@ -292,26 +286,28 @@ async def file_meta(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.head(url, headers=headers) as r,
session.stream("HEAD", url, headers=headers) as r,
):
if r.status == 307:
if r.status_code == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
redirected_location = cast(str | None, r.headers.get("location"))
return await file_meta(model_id, revision, path, redirected_location)
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
if r.status_code in [401, 403]:
msg = await _build_auth_error_message(r.status_code, model_id)
raise HuggingFaceAuthenticationError(msg)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
content_length = cast(
str | None,
r.headers.get("x-linked-size") or r.headers.get("content-length"),
)
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
content_length = 0 if content_length is None else int(content_length)
etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
etag = trim_etag(etag)
@@ -340,7 +336,7 @@ async def download_file_with_retry(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
await asyncio.sleep(min(16, 0.5 * (2.0**attempt)))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -353,6 +349,7 @@ async def _download_file(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
logger.warning(f"downloading {path} from {model_id} to {target_dir}")
target_path = target_dir / path
if await aios.path.exists(target_path):
@@ -392,20 +389,20 @@ async def _download_file(
n_read = resume_byte_pos or 0
async with (
create_http_session(timeout_profile="long") as session,
session.get(url, headers=headers) as r,
session.stream("GET", url, headers=headers, follow_redirects=True) as r,
):
if r.status == 404:
if r.status_code == 404:
raise FileNotFoundError(f"File not found: {url}")
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
if r.status_code in [401, 403]:
msg = await _build_auth_error_message(r.status_code, model_id)
raise HuggingFaceAuthenticationError(msg)
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
assert r.status_code in [200, 206], (
f"Failed to download {path} from {url}: {r.status_code}"
)
async with aiofiles.open(
partial_path, "ab" if resume_byte_pos else "wb"
) as f:
while chunk := await r.content.read(8 * 1024 * 1024):
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
n_read = n_read + (await f.write(chunk))
on_progress(n_read, length, False)

View File

@@ -168,7 +168,8 @@ class ResumableShardDownloader(ShardDownloader):
yield await task
# TODO: except Exception
except Exception as e:
logger.error("Error downloading shard:", e)
task.cancel()
logger.opt(exception=e).error("Error downloading shard")
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated
from typing import Annotated, Any
import aiofiles
import aiofiles.os as aios
@@ -7,7 +7,14 @@ import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt, field_validator
from pydantic import (
AliasChoices,
BaseModel,
Field,
PositiveInt,
field_validator,
model_validator,
)
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.types.common import ModelId
@@ -711,15 +718,18 @@ if EXO_ENABLE_IMAGE_MODELS:
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
hidden_size: Annotated[int, Field(ge=0)] | None = None
layer_count: int = Field(
validation_alias=AliasChoices(
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
)
)
@property
def supports_tensor(self) -> bool:
@@ -734,25 +744,27 @@ class ConfigData(BaseModel):
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
@model_validator(mode="before")
@classmethod
def defer_to_text_config(cls, data: dict[str, Any]):
text_config = data.get("text_config")
if text_config is None:
return data
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
for field in [
"architectures",
"hidden_size",
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
]:
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
data[field] = val
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
return data
async def get_config_data(model_id: ModelId) -> ConfigData:

View File

@@ -1,10 +1,8 @@
import time
from typing import Generic, TypeVar
K = TypeVar("K")
from collections.abc import Hashable
class KeyedBackoff(Generic[K]):
class KeyedBackoff[K: Hashable]:
"""Tracks exponential backoff state per key."""
def __init__(self, base: float = 0.5, cap: float = 10.0):

4
uv.lock generated
View File

@@ -366,7 +366,6 @@ version = "0.3.0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -387,7 +386,6 @@ dependencies = [
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[package.dev-dependencies]
@@ -403,7 +401,6 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=24.1.0" },
{ name = "aiohttp", specifier = ">=3.12.14" },
{ name = "anyio", specifier = "==4.11.0" },
{ name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" },
{ name = "fastapi", specifier = ">=0.116.1" },
@@ -424,7 +421,6 @@ requires-dist = [
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
[package.metadata.requires-dev]