mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-19 04:05:23 -04:00
A few targeted tweaks to address HF rate limits (#2009)
## Motivation

- exo bursts ~200 HF Hub-API requests on every cold start, blowing past the anonymous 500-req/5-min budget.
- The existing retry loop catches 429 generically and gives up in ~3s — well before HF's reset window.
- `file_meta` and `_download_file` had no 429 handling at all (a 429 surfaced as an `AssertionError`).
- The disk file-list cache was bypassed on every process restart.

## Changes

All in `src/exo/download/download_utils.py` + tests.

- Parse `t=` from HF's `RateLimit` header on 429; sleep `min(t, 300s) + jitter`.
- Handle 429 at all three call sites (`_fetch_file_list`, `file_meta`, `_download_file`).
- `n_attempts`: 3 → 5.
- Disk cache is now primary across restarts (24h mtime TTL).
- `?recursive=true` instead of N+1 subdir walks.

## Why It Works

`t=<seconds>` is HF's "wait this long and you'll be unblocked" — sleeping that long lets the window reset. Disk-cache-as-primary plus recursive listing cuts cold-start Hub-API traffic by ~10×.

## Test Plan

### Manual Testing

MacBook Pro M1 Max. Tripped the real HF 429. Pre-fix: failed in 3.4s. Post-fix: slept (HF returned `t=158`) and recovered.

### Automated Testing

- New `test_rate_limit_handling.py` (19 tests) — header parsing, retry-loop behaviour, plus HTTP-level coverage that mocks aiohttp to return a 429 and asserts each call site raises `HuggingFaceRateLimitError(retry_after=52.0)`.
- New `TestFileListCacheTTL` in `test_offline_mode.py` — fresh cache hits, stale cache refetches.
- 421 tests pass; basedpyright / ruff / nix fmt clean.
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import ssl
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Awaitable, Mapping
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal
|
||||
@@ -55,6 +56,36 @@ class HuggingFaceAuthenticationError(Exception):
|
||||
class HuggingFaceRateLimitError(Exception):
|
||||
"""429 Huggingface code"""
|
||||
|
||||
def __init__(self, msg: str, retry_after: float | None = None) -> None:
|
||||
super().__init__(msg)
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
def _parse_retry_after(headers: Mapping[str, str]) -> float | None:
|
||||
"""Parse seconds-to-reset from HF's RateLimit header.
|
||||
|
||||
HF sends e.g. ``ratelimit: "api";r=0;t=52`` on 429s; ``t`` is the wait.
|
||||
Returns ``None`` if the header is missing or has no ``t`` field.
|
||||
"""
|
||||
raw = headers.get("RateLimit") or headers.get("ratelimit")
|
||||
if raw is None:
|
||||
return None
|
||||
for part in raw.split(";"):
|
||||
key, _, val = part.strip().partition("=")
|
||||
if key == "t":
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
# reset window is 5 min
|
||||
_RATE_LIMIT_MAX_SLEEP_SECS = 300.0
|
||||
|
||||
# 24h. Manually clear the cache (or `delete_model`) to force a refresh.
|
||||
_FILE_LIST_CACHE_TTL_SECS = 24 * 60 * 60
|
||||
|
||||
|
||||
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
|
||||
token = await get_hf_token()
|
||||
@@ -348,9 +379,6 @@ async def _build_file_list_from_local_directory(
|
||||
return None
|
||||
|
||||
|
||||
_fetched_file_lists_this_session: set[str] = set()
|
||||
|
||||
|
||||
async def fetch_file_list_with_cache(
|
||||
model_id: ModelId,
|
||||
revision: str = "main",
|
||||
@@ -360,13 +388,16 @@ async def fetch_file_list_with_cache(
|
||||
) -> list[FileListEntry]:
|
||||
target_dir = await ensure_cache_dir(model_id)
|
||||
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
|
||||
cache_key = f"{model_id.normalize()}--{revision}"
|
||||
|
||||
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
|
||||
cache_file
|
||||
):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
# cache survives process restarts so cold starts don't re-burst HF
|
||||
if await aios.path.exists(cache_file):
|
||||
try:
|
||||
cache_age = time.time() - (await aios.stat(cache_file)).st_mtime
|
||||
except OSError:
|
||||
cache_age = float("inf")
|
||||
if cache_age < _FILE_LIST_CACHE_TTL_SECS:
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
|
||||
if skip_internet:
|
||||
if await aios.path.exists(cache_file):
|
||||
@@ -395,7 +426,6 @@ async def fetch_file_list_with_cache(
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
|
||||
)
|
||||
_fetched_file_lists_this_session.add(cache_key)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
@@ -426,17 +456,29 @@ async def fetch_file_list_with_retry(
|
||||
recursive: bool = False,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
) -> list[FileListEntry]:
|
||||
n_attempts = 3
|
||||
n_attempts = 5
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _fetch_file_list(model_id, revision, path, recursive)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except HuggingFaceRateLimitError as e:
|
||||
if attempt == n_attempts - 1:
|
||||
raise
|
||||
sleep_for = e.retry_after if e.retry_after is not None else 2.0**attempt
|
||||
sleep_for = min(sleep_for, _RATE_LIMIT_MAX_SLEEP_SECS) + random.uniform(
|
||||
0, 1
|
||||
)
|
||||
logger.warning(
|
||||
f"Rate limited by HuggingFace fetching file list for {model_id}; "
|
||||
f"sleeping {sleep_for:.1f}s before retry {attempt + 2}/{n_attempts}"
|
||||
)
|
||||
await asyncio.sleep(sleep_for)
|
||||
except Exception as e:
|
||||
on_connection_lost()
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
await asyncio.sleep(2.0**attempt + random.uniform(0, 1))
|
||||
raise Exception(
|
||||
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
|
||||
)
|
||||
@@ -447,6 +489,9 @@ async def _fetch_file_list(
|
||||
) -> list[FileListEntry]:
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
# ?recursive=true returns the whole subtree in one request
|
||||
if recursive:
|
||||
url = f"{url}?recursive=true"
|
||||
|
||||
headers = await get_download_headers()
|
||||
async with (
|
||||
@@ -458,7 +503,8 @@ async def _fetch_file_list(
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
elif response.status == 429:
|
||||
raise HuggingFaceRateLimitError(
|
||||
f"Couldn't download {model_id} because of HuggingFace rate limit."
|
||||
f"HuggingFace rate limit hit fetching file list for {model_id}",
|
||||
retry_after=_parse_retry_after(response.headers),
|
||||
)
|
||||
elif response.status == 200:
|
||||
data_json = await response.text()
|
||||
@@ -468,10 +514,14 @@ async def _fetch_file_list(
|
||||
if item.type == "file":
|
||||
files.append(FileListEntry.model_validate(item))
|
||||
elif item.type == "directory" and recursive:
|
||||
subfiles = await _fetch_file_list(
|
||||
model_id, revision, item.path, recursive
|
||||
)
|
||||
files.extend(subfiles)
|
||||
# already inlined by ?recursive=true
|
||||
continue
|
||||
if recursive and len(data) >= 1000:
|
||||
# HF tree endpoint paginates at 1000; we don't follow cursors
|
||||
logger.warning(
|
||||
f"File list for {model_id} hit the 1000-entry page cap "
|
||||
"and may be truncated; cursor pagination is not implemented"
|
||||
)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
@@ -552,6 +602,11 @@ async def file_meta(
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
if r.status == 429:
|
||||
raise HuggingFaceRateLimitError(
|
||||
f"HuggingFace rate limit hit fetching metadata for {model_id}/{path}",
|
||||
retry_after=_parse_retry_after(r.headers),
|
||||
)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
@@ -571,7 +626,7 @@ async def download_file_with_retry(
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
n_attempts = 3
|
||||
n_attempts = 5
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
@@ -583,12 +638,16 @@ async def download_file_with_retry(
|
||||
raise
|
||||
except HuggingFaceRateLimitError as e:
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
raise
|
||||
sleep_for = e.retry_after if e.retry_after is not None else 2.0**attempt
|
||||
sleep_for = min(sleep_for, _RATE_LIMIT_MAX_SLEEP_SECS) + random.uniform(
|
||||
0, 1
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
logger.warning(
|
||||
f"Rate limited by HuggingFace downloading {model_id}/{path}; "
|
||||
f"sleeping {sleep_for:.1f}s before retry {attempt + 2}/{n_attempts}"
|
||||
)
|
||||
await asyncio.sleep(sleep_for)
|
||||
except Exception as e:
|
||||
if attempt == n_attempts - 1:
|
||||
on_connection_lost()
|
||||
@@ -597,7 +656,7 @@ async def download_file_with_retry(
|
||||
f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(2.0**attempt)
|
||||
await asyncio.sleep(2.0**attempt + random.uniform(0, 1))
|
||||
raise Exception(
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
@@ -665,6 +724,11 @@ async def _download_file(
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
if r.status == 429:
|
||||
raise HuggingFaceRateLimitError(
|
||||
f"HuggingFace rate limit hit downloading {model_id}/{path}",
|
||||
retry_after=_parse_retry_after(r.headers),
|
||||
)
|
||||
assert r.status in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status}"
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for offline/air-gapped mode."""
|
||||
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
@@ -231,3 +233,64 @@ class TestFetchFileListOffline:
|
||||
raise FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="No internet"):
|
||||
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)
|
||||
|
||||
|
||||
class TestFileListCacheTTL:
    """Checks how ``fetch_file_list_with_cache`` treats the on-disk cache age."""

    async def test_uses_fresh_cache_without_fetching(
        self, model_id: ModelId, temp_models_dir: Path
    ) -> None:
        """A freshly written cache file is returned and no fetch is made."""
        from pydantic import TypeAdapter

        normalized = model_id.normalize()
        cache_dir = temp_models_dir / "caches" / normalized
        await aios.makedirs(cache_dir, exist_ok=True)
        cache_file = cache_dir / f"{normalized}--main--file_list.json"

        cached_list = [
            FileListEntry(type="file", path="model.safetensors", size=1000),
        ]
        serialized = TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
        async with aiofiles.open(cache_file, "w") as f:
            await f.write(serialized)

        retry_patch = patch(
            "exo.download.download_utils.fetch_file_list_with_retry",
            new_callable=AsyncMock,
        )
        with retry_patch as mock_fetch:
            result = await fetch_file_list_with_cache(model_id, "main")

        assert result == cached_list
        mock_fetch.assert_not_called()

    async def test_refetches_when_cache_older_than_ttl(
        self, model_id: ModelId, temp_models_dir: Path
    ) -> None:
        """A cache file whose mtime is past the TTL triggers exactly one fetch."""
        from pydantic import TypeAdapter

        from exo.download.download_utils import (
            _FILE_LIST_CACHE_TTL_SECS,  # pyright: ignore[reportPrivateUsage]
        )

        normalized = model_id.normalize()
        cache_dir = temp_models_dir / "caches" / normalized
        await aios.makedirs(cache_dir, exist_ok=True)
        cache_file = cache_dir / f"{normalized}--main--file_list.json"

        stale_list = [FileListEntry(type="file", path="stale.bin", size=1)]
        serialized = TypeAdapter(list[FileListEntry]).dump_json(stale_list).decode()
        async with aiofiles.open(cache_file, "w") as f:
            await f.write(serialized)

        # Back-date the file to one minute beyond the TTL.
        old_mtime = time.time() - _FILE_LIST_CACHE_TTL_SECS - 60
        os.utime(cache_file, (old_mtime, old_mtime))

        fresh_list = [FileListEntry(type="file", path="fresh.bin", size=2)]
        retry_patch = patch(
            "exo.download.download_utils.fetch_file_list_with_retry",
            new_callable=AsyncMock,
            return_value=fresh_list,
        )
        with retry_patch as mock_fetch:
            result = await fetch_file_list_with_cache(model_id, "main")

        assert result == fresh_list
        mock_fetch.assert_called_once()
|
||||
|
||||
355
src/exo/download/tests/test_rate_limit_handling.py
Normal file
355
src/exo/download/tests/test_rate_limit_handling.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Tests for HuggingFace 429 rate-limit handling in download_utils."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
|
||||
from exo.download.download_utils import (
|
||||
HuggingFaceRateLimitError,
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
_fetch_file_list, # pyright: ignore[reportPrivateUsage]
|
||||
_parse_retry_after, # pyright: ignore[reportPrivateUsage]
|
||||
download_file_with_retry,
|
||||
fetch_file_list_with_retry,
|
||||
file_meta,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
|
||||
# captured from a real HF 429 on 2026-04-30 (header is lowercased by Cloudfront)
|
||||
REAL_HF_429_HEADERS_2026_04_30 = {
|
||||
"ratelimit": '"api";r=0;t=52',
|
||||
"ratelimit-policy": '"fixed window";"api";q=500;w=300',
|
||||
}
|
||||
|
||||
|
||||
class TestParseRetryAfter:
    """Unit tests for ``_parse_retry_after`` on a range of header inputs."""

    def test_parses_documented_format(self) -> None:
        headers = {"RateLimit": '"api";r=0;t=243'}
        assert _parse_retry_after(headers) == 243.0

    def test_parses_real_hf_response(self) -> None:
        parsed = _parse_retry_after(REAL_HF_429_HEADERS_2026_04_30)
        assert parsed == 52.0

    def test_parses_resolvers_bucket(self) -> None:
        headers = {"ratelimit": '"resolvers";r=0;t=120'}
        assert _parse_retry_after(headers) == 120.0

    def test_parses_pages_bucket(self) -> None:
        headers = {"ratelimit": '"pages";r=0;t=10'}
        assert _parse_retry_after(headers) == 10.0

    def test_returns_none_when_header_missing(self) -> None:
        parsed = _parse_retry_after({})
        assert parsed is None

    def test_returns_none_when_only_retry_after_present(self) -> None:
        # Only the RateLimit header is consulted, not Retry-After.
        headers = {"Retry-After": "60"}
        assert _parse_retry_after(headers) is None

    def test_returns_none_when_format_unrecognised(self) -> None:
        headers = {"ratelimit": "garbage"}
        assert _parse_retry_after(headers) is None

    def test_handles_extra_whitespace(self) -> None:
        headers = {"ratelimit": '"api"; r=0; t=42'}
        assert _parse_retry_after(headers) == 42.0
|
||||
|
||||
|
||||
class TestFetchFileListRetry:
    """Retry-loop behaviour of ``fetch_file_list_with_retry`` on rate limits.

    ``_fetch_file_list`` and ``asyncio.sleep`` are patched, so each test only
    records the delays the loop asks for.
    """

    async def test_uses_retry_after_from_error(self) -> None:
        recorded: list[float] = []

        async def record_sleep(seconds: float) -> None:
            recorded.append(seconds)

        async def limited_once(*args: object, **kwargs: object) -> list[object]:
            # Fails until the loop has slept once, then succeeds.
            if recorded:
                return []
            raise HuggingFaceRateLimitError("rate limited", retry_after=2.0)

        fetch_patch = patch(
            "exo.download.download_utils._fetch_file_list", side_effect=limited_once
        )
        sleep_patch = patch(
            "exo.download.download_utils.asyncio.sleep", side_effect=record_sleep
        )
        with fetch_patch, sleep_patch:
            result = await fetch_file_list_with_retry(ModelId("test/model"))

        assert result == []
        assert len(recorded) == 1
        assert 2.0 <= recorded[0] < 3.0  # retry_after + jitter[0,1)

    async def test_falls_back_to_exp_backoff_when_no_retry_after(self) -> None:
        recorded: list[float] = []

        async def record_sleep(seconds: float) -> None:
            recorded.append(seconds)

        async def limited_once(*args: object, **kwargs: object) -> list[object]:
            if recorded:
                return []
            raise HuggingFaceRateLimitError("rate limited", retry_after=None)

        fetch_patch = patch(
            "exo.download.download_utils._fetch_file_list", side_effect=limited_once
        )
        sleep_patch = patch(
            "exo.download.download_utils.asyncio.sleep", side_effect=record_sleep
        )
        with fetch_patch, sleep_patch:
            await fetch_file_list_with_retry(ModelId("test/model"))

        assert len(recorded) == 1
        assert 1.0 <= recorded[0] < 2.0  # 2**0 + jitter[0,1)

    async def test_caps_sleep_at_max_window(self) -> None:
        recorded: list[float] = []

        async def record_sleep(seconds: float) -> None:
            recorded.append(seconds)

        async def limited_once(*args: object, **kwargs: object) -> list[object]:
            if recorded:
                return []
            raise HuggingFaceRateLimitError("rate limited", retry_after=10_000.0)

        fetch_patch = patch(
            "exo.download.download_utils._fetch_file_list", side_effect=limited_once
        )
        sleep_patch = patch(
            "exo.download.download_utils.asyncio.sleep", side_effect=record_sleep
        )
        with fetch_patch, sleep_patch:
            await fetch_file_list_with_retry(ModelId("test/model"))

        assert len(recorded) == 1
        assert 300.0 <= recorded[0] < 301.0  # cap + jitter[0,1)

    async def test_retries_up_to_five_times(self) -> None:
        recorded: list[float] = []

        async def record_sleep(seconds: float) -> None:
            recorded.append(seconds)

        async def always_limited(*args: object, **kwargs: object) -> list[object]:
            raise HuggingFaceRateLimitError("rate limited", retry_after=1.0)

        fetch_patch = patch(
            "exo.download.download_utils._fetch_file_list", side_effect=always_limited
        )
        sleep_patch = patch(
            "exo.download.download_utils.asyncio.sleep", side_effect=record_sleep
        )
        with fetch_patch, sleep_patch, pytest.raises(HuggingFaceRateLimitError):
            await fetch_file_list_with_retry(ModelId("test/model"))

        assert len(recorded) == 4  # 5 attempts -> 4 sleeps before giving up
|
||||
|
||||
|
||||
class TestDownloadFileRetry:
    """Retry-loop behaviour of ``download_file_with_retry`` on rate limits.

    ``_download_file`` and ``asyncio.sleep`` are patched, so each test only
    records the delays the loop asks for.
    """

    @pytest.fixture
    async def target_dir(self, tmp_path: Path) -> AsyncIterator[Path]:
        """Yield an existing ``downloads`` directory under ``tmp_path``."""
        downloads = tmp_path / "downloads"
        await aios.makedirs(downloads, exist_ok=True)
        yield downloads

    async def test_uses_retry_after_from_error(self, target_dir: Path) -> None:
        recorded: list[float] = []
        expected = target_dir / "file.bin"

        async def record_sleep(seconds: float) -> None:
            recorded.append(seconds)

        async def limited_once(*args: object, **kwargs: object) -> Path:
            # Fails until the loop has slept once, then succeeds.
            if recorded:
                return expected
            raise HuggingFaceRateLimitError("rate limited", retry_after=5.0)

        download_patch = patch(
            "exo.download.download_utils._download_file",
            side_effect=limited_once,
        )
        sleep_patch = patch(
            "exo.download.download_utils.asyncio.sleep", side_effect=record_sleep
        )
        with download_patch, sleep_patch:
            result = await download_file_with_retry(
                ModelId("test/model"), "main", "file.bin", target_dir
            )

        assert result == expected
        assert len(recorded) == 1
        assert 5.0 <= recorded[0] < 6.0

    async def test_caps_sleep_at_max_window(self, target_dir: Path) -> None:
        recorded: list[float] = []
        expected = target_dir / "file.bin"

        async def record_sleep(seconds: float) -> None:
            recorded.append(seconds)

        async def limited_once(*args: object, **kwargs: object) -> Path:
            if recorded:
                return expected
            raise HuggingFaceRateLimitError("rate limited", retry_after=99_999.0)

        download_patch = patch(
            "exo.download.download_utils._download_file",
            side_effect=limited_once,
        )
        sleep_patch = patch(
            "exo.download.download_utils.asyncio.sleep", side_effect=record_sleep
        )
        with download_patch, sleep_patch:
            await download_file_with_retry(
                ModelId("test/model"), "main", "file.bin", target_dir
            )

        assert len(recorded) == 1
        assert 300.0 <= recorded[0] < 301.0

    async def test_retries_up_to_five_times(self, target_dir: Path) -> None:
        recorded: list[float] = []

        async def record_sleep(seconds: float) -> None:
            recorded.append(seconds)

        download_patch = patch(
            "exo.download.download_utils._download_file",
            new_callable=AsyncMock,
            side_effect=HuggingFaceRateLimitError("rate limited", retry_after=1.0),
        )
        sleep_patch = patch(
            "exo.download.download_utils.asyncio.sleep", side_effect=record_sleep
        )
        with download_patch, sleep_patch, pytest.raises(HuggingFaceRateLimitError):
            await download_file_with_retry(
                ModelId("test/model"), "main", "file.bin", target_dir
            )

        assert len(recorded) == 4
|
||||
|
||||
|
||||
def _make_mock_session_returning(
|
||||
response_attrs: dict[str, object], method: str = "get"
|
||||
) -> MagicMock:
|
||||
"""Build a MagicMock that mimics ``create_http_session`` returning a
|
||||
response whose ``status`` / ``headers`` are set from ``response_attrs``.
|
||||
|
||||
Mocks the chain ``create_http_session().__aenter__() -> session``, and
|
||||
``session.<method>().__aenter__() -> response``.
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
for k, v in response_attrs.items():
|
||||
setattr(mock_response, k, v)
|
||||
|
||||
mock_session = MagicMock()
|
||||
method_mock = getattr(mock_session, method) # pyright: ignore[reportAny]
|
||||
method_mock.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_response
|
||||
)
|
||||
method_mock.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_session
|
||||
)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
return mock_factory
|
||||
|
||||
|
||||
REAL_HF_429_HEADER_DICT = {"ratelimit": '"api";r=0;t=52'}
|
||||
|
||||
|
||||
class TestRateLimitAtHttpCallSites:
    """Each HF call site must turn an HTTP 429 into a
    ``HuggingFaceRateLimitError`` that carries the parsed ``retry_after``.

    These tests guard against regressions where (a) the 429 branch is
    removed, (b) ``_parse_retry_after`` is no longer called, or
    (c) the wrong header object is handed to it.
    """

    async def test_fetch_file_list_maps_429_to_rate_limit_error(self) -> None:
        session_factory = _make_mock_session_returning(
            {"status": 429, "headers": REAL_HF_429_HEADER_DICT}
        )
        session_patch = patch(
            "exo.download.download_utils.create_http_session", session_factory
        )
        with session_patch, pytest.raises(HuggingFaceRateLimitError) as exc_info:
            await _fetch_file_list(ModelId("test/model"), "main")
        assert exc_info.value.retry_after == 52.0

    async def test_file_meta_maps_429_to_rate_limit_error(self) -> None:
        session_factory = _make_mock_session_returning(
            {"status": 429, "headers": REAL_HF_429_HEADER_DICT}, method="head"
        )
        session_patch = patch(
            "exo.download.download_utils.create_http_session", session_factory
        )
        with session_patch, pytest.raises(HuggingFaceRateLimitError) as exc_info:
            await file_meta(ModelId("test/model"), "main", "weights.safetensors")
        assert exc_info.value.retry_after == 52.0

    async def test_file_meta_maps_429_after_307_redirect(self) -> None:
        """When the initial HEAD 307s and the redirected HEAD then 429s,
        the 429 must still surface as ``HuggingFaceRateLimitError``."""
        # First HEAD -> 307 with a Location header pointing somewhere new.
        redirect_response = MagicMock()
        redirect_response.status = 307
        redirect_response.headers = {"location": "/redirected/url"}

        # Second HEAD (the recursive call) -> 429 with the real-HF header.
        limited_response = MagicMock()
        limited_response.status = 429
        limited_response.headers = REAL_HF_429_HEADER_DICT

        pending = iter([redirect_response, limited_response])

        session = MagicMock()
        head_cm = session.head.return_value  # pyright: ignore[reportAny]
        head_cm.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]
            side_effect=lambda: next(pending)
        )
        head_cm.__aexit__ = AsyncMock(return_value=None)  # pyright: ignore[reportAny]

        session_factory = MagicMock()
        session_cm = session_factory.return_value  # pyright: ignore[reportAny]
        session_cm.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]
            return_value=session
        )
        session_cm.__aexit__ = AsyncMock(  # pyright: ignore[reportAny]
            return_value=None
        )

        session_patch = patch(
            "exo.download.download_utils.create_http_session", session_factory
        )
        with session_patch, pytest.raises(HuggingFaceRateLimitError) as exc_info:
            await file_meta(ModelId("test/model"), "main", "weights.safetensors")
        assert exc_info.value.retry_after == 52.0

    async def test_download_file_maps_429_to_rate_limit_error(
        self, tmp_path: Path
    ) -> None:
        target_dir = tmp_path / "downloads"
        await aios.makedirs(target_dir, exist_ok=True)

        # No local file -> _download_file goes straight to file_meta then GET.
        # Both calls must get far enough to reach the GET branch:
        # - file_meta returns a non-429 (size, etag) so we proceed.
        # - the GET then 429s.
        meta_patch = patch(
            "exo.download.download_utils.file_meta",
            new_callable=AsyncMock,
            return_value=(100, "abc123"),
        )
        session_patch = patch(
            "exo.download.download_utils.create_http_session",
            _make_mock_session_returning(
                {"status": 429, "headers": REAL_HF_429_HEADER_DICT}
            ),
        )
        with (
            meta_patch,
            session_patch,
            pytest.raises(HuggingFaceRateLimitError) as exc_info,
        ):
            await _download_file(
                ModelId("test/model"), "main", "weights.safetensors", target_dir
            )
        assert exc_info.value.retry_after == 52.0
|
||||
Reference in New Issue
Block a user