Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-31 01:01:11 -05:00)

Compare commits: ciaran/pro...aiohttp (1 commit)
Commit 02b3bbfb3d
pyproject.toml

@@ -6,8 +6,6 @@ readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
     "aiofiles>=24.1.0",
-    "aiohttp>=3.12.14",
-    "types-aiofiles>=24.1.0.20250708",
     "pydantic>=2.11.7",
     "fastapi>=0.116.1",
     "filelock>=3.18.0",
@@ -121,6 +121,7 @@ class DownloadCoordinator:
     def _start_download_task(
         self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
     ) -> None:
+        logger.warning(f"starting download for {shard}")
        model_id = shard.model_card.model_id

        # Emit ongoing status
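Side note on the added log line: exo's logger appears to be loguru (inferred from the logger.opt(...) call later in this diff), which does not interpolate {shard} on its own, so the message needs the f-string prefix shown above, or loguru's deferred brace formatting. A minimal sketch of the two equivalent spellings, with a stand-in shard value:

from loguru import logger

shard = "llama-3.1-8b/shard-0"  # stand-in value for illustration

# f-string: interpolation happens before loguru sees the message
logger.warning(f"starting download for {shard}")

# deferred brace formatting: loguru interpolates the argument itself
logger.warning("starting download for {}", shard)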
@@ -8,13 +8,13 @@ import traceback
 from collections.abc import Awaitable
 from datetime import timedelta
 from pathlib import Path
-from typing import Callable, Literal
+from typing import Callable, Literal, cast
 from urllib.parse import urljoin

 import aiofiles
 import aiofiles.os as aios
-import aiohttp
 import certifi
+import httpx
 from huggingface_hub import (
     snapshot_download,  # pyright: ignore[reportUnknownVariableType]
 )
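This import swap is the heart of the change: aiohttp's ClientSession gives way to httpx.AsyncClient, and cast is pulled in because httpx's header lookups are typed loosely. The two libraries shape even a simple GET differently, roughly as in this illustrative sketch (not code from the repo):

import aiohttp
import httpx


async def get_text_aiohttp(url: str) -> str:
    # aiohttp: the request itself is an async context manager
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()


async def get_text_httpx(url: str) -> str:
    # httpx: get() returns a fully read Response; only the client is a context manager
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        return response.text  # a property in httpx, not a coroutine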
@@ -176,7 +176,7 @@ async def fetch_file_list_with_cache(
         # Fetch failed - try cache fallback
         if await aios.path.exists(cache_file):
             logger.warning(
-                f"Failed to fetch file list for {model_id}, using cached data: {e}"
+                f"{type(e).__name__}: Failed to fetch file list for {model_id}, using cached data"
             )
             async with aiofiles.open(cache_file, "r") as f:
                 return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
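The reworded warning moves the exception out of the interpolation and leads with its class name, which keeps the message stable while still identifying the failure. A two-line illustration of what type(e).__name__ produces:

try:
    raise TimeoutError("fetch timed out")  # stand-in failure
except Exception as e:
    # prints: TimeoutError: Failed to fetch file list for my-model, using cached data
    print(f"{type(e).__name__}: Failed to fetch file list for my-model, using cached data")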
@@ -196,7 +196,7 @@ async def fetch_file_list_with_retry(
         except Exception as e:
             if attempt == n_attempts - 1:
                 raise e
-            await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
+            await asyncio.sleep(min(16, 0.5 * float(2.0 ** int(attempt))))
     raise Exception(
         f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
     )
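Both retry loops in this commit slow the backoff down: base delay 0.1 s to 0.5 s and cap 8 s to 16 s. The resulting schedules, as a quick sketch:

# Old vs. new delay (seconds) per retry attempt
for attempt in range(6):
    old = min(8, 0.1 * 2.0**attempt)   # 0.1, 0.2, 0.4, 0.8, 1.6, 3.2
    new = min(16, 0.5 * 2.0**attempt)  # 0.5, 1.0, 2.0, 4.0, 8.0, 16.0
    print(attempt, old, new)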
@@ -211,26 +211,25 @@ async def _fetch_file_list(
     headers = await get_download_headers()
     async with (
         create_http_session(timeout_profile="short") as session,
-        session.get(url, headers=headers) as response,
     ):
-        if response.status in [401, 403]:
-            msg = await _build_auth_error_message(response.status, model_id)
+        response = await session.get(url, headers=headers)
+        if response.status_code in [401, 403]:
+            msg = await _build_auth_error_message(response.status_code, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        if response.status == 200:
-            data_json = await response.text()
-            data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
-            files: list[FileListEntry] = []
-            for item in data:
-                if item.type == "file":
-                    files.append(FileListEntry.model_validate(item))
-                elif item.type == "directory" and recursive:
-                    subfiles = await _fetch_file_list(
-                        model_id, revision, item.path, recursive
-                    )
-                    files.extend(subfiles)
-            return files
-        else:
-            raise Exception(f"Failed to fetch file list: {response.status}")
+        if response.status_code != 200:
+            raise Exception(f"Failed to fetch file list: {response.status_code}")
+        data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
+        files: list[FileListEntry] = []
+        for item in data:
+            if item.type == "file":
+                files.append(FileListEntry.model_validate(item))
+            elif item.type == "directory" and recursive:
+                subfiles = await _fetch_file_list(
+                    model_id, revision, item.path, recursive
+                )
+                files.extend(subfiles)
+        return files


 async def get_download_headers() -> dict[str, str]:
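The reshuffling here follows from an API difference: aiohttp's session.get(...) is an async context manager and so lived in the async with header, while httpx's is a plain awaitable. The rewrite also inverts the status check into an early raise, which un-nests the happy path. The pattern in isolation, as a minimal sketch against a generic JSON endpoint:

import httpx


async def fetch_json(client: httpx.AsyncClient, url: str) -> object:
    response = await client.get(url)  # awaitable, not a context manager
    if response.status_code != 200:   # guard clause replaces the old if/else nesting
        raise Exception(f"Failed to fetch: {response.status_code}")
    return response.json()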
@@ -238,34 +237,29 @@ async def get_download_headers() -> dict[str, str]:


 def create_http_session(
-    auto_decompress: bool = False,
     timeout_profile: Literal["short", "long"] = "long",
-) -> aiohttp.ClientSession:
+) -> httpx.AsyncClient:
     if timeout_profile == "short":
         total_timeout = 30
         connect_timeout = 10
-        sock_read_timeout = 30
-        sock_connect_timeout = 10
+        read_timeout = 30
     else:
         total_timeout = 1800
         connect_timeout = 60
-        sock_read_timeout = 1800
-        sock_connect_timeout = 60
+        read_timeout = 1800

     ssl_context = ssl.create_default_context(
         cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
     )
-    connector = aiohttp.TCPConnector(ssl=ssl_context)

-    return aiohttp.ClientSession(
-        auto_decompress=auto_decompress,
-        connector=connector,
-        proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
-        timeout=aiohttp.ClientTimeout(
-            total=total_timeout,
+    # default here is to load env vars
+    return httpx.AsyncClient(
+        verify=ssl_context,
+        timeout=httpx.Timeout(
             connect=connect_timeout,
-            sock_read=sock_read_timeout,
-            sock_connect=sock_connect_timeout,
+            read=read_timeout,
+            write=total_timeout,
+            pool=total_timeout,
         ),
     )

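Two httpx details explain the shape of the new session factory. First, httpx.Timeout has no aiohttp-style total budget; its knobs are connect/read/write/pool, which is why total_timeout now feeds write and pool rather than an overall deadline. Second, the "default here is to load env vars" comment refers to httpx's trust_env=True default, which picks up HTTP_PROXY/HTTPS_PROXY (and SSL_CERT_FILE-style variables) on its own, replacing the explicit proxy= argument from the aiohttp version. A sketch of exercising such a client, with a hypothetical URL:

import asyncio

import httpx


async def main() -> None:
    # trust_env defaults to True, so proxy settings come from the environment
    timeout = httpx.Timeout(connect=10, read=30, write=30, pool=30)
    async with httpx.AsyncClient(timeout=timeout) as client:
        r = await client.get("https://huggingface.co/api/models")  # example URL
        print(r.status_code)


asyncio.run(main())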
@@ -292,26 +286,28 @@ async def file_meta(
     headers = await get_download_headers()
     async with (
         create_http_session(timeout_profile="short") as session,
-        session.head(url, headers=headers) as r,
+        session.stream("HEAD", url, headers=headers) as r,
     ):
-        if r.status == 307:
+        if r.status_code == 307:
             # On redirect, only trust Hugging Face's x-linked-* headers.
-            x_linked_size = r.headers.get("x-linked-size")
-            x_linked_etag = r.headers.get("x-linked-etag")
+            x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
+            x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
             if x_linked_size and x_linked_etag:
                 content_length = int(x_linked_size)
                 etag = trim_etag(x_linked_etag)
                 return content_length, etag
             # Otherwise, follow the redirect to get authoritative size/hash
-            redirected_location = r.headers.get("location")
+            redirected_location = cast(str | None, r.headers.get("location"))
             return await file_meta(model_id, revision, path, redirected_location)
-        if r.status in [401, 403]:
-            msg = await _build_auth_error_message(r.status, model_id)
+        if r.status_code in [401, 403]:
+            msg = await _build_auth_error_message(r.status_code, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        content_length = int(
-            r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
-        )
-        etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
+        content_length = cast(
+            str | None,
+            r.headers.get("x-linked-size") or r.headers.get("content-length"),
+        )
+        content_length = 0 if content_length is None else int(content_length)
+        etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
         assert content_length > 0, f"No content length for {url}"
         assert etag is not None, f"No remote hash for {url}"
         etag = trim_etag(etag)
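Worth noting: httpx does not follow redirects unless asked (aiohttp does by default), which is exactly what lets file_meta observe the CDN's 307 and read the x-linked-* headers before deciding whether to chase it. The cast(str | None, ...) wrappers only narrow types for the checker; they are no-ops at runtime. The redirect-inspection pattern on its own, sketched with a hypothetical URL parameter:

import httpx


async def peek_redirect(client: httpx.AsyncClient, url: str) -> str | None:
    # follow_redirects defaults to False in httpx, so a 307 comes back as-is
    response = await client.head(url)
    if response.status_code == 307:
        return response.headers.get("location")
    return None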
@@ -340,7 +336,7 @@ async def download_file_with_retry(
                 f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
             )
             logger.error(traceback.format_exc())
-            await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
+            await asyncio.sleep(min(16, 0.5 * (2.0**attempt)))
     raise Exception(
         f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
     )
@@ -353,6 +349,7 @@ async def _download_file(
     target_dir: Path,
     on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
 ) -> Path:
+    logger.warning(f"downloading {path} from {model_id} to {target_dir}")
     target_path = target_dir / path

     if await aios.path.exists(target_path):
@@ -392,20 +389,20 @@ async def _download_file(
     n_read = resume_byte_pos or 0
     async with (
         create_http_session(timeout_profile="long") as session,
-        session.get(url, headers=headers) as r,
+        session.stream("GET", url, headers=headers, follow_redirects=True) as r,
     ):
-        if r.status == 404:
+        if r.status_code == 404:
             raise FileNotFoundError(f"File not found: {url}")
-        if r.status in [401, 403]:
-            msg = await _build_auth_error_message(r.status, model_id)
+        if r.status_code in [401, 403]:
+            msg = await _build_auth_error_message(r.status_code, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        assert r.status in [200, 206], (
-            f"Failed to download {path} from {url}: {r.status}"
+        assert r.status_code in [200, 206], (
+            f"Failed to download {path} from {url}: {r.status_code}"
         )
         async with aiofiles.open(
             partial_path, "ab" if resume_byte_pos else "wb"
         ) as f:
-            while chunk := await r.content.read(8 * 1024 * 1024):
+            async for chunk in r.aiter_bytes(8 * 1024 * 1024):
                 n_read = n_read + (await f.write(chunk))
                 on_progress(n_read, length, False)
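On the download path, aiohttp's `while chunk := await r.content.read(n)` loop becomes httpx's stream(...) context manager plus aiter_bytes(n); follow_redirects=True must be explicit here because httpx will not chase the CDN redirect on its own. The streaming idiom in isolation, as a sketch with a hypothetical destination path:

import httpx


async def stream_to_file(client: httpx.AsyncClient, url: str, dest: str) -> None:
    async with client.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()
        with open(dest, "wb") as f:
            # 8 MiB chunks, matching the size used in the diff
            async for chunk in response.aiter_bytes(8 * 1024 * 1024):
                f.write(chunk)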
@@ -168,7 +168,8 @@ class ResumableShardDownloader(ShardDownloader):
                 yield await task
             # TODO: except Exception
             except Exception as e:
-                logger.error("Error downloading shard:", e)
+                task.cancel()
+                logger.opt(exception=e).error("Error downloading shard")

     async def get_shard_download_status_for_shard(
         self, shard: ShardMetadata
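The old call passed the exception as a second positional argument, which loguru treats as a format arg rather than logging the traceback. logger.opt(exception=e) is loguru's mechanism for attaching the traceback to the record. A minimal sketch:

from loguru import logger

try:
    raise ValueError("stand-in shard failure")
except Exception as e:
    # attaches e's traceback to the log record instead of dropping it
    logger.opt(exception=e).error("Error downloading shard")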
@@ -1,10 +1,8 @@
 import time
-from typing import Generic, TypeVar
+from collections.abc import Hashable

-K = TypeVar("K")

-
-class KeyedBackoff(Generic[K]):
+class KeyedBackoff[K: Hashable]:
     """Tracks exponential backoff state per key."""

     def __init__(self, base: float = 0.5, cap: float = 10.0):
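`class KeyedBackoff[K: Hashable]` uses PEP 695 type-parameter syntax (Python 3.12+, comfortably within this project's requires-python = ">=3.13"), dropping the module-level TypeVar and the Generic base while adding a Hashable bound so keys can index a dict. A sketch of how the bounded parameter is typically used; the next_delay helper is hypothetical, not part of the diff:

from collections.abc import Hashable


class KeyedBackoff[K: Hashable]:
    """Tracks exponential backoff state per key (PEP 695 syntax)."""

    def __init__(self, base: float = 0.5, cap: float = 10.0):
        self.base = base
        self.cap = cap
        self._attempts: dict[K, int] = {}  # K must be Hashable to key a dict

    def next_delay(self, key: K) -> float:
        # hypothetical helper for illustration
        n = self._attempts.get(key, 0)
        self._attempts[key] = n + 1
        return min(self.cap, self.base * 2.0**n)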
uv.lock (generated)
@@ -366,7 +366,6 @@ version = "0.3.0"
 source = { editable = "." }
 dependencies = [
     { name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -387,7 +386,6 @@ dependencies = [
     { name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]

 [package.dev-dependencies]
@@ -403,7 +401,6 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "aiofiles", specifier = ">=24.1.0" },
-    { name = "aiohttp", specifier = ">=3.12.14" },
     { name = "anyio", specifier = "==4.11.0" },
     { name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" },
     { name = "fastapi", specifier = ">=0.116.1" },
@@ -424,7 +421,6 @@ requires-dist = [
     { name = "rustworkx", specifier = ">=0.17.1" },
     { name = "tiktoken", specifier = ">=0.12.0" },
     { name = "tomlkit", specifier = ">=0.14.0" },
-    { name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
 ]

 [package.metadata.requires-dev]