Compare commits

...

1 Commit

Author   SHA1         Message   Date
Evan     02b3bbfb3d   aaa       2026-01-28 15:57:35 +00:00
6 changed files with 57 additions and 66 deletions

View File

@@ -6,8 +6,6 @@ readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
     "aiofiles>=24.1.0",
-    "aiohttp>=3.12.14",
-    "types-aiofiles>=24.1.0.20250708",
     "pydantic>=2.11.7",
     "fastapi>=0.116.1",
     "filelock>=3.18.0",

View File

@@ -121,6 +121,7 @@ class DownloadCoordinator:
     def _start_download_task(
         self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
     ) -> None:
+        logger.warning(f"starting download for {shard}")
         model_id = shard.model_card.model_id
         # Emit ongoing status

View File

@@ -8,13 +8,13 @@ import traceback
 from collections.abc import Awaitable
 from datetime import timedelta
 from pathlib import Path
-from typing import Callable, Literal
+from typing import Callable, Literal, cast
 from urllib.parse import urljoin

 import aiofiles
 import aiofiles.os as aios
-import aiohttp
 import certifi
+import httpx
 from huggingface_hub import (
     snapshot_download,  # pyright: ignore[reportUnknownVariableType]
 )
@@ -176,7 +176,7 @@ async def fetch_file_list_with_cache(
     # Fetch failed - try cache fallback
     if await aios.path.exists(cache_file):
         logger.warning(
-            f"Failed to fetch file list for {model_id}, using cached data: {e}"
+            f"{type(e).__name__}: Failed to fetch file list for {model_id}, using cached data"
         )
         async with aiofiles.open(cache_file, "r") as f:
             return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
@@ -196,7 +196,7 @@ async def fetch_file_list_with_retry(
         except Exception as e:
             if attempt == n_attempts - 1:
                 raise e
-            await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
+            await asyncio.sleep(min(16, 0.5 * float(2.0 ** int(attempt))))
     raise Exception(
         f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
     )
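Both retry helpers move from a 0.1 s base with an 8 s cap to a 0.5 s base with a 16 s cap (the same change appears in download_file_with_retry below). For reference, a quick sketch of the per-attempt delays the new schedule produces:

    # Delay before retrying after failed attempt `attempt` (new schedule).
    for attempt in range(7):
        print(attempt, min(16, 0.5 * (2.0**attempt)))
    # 0 -> 0.5s, 1 -> 1s, 2 -> 2s, 3 -> 4s, 4 -> 8s, 5 onwards -> 16s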
@@ -211,26 +211,25 @@ async def _fetch_file_list(
     headers = await get_download_headers()
     async with (
         create_http_session(timeout_profile="short") as session,
-        session.get(url, headers=headers) as response,
     ):
-        if response.status in [401, 403]:
-            msg = await _build_auth_error_message(response.status, model_id)
+        response = await session.get(url, headers=headers)
+        if response.status_code in [401, 403]:
+            msg = await _build_auth_error_message(response.status_code, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        if response.status == 200:
-            data_json = await response.text()
-            data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
-            files: list[FileListEntry] = []
-            for item in data:
-                if item.type == "file":
-                    files.append(FileListEntry.model_validate(item))
-                elif item.type == "directory" and recursive:
-                    subfiles = await _fetch_file_list(
-                        model_id, revision, item.path, recursive
-                    )
-                    files.extend(subfiles)
-            return files
-        else:
-            raise Exception(f"Failed to fetch file list: {response.status}")
+        if response.status_code != 200:
+            raise Exception(f"Failed to fetch file list: {response.status_code}")
+
+        data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
+        files: list[FileListEntry] = []
+        for item in data:
+            if item.type == "file":
+                files.append(FileListEntry.model_validate(item))
+            elif item.type == "directory" and recursive:
+                subfiles = await _fetch_file_list(
+                    model_id, revision, item.path, recursive
+                )
+                files.extend(subfiles)
+        return files


 async def get_download_headers() -> dict[str, str]:
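A note on the pattern above: unlike aiohttp, httpx's AsyncClient.get() is a plain awaitable that returns a fully-read Response rather than an async context manager, and .text and .status_code are synchronous properties. A minimal sketch of the new call shape (URL is illustrative):

    import asyncio
    import httpx

    async def main() -> None:
        async with httpx.AsyncClient() as client:
            # The body is read eagerly; no `async with ... as response` needed.
            response = await client.get("https://example.com")
            assert response.status_code == 200
            print(response.text[:80])

    asyncio.run(main())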
@@ -238,34 +237,29 @@ async def get_download_headers() -> dict[str, str]:


 def create_http_session(
-    auto_decompress: bool = False,
     timeout_profile: Literal["short", "long"] = "long",
-) -> aiohttp.ClientSession:
+) -> httpx.AsyncClient:
     if timeout_profile == "short":
         total_timeout = 30
         connect_timeout = 10
-        sock_read_timeout = 30
-        sock_connect_timeout = 10
+        read_timeout = 30
     else:
         total_timeout = 1800
         connect_timeout = 60
-        sock_read_timeout = 1800
-        sock_connect_timeout = 60
+        read_timeout = 1800
     ssl_context = ssl.create_default_context(
         cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
     )
-    connector = aiohttp.TCPConnector(ssl=ssl_context)
-    return aiohttp.ClientSession(
-        auto_decompress=auto_decompress,
-        connector=connector,
-        proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
-        timeout=aiohttp.ClientTimeout(
-            total=total_timeout,
+    # proxy settings default to being loaded from env vars
+    return httpx.AsyncClient(
+        verify=ssl_context,
+        timeout=httpx.Timeout(
             connect=connect_timeout,
-            sock_read=sock_read_timeout,
-            sock_connect=sock_connect_timeout,
+            read=read_timeout,
+            write=total_timeout,
+            pool=total_timeout,
         ),
     )
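One behavioural difference worth noting: aiohttp's ClientTimeout(total=...) bounds the whole request, while httpx.Timeout has no overall budget, only per-phase connect/read/write/pool limits, so total_timeout is repurposed here for the write and pool phases. httpx also reads HTTP_PROXY/HTTPS_PROXY itself (trust_env=True by default), which is why the explicit proxy= argument could be dropped. A minimal sketch of the equivalent client:

    import httpx

    # Each limit applies to its own phase; there is no aiohttp-style `total`.
    timeout = httpx.Timeout(connect=10, read=30, write=30, pool=30)

    # trust_env=True (the default) picks up HTTP_PROXY, HTTPS_PROXY,
    # NO_PROXY and SSL_CERT_FILE from the environment.
    client = httpx.AsyncClient(timeout=timeout)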
@@ -292,26 +286,28 @@ async def file_meta(
     headers = await get_download_headers()
     async with (
         create_http_session(timeout_profile="short") as session,
-        session.head(url, headers=headers) as r,
+        session.stream("HEAD", url, headers=headers) as r,
     ):
-        if r.status == 307:
+        if r.status_code == 307:
             # On redirect, only trust Hugging Face's x-linked-* headers.
-            x_linked_size = r.headers.get("x-linked-size")
-            x_linked_etag = r.headers.get("x-linked-etag")
+            x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
+            x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
             if x_linked_size and x_linked_etag:
                 content_length = int(x_linked_size)
                 etag = trim_etag(x_linked_etag)
                 return content_length, etag
             # Otherwise, follow the redirect to get authoritative size/hash
-            redirected_location = r.headers.get("location")
+            redirected_location = cast(str | None, r.headers.get("location"))
             return await file_meta(model_id, revision, path, redirected_location)
-        if r.status in [401, 403]:
-            msg = await _build_auth_error_message(r.status, model_id)
+        if r.status_code in [401, 403]:
+            msg = await _build_auth_error_message(r.status_code, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        content_length = int(
-            r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
+        content_length = cast(
+            str | None,
+            r.headers.get("x-linked-size") or r.headers.get("content-length"),
         )
-        etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
+        content_length = 0 if content_length is None else int(content_length)
+        etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
         assert content_length > 0, f"No content length for {url}"
         assert etag is not None, f"No remote hash for {url}"
         etag = trim_etag(etag)
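Also relevant here: httpx does not follow redirects unless asked (follow_redirects defaults to False), which is what lets this HEAD request inspect the 307 and its x-linked-* headers directly; the streaming GET in _download_file opts in explicitly. A small sketch of the distinction (URL is illustrative):

    import httpx

    async def head_example(url: str) -> None:
        async with httpx.AsyncClient() as client:
            # Default: the 3xx response itself is returned to the caller.
            r = await client.head(url)
            if r.status_code in (301, 302, 307, 308):
                print("redirect target:", r.headers.get("location"))

            # Opt in to transparent redirect handling instead.
            r = await client.head(url, follow_redirects=True)
            print(r.status_code)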
@@ -340,7 +336,7 @@ async def download_file_with_retry(
             f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
         )
         logger.error(traceback.format_exc())
-        await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
+        await asyncio.sleep(min(16, 0.5 * (2.0**attempt)))
     raise Exception(
         f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
     )
@@ -353,6 +349,7 @@ async def _download_file(
     target_dir: Path,
     on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
 ) -> Path:
+    logger.warning(f"downloading {path} from {model_id} to {target_dir}")
     target_path = target_dir / path

     if await aios.path.exists(target_path):
@@ -392,20 +389,20 @@ async def _download_file(
     n_read = resume_byte_pos or 0
     async with (
         create_http_session(timeout_profile="long") as session,
-        session.get(url, headers=headers) as r,
+        session.stream("GET", url, headers=headers, follow_redirects=True) as r,
     ):
-        if r.status == 404:
+        if r.status_code == 404:
             raise FileNotFoundError(f"File not found: {url}")
-        if r.status in [401, 403]:
-            msg = await _build_auth_error_message(r.status, model_id)
+        if r.status_code in [401, 403]:
+            msg = await _build_auth_error_message(r.status_code, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        assert r.status in [200, 206], (
-            f"Failed to download {path} from {url}: {r.status}"
+        assert r.status_code in [200, 206], (
+            f"Failed to download {path} from {url}: {r.status_code}"
         )
         async with aiofiles.open(
             partial_path, "ab" if resume_byte_pos else "wb"
         ) as f:
-            while chunk := await r.content.read(8 * 1024 * 1024):
+            async for chunk in r.aiter_bytes(8 * 1024 * 1024):
                 n_read = n_read + (await f.write(chunk))
                 on_progress(n_read, length, False)
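The chunked read changes shape as well: aiohttp exposed a stream reader (`while chunk := await r.content.read(n)`), whereas httpx streams via client.stream(...) plus aiter_bytes(), which yields chunks of up to the requested size as they arrive. A self-contained sketch of the new pattern (URL and filename are illustrative):

    import asyncio
    import httpx

    async def download(url: str, dest: str) -> None:
        async with httpx.AsyncClient() as client:
            # stream() defers the body transfer; aiter_bytes() yields
            # chunks of up to the given size as they arrive.
            async with client.stream("GET", url, follow_redirects=True) as r:
                r.raise_for_status()
                with open(dest, "wb") as f:
                    async for chunk in r.aiter_bytes(8 * 1024 * 1024):
                        f.write(chunk)

    asyncio.run(download("https://example.com/file.bin", "file.bin"))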

View File

@@ -168,7 +168,8 @@ class ResumableShardDownloader(ShardDownloader):
                 yield await task
         # TODO: except Exception
         except Exception as e:
-            logger.error("Error downloading shard:", e)
+            task.cancel()
+            logger.opt(exception=e).error("Error downloading shard")

     async def get_shard_download_status_for_shard(
         self, shard: ShardMetadata
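The logging fix matters beyond style: with loguru (which the .opt() call implies), extra positional arguments only fill {} placeholders in the message, so logger.error("Error downloading shard:", e) silently dropped the exception. logger.opt(exception=e) records it and logs the full traceback. A minimal sketch, assuming loguru:

    from loguru import logger

    try:
        raise ValueError("boom")
    except Exception as e:
        # The exception is discarded: no {} placeholder to fill.
        logger.error("Error downloading shard:", e)
        # The exception and its traceback are logged.
        logger.opt(exception=e).error("Error downloading shard")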

View File

@@ -1,10 +1,8 @@
 import time
-from typing import Generic, TypeVar
+from collections.abc import Hashable

-K = TypeVar("K")

-class KeyedBackoff(Generic[K]):
+class KeyedBackoff[K: Hashable]:
     """Tracks exponential backoff state per key."""

     def __init__(self, base: float = 0.5, cap: float = 10.0):
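The new generic syntax relies on PEP 695 (Python 3.12+), in line with the project's requires-python = ">=3.13": the type parameter is declared inline with an upper bound instead of a module-level TypeVar. A sketch under those assumptions (the next_delay method is hypothetical, for illustration only):

    from collections.abc import Hashable

    # Pre-3.12 spelling:
    #   K = TypeVar("K")
    #   class KeyedBackoff(Generic[K]): ...

    class KeyedBackoff[K: Hashable]:
        """Tracks exponential backoff state per key."""

        def __init__(self, base: float = 0.5, cap: float = 10.0):
            self.base = base
            self.cap = cap
            self._attempts: dict[K, int] = {}

        def next_delay(self, key: K) -> float:
            # Hypothetical helper: doubles per failure, capped at `cap`.
            n = self._attempts.get(key, 0)
            self._attempts[key] = n + 1
            return min(self.cap, self.base * (2.0**n))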

uv.lock generated
View File

@@ -366,7 +366,6 @@ version = "0.3.0"
 source = { editable = "." }
 dependencies = [
     { name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -387,7 +386,6 @@ dependencies = [
     { name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]

 [package.dev-dependencies]
@@ -403,7 +401,6 @@ dev = [

 [package.metadata]
 requires-dist = [
     { name = "aiofiles", specifier = ">=24.1.0" },
-    { name = "aiohttp", specifier = ">=3.12.14" },
     { name = "anyio", specifier = "==4.11.0" },
     { name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" },
     { name = "fastapi", specifier = ">=0.116.1" },
@@ -424,7 +421,6 @@ requires-dist = [
     { name = "rustworkx", specifier = ">=0.17.1" },
     { name = "tiktoken", specifier = ">=0.12.0" },
     { name = "tomlkit", specifier = ">=0.14.0" },
-    { name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
 ]

 [package.metadata.requires-dev]