A few targeted tweaks to address HF rate limits (#2009)

## Motivation

- exo bursts ~200 HF Hub-API requests on every cold start, blowing past
the anonymous 500-req/5-min budget.
- The existing retry loop catches 429 generically and gives up in ~3s —
well before HF's reset window.
- `file_meta` and `_download_file` had no 429 handling at all (a 429 surfaced
as an `AssertionError`).
- Disk file-list cache was bypassed on every process restart.

## Changes

All in `src/exo/download/download_utils.py` + tests.

- Parse `t=` from HF's `RateLimit` header on 429; sleep `min(t, 300s) +
jitter`.
- Handle 429 at all three call sites (`_fetch_file_list`, `file_meta`,
`_download_file`).
- `n_attempts`: 3 → 5.
- The on-disk file-list cache is now consulted first, including after a process restart (24h TTL based on the file's mtime).
- `?recursive=true` instead of N+1 subdir walks.

## Why It Works

`t=<seconds>` is HF's "wait this long and you'll be unblocked" —
sleeping that long lets the window reset. Disk-cache-as-primary plus
recursive listing cuts cold-start Hub-API traffic by ~10×.

## Test Plan

### Manual Testing

MacBook Pro M1 Max. Tripped the real HF 429. Pre-fix: failed in 3.4s.
Post-fix: slept (HF returned `t=158`) and recovered.

### Automated Testing

- New `test_rate_limit_handling.py` (19 tests) — header parsing,
retry-loop behaviour, plus HTTP-level coverage that mocks aiohttp to
return a 429 and asserts each call site raises
`HuggingFaceRateLimitError(retry_after=52.0)`.
- New `TestFileListCacheTTL` in `test_offline_mode.py` — fresh cache
hits, stale cache refetches.
- 421 tests pass; basedpyright / ruff / nix fmt clean.
This commit is contained in:
ciaranbor
2026-04-30 19:06:15 +01:00
committed by GitHub
parent fb12b403ea
commit 8dae3ecb9a
3 changed files with 507 additions and 25 deletions

View File

@@ -1,11 +1,12 @@
import asyncio
import hashlib
import os
import random
import shutil
import ssl
import time
import traceback
from collections.abc import Awaitable
from collections.abc import Awaitable, Mapping
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
@@ -55,6 +56,36 @@ class HuggingFaceAuthenticationError(Exception):
class HuggingFaceRateLimitError(Exception):
"""429 Huggingface code"""
def __init__(self, msg: str, retry_after: float | None = None) -> None:
super().__init__(msg)
self.retry_after = retry_after
def _parse_retry_after(headers: Mapping[str, str]) -> float | None:
"""Parse seconds-to-reset from HF's RateLimit header.
HF sends e.g. ``ratelimit: "api";r=0;t=52`` on 429s; ``t`` is the wait.
Returns ``None`` if the header is missing or has no ``t`` field.
"""
raw = headers.get("RateLimit") or headers.get("ratelimit")
if raw is None:
return None
for part in raw.split(";"):
key, _, val = part.strip().partition("=")
if key == "t":
try:
return float(val)
except ValueError:
return None
return None
# Upper bound, in seconds, on the rate-limit wait before jitter is added.
# HF's reset window is 5 min, so there is no point sleeping longer than that.
_RATE_LIMIT_MAX_SLEEP_SECS = 300.0
# 24h, in seconds; compared against the cache file's age (now - mtime).
# Manually clear the cache (or `delete_model`) to force a refresh.
_FILE_LIST_CACHE_TTL_SECS = 24 * 60 * 60
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
@@ -348,9 +379,6 @@ async def _build_file_list_from_local_directory(
return None
_fetched_file_lists_this_session: set[str] = set()
async def fetch_file_list_with_cache(
model_id: ModelId,
revision: str = "main",
@@ -360,13 +388,16 @@ async def fetch_file_list_with_cache(
) -> list[FileListEntry]:
target_dir = await ensure_cache_dir(model_id)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_key = f"{model_id.normalize()}--{revision}"
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
cache_file
):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
# cache survives process restarts so cold starts don't re-burst HF
if await aios.path.exists(cache_file):
try:
cache_age = time.time() - (await aios.stat(cache_file)).st_mtime
except OSError:
cache_age = float("inf")
if cache_age < _FILE_LIST_CACHE_TTL_SECS:
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
if skip_internet:
if await aios.path.exists(cache_file):
@@ -395,7 +426,6 @@ async def fetch_file_list_with_cache(
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
logger.opt(exception=e).warning(
@@ -426,17 +456,29 @@ async def fetch_file_list_with_retry(
recursive: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
) -> list[FileListEntry]:
n_attempts = 3
n_attempts = 5
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except HuggingFaceRateLimitError as e:
if attempt == n_attempts - 1:
raise
sleep_for = e.retry_after if e.retry_after is not None else 2.0**attempt
sleep_for = min(sleep_for, _RATE_LIMIT_MAX_SLEEP_SECS) + random.uniform(
0, 1
)
logger.warning(
f"Rate limited by HuggingFace fetching file list for {model_id}; "
f"sleeping {sleep_for:.1f}s before retry {attempt + 2}/{n_attempts}"
)
await asyncio.sleep(sleep_for)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(2.0**attempt)
await asyncio.sleep(2.0**attempt + random.uniform(0, 1))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
@@ -447,6 +489,9 @@ async def _fetch_file_list(
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
# ?recursive=true returns the whole subtree in one request
if recursive:
url = f"{url}?recursive=true"
headers = await get_download_headers()
async with (
@@ -458,7 +503,8 @@ async def _fetch_file_list(
raise HuggingFaceAuthenticationError(msg)
elif response.status == 429:
raise HuggingFaceRateLimitError(
f"Couldn't download {model_id} because of HuggingFace rate limit."
f"HuggingFace rate limit hit fetching file list for {model_id}",
retry_after=_parse_retry_after(response.headers),
)
elif response.status == 200:
data_json = await response.text()
@@ -468,10 +514,14 @@ async def _fetch_file_list(
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
model_id, revision, item.path, recursive
)
files.extend(subfiles)
# already inlined by ?recursive=true
continue
if recursive and len(data) >= 1000:
# HF tree endpoint paginates at 1000; we don't follow cursors
logger.warning(
f"File list for {model_id} hit the 1000-entry page cap "
"and may be truncated; cursor pagination is not implemented"
)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
@@ -552,6 +602,11 @@ async def file_meta(
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
if r.status == 429:
raise HuggingFaceRateLimitError(
f"HuggingFace rate limit hit fetching metadata for {model_id}/{path}",
retry_after=_parse_retry_after(r.headers),
)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -571,7 +626,7 @@ async def download_file_with_retry(
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
n_attempts = 5
for attempt in range(n_attempts):
try:
return await _download_file(
@@ -583,12 +638,16 @@ async def download_file_with_retry(
raise
except HuggingFaceRateLimitError as e:
if attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
raise
sleep_for = e.retry_after if e.retry_after is not None else 2.0**attempt
sleep_for = min(sleep_for, _RATE_LIMIT_MAX_SLEEP_SECS) + random.uniform(
0, 1
)
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
logger.warning(
f"Rate limited by HuggingFace downloading {model_id}/{path}; "
f"sleeping {sleep_for:.1f}s before retry {attempt + 2}/{n_attempts}"
)
await asyncio.sleep(sleep_for)
except Exception as e:
if attempt == n_attempts - 1:
on_connection_lost()
@@ -597,7 +656,7 @@ async def download_file_with_retry(
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
await asyncio.sleep(2.0**attempt + random.uniform(0, 1))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -665,6 +724,11 @@ async def _download_file(
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
if r.status == 429:
raise HuggingFaceRateLimitError(
f"HuggingFace rate limit hit downloading {model_id}/{path}",
retry_after=_parse_retry_after(r.headers),
)
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
)

View File

@@ -1,5 +1,7 @@
"""Tests for offline/air-gapped mode."""
import os
import time
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
@@ -231,3 +233,64 @@ class TestFetchFileListOffline:
raise FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="No internet"):
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)
class TestFileListCacheTTL:
    """Tests for the mtime-based TTL applied to the on-disk file-list cache."""

    async def test_uses_fresh_cache_without_fetching(
        self, model_id: ModelId, temp_models_dir: Path
    ) -> None:
        """A freshly written cache file is returned and no fetch is attempted."""
        from pydantic import TypeAdapter

        cache_dir = temp_models_dir / "caches" / model_id.normalize()
        await aios.makedirs(cache_dir, exist_ok=True)
        cached_list = [
            FileListEntry(type="file", path="model.safetensors", size=1000),
        ]
        # Same file name that fetch_file_list_with_cache builds for revision "main".
        cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
        async with aiofiles.open(cache_file, "w") as f:
            await f.write(
                TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
            )

        # The file was written just now, so its mtime is well inside the TTL.
        with patch(
            "exo.download.download_utils.fetch_file_list_with_retry",
            new_callable=AsyncMock,
        ) as mock_fetch:
            result = await fetch_file_list_with_cache(model_id, "main")

        assert result == cached_list
        mock_fetch.assert_not_called()

    async def test_refetches_when_cache_older_than_ttl(
        self, model_id: ModelId, temp_models_dir: Path
    ) -> None:
        """A cache file whose mtime is older than the TTL is ignored and refetched."""
        from pydantic import TypeAdapter

        from exo.download.download_utils import (
            _FILE_LIST_CACHE_TTL_SECS,  # pyright: ignore[reportPrivateUsage]
        )

        cache_dir = temp_models_dir / "caches" / model_id.normalize()
        await aios.makedirs(cache_dir, exist_ok=True)
        stale_list = [FileListEntry(type="file", path="stale.bin", size=1)]
        cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
        async with aiofiles.open(cache_file, "w") as f:
            await f.write(
                TypeAdapter(list[FileListEntry]).dump_json(stale_list).decode()
            )

        # Back-date atime/mtime to one minute past the TTL so the cache is stale.
        old_mtime = time.time() - _FILE_LIST_CACHE_TTL_SECS - 60
        os.utime(cache_file, (old_mtime, old_mtime))

        fresh_list = [FileListEntry(type="file", path="fresh.bin", size=2)]
        with patch(
            "exo.download.download_utils.fetch_file_list_with_retry",
            new_callable=AsyncMock,
            return_value=fresh_list,
        ) as mock_fetch:
            result = await fetch_file_list_with_cache(model_id, "main")

        # The mocked fetch result wins over the stale on-disk content.
        assert result == fresh_list
        mock_fetch.assert_called_once()

View File

@@ -0,0 +1,355 @@
"""Tests for HuggingFace 429 rate-limit handling in download_utils."""
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiofiles.os as aios
import pytest
from exo.download.download_utils import (
HuggingFaceRateLimitError,
_download_file, # pyright: ignore[reportPrivateUsage]
_fetch_file_list, # pyright: ignore[reportPrivateUsage]
_parse_retry_after, # pyright: ignore[reportPrivateUsage]
download_file_with_retry,
fetch_file_list_with_retry,
file_meta,
)
from exo.shared.types.common import ModelId
# captured from a real HF 429 on 2026-04-30 (header is lowercased by Cloudfront)
# `t=52` is the value the parsing tests expect back as 52.0 seconds.
REAL_HF_429_HEADERS_2026_04_30 = {
    "ratelimit": '"api";r=0;t=52',
    "ratelimit-policy": '"fixed window";"api";q=500;w=300',
}
class TestParseRetryAfter:
    """Header-parsing checks for ``_parse_retry_after``."""

    def test_parses_documented_format(self) -> None:
        headers = {"RateLimit": '"api";r=0;t=243'}
        assert _parse_retry_after(headers) == 243.0

    def test_parses_real_hf_response(self) -> None:
        wait = _parse_retry_after(REAL_HF_429_HEADERS_2026_04_30)
        assert wait == 52.0

    def test_parses_resolvers_bucket(self) -> None:
        headers = {"ratelimit": '"resolvers";r=0;t=120'}
        assert _parse_retry_after(headers) == 120.0

    def test_parses_pages_bucket(self) -> None:
        headers = {"ratelimit": '"pages";r=0;t=10'}
        assert _parse_retry_after(headers) == 10.0

    def test_returns_none_when_header_missing(self) -> None:
        no_headers: dict[str, str] = {}
        assert _parse_retry_after(no_headers) is None

    def test_returns_none_when_only_retry_after_present(self) -> None:
        # Only the RateLimit header is consulted; Retry-After alone yields None.
        headers = {"Retry-After": "60"}
        assert _parse_retry_after(headers) is None

    def test_returns_none_when_format_unrecognised(self) -> None:
        headers = {"ratelimit": "garbage"}
        assert _parse_retry_after(headers) is None

    def test_handles_extra_whitespace(self) -> None:
        headers = {"ratelimit": '"api"; r=0; t=42'}
        assert _parse_retry_after(headers) == 42.0
class TestFetchFileListRetry:
    """Retry-loop behaviour of ``fetch_file_list_with_retry`` on rate-limit errors.

    Each test patches ``_fetch_file_list`` and ``asyncio.sleep`` so no network
    call or real sleep happens; requested sleep durations are recorded instead.
    """

    async def test_uses_retry_after_from_error(self) -> None:
        """The sleep duration comes from the error's ``retry_after`` plus jitter."""
        sleeps: list[float] = []

        async def fake_sleep(seconds: float) -> None:
            sleeps.append(seconds)

        # Fails until one sleep has been recorded, then succeeds.
        async def fake_fetch(*args: object, **kwargs: object) -> list[object]:
            if not sleeps:
                raise HuggingFaceRateLimitError("rate limited", retry_after=2.0)
            return []

        with (
            patch(
                "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch
            ),
            patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep),
        ):
            result = await fetch_file_list_with_retry(ModelId("test/model"))
        assert result == []
        assert len(sleeps) == 1
        assert 2.0 <= sleeps[0] < 3.0  # retry_after + jitter[0,1)

    async def test_falls_back_to_exp_backoff_when_no_retry_after(self) -> None:
        """Without ``retry_after`` the first sleep is ``2**0`` plus jitter."""
        sleeps: list[float] = []

        async def fake_sleep(seconds: float) -> None:
            sleeps.append(seconds)

        async def fake_fetch(*args: object, **kwargs: object) -> list[object]:
            if not sleeps:
                raise HuggingFaceRateLimitError("rate limited", retry_after=None)
            return []

        with (
            patch(
                "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch
            ),
            patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep),
        ):
            await fetch_file_list_with_retry(ModelId("test/model"))
        assert len(sleeps) == 1
        assert 1.0 <= sleeps[0] < 2.0  # 2**0 + jitter[0,1)

    async def test_caps_sleep_at_max_window(self) -> None:
        """A huge ``retry_after`` is clamped to 300s before jitter is added."""
        sleeps: list[float] = []

        async def fake_sleep(seconds: float) -> None:
            sleeps.append(seconds)

        async def fake_fetch(*args: object, **kwargs: object) -> list[object]:
            if not sleeps:
                raise HuggingFaceRateLimitError("rate limited", retry_after=10_000.0)
            return []

        with (
            patch(
                "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch
            ),
            patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep),
        ):
            await fetch_file_list_with_retry(ModelId("test/model"))
        assert len(sleeps) == 1
        assert 300.0 <= sleeps[0] < 301.0  # cap + jitter[0,1)

    async def test_retries_up_to_five_times(self) -> None:
        """When every attempt is rate limited, the error is re-raised after 4 sleeps."""
        sleeps: list[float] = []

        async def fake_sleep(seconds: float) -> None:
            sleeps.append(seconds)

        # Always fails, so the loop runs out of attempts.
        async def fake_fetch(*args: object, **kwargs: object) -> list[object]:
            raise HuggingFaceRateLimitError("rate limited", retry_after=1.0)

        with (
            patch(
                "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch
            ),
            patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep),
            pytest.raises(HuggingFaceRateLimitError),
        ):
            await fetch_file_list_with_retry(ModelId("test/model"))
        assert len(sleeps) == 4  # 5 attempts -> 4 sleeps before giving up
class TestDownloadFileRetry:
    """Retry-loop behaviour of ``download_file_with_retry`` on rate-limit errors.

    ``_download_file`` and ``asyncio.sleep`` are patched, so nothing is
    downloaded and requested sleep durations are only recorded.
    """

    @pytest.fixture
    async def target_dir(self, tmp_path: Path) -> AsyncIterator[Path]:
        """Provide an existing ``downloads`` directory under pytest's ``tmp_path``."""
        target = tmp_path / "downloads"
        await aios.makedirs(target, exist_ok=True)
        yield target

    async def test_uses_retry_after_from_error(self, target_dir: Path) -> None:
        """The sleep duration comes from the error's ``retry_after`` plus jitter."""
        sleeps: list[float] = []
        results: list[Path] = [target_dir / "file.bin"]

        async def fake_sleep(seconds: float) -> None:
            sleeps.append(seconds)

        # Fails until one sleep has been recorded, then returns the fake path.
        async def fake_download(*args: object, **kwargs: object) -> Path:
            if not sleeps:
                raise HuggingFaceRateLimitError("rate limited", retry_after=5.0)
            return results[0]

        with (
            patch(
                "exo.download.download_utils._download_file",
                side_effect=fake_download,
            ),
            patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep),
        ):
            result = await download_file_with_retry(
                ModelId("test/model"), "main", "file.bin", target_dir
            )
        assert result == results[0]
        assert len(sleeps) == 1
        assert 5.0 <= sleeps[0] < 6.0  # retry_after + jitter[0,1)

    async def test_caps_sleep_at_max_window(self, target_dir: Path) -> None:
        """A huge ``retry_after`` is clamped to 300s before jitter is added."""
        sleeps: list[float] = []
        results: list[Path] = [target_dir / "file.bin"]

        async def fake_sleep(seconds: float) -> None:
            sleeps.append(seconds)

        async def fake_download(*args: object, **kwargs: object) -> Path:
            if not sleeps:
                raise HuggingFaceRateLimitError("rate limited", retry_after=99_999.0)
            return results[0]

        with (
            patch(
                "exo.download.download_utils._download_file",
                side_effect=fake_download,
            ),
            patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep),
        ):
            await download_file_with_retry(
                ModelId("test/model"), "main", "file.bin", target_dir
            )
        assert len(sleeps) == 1
        assert 300.0 <= sleeps[0] < 301.0  # cap + jitter[0,1)

    async def test_retries_up_to_five_times(self, target_dir: Path) -> None:
        """When every attempt is rate limited, the error is re-raised after 4 sleeps."""
        sleeps: list[float] = []

        async def fake_sleep(seconds: float) -> None:
            sleeps.append(seconds)

        with (
            patch(
                "exo.download.download_utils._download_file",
                new_callable=AsyncMock,
                side_effect=HuggingFaceRateLimitError("rate limited", retry_after=1.0),
            ),
            patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep),
            pytest.raises(HuggingFaceRateLimitError),
        ):
            await download_file_with_retry(
                ModelId("test/model"), "main", "file.bin", target_dir
            )
        assert len(sleeps) == 4  # 5 attempts -> 4 sleeps before giving up
def _make_mock_session_returning(
    response_attrs: dict[str, object], method: str = "get"
) -> MagicMock:
    """Return a stand-in for ``create_http_session`` that yields a canned response.

    The returned factory supports the two nested ``async with`` steps:
    ``async with factory() as session`` and then
    ``async with session.<method>(...) as response``. ``response`` carries
    every key of ``response_attrs`` as an attribute (e.g. ``status``,
    ``headers``).
    """

    def wire_async_context(manager: MagicMock, entered: object) -> None:
        # Make ``async with manager`` yield ``entered`` and exit without
        # suppressing exceptions.
        manager.__aenter__ = AsyncMock(return_value=entered)
        manager.__aexit__ = AsyncMock(return_value=None)

    response = MagicMock()
    for attr_name, attr_value in response_attrs.items():
        setattr(response, attr_name, attr_value)

    session = MagicMock()
    request_call = getattr(session, method)  # pyright: ignore[reportAny]
    wire_async_context(request_call.return_value, response)  # pyright: ignore[reportAny]

    factory = MagicMock()
    wire_async_context(factory.return_value, session)  # pyright: ignore[reportAny]
    return factory
# Header set attached to mocked 429 responses in the HTTP-level tests;
# `t=52` is what those tests expect back as `retry_after == 52.0`.
REAL_HF_429_HEADER_DICT = {"ratelimit": '"api";r=0;t=52'}
class TestRateLimitAtHttpCallSites:
    """Verify each HF call site translates an HTTP 429 into a
    ``HuggingFaceRateLimitError`` carrying the parsed ``retry_after``.

    These tests would catch regressions where (a) the 429 branch is
    deleted, (b) ``_parse_retry_after`` stops being called, or
    (c) the wrong header object is passed to it.
    """

    async def test_fetch_file_list_maps_429_to_rate_limit_error(self) -> None:
        """``_fetch_file_list`` raises with ``retry_after`` parsed from a 429 GET."""
        mock_factory = _make_mock_session_returning(
            {"status": 429, "headers": REAL_HF_429_HEADER_DICT}
        )
        with (
            patch("exo.download.download_utils.create_http_session", mock_factory),
            pytest.raises(HuggingFaceRateLimitError) as exc_info,
        ):
            await _fetch_file_list(ModelId("test/model"), "main")
        assert exc_info.value.retry_after == 52.0

    async def test_file_meta_maps_429_to_rate_limit_error(self) -> None:
        """``file_meta`` raises with ``retry_after`` parsed from a 429 HEAD."""
        mock_factory = _make_mock_session_returning(
            {"status": 429, "headers": REAL_HF_429_HEADER_DICT}, method="head"
        )
        with (
            patch("exo.download.download_utils.create_http_session", mock_factory),
            pytest.raises(HuggingFaceRateLimitError) as exc_info,
        ):
            await file_meta(ModelId("test/model"), "main", "weights.safetensors")
        assert exc_info.value.retry_after == 52.0

    async def test_file_meta_maps_429_after_307_redirect(self) -> None:
        """When the initial HEAD 307s and the redirected HEAD then 429s,
        the 429 must still surface as ``HuggingFaceRateLimitError``."""
        # First HEAD -> 307 with a Location header pointing somewhere new.
        first_response = MagicMock()
        first_response.status = 307
        first_response.headers = {"location": "/redirected/url"}
        # Second HEAD (the recursive call) -> 429 with the real-HF header.
        second_response = MagicMock()
        second_response.status = 429
        second_response.headers = REAL_HF_429_HEADER_DICT
        # Each ``async with session.head(...)`` entry consumes the next response.
        responses = iter([first_response, second_response])
        mock_session = MagicMock()
        mock_session.head.return_value.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]
            side_effect=lambda: next(responses)
        )
        mock_session.head.return_value.__aexit__ = AsyncMock(  # pyright: ignore[reportAny]
            return_value=None
        )
        mock_factory = MagicMock()
        mock_factory.return_value.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]
            return_value=mock_session
        )
        mock_factory.return_value.__aexit__ = AsyncMock(  # pyright: ignore[reportAny]
            return_value=None
        )
        with (
            patch("exo.download.download_utils.create_http_session", mock_factory),
            pytest.raises(HuggingFaceRateLimitError) as exc_info,
        ):
            await file_meta(ModelId("test/model"), "main", "weights.safetensors")
        assert exc_info.value.retry_after == 52.0

    async def test_download_file_maps_429_to_rate_limit_error(
        self, tmp_path: Path
    ) -> None:
        """``_download_file`` raises with ``retry_after`` parsed from a 429 GET."""
        target_dir = tmp_path / "downloads"
        await aios.makedirs(target_dir, exist_ok=True)
        # No local file -> _download_file goes straight to file_meta then GET.
        # We need both calls to succeed enough to reach the GET branch:
        # - file_meta returns a non-429 (size, etag) so we proceed.
        # - the GET then 429s.
        with (
            patch(
                "exo.download.download_utils.file_meta",
                new_callable=AsyncMock,
                return_value=(100, "abc123"),
            ),
            patch(
                "exo.download.download_utils.create_http_session",
                _make_mock_session_returning(
                    {"status": 429, "headers": REAL_HF_429_HEADER_DICT}
                ),
            ),
            pytest.raises(HuggingFaceRateLimitError) as exc_info,
        ):
            await _download_file(
                ModelId("test/model"), "main", "weights.safetensors", target_dir
            )
        assert exc_info.value.retry_after == 52.0