mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 07:17:30 -05:00
Compare commits
1 Commit
gh-screens
...
aiohttp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
523eff541e |
@@ -6,8 +6,6 @@ readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.12.14",
|
||||
"types-aiofiles>=24.1.0.20250708",
|
||||
"pydantic>=2.11.7",
|
||||
"fastapi>=0.116.1",
|
||||
"filelock>=3.18.0",
|
||||
|
||||
@@ -8,13 +8,13 @@ import traceback
|
||||
from collections.abc import Awaitable
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal
|
||||
from typing import Callable, Literal, cast
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import aiohttp
|
||||
import certifi
|
||||
import httpx
|
||||
from huggingface_hub import (
|
||||
snapshot_download, # pyright: ignore[reportUnknownVariableType]
|
||||
)
|
||||
@@ -330,17 +330,17 @@ async def _fetch_file_list(
|
||||
headers = await get_download_headers()
|
||||
async with (
|
||||
create_http_session(timeout_profile="short") as session,
|
||||
session.get(url, headers=headers) as response,
|
||||
):
|
||||
if response.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(response.status, model_id)
|
||||
response = await session.get(url, headers=headers)
|
||||
if response.status_code in [401, 403]:
|
||||
msg = await _build_auth_error_message(response.status_code, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
elif response.status == 429:
|
||||
elif response.status_code == 429:
|
||||
raise HuggingFaceRateLimitError(
|
||||
f"Couldn't download {model_id} because of HuggingFace rate limit."
|
||||
)
|
||||
elif response.status == 200:
|
||||
data_json = await response.text()
|
||||
elif response.status_code == 200:
|
||||
data_json = response.text
|
||||
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
|
||||
files: list[FileListEntry] = []
|
||||
for item in data:
|
||||
@@ -353,7 +353,7 @@ async def _fetch_file_list(
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
raise Exception(f"Failed to fetch file list: {response.status_code}")
|
||||
|
||||
|
||||
async def get_download_headers() -> dict[str, str]:
|
||||
@@ -361,34 +361,29 @@ async def get_download_headers() -> dict[str, str]:
|
||||
|
||||
|
||||
def create_http_session(
|
||||
auto_decompress: bool = False,
|
||||
timeout_profile: Literal["short", "long"] = "long",
|
||||
) -> aiohttp.ClientSession:
|
||||
) -> httpx.AsyncClient:
|
||||
if timeout_profile == "short":
|
||||
total_timeout = 30
|
||||
connect_timeout = 10
|
||||
sock_read_timeout = 30
|
||||
sock_connect_timeout = 10
|
||||
read_timeout = 30
|
||||
else:
|
||||
total_timeout = 1800
|
||||
connect_timeout = 60
|
||||
sock_read_timeout = 60
|
||||
sock_connect_timeout = 60
|
||||
read_timeout = 60
|
||||
|
||||
ssl_context = ssl.create_default_context(
|
||||
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
|
||||
)
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
|
||||
return aiohttp.ClientSession(
|
||||
auto_decompress=auto_decompress,
|
||||
connector=connector,
|
||||
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=total_timeout,
|
||||
# default here is to load env vars
|
||||
return httpx.AsyncClient(
|
||||
verify=ssl_context,
|
||||
timeout=httpx.Timeout(
|
||||
connect=connect_timeout,
|
||||
sock_read=sock_read_timeout,
|
||||
sock_connect=sock_connect_timeout,
|
||||
read=read_timeout,
|
||||
write=total_timeout,
|
||||
pool=total_timeout,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -415,26 +410,28 @@ async def file_meta(
|
||||
headers = await get_download_headers()
|
||||
async with (
|
||||
create_http_session(timeout_profile="short") as session,
|
||||
session.head(url, headers=headers) as r,
|
||||
session.stream("HEAD", url, headers=headers) as r,
|
||||
):
|
||||
if r.status == 307:
|
||||
if r.status_code == 307:
|
||||
# On redirect, only trust Hugging Face's x-linked-* headers.
|
||||
x_linked_size = r.headers.get("x-linked-size")
|
||||
x_linked_etag = r.headers.get("x-linked-etag")
|
||||
x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
|
||||
x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
|
||||
if x_linked_size and x_linked_etag:
|
||||
content_length = int(x_linked_size)
|
||||
etag = trim_etag(x_linked_etag)
|
||||
return content_length, etag
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
redirected_location = r.headers.get("location")
|
||||
redirected_location = cast(str | None, r.headers.get("location"))
|
||||
return await file_meta(model_id, revision, path, redirected_location)
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
if r.status_code in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status_code, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
content_length = cast(
|
||||
str | None,
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length"),
|
||||
)
|
||||
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
|
||||
content_length = 0 if content_length is None else int(content_length)
|
||||
etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
|
||||
assert content_length > 0, f"No content length for {url}"
|
||||
assert etag is not None, f"No remote hash for {url}"
|
||||
etag = trim_etag(etag)
|
||||
@@ -537,20 +534,20 @@ async def _download_file(
|
||||
n_read = resume_byte_pos or 0
|
||||
async with (
|
||||
create_http_session(timeout_profile="long") as session,
|
||||
session.get(url, headers=headers) as r,
|
||||
session.stream("GET", url, headers=headers, follow_redirects=True) as r,
|
||||
):
|
||||
if r.status == 404:
|
||||
if r.status_code == 404:
|
||||
raise FileNotFoundError(f"File not found: {url}")
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
if r.status_code in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status_code, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
assert r.status in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status}"
|
||||
assert r.status_code in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status_code}"
|
||||
)
|
||||
async with aiofiles.open(
|
||||
partial_path, "ab" if resume_byte_pos else "wb"
|
||||
) as f:
|
||||
while chunk := await r.content.read(8 * 1024 * 1024):
|
||||
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
|
||||
n_read = n_read + (await f.write(chunk))
|
||||
on_progress(n_read, length, False)
|
||||
|
||||
|
||||
@@ -189,6 +189,7 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
try:
|
||||
yield await task
|
||||
except Exception as e:
|
||||
task.cancel()
|
||||
logger.warning(f"Error downloading shard: {type(e).__name__}")
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import time
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
K = TypeVar("K")
|
||||
from collections.abc import Hashable
|
||||
|
||||
|
||||
class KeyedBackoff(Generic[K]):
|
||||
class KeyedBackoff[K: Hashable]:
|
||||
"""Tracks exponential backoff state per key."""
|
||||
|
||||
def __init__(self, base: float = 0.5, cap: float = 10.0):
|
||||
|
||||
4
uv.lock
generated
4
uv.lock
generated
@@ -367,7 +367,6 @@ version = "0.3.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -389,7 +388,6 @@ dependencies = [
|
||||
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "zstandard", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
@@ -406,7 +404,6 @@ dev = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiofiles", specifier = ">=24.1.0" },
|
||||
{ name = "aiohttp", specifier = ">=3.12.14" },
|
||||
{ name = "anyio", specifier = "==4.11.0" },
|
||||
{ name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" },
|
||||
{ name = "fastapi", specifier = ">=0.116.1" },
|
||||
@@ -428,7 +425,6 @@ requires-dist = [
|
||||
{ name = "rustworkx", specifier = ">=0.17.1" },
|
||||
{ name = "tiktoken", specifier = ">=0.12.0" },
|
||||
{ name = "tomlkit", specifier = ">=0.14.0" },
|
||||
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
|
||||
{ name = "zstandard", specifier = ">=0.23.0" },
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user