Compare commits

...

1 Commits

Author SHA1 Message Date
Evan
e8078f5a0e wow 2026-01-20 15:07:03 +00:00
14 changed files with 332 additions and 275 deletions

View File

@@ -24,6 +24,7 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -19,8 +19,12 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
get_model_card,
)
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -86,7 +90,7 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: str) -> ModelCard:
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
@@ -236,7 +240,7 @@ class API:
async def get_placement(
self,
model_id: str,
model_id: ModelId,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
@@ -551,7 +555,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(payload.model)
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -578,7 +582,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(payload.model)
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(

View File

@@ -1,16 +1,18 @@
from pydantic import PositiveInt
from typing import Annotated
from exo.shared.types.common import Id
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
_card_cache: dict[str, "ModelCard"] = {}
class ModelCard(CamelCaseModel):
@@ -20,6 +22,44 @@ class ModelCard(CamelCaseModel):
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=False,
)
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
@@ -308,3 +348,116 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
}
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_card(model_id: ModelId) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_card(model_id: ModelId) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == model_id),
None,
)
return ModelCard(
model_id=model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)

View File

@@ -1,122 +0,0 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_card(model_id: str) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)

View File

@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: str
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1

View File

@@ -25,6 +25,14 @@ class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int

View File

@@ -1,3 +1,8 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)

View File

@@ -17,17 +17,20 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.downloads import (
DownloadProgressData,
FileListEntry,
ModelSafetensorsIndex,
RepoDownloadProgress,
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -37,53 +40,6 @@ from exo.worker.download.huggingface_utils import (
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -125,12 +81,12 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.replace("/", "--")
def build_model_path(model_id: ModelId) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
return (await ensure_models_dir()) / model_id.normalize()
async def ensure_models_dir() -> Path:
@@ -138,8 +94,8 @@ async def ensure_models_dir() -> Path:
return EXO_MODELS_DIR
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -164,19 +120,17 @@ async def seed_models(seed_dir: str | Path):
async def fetch_file_list_with_cache(
repo_id: str, revision: str = "main", recursive: bool = False
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -184,25 +138,25 @@ async def fetch_file_list_with_cache(
async def fetch_file_list_with_retry(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(repo_id, revision, path, recursive)
return await _fetch_file_list(model_id, revision, path, recursive)
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
async def _fetch_file_list(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_download_headers()
@@ -219,7 +173,7 @@ async def _fetch_file_list(
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
@@ -276,10 +230,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
async def file_meta(
repo_id: str, revision: str, path: str, redirected_location: str | None = None
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
url = (
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
@@ -298,7 +252,7 @@ async def file_meta(
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
return await file_meta(model_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -310,7 +264,7 @@ async def file_meta(
async def download_file_with_retry(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -320,23 +274,23 @@ async def download_file_with_retry(
for attempt in range(n_attempts):
try:
return await _download_file(
repo_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
async def _download_file(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -345,7 +299,7 @@ async def _download_file(
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(repo_id, revision, path)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
@@ -354,7 +308,7 @@ async def _download_file(
else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -394,7 +348,7 @@ async def _download_file(
def calculate_repo_progress(
shard: ShardMetadata,
repo_id: str,
model_id: ModelId,
revision: str,
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
@@ -423,7 +377,7 @@ def calculate_repo_progress(
else "not_started"
)
return RepoDownloadProgress(
repo_id=repo_id,
repo_id=model_id,
repo_revision=revision,
shard=shard,
completed_files=len(
@@ -442,11 +396,11 @@ def calculate_repo_progress(
)
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
model_id, revision, "model.safetensors.index.json", target_dir
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -478,7 +432,7 @@ async def get_downloaded_size(path: Path) -> int:
async def download_progress_for_local_path(
repo_id: str, shard: ShardMetadata, local_path: Path
model_id: ModelId, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
file_progress: dict[str, RepoFileDownloadProgress] = {}
total_files = 0
@@ -492,7 +446,7 @@ async def download_progress_for_local_path(
size = (await aios.stat(file_path)).st_size
rel_path = str(file_path.relative_to(local_path))
file_progress[rel_path] = RepoFileDownloadProgress(
repo_id=repo_id,
repo_id=model_id,
repo_revision="local",
file_path=rel_path,
downloaded=Memory.from_bytes(size),
@@ -509,7 +463,7 @@ async def download_progress_for_local_path(
raise ValueError(f"Local path {local_path} is not a directory")
return RepoDownloadProgress(
repo_id=repo_id,
repo_id=model_id,
repo_revision="local",
shard=shard,
completed_files=total_files,
@@ -539,7 +493,7 @@ async def download_shard(
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
shard.model_card.model_id, shard, local_path
)
revision = "main"
@@ -558,7 +512,7 @@ async def download_shard(
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -592,7 +546,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -609,7 +563,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_card.model_id),
shard.model_card.model_id,
revision,
file_progress,
all_start_time,
@@ -619,7 +573,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -643,7 +597,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_card.model_id),
shard.model_card.model_id,
revision,
file.path,
target_dir,
@@ -657,7 +611,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
shard, shard.model_card.model_id, revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -3,8 +3,7 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import MODEL_CARDS, ModelId, get_model_card
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -19,7 +18,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
)
async def build_base_shard(model_id: str) -> ShardMetadata:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await get_model_card(model_id)
return PipelineShardMetadata(
model_card=model_card,
@@ -31,7 +30,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
@@ -148,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async def _status_for_model(
model_id: str,
model_id: ModelId,
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)

View File

@@ -23,6 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
@@ -296,7 +297,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
@@ -320,7 +321,9 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.

View File

@@ -11,8 +11,9 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
return False
async def download_tokenizer_files(model_id: str) -> Path:
async def download_tokenizer_files(model_id: ModelId) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir = await ensure_models_dir() / model_id.normalize()
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: str) -> Path:
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (card.model_id.short(), card)
families[family] = card
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
TEST_MODELS: list[ModelCard] = get_test_models()
pytestmark = pytest.mark.slow
@@ -101,14 +101,13 @@ def event_loop():
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
@@ -167,16 +166,15 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -209,19 +207,18 @@ async def test_tokenizer_has_required_attributes(
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -301,16 +298,14 @@ async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) ->
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_card = kimi_models[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -349,17 +344,15 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
]
if not glm_models:
if not glm_model_cards:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_card = glm_model_cards[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)

View File

@@ -1,6 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress

12
uv.lock generated
View File

@@ -248,6 +248,7 @@ dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -281,6 +282,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
@@ -315,6 +317,16 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"