From 8dae3ecb9a58cd169723e064560dd05e5d5672f3 Mon Sep 17 00:00:00 2001 From: ciaranbor <81697641+ciaranbor@users.noreply.github.com> Date: Thu, 30 Apr 2026 19:06:15 +0100 Subject: [PATCH] A few targeted tweaks to address HF rate limits (#2009) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation - exo bursts ~200 HF Hub-API requests on every cold start, blowing past the anonymous 500-req/5-min budget. - The existing retry loop catches 429 generically and gives up in ~3s — well before HF's reset window. - `file_meta` and `_download_file` had no 429 handling at all (became `AssertionError`). - Disk file-list cache was bypassed on every process restart. ## Changes All in `src/exo/download/download_utils.py` + tests. - Parse `t=` from HF's `RateLimit` header on 429; sleep `min(t, 300s) + jitter`. - Handle 429 at all three call sites (`_fetch_file_list`, `file_meta`, `_download_file`). - `n_attempts`: 3 → 5. - Disk cache now primary across restarts (24h mtime TTL). - `?recursive=true` instead of N+1 subdir walks. ## Why It Works `t=` is HF's "wait this long and you'll be unblocked" — sleeping that long lets the window reset. Disk-cache-as-primary plus recursive listing cuts cold-start Hub-API traffic by ~10×. ## Test Plan ### Manual Testing MacBook Pro M1 Max. Tripped the real HF 429. Pre-fix: failed in 3.4s. Post-fix: slept (HF returned `t=158`) and recovered. ### Automated Testing - New `test_rate_limit_handling.py` (19 tests) — header parsing, retry-loop behaviour, plus HTTP-level coverage that mocks aiohttp to return a 429 and asserts each call site raises `HuggingFaceRateLimitError(retry_after=52.0)`. - New `TestFileListCacheTTL` in `test_offline_mode.py` — fresh cache hits, stale cache refetches. - 421 tests pass; basedpyright / ruff / nix fmt clean. 
--- src/exo/download/download_utils.py | 114 ++++-- src/exo/download/tests/test_offline_mode.py | 63 ++++ .../tests/test_rate_limit_handling.py | 355 ++++++++++++++++++ 3 files changed, 507 insertions(+), 25 deletions(-) create mode 100644 src/exo/download/tests/test_rate_limit_handling.py diff --git a/src/exo/download/download_utils.py b/src/exo/download/download_utils.py index bab70a1c5..1e81f6d09 100644 --- a/src/exo/download/download_utils.py +++ b/src/exo/download/download_utils.py @@ -1,11 +1,12 @@ import asyncio import hashlib import os +import random import shutil import ssl import time import traceback -from collections.abc import Awaitable +from collections.abc import Awaitable, Mapping from datetime import timedelta from pathlib import Path from typing import Callable, Literal @@ -55,6 +56,36 @@ class HuggingFaceAuthenticationError(Exception): class HuggingFaceRateLimitError(Exception): """429 Huggingface code""" + def __init__(self, msg: str, retry_after: float | None = None) -> None: + super().__init__(msg) + self.retry_after = retry_after + + +def _parse_retry_after(headers: Mapping[str, str]) -> float | None: + """Parse seconds-to-reset from HF's RateLimit header. + + HF sends e.g. ``ratelimit: "api";r=0;t=52`` on 429s; ``t`` is the wait. + Returns ``None`` if the header is missing or has no ``t`` field. + """ + raw = headers.get("RateLimit") or headers.get("ratelimit") + if raw is None: + return None + for part in raw.split(";"): + key, _, val = part.strip().partition("=") + if key == "t": + try: + return float(val) + except ValueError: + return None + return None + + +# reset window is 5 min +_RATE_LIMIT_MAX_SLEEP_SECS = 300.0 + +# 24h. Manually clear the cache (or `delete_model`) to force a refresh. 
+_FILE_LIST_CACHE_TTL_SECS = 24 * 60 * 60 + async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str: token = await get_hf_token() @@ -348,9 +379,6 @@ async def _build_file_list_from_local_directory( return None -_fetched_file_lists_this_session: set[str] = set() - - async def fetch_file_list_with_cache( model_id: ModelId, revision: str = "main", @@ -360,13 +388,16 @@ async def fetch_file_list_with_cache( ) -> list[FileListEntry]: target_dir = await ensure_cache_dir(model_id) cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json" - cache_key = f"{model_id.normalize()}--{revision}" - if cache_key in _fetched_file_lists_this_session and await aios.path.exists( - cache_file - ): - async with aiofiles.open(cache_file, "r") as f: - return TypeAdapter(list[FileListEntry]).validate_json(await f.read()) + # cache survives process restarts so cold starts don't re-burst HF + if await aios.path.exists(cache_file): + try: + cache_age = time.time() - (await aios.stat(cache_file)).st_mtime + except OSError: + cache_age = float("inf") + if cache_age < _FILE_LIST_CACHE_TTL_SECS: + async with aiofiles.open(cache_file, "r") as f: + return TypeAdapter(list[FileListEntry]).validate_json(await f.read()) if skip_internet: if await aios.path.exists(cache_file): @@ -395,7 +426,6 @@ async def fetch_file_list_with_cache( await f.write( TypeAdapter(list[FileListEntry]).dump_json(file_list).decode() ) - _fetched_file_lists_this_session.add(cache_key) return file_list except Exception as e: logger.opt(exception=e).warning( @@ -426,17 +456,29 @@ async def fetch_file_list_with_retry( recursive: bool = False, on_connection_lost: Callable[[], None] = lambda: None, ) -> list[FileListEntry]: - n_attempts = 3 + n_attempts = 5 for attempt in range(n_attempts): try: return await _fetch_file_list(model_id, revision, path, recursive) except HuggingFaceAuthenticationError: raise + except HuggingFaceRateLimitError as e: + if attempt == n_attempts - 1: + raise 
+ sleep_for = e.retry_after if e.retry_after is not None else 2.0**attempt + sleep_for = min(sleep_for, _RATE_LIMIT_MAX_SLEEP_SECS) + random.uniform( + 0, 1 + ) + logger.warning( + f"Rate limited by HuggingFace fetching file list for {model_id}; " + f"sleeping {sleep_for:.1f}s before retry {attempt + 2}/{n_attempts}" + ) + await asyncio.sleep(sleep_for) except Exception as e: on_connection_lost() if attempt == n_attempts - 1: raise e - await asyncio.sleep(2.0**attempt) + await asyncio.sleep(2.0**attempt + random.uniform(0, 1)) raise Exception( f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}" ) @@ -447,6 +489,9 @@ async def _fetch_file_list( ) -> list[FileListEntry]: api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}" url = f"{api_url}/{path}" if path else api_url + # ?recursive=true returns the whole subtree in one request + if recursive: + url = f"{url}?recursive=true" headers = await get_download_headers() async with ( @@ -458,7 +503,8 @@ async def _fetch_file_list( raise HuggingFaceAuthenticationError(msg) elif response.status == 429: raise HuggingFaceRateLimitError( - f"Couldn't download {model_id} because of HuggingFace rate limit." 
+ f"HuggingFace rate limit hit fetching file list for {model_id}", + retry_after=_parse_retry_after(response.headers), ) elif response.status == 200: data_json = await response.text() @@ -468,10 +514,14 @@ async def _fetch_file_list( if item.type == "file": files.append(FileListEntry.model_validate(item)) elif item.type == "directory" and recursive: - subfiles = await _fetch_file_list( - model_id, revision, item.path, recursive - ) - files.extend(subfiles) + # already inlined by ?recursive=true + continue + if recursive and len(data) >= 1000: + # HF tree endpoint paginates at 1000; we don't follow cursors + logger.warning( + f"File list for {model_id} hit the 1000-entry page cap " + "and may be truncated; cursor pagination is not implemented" + ) return files else: raise Exception(f"Failed to fetch file list: {response.status}") @@ -552,6 +602,11 @@ async def file_meta( if r.status in [401, 403]: msg = await _build_auth_error_message(r.status, model_id) raise HuggingFaceAuthenticationError(msg) + if r.status == 429: + raise HuggingFaceRateLimitError( + f"HuggingFace rate limit hit fetching metadata for {model_id}/{path}", + retry_after=_parse_retry_after(r.headers), + ) content_length = int( r.headers.get("x-linked-size") or r.headers.get("content-length") or 0 ) @@ -571,7 +626,7 @@ async def download_file_with_retry( on_connection_lost: Callable[[], None] = lambda: None, skip_internet: bool = False, ) -> Path: - n_attempts = 3 + n_attempts = 5 for attempt in range(n_attempts): try: return await _download_file( @@ -583,12 +638,16 @@ async def download_file_with_retry( raise except HuggingFaceRateLimitError as e: if attempt == n_attempts - 1: - raise e - logger.error( - f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}" + raise + sleep_for = e.retry_after if e.retry_after is not None else 2.0**attempt + sleep_for = min(sleep_for, _RATE_LIMIT_MAX_SLEEP_SECS) + random.uniform( + 0, 1 ) - 
logger.error(traceback.format_exc()) - await asyncio.sleep(2.0**attempt) + logger.warning( + f"Rate limited by HuggingFace downloading {model_id}/{path}; " + f"sleeping {sleep_for:.1f}s before retry {attempt + 2}/{n_attempts}" + ) + await asyncio.sleep(sleep_for) except Exception as e: if attempt == n_attempts - 1: on_connection_lost() @@ -597,7 +656,7 @@ async def download_file_with_retry( f"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}" ) logger.error(traceback.format_exc()) - await asyncio.sleep(2.0**attempt) + await asyncio.sleep(2.0**attempt + random.uniform(0, 1)) raise Exception( f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}" ) @@ -665,6 +724,11 @@ async def _download_file( if r.status in [401, 403]: msg = await _build_auth_error_message(r.status, model_id) raise HuggingFaceAuthenticationError(msg) + if r.status == 429: + raise HuggingFaceRateLimitError( + f"HuggingFace rate limit hit downloading {model_id}/{path}", + retry_after=_parse_retry_after(r.headers), + ) assert r.status in [200, 206], ( f"Failed to download {path} from {url}: {r.status}" ) diff --git a/src/exo/download/tests/test_offline_mode.py b/src/exo/download/tests/test_offline_mode.py index 9a94e5205..29a69275a 100644 --- a/src/exo/download/tests/test_offline_mode.py +++ b/src/exo/download/tests/test_offline_mode.py @@ -1,5 +1,7 @@ """Tests for offline/air-gapped mode.""" +import os +import time from collections.abc import AsyncIterator from pathlib import Path from unittest.mock import AsyncMock, patch @@ -231,3 +233,64 @@ class TestFetchFileListOffline: raise FileNotFoundError.""" with pytest.raises(FileNotFoundError, match="No internet"): await fetch_file_list_with_cache(model_id, "main", skip_internet=True) + + +class TestFileListCacheTTL: + async def test_uses_fresh_cache_without_fetching( + self, model_id: ModelId, temp_models_dir: Path + ) -> None: + from pydantic import TypeAdapter + + cache_dir = 
temp_models_dir / "caches" / model_id.normalize() + await aios.makedirs(cache_dir, exist_ok=True) + + cached_list = [ + FileListEntry(type="file", path="model.safetensors", size=1000), + ] + cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json" + async with aiofiles.open(cache_file, "w") as f: + await f.write( + TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode() + ) + + with patch( + "exo.download.download_utils.fetch_file_list_with_retry", + new_callable=AsyncMock, + ) as mock_fetch: + result = await fetch_file_list_with_cache(model_id, "main") + + assert result == cached_list + mock_fetch.assert_not_called() + + async def test_refetches_when_cache_older_than_ttl( + self, model_id: ModelId, temp_models_dir: Path + ) -> None: + from pydantic import TypeAdapter + + from exo.download.download_utils import ( + _FILE_LIST_CACHE_TTL_SECS, # pyright: ignore[reportPrivateUsage] + ) + + cache_dir = temp_models_dir / "caches" / model_id.normalize() + await aios.makedirs(cache_dir, exist_ok=True) + + stale_list = [FileListEntry(type="file", path="stale.bin", size=1)] + cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json" + async with aiofiles.open(cache_file, "w") as f: + await f.write( + TypeAdapter(list[FileListEntry]).dump_json(stale_list).decode() + ) + + old_mtime = time.time() - _FILE_LIST_CACHE_TTL_SECS - 60 + os.utime(cache_file, (old_mtime, old_mtime)) + + fresh_list = [FileListEntry(type="file", path="fresh.bin", size=2)] + with patch( + "exo.download.download_utils.fetch_file_list_with_retry", + new_callable=AsyncMock, + return_value=fresh_list, + ) as mock_fetch: + result = await fetch_file_list_with_cache(model_id, "main") + + assert result == fresh_list + mock_fetch.assert_called_once() diff --git a/src/exo/download/tests/test_rate_limit_handling.py b/src/exo/download/tests/test_rate_limit_handling.py new file mode 100644 index 000000000..3af71334c --- /dev/null +++ 
b/src/exo/download/tests/test_rate_limit_handling.py @@ -0,0 +1,355 @@ +"""Tests for HuggingFace 429 rate-limit handling in download_utils.""" + +from collections.abc import AsyncIterator +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import aiofiles.os as aios +import pytest + +from exo.download.download_utils import ( + HuggingFaceRateLimitError, + _download_file, # pyright: ignore[reportPrivateUsage] + _fetch_file_list, # pyright: ignore[reportPrivateUsage] + _parse_retry_after, # pyright: ignore[reportPrivateUsage] + download_file_with_retry, + fetch_file_list_with_retry, + file_meta, +) +from exo.shared.types.common import ModelId + +# captured from a real HF 429 on 2026-04-30 (header is lowercased by Cloudfront) +REAL_HF_429_HEADERS_2026_04_30 = { + "ratelimit": '"api";r=0;t=52', + "ratelimit-policy": '"fixed window";"api";q=500;w=300', +} + + +class TestParseRetryAfter: + def test_parses_documented_format(self) -> None: + assert _parse_retry_after({"RateLimit": '"api";r=0;t=243'}) == 243.0 + + def test_parses_real_hf_response(self) -> None: + assert _parse_retry_after(REAL_HF_429_HEADERS_2026_04_30) == 52.0 + + def test_parses_resolvers_bucket(self) -> None: + assert _parse_retry_after({"ratelimit": '"resolvers";r=0;t=120'}) == 120.0 + + def test_parses_pages_bucket(self) -> None: + assert _parse_retry_after({"ratelimit": '"pages";r=0;t=10'}) == 10.0 + + def test_returns_none_when_header_missing(self) -> None: + assert _parse_retry_after({}) is None + + def test_returns_none_when_only_retry_after_present(self) -> None: + assert _parse_retry_after({"Retry-After": "60"}) is None + + def test_returns_none_when_format_unrecognised(self) -> None: + assert _parse_retry_after({"ratelimit": "garbage"}) is None + + def test_handles_extra_whitespace(self) -> None: + assert _parse_retry_after({"ratelimit": '"api"; r=0; t=42'}) == 42.0 + + +class TestFetchFileListRetry: + async def test_uses_retry_after_from_error(self) -> None: + 
sleeps: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleeps.append(seconds) + + async def fake_fetch(*args: object, **kwargs: object) -> list[object]: + if not sleeps: + raise HuggingFaceRateLimitError("rate limited", retry_after=2.0) + return [] + + with ( + patch( + "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch + ), + patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep), + ): + result = await fetch_file_list_with_retry(ModelId("test/model")) + + assert result == [] + assert len(sleeps) == 1 + assert 2.0 <= sleeps[0] < 3.0 # retry_after + jitter[0,1) + + async def test_falls_back_to_exp_backoff_when_no_retry_after(self) -> None: + sleeps: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleeps.append(seconds) + + async def fake_fetch(*args: object, **kwargs: object) -> list[object]: + if not sleeps: + raise HuggingFaceRateLimitError("rate limited", retry_after=None) + return [] + + with ( + patch( + "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch + ), + patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep), + ): + await fetch_file_list_with_retry(ModelId("test/model")) + + assert len(sleeps) == 1 + assert 1.0 <= sleeps[0] < 2.0 # 2**0 + jitter[0,1) + + async def test_caps_sleep_at_max_window(self) -> None: + sleeps: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleeps.append(seconds) + + async def fake_fetch(*args: object, **kwargs: object) -> list[object]: + if not sleeps: + raise HuggingFaceRateLimitError("rate limited", retry_after=10_000.0) + return [] + + with ( + patch( + "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch + ), + patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep), + ): + await fetch_file_list_with_retry(ModelId("test/model")) + + assert len(sleeps) == 1 + assert 300.0 <= sleeps[0] < 301.0 # cap + jitter[0,1) + + async def 
test_retries_up_to_five_times(self) -> None: + sleeps: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleeps.append(seconds) + + async def fake_fetch(*args: object, **kwargs: object) -> list[object]: + raise HuggingFaceRateLimitError("rate limited", retry_after=1.0) + + with ( + patch( + "exo.download.download_utils._fetch_file_list", side_effect=fake_fetch + ), + patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep), + pytest.raises(HuggingFaceRateLimitError), + ): + await fetch_file_list_with_retry(ModelId("test/model")) + + assert len(sleeps) == 4 # 5 attempts -> 4 sleeps before giving up + + +class TestDownloadFileRetry: + @pytest.fixture + async def target_dir(self, tmp_path: Path) -> AsyncIterator[Path]: + target = tmp_path / "downloads" + await aios.makedirs(target, exist_ok=True) + yield target + + async def test_uses_retry_after_from_error(self, target_dir: Path) -> None: + sleeps: list[float] = [] + results: list[Path] = [target_dir / "file.bin"] + + async def fake_sleep(seconds: float) -> None: + sleeps.append(seconds) + + async def fake_download(*args: object, **kwargs: object) -> Path: + if not sleeps: + raise HuggingFaceRateLimitError("rate limited", retry_after=5.0) + return results[0] + + with ( + patch( + "exo.download.download_utils._download_file", + side_effect=fake_download, + ), + patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep), + ): + result = await download_file_with_retry( + ModelId("test/model"), "main", "file.bin", target_dir + ) + + assert result == results[0] + assert len(sleeps) == 1 + assert 5.0 <= sleeps[0] < 6.0 + + async def test_caps_sleep_at_max_window(self, target_dir: Path) -> None: + sleeps: list[float] = [] + results: list[Path] = [target_dir / "file.bin"] + + async def fake_sleep(seconds: float) -> None: + sleeps.append(seconds) + + async def fake_download(*args: object, **kwargs: object) -> Path: + if not sleeps: + raise HuggingFaceRateLimitError("rate 
limited", retry_after=99_999.0) + return results[0] + + with ( + patch( + "exo.download.download_utils._download_file", + side_effect=fake_download, + ), + patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep), + ): + await download_file_with_retry( + ModelId("test/model"), "main", "file.bin", target_dir + ) + + assert len(sleeps) == 1 + assert 300.0 <= sleeps[0] < 301.0 + + async def test_retries_up_to_five_times(self, target_dir: Path) -> None: + sleeps: list[float] = [] + + async def fake_sleep(seconds: float) -> None: + sleeps.append(seconds) + + with ( + patch( + "exo.download.download_utils._download_file", + new_callable=AsyncMock, + side_effect=HuggingFaceRateLimitError("rate limited", retry_after=1.0), + ), + patch("exo.download.download_utils.asyncio.sleep", side_effect=fake_sleep), + pytest.raises(HuggingFaceRateLimitError), + ): + await download_file_with_retry( + ModelId("test/model"), "main", "file.bin", target_dir + ) + + assert len(sleeps) == 4 + + +def _make_mock_session_returning( + response_attrs: dict[str, object], method: str = "get" +) -> MagicMock: + """Build a MagicMock that mimics ``create_http_session`` returning a + response whose ``status`` / ``headers`` are set from ``response_attrs``. + + Mocks the chain ``create_http_session().__aenter__() -> session``, and + ``session.{method}().__aenter__() -> response``. 
+ """ + mock_response = MagicMock() + for k, v in response_attrs.items(): + setattr(mock_response, k, v) + + mock_session = MagicMock() + method_mock = getattr(mock_session, method) # pyright: ignore[reportAny] + method_mock.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny] + return_value=mock_response + ) + method_mock.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny] + return_value=None + ) + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny] + return_value=mock_session + ) + mock_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny] + return_value=None + ) + return mock_factory + + +REAL_HF_429_HEADER_DICT = {"ratelimit": '"api";r=0;t=52'} + + +class TestRateLimitAtHttpCallSites: + """Verify each HF call site translates an HTTP 429 into a + ``HuggingFaceRateLimitError`` carrying the parsed ``retry_after``. + + These tests would catch regressions where (a) the 429 branch is + deleted, (b) ``_parse_retry_after`` stops being called, or + (c) the wrong header object is passed to it. 
+ """ + + async def test_fetch_file_list_maps_429_to_rate_limit_error(self) -> None: + mock_factory = _make_mock_session_returning( + {"status": 429, "headers": REAL_HF_429_HEADER_DICT} + ) + with ( + patch("exo.download.download_utils.create_http_session", mock_factory), + pytest.raises(HuggingFaceRateLimitError) as exc_info, + ): + await _fetch_file_list(ModelId("test/model"), "main") + assert exc_info.value.retry_after == 52.0 + + async def test_file_meta_maps_429_to_rate_limit_error(self) -> None: + mock_factory = _make_mock_session_returning( + {"status": 429, "headers": REAL_HF_429_HEADER_DICT}, method="head" + ) + with ( + patch("exo.download.download_utils.create_http_session", mock_factory), + pytest.raises(HuggingFaceRateLimitError) as exc_info, + ): + await file_meta(ModelId("test/model"), "main", "weights.safetensors") + assert exc_info.value.retry_after == 52.0 + + async def test_file_meta_maps_429_after_307_redirect(self) -> None: + """When the initial HEAD 307s and the redirected HEAD then 429s, + the 429 must still surface as ``HuggingFaceRateLimitError``.""" + # First HEAD -> 307 with a Location header pointing somewhere new. + first_response = MagicMock() + first_response.status = 307 + first_response.headers = {"location": "/redirected/url"} + + # Second HEAD (the recursive call) -> 429 with the real-HF header. 
+ second_response = MagicMock() + second_response.status = 429 + second_response.headers = REAL_HF_429_HEADER_DICT + + responses = iter([first_response, second_response]) + + mock_session = MagicMock() + mock_session.head.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny] + side_effect=lambda: next(responses) + ) + mock_session.head.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny] + return_value=None + ) + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny] + return_value=mock_session + ) + mock_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny] + return_value=None + ) + + with ( + patch("exo.download.download_utils.create_http_session", mock_factory), + pytest.raises(HuggingFaceRateLimitError) as exc_info, + ): + await file_meta(ModelId("test/model"), "main", "weights.safetensors") + assert exc_info.value.retry_after == 52.0 + + async def test_download_file_maps_429_to_rate_limit_error( + self, tmp_path: Path + ) -> None: + target_dir = tmp_path / "downloads" + await aios.makedirs(target_dir, exist_ok=True) + # No local file -> _download_file goes straight to file_meta then GET. + # We need both calls to succeed enough to reach the GET branch: + # - file_meta returns a non-429 (size, etag) so we proceed. + # - the GET then 429s. + with ( + patch( + "exo.download.download_utils.file_meta", + new_callable=AsyncMock, + return_value=(100, "abc123"), + ), + patch( + "exo.download.download_utils.create_http_session", + _make_mock_session_returning( + {"status": 429, "headers": REAL_HF_429_HEADER_DICT} + ), + ), + pytest.raises(HuggingFaceRateLimitError) as exc_info, + ): + await _download_file( + ModelId("test/model"), "main", "weights.safetensors", target_dir + ) + assert exc_info.value.retry_after == 52.0