switch from uvicorn to hypercorn

Evan Quiney
2025-12-05 17:29:06 +00:00
committed by GitHub
parent e8566a3f95
commit c9e2062f6e
13 changed files with 213 additions and 144 deletions

View File

@@ -17,6 +17,8 @@
 23. Do we need cache_limit? We went back and forth on that a lot because we thought it might be causing issues. One problem is it sets it relative to model size. So if you have multiple models loaded in it will take the most recent model size for the cache_limit. This is problematic if you launch DeepSeek -> Llama for example.
 24. further openai/lmstudio api compatibility
 25. Rethink retry logic
+26. Task cancellation. When an API HTTP request gets cancelled, it should cancel the corresponding task.
+27. Log cleanup - per-module log filters and default to DEBUG log levels
 Potential refactors:

View File

@@ -13,7 +13,6 @@ dependencies = [
     "base58>=2.1.1",
     "cryptography>=45.0.5",
     "fastapi>=0.116.1",
-    "uvicorn>=0.35.0",
     "filelock>=3.18.0",
     "aiosqlite>=0.21.0",
     "networkx>=3.5",
@@ -26,7 +25,6 @@ dependencies = [
     "greenlet>=3.2.4",
     "huggingface-hub>=0.33.4",
     "psutil>=7.0.0",
-    "cobs>=1.2.2",
     "loguru>=0.7.3",
     "textual>=5.3.0",
     "exo_pyo3_bindings", # rust bindings
@@ -35,6 +33,7 @@ dependencies = [
     "mlx>=0.29.3",
     "mlx-lm>=0.28.3",
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
+    "hypercorn>=0.18.0",
 ]

 [project.scripts]

View File

@@ -6,6 +6,7 @@ from typing import Self
 import anyio
 from anyio.abc import TaskGroup
+from loguru import logger
 from pydantic import PositiveInt

 import exo.routing.topics as topics
@@ -14,7 +15,7 @@ from exo.master.main import Master
 from exo.routing.router import Router, get_node_id_keypair
 from exo.shared.constants import EXO_LOG
 from exo.shared.election import Election, ElectionResult
-from exo.shared.logging import logger, logger_cleanup, logger_setup
+from exo.shared.logging import logger_cleanup, logger_setup
 from exo.shared.types.commands import KillCommand
 from exo.shared.types.common import NodeId, SessionId
 from exo.utils.channels import Receiver, channel

View File

@@ -1,21 +1,23 @@
-import asyncio
 import os
 import time
 from collections.abc import AsyncGenerator
-from typing import final
+from typing import cast

-import uvicorn
+import anyio
-from anyio import Event as AsyncTaskEvent
 from anyio import create_task_group
 from anyio.abc import TaskGroup
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.staticfiles import StaticFiles
+from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
+from hypercorn.config import Config
+from hypercorn.typing import ASGIFramework
 from loguru import logger

 from exo.shared.apply import apply
 from exo.shared.election import ElectionMessage
+from exo.shared.logging import InterceptLogger
 from exo.shared.models.model_cards import MODEL_CARDS
 from exo.shared.models.model_meta import get_model_meta
 from exo.shared.types.api import (
@@ -46,9 +48,10 @@ from exo.shared.types.state import State
 from exo.shared.types.tasks import ChatCompletionTaskParams
 from exo.shared.types.worker.instances import Instance, InstanceId
 from exo.utils.banner import print_startup_banner
-from exo.utils.channels import Receiver, Sender
+from exo.utils.channels import Receiver, Sender, channel
 from exo.utils.event_buffer import OrderedBuffer
-from exo.worker.engines.mlx.constants import HIDE_THINKING
+
+HIDE_THINKING = False

 def chunk_to_response(
@@ -76,7 +79,6 @@ async def resolve_model_meta(model_id: str) -> ModelMetadata:
     return await get_model_meta(model_id)

-@final
 class API:
     def __init__(
         self,
@@ -101,7 +103,7 @@
         self.port = port

         self.paused: bool = False
-        self.paused_ev: AsyncTaskEvent = AsyncTaskEvent()
+        self.paused_ev: anyio.Event = anyio.Event()

         self.app = FastAPI()
         self._setup_cors()
@@ -121,7 +123,7 @@
             name="dashboard",
         )

-        self._chat_completion_queues: dict[CommandId, asyncio.Queue[TokenChunk]] = {}
+        self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
         self._tg: TaskGroup | None = None

     def reset(self, new_session_id: SessionId, result_clock: int):
@@ -135,7 +137,7 @@
         self.last_completed_election = result_clock
         self.paused = False
         self.paused_ev.set()
-        self.paused_ev = AsyncTaskEvent()
+        self.paused_ev = anyio.Event()

     def _setup_cors(self) -> None:
         self.app.add_middleware(
@@ -210,37 +212,40 @@
     ) -> AsyncGenerator[str, None]:
         """Generate chat completion stream as JSON strings."""
-        self._chat_completion_queues[command_id] = asyncio.Queue()
-
-        finished = False
-        is_thinking = False
-        while not finished:
-            # TODO: how long should this timeout be?
-            chunk = await asyncio.wait_for(
-                self._chat_completion_queues[command_id].get(), timeout=600
-            )
-            assert isinstance(chunk, TokenChunk)
-            # TODO: Do we want this?
-            if HIDE_THINKING:
-                if chunk.text == "<think>":
-                    chunk.text = "\n"
-                if chunk.text == "</think>":
-                    chunk.text = "\n"
-            chunk_response: ChatCompletionResponse = chunk_to_response(
-                chunk, command_id
-            )
-            logger.debug(f"chunk_response: {chunk_response}")
-            if not HIDE_THINKING or not is_thinking:
-                yield f"data: {chunk_response.model_dump_json()}\n\n"
-            if chunk.finish_reason is not None:
-                yield "data: [DONE]\n\n"
-                finished = True
-        command = TaskFinished(finished_command_id=command_id)
-        await self._send(command)
-        del self._chat_completion_queues[command_id]
+        try:
+            self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
+            is_thinking = False
+            with recv as token_chunks:
+                async for chunk in token_chunks:
+                    if HIDE_THINKING:
+                        if chunk.text == "<think>":
+                            is_thinking = True
+                        if chunk.text == "</think>":
+                            is_thinking = False
+                    chunk_response: ChatCompletionResponse = chunk_to_response(
+                        chunk, command_id
+                    )
+                    if not (is_thinking and HIDE_THINKING):
+                        logger.debug(f"chunk_response: {chunk_response}")
+                        yield f"data: {chunk_response.model_dump_json()}\n\n"
+                    if chunk.finish_reason is not None:
+                        yield "data: [DONE]\n\n"
+                        break
+        except anyio.get_cancelled_exc_class():
+            # TODO: TaskCancelled
+            """
+            self.command_sender.send_nowait(
+                ForwarderCommand(origin=self.node_id, command=command)
+            )
+            """
+            raise
+        finally:
+            command = TaskFinished(finished_command_id=command_id)
+            await self._send(command)
+            del self._chat_completion_queues[command_id]

     async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
         logger.warning(
@@ -298,30 +303,28 @@
         )

     async def run(self):
-        uvicorn_config = uvicorn.Config(
-            self.app, host="0.0.0.0", port=self.port, access_log=False
-        )
-        uvicorn_server = uvicorn.Server(uvicorn_config)
+        cfg = Config()
+        cfg.bind = f"0.0.0.0:{self.port}"
+        # nb: shared.logging needs updating if any of this changes
+        cfg.accesslog = None
+        cfg.errorlog = "-"
+        cfg.logger_class = InterceptLogger

         async with create_task_group() as tg:
             self._tg = tg
             logger.info("Starting API")
-            tg.start_soon(uvicorn_server.serve)
             tg.start_soon(self._apply_state)
             tg.start_soon(self._pause_on_new_election)
-            tg.start_soon(self._print_banner_when_ready, uvicorn_server)
+            print_startup_banner(self.port)
+            await serve(
+                cast(ASGIFramework, self.app),
+                cfg,
+                shutdown_trigger=lambda: anyio.sleep_forever(),
+            )

         self.command_sender.close()
         self.global_event_receiver.close()

-    async def _print_banner_when_ready(self, uvicorn_server: uvicorn.Server):
-        """Wait for the uvicorn server to be ready, then print the startup banner."""
-        # TODO: Is this the best condition to check for?
-        # The point is this should log when exo is ready.
-        while not uvicorn_server.started:
-            await asyncio.sleep(0.1)
-        print_startup_banner(self.port)

     async def _apply_state(self):
         with self.global_event_receiver as events:
             async for f_event in events:
@@ -333,7 +336,7 @@
                     and event.command_id in self._chat_completion_queues
                 ):
                     assert isinstance(event.chunk, TokenChunk)
-                    self._chat_completion_queues[event.command_id].put_nowait(
+                    await self._chat_completion_queues[event.command_id].send(
                         event.chunk
                     )
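For reference, the serving pattern this hunk adopts: hypercorn exposes an awaitable `serve(app, config)` coroutine rather than a server object, which is what lets it run inside the existing anyio task group. Below is a minimal, self-contained sketch of the same shape; the app, port, and health route are illustrative placeholders, not part of this commit:

```python
# Sketch: serving a FastAPI app with hypercorn inside an anyio task group,
# mirroring API.run() above. App and port are illustrative placeholders.
from typing import cast

import anyio
from fastapi import FastAPI
from hypercorn.asyncio import serve
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework

app = FastAPI()

@app.get("/health")
async def health() -> dict[str, str]:
    return {"status": "ok"}

async def main() -> None:
    cfg = Config()
    cfg.bind = "0.0.0.0:8000"
    cfg.accesslog = None  # no access log
    cfg.errorlog = "-"    # error log to stderr
    async with anyio.create_task_group() as tg:
        # background tasks would be started here with tg.start_soon(...)
        # Without an explicit shutdown_trigger, hypercorn installs its own
        # signal handlers; sleep_forever() instead defers shutdown to the
        # surrounding task group's cancellation.
        await serve(
            cast(ASGIFramework, app),
            cfg,
            shutdown_trigger=lambda: anyio.sleep_forever(),
        )

if __name__ == "__main__":
    anyio.run(main)
```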

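The other substantive change in this file swaps a per-request `asyncio.Queue` for a closable channel, so the producer side can end the stream and the generator can simply iterate instead of polling with a timeout. `exo.utils.channels.channel` itself is not shown in this diff; the sketch below uses anyio's memory object streams as a stand-in to illustrate the shape of the pattern, not the project's actual channel implementation:

```python
# Sketch of the queue -> channel pattern, with anyio memory object streams
# standing in for exo.utils.channels.channel (not shown in this diff).
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

async def producer(send: ObjectSendStream[str]) -> None:
    async with send:  # closing the send side ends the consumer's loop
        for token in ("hello", " ", "world"):
            await send.send(token)

async def consumer(recv: ObjectReceiveStream[str]) -> None:
    async with recv:
        async for token in recv:  # terminates once the sender is closed
            print(f"data: {token}")

async def main() -> None:
    send, recv = anyio.create_memory_object_stream[str](max_buffer_size=16)
    async with anyio.create_task_group() as tg:
        tg.start_soon(producer, send)
        tg.start_soon(consumer, recv)

anyio.run(main)
```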
View File

@@ -1,4 +1,3 @@
-import inspect
 import os
 from pathlib import Path
@@ -34,19 +33,3 @@ LIBP2P_COMMANDS_TOPIC = "commands"
 LB_TFLOPS = 2.3
 LB_MEMBW_GBPS = 68
 LB_DISK_GBPS = 1.5
-
-# little helper function to get the name of the module that raised the error
-def get_caller_module_name() -> str:
-    frm = inspect.stack()[1]
-    mod = inspect.getmodule(frm[0])
-    if mod is None:
-        return "UNKNOWN MODULE"
-    return mod.__name__
-
-def get_error_reporting_message() -> str:
-    return (
-        f"THIS IS A BUG IN THE EXO SOFTWARE, PLEASE REPORT IT AT https://github.com/exo-explore/exo/\n"
-        f"The module that raised the error was: {get_caller_module_name()}"
-    )

View File

@@ -1,12 +1,39 @@
+import logging
 import sys
 from pathlib import Path

+from hypercorn import Config
+from hypercorn.logging import Logger as HypercornLogger
 from loguru import logger

+class InterceptLogger(HypercornLogger):
+    def __init__(self, config: Config):
+        super().__init__(config)
+        assert self.error_logger
+        # TODO: Decide if we want to provide access logs
+        # assert self.access_logger
+        # self.access_logger.handlers = [_InterceptHandler()]
+        self.error_logger.handlers = [_InterceptHandler()]
+
+class _InterceptHandler(logging.Handler):
+    def emit(self, record: logging.LogRecord):
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+        logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())
+
 def logger_setup(log_file: Path | None, verbosity: int = 0):
     """Set up logging for this process - formatting, file handles, verbosity and output"""
     logger.remove()
+    # replace all stdlib loggers with _InterceptHandlers that log to loguru
+    logging.basicConfig(handlers=[_InterceptHandler()], level=0)
     if verbosity == 0:
         logger.add(
             sys.__stderr__,  # type: ignore
@@ -37,3 +64,29 @@ def logger_setup(log_file: Path | None, verbosity: int = 0):
 def logger_cleanup():
     """Flush all queues before shutting down so any in-flight logs are written to disk"""
     logger.complete()
+
+""" --- TODO: Capture MLX Log output:
+import contextlib
+import sys
+
+from loguru import logger
+
+class StreamToLogger:
+    def __init__(self, level="INFO"):
+        self._level = level
+
+    def write(self, buffer):
+        for line in buffer.rstrip().splitlines():
+            logger.opt(depth=1).log(self._level, line.rstrip())
+
+    def flush(self):
+        pass
+
+logger.remove()
+logger.add(sys.__stdout__)
+
+stream = StreamToLogger()
+with contextlib.redirect_stdout(stream):
+    print("Standard output is sent to added handlers.")
+"""

View File

@@ -4,6 +4,8 @@ import asyncio
 from typing import Generator

 import pytest
+from _pytest.logging import LogCaptureFixture
+from loguru import logger

 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
@@ -41,3 +43,16 @@ def get_pipeline_shard_metadata(
         end_layer=32,
         n_layers=32,
     )
+
+@pytest.fixture
+def caplog(caplog: LogCaptureFixture):
+    handler_id = logger.add(
+        caplog.handler,
+        format="{message}",
+        level=0,
+        filter=lambda record: record["level"].no >= caplog.handler.level,
+        enqueue=True,  # Set to 'True' if your test is spawning child processes.
+    )
+    yield caplog
+    logger.remove(handler_id)
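This override is the fixture loguru's documentation recommends for pytest integration: loguru bypasses stdlib `logging`, so the stock `caplog` fixture would capture nothing without it. A hypothetical test using the overridden fixture might look like:

```python
# Hypothetical example: with the fixture above in conftest.py, loguru
# records show up in pytest's standard caplog assertions.
from loguru import logger

def test_warns_about_missing_model(caplog):
    logger.warning("model not downloaded")
    assert "not downloaded" in caplog.text
```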

View File

@@ -1,5 +1,4 @@
 import contextlib
-import logging
 import multiprocessing
 import os
 from multiprocessing import Event, Queue, Semaphore
@@ -8,6 +7,7 @@ from multiprocessing.queues import Queue as QueueT
 from multiprocessing.synchronize import Event as EventT
 from multiprocessing.synchronize import Semaphore as SemaphoreT

+from loguru import logger
 from pytest import LogCaptureFixture

 from exo.routing.router import get_node_id_keypair
@@ -17,20 +17,13 @@ NUM_CONCURRENT_PROCS = 10

 def _get_keypair_concurrent_subprocess_task(
-    pid: int, sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]
+    sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]
 ) -> None:
-    try:
-        # synchronise with parent process
-        logging.info(msg=f"SUBPROCESS {pid}: Started")
-        sem.release()
-
-        # wait to be told to begin simultaneous read
-        ev.wait()
-        logging.info(msg=f"SUBPROCESS {pid}: Reading start")
-        queue.put(get_node_id_keypair().to_protobuf_encoding())
-        logging.info(msg=f"SUBPROCESS {pid}: Reading end")
-    except Exception as e:
-        logging.error(msg=f"SUBPROCESS {pid}: Error encountered: {e}")
+    # synchronise with parent process
+    sem.release()
+    # wait to be told to begin simultaneous read
+    ev.wait()
+    queue.put(get_node_id_keypair().to_protobuf_encoding())

 def _get_keypair_concurrent(num_procs: int) -> bytes:
@@ -41,11 +34,11 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
     queue: QueueT[bytes] = Queue(maxsize=num_procs)

     # make parent process wait for all subprocesses to start
-    logging.info(msg=f"PARENT: Starting {num_procs} subprocesses")
+    logger.info(f"PARENT: Starting {num_procs} subprocesses")
     ps: list[BaseProcess] = []
-    for i in range(num_procs):
+    for _ in range(num_procs):
         p = multiprocessing.get_context("fork").Process(
-            target=_get_keypair_concurrent_subprocess_task, args=(i + 1, sem, ev, queue)
+            target=_get_keypair_concurrent_subprocess_task, args=(sem, ev, queue)
         )
         ps.append(p)
         p.start()
@@ -53,7 +46,7 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
         sem.acquire()

     # start all the sub processes simultaneously
-    logging.info(msg="PARENT: Beginning read")
+    logger.info("PARENT: Beginning read")
     ev.set()

     # wait until all subprocesses are done & read results
@@ -62,7 +55,7 @@ def _get_keypair_concurrent(num_procs: int) -> bytes:
     # check that the input/output order match, and that
     # all subprocesses end up reading the same file
-    logging.info(msg="PARENT: Checking consistency")
+    logger.info("PARENT: Checking consistency")
     keypair: bytes | None = None
     qsize = 0  # cannot use Queue.qsize due to MacOS incompatibility :(
     while not queue.empty():
@@ -88,7 +81,7 @@ def test_node_id_fetching(caplog: LogCaptureFixture):
     _delete_if_exists(EXO_NODE_ID_KEYPAIR)
     kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)

-    with caplog.at_level(logging.CRITICAL):  # suppress logs
+    with caplog.at_level(101):  # suppress logs
         # make sure that continuous fetches return the same value
         for _ in range(reps):
             assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)
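The move from `logging.CRITICAL` to the bare number 101 makes sense given the overridden `caplog` fixture above, which filters on loguru level numbers rather than stdlib ones: loguru's built-in scale tops out at CRITICAL = 50, so any threshold above 50 suppresses everything. A quick check of the defaults (values per loguru's documentation):

```python
# Loguru's built-in severity numbers; a handler level above 50 drops all
# built-in levels, which is what at_level(101) relies on.
from loguru import logger

for name in ("TRACE", "DEBUG", "INFO", "SUCCESS", "WARNING", "ERROR", "CRITICAL"):
    print(name, logger.level(name).no)
# TRACE 5, DEBUG 10, INFO 20, SUCCESS 25, WARNING 30, ERROR 40, CRITICAL 50
```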

View File

@@ -1,6 +1,7 @@
 def print_startup_banner(port: int) -> None:
     """Print a prominent startup banner with API endpoint information."""
-    banner = """
+    dashboard_url = f"http://localhost:{port}"
+    banner = f"""
 ╔═══════════════════════════════════════════════════════════════════════╗
 ║ ║
 ║ ███████╗██╗ ██╗ ██████╗ ║
@@ -13,11 +14,7 @@ def print_startup_banner(port: int) -> None:
 ║ Distributed AI Inference Cluster ║
 ║ ║
 ╚═══════════════════════════════════════════════════════════════════════╝
-"""
-
-    dashboard_url = f"http://localhost:{port}"
-    api_info = f"""
 ╔═══════════════════════════════════════════════════════════════════════╗
 ║ ║
 ║ 🌐 Dashboard & API Ready ║
@@ -27,8 +24,7 @@ def print_startup_banner(port: int) -> None:
 ║ Click the URL above to open the dashboard in your browser ║
 ║ ║
 ╚═══════════════════════════════════════════════════════════════════════╝
 """
     print(banner)
-    print(api_info)
     print()

View File

@@ -77,9 +77,6 @@ class _MpEndOfStream:
     pass

-MP_END_OF_STREAM = _MpEndOfStream()

 class MpState[T]:
     def __init__(self, max_buffer_size: float):
         if max_buffer_size == inf:
@@ -133,7 +130,7 @@
     def close(self) -> None:
         if not self._state.closed.is_set():
             self._state.closed.set()
-            self._state.buffer.put(MP_END_OF_STREAM)
+            self._state.buffer.put(_MpEndOfStream())
             self._state.buffer.close()

     # == unique to Mp channels ==
@@ -177,10 +174,9 @@ class MpReceiver[T]:
         try:
             item = self._state.buffer.get(block=False)
-            if item == MP_END_OF_STREAM:
+            if isinstance(item, _MpEndOfStream):
                 self.close()
                 raise EndOfStream
-            assert not isinstance(item, _MpEndOfStream)
             return item
         except Empty:
             raise WouldBlock from None
@@ -193,10 +189,9 @@ class MpReceiver[T]:
             return self.receive_nowait()
         except WouldBlock:
             item = self._state.buffer.get()
-            if item == MP_END_OF_STREAM:
+            if isinstance(item, _MpEndOfStream):
                 self.close()
                 raise EndOfStream from None
-            assert not isinstance(item, _MpEndOfStream)
             return item

 # nb: this function will not cancel particularly well
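The sentinel change here is subtle: `MP_END_OF_STREAM` was a module-level instance, but objects passed through a `multiprocessing` queue are pickled, and the unpickled copy on the receiving side is a different object, so the default identity-based `==` never matches. Checking `isinstance(item, _MpEndOfStream)` survives the round trip, and also lets the type checker narrow `item` without the removed asserts. A tiny demonstration of the failure mode:

```python
# Why isinstance beats == for cross-process sentinels: pickling makes a copy.
import pickle

class _EndOfStream:
    pass

SENTINEL = _EndOfStream()
received = pickle.loads(pickle.dumps(SENTINEL))  # what the other process sees

print(received == SENTINEL)                # False: default __eq__ is identity
print(isinstance(received, _EndOfStream))  # True: class membership survives pickling
```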

View File

@@ -20,7 +20,6 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:

 async def build_base_shard(model_id: str) -> ShardMetadata:
     model_meta = await get_model_meta(model_id)
-    # print(f"build_base_shard {model_id=} {model_meta=}")
     return PipelineShardMetadata(
         model_meta=model_meta,
         device_rank=0,
@@ -92,10 +91,8 @@ class CachedShardDownloader(ShardDownloader):
         self, shard: ShardMetadata, config_only: bool = False
     ) -> Path:
         if (shard.model_meta.model_id, shard) in self.cache:
-            # print(f"ensure_shard cache hit {shard=}")
             return self.cache[(shard.model_meta.model_id, shard)]
-        # print(f"ensure_shard cache miss {shard=}")
         target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
         self.cache[(shard.model_meta.model_id, shard)] = target_dir
         return target_dir
@@ -135,7 +132,6 @@ class ResumableShardDownloader(ShardDownloader):
     ) -> Path:
         allow_patterns = ["config.json"] if config_only else None
-        # print(f"ensure_shard {shard=} {config_only=} {allow_patterns=}")
         target_dir, _ = await download_shard(
             shard,
             self.on_progress_wrapper,
@@ -147,7 +143,6 @@ class ResumableShardDownloader(ShardDownloader):
     async def get_shard_download_status(
         self,
     ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
-        # print("get_shard_download_status")
         async def _status_for_model(
             model_id: str,
         ) -> tuple[Path, RepoDownloadProgress]:
@@ -165,9 +160,8 @@ class ResumableShardDownloader(ShardDownloader):
         for task in asyncio.as_completed(tasks):
             try:
-                result = await task
-                path, progress = result
-                yield (path, progress)
+                yield await task
+            # TODO: except Exception
             except Exception as e:
                 print("Error downloading shard:", e)

View File

@@ -14,5 +14,3 @@ TEMPERATURE: float = 1.0

 # TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
 TRUST_REMOTE_CODE: bool = True
-
-# TODO: Do we really want this?
-HIDE_THINKING: bool = False

uv.lock generated
View File

@@ -258,21 +258,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" },
 ]

-[[package]]
-name = "click"
-version = "8.3.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/46/61/de6cd827efad202d7057d93e0fed9294b96952e188f7384832791c7b2254/click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4", size = 276943, upload-time = "2025-09-18T17:32:23.696Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" },
-]
-
-[[package]]
-name = "cobs"
-version = "1.2.2"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/34/ef/ea149311227a4fc3160cc885fce06da7c7d76782a308ef070b8065c69953/cobs-1.2.2.tar.gz", hash = "sha256:dbdd5e32111d72786f83d0c269215dcd6ac629b1ac1962c6878221f3b2ca98da", size = 14582, upload-time = "2025-07-20T01:08:35.434Z" }
-
 [[package]]
 name = "cryptography"
 version = "46.0.3"
@@ -331,13 +316,13 @@ dependencies = [
     { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "base58", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "bidict", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "cobs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "cryptography", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -354,7 +339,6 @@ dependencies = [
     { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "typeguard", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "uvicorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]

 [package.dev-dependencies]
@@ -373,13 +357,13 @@ requires-dist = [
     { name = "anyio", specifier = "==4.11.0" },
     { name = "base58", specifier = ">=2.1.1" },
     { name = "bidict", specifier = ">=0.23.1" },
-    { name = "cobs", specifier = ">=1.2.2" },
     { name = "cryptography", specifier = ">=45.0.5" },
     { name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" },
     { name = "fastapi", specifier = ">=0.116.1" },
     { name = "filelock", specifier = ">=3.18.0" },
     { name = "greenlet", specifier = ">=3.2.4" },
     { name = "huggingface-hub", specifier = ">=0.33.4" },
+    { name = "hypercorn", specifier = ">=0.18.0" },
     { name = "loguru", specifier = ">=0.7.3" },
     { name = "mlx", specifier = ">=0.29.3" },
     { name = "mlx-lm", specifier = ">=0.28.3" },
@@ -396,7 +380,6 @@ requires-dist = [
     { name = "tiktoken", specifier = ">=0.12.0" },
     { name = "typeguard", specifier = ">=4.4.4" },
     { name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
-    { name = "uvicorn", specifier = ">=0.35.0" },
 ]

 [package.metadata.requires-dev]
@@ -557,6 +540,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
 ]

+[[package]]
+name = "h2"
+version = "4.3.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "hpack", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "hyperframe", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" },
+]
+
 [[package]]
 name = "hf-xet"
 version = "1.2.0"
@@ -583,6 +579,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" },
 ]

+[[package]]
+name = "hpack"
+version = "4.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" },
+]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.36.0"
@@ -602,6 +607,30 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d", size = 566094, upload-time = "2025-10-23T12:11:59.557Z" },
 ]

+[[package]]
+name = "hypercorn"
+version = "0.18.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "h2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "priority", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "wsproto", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/44/01/39f41a014b83dd5c795217362f2ca9071cf243e6a75bdcd6cd5b944658cc/hypercorn-0.18.0.tar.gz", hash = "sha256:d63267548939c46b0247dc8e5b45a9947590e35e64ee73a23c074aa3cf88e9da", size = 68420, upload-time = "2025-11-08T13:54:04.78Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/93/35/850277d1b17b206bd10874c8a9a3f52e059452fb49bb0d22cbb908f6038b/hypercorn-0.18.0-py3-none-any.whl", hash = "sha256:225e268f2c1c2f28f6d8f6db8f40cb8c992963610c5725e13ccfcddccb24b1cd", size = 61640, upload-time = "2025-11-08T13:54:03.202Z" },
+]
+
+[[package]]
+name = "hyperframe"
+version = "6.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" },
+]
+
 [[package]]
 name = "idna"
 version = "3.11"
@@ -926,6 +955,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
 ]

+[[package]]
+name = "priority"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f5/3c/eb7c35f4dcede96fca1842dac5f4f5d15511aa4b52f3a961219e68ae9204/priority-2.0.0.tar.gz", hash = "sha256:c965d54f1b8d0d0b19479db3924c7c36cf672dbf2aec92d43fbdaf4492ba18c0", size = 24792, upload-time = "2021-06-27T10:15:05.487Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/5e/5f/82c8074f7e84978129347c2c6ec8b6c59f3584ff1a20bc3c940a3e061790/priority-2.0.0-py3-none-any.whl", hash = "sha256:6f8eefce5f3ad59baf2c080a664037bb4725cd0a790d53d59ab4059288faf6aa", size = 8946, upload-time = "2021-06-27T10:15:03.856Z" },
+]
+
 [[package]]
 name = "propcache"
 version = "0.4.1"
@@ -1524,16 +1562,15 @@
 ]

 [[package]]
-name = "uvicorn"
-version = "0.38.0"
+name = "wsproto"
+version = "1.3.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "click", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/f06b84e2697fef4688ca63bdb2fdf113ca0a3be33f94488f2cadb690b0cf/uvicorn-0.38.0.tar.gz", hash = "sha256:fd97093bdd120a2609fc0d3afe931d4d4ad688b6e75f0f929fde1bc36fe0e91d", size = 80605, upload-time = "2025-10-18T13:46:44.63Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/ee/d9/d88e73ca598f4f6ff671fb5fde8a32925c2e08a637303a1d12883c7305fa/uvicorn-0.38.0-py3-none-any.whl", hash = "sha256:48c0afd214ceb59340075b4a052ea1ee91c16fbc2a9b1469cca0e54566977b02", size = 68109, upload-time = "2025-10-18T13:46:42.958Z" },
+    { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" },
 ]

 [[package]]