Mirror of https://github.com/exo-explore/exo.git
Synced 2026-01-21 04:22:21 -05:00

Compare commits (3 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 758464703d |  |
|  | 9e2179c848 |  |
|  | 22b5d836ef |  |
@@ -24,6 +24,7 @@ dependencies = [
     "hypercorn>=0.18.0",
     "openai-harmony>=0.0.8",
     "httpx>=0.28.1",
+    "tomlkit>=0.14.0",
 ]
 
 [project.scripts]
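The only dependency change is the new tomlkit pin, which backs the TOML save/load helpers added to ModelCard further down. A minimal sketch of the round-trip tomlkit provides (the dict contents here are invented for illustration):

```python
import tomlkit

# Any TOML-compatible mapping can be serialized and parsed back.
card = {"model_id": "org/model", "n_layers": 32, "supports_tensor": False}
text = tomlkit.dumps(card)   # TOML string
doc = tomlkit.loads(text)    # dict-like TOMLDocument
print(doc["n_layers"])       # 32
```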
@@ -19,8 +19,11 @@ from exo.master.placement import place_instance as get_instance_placements
 from exo.shared.apply import apply
 from exo.shared.election import ElectionMessage
 from exo.shared.logging import InterceptLogger
-from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
-from exo.shared.models.model_meta import get_model_card
+from exo.shared.models.model_cards import (
+    MODEL_CARDS,
+    ModelCard,
+    ModelId,
+)
 from exo.shared.types.api import (
     BenchChatCompletionResponse,
     BenchChatCompletionTaskParams,

@@ -86,12 +89,12 @@ def chunk_to_response(
     )
 
 
-async def resolve_model_card(model_id: str) -> ModelCard:
+async def resolve_model_card(model_id: ModelId) -> ModelCard:
     if model_id in MODEL_CARDS:
         model_card = MODEL_CARDS[model_id]
         return model_card
     else:
-        return await get_model_card(model_id)
+        return await ModelCard.from_hf(model_id)
 
 
 class API:

@@ -236,7 +239,7 @@ class API:
 
     async def get_placement(
         self,
-        model_id: str,
+        model_id: ModelId,
         sharding: Sharding = Sharding.Pipeline,
         instance_meta: InstanceMeta = InstanceMeta.MlxRing,
         min_nodes: int = 1,

@@ -551,7 +554,7 @@ class API:
         self, payload: ChatCompletionTaskParams
     ) -> ChatCompletionResponse | StreamingResponse:
         """Handle chat completions, supporting both streaming and non-streaming responses."""
-        model_card = await resolve_model_card(payload.model)
+        model_card = await resolve_model_card(ModelId(payload.model))
         payload.model = model_card.model_id
 
         if not any(

@@ -578,7 +581,7 @@ class API:
     async def bench_chat_completions(
         self, payload: BenchChatCompletionTaskParams
     ) -> BenchChatCompletionResponse:
-        model_card = await resolve_model_card(payload.model)
+        model_card = await resolve_model_card(ModelId(payload.model))
         payload.model = model_card.model_id
 
         if not any(
@@ -1,16 +1,18 @@
-from pydantic import PositiveInt
+from typing import Annotated
 
-from exo.shared.types.common import Id
+import aiofiles
+import aiofiles.os as aios
+import tomlkit
+from anyio import Path, open_file
+from huggingface_hub import model_info
+from loguru import logger
+from pydantic import BaseModel, Field, PositiveInt
+
+from exo.shared.types.common import ModelId
 from exo.shared.types.memory import Memory
 from exo.utils.pydantic_ext import CamelCaseModel
 
-
-class ModelId(Id):
-    def normalize(self) -> str:
-        return self.replace("/", "--")
-
-    def short(self) -> str:
-        return self.split("/")[-1]
+_card_cache: dict[str, "ModelCard"] = {}
 
 
 class ModelCard(CamelCaseModel):

@@ -20,6 +22,43 @@ class ModelCard(CamelCaseModel):
     hidden_size: PositiveInt
     supports_tensor: bool
 
+    async def save(self, path: Path) -> None:
+        async with await open_file(path, "w") as f:
+            py = self.model_dump()
+            data = tomlkit.dumps(py)  # pyright: ignore[reportUnknownMemberType]
+            await f.write(data)
+
+    @staticmethod
+    async def load_from_path(path: Path) -> "ModelCard":
+        async with await open_file(path, "r") as f:
+            py = tomlkit.loads(await f.read())
+            return ModelCard.model_validate(py)
+
+    @staticmethod
+    async def load(model_id: ModelId) -> "ModelCard":
+        if model_id in MODEL_CARDS:
+            return MODEL_CARDS[model_id]
+        return await ModelCard.from_hf(model_id)
+
+    @staticmethod
+    async def from_hf(model_id: ModelId) -> "ModelCard":
+        """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
+        if (mc := _card_cache.get(model_id)) is not None:
+            return mc
+        config_data = await get_config_data(model_id)
+        num_layers = config_data.layer_count
+        mem_size_bytes = await get_safetensors_size(model_id)
+
+        mc = ModelCard(
+            model_id=ModelId(model_id),
+            storage_size=mem_size_bytes,
+            n_layers=num_layers,
+            hidden_size=config_data.hidden_size or 0,
+            supports_tensor=config_data.supports_tensor,
+        )
+        _card_cache[model_id] = mc
+        return mc
+
+
 MODEL_CARDS: dict[str, ModelCard] = {
     # deepseek v3

@@ -308,3 +347,99 @@ MODEL_CARDS: dict[str, ModelCard] = {
         supports_tensor=True,
     ),
 }
+
+from exo.worker.download.download_utils import (  # noqa: E402
+    ModelSafetensorsIndex,
+    download_file_with_retry,
+    ensure_models_dir,
+)
+
+
+class ConfigData(BaseModel):
+    model_config = {"extra": "ignore"}  # Allow unknown fields
+
+    # Common field names for number of layers across different architectures
+    num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
+    num_layers: Annotated[int, Field(ge=0)] | None = None
+    n_layer: Annotated[int, Field(ge=0)] | None = None
+    n_layers: Annotated[int, Field(ge=0)] | None = None  # Sometimes used
+    num_decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Transformer models
+    decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Some architectures
+    hidden_size: Annotated[int, Field(ge=0)] | None = None
+    architectures: list[str] | None = None
+
+    @property
+    def supports_tensor(self) -> bool:
+        return self.architectures in [
+            ["Glm4MoeLiteForCausalLM"],
+            ["DeepseekV32ForCausalLM"],
+            ["DeepseekV3ForCausalLM"],
+            ["Qwen3NextForCausalLM"],
+            ["Qwen3MoeForCausalLM"],
+            ["MiniMaxM2ForCausalLM"],
+            ["LlamaForCausalLM"],
+            ["GptOssForCausalLM"],
+        ]
+
+    @property
+    def layer_count(self) -> int:
+        # Check common field names for layer count
+        layer_fields = [
+            self.num_hidden_layers,
+            self.num_layers,
+            self.n_layer,
+            self.n_layers,
+            self.num_decoder_layers,
+            self.decoder_layers,
+        ]
+
+        for layer_count in layer_fields:
+            if layer_count is not None:
+                return layer_count
+
+        raise ValueError(
+            f"No layer count found in config.json: {self.model_dump_json()}"
+        )
+
+
+async def get_config_data(model_id: ModelId) -> ConfigData:
+    """Downloads and parses config.json for a model."""
+    target_dir = (await ensure_models_dir()) / model_id.normalize()
+    await aios.makedirs(target_dir, exist_ok=True)
+    config_path = await download_file_with_retry(
+        model_id,
+        "main",
+        "config.json",
+        target_dir,
+        lambda curr_bytes, total_bytes, is_renamed: logger.info(
+            f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
+        ),
+    )
+    async with aiofiles.open(config_path, "r") as f:
+        return ConfigData.model_validate_json(await f.read())
+
+
+async def get_safetensors_size(model_id: ModelId) -> Memory:
+    """Gets model size from safetensors index or falls back to HF API."""
+    target_dir = (await ensure_models_dir()) / model_id.normalize()
+    await aios.makedirs(target_dir, exist_ok=True)
+    index_path = await download_file_with_retry(
+        model_id,
+        "main",
+        "model.safetensors.index.json",
+        target_dir,
+        lambda curr_bytes, total_bytes, is_renamed: logger.info(
+            f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
+        ),
+    )
+    async with aiofiles.open(index_path, "r") as f:
+        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
+
+    metadata = index_data.metadata
+    if metadata is not None:
+        return Memory.from_bytes(metadata.total_size)
+
+    info = model_info(model_id)
+    if info.safetensors is None:
+        raise ValueError(f"No safetensors info found for {model_id}")
+    return Memory.from_bytes(info.safetensors.total)
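The new helpers above round-trip a ModelCard through TOML. A rough usage sketch, assuming it runs where an event loop can be created; the output path and the printed fields are illustrative only, and exact field-for-field equality after the round-trip depends on how nested types such as Memory serialize:

```python
import asyncio

from anyio import Path

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard


async def main() -> None:
    card = next(iter(MODEL_CARDS.values()))   # any built-in card
    path = Path("card.toml")                  # hypothetical output location
    await card.save(path)                     # tomlkit.dumps(card.model_dump())
    restored = await ModelCard.load_from_path(path)
    print(restored.model_id, restored.n_layers)


asyncio.run(main())
```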
@@ -1,122 +0,0 @@
-from typing import Annotated
-
-import aiofiles
-import aiofiles.os as aios
-from huggingface_hub import model_info
-from loguru import logger
-from pydantic import BaseModel, Field
-
-from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
-from exo.shared.types.memory import Memory
-from exo.worker.download.download_utils import (
-    ModelSafetensorsIndex,
-    download_file_with_retry,
-    ensure_models_dir,
-)
-
-
-class ConfigData(BaseModel):
-    model_config = {"extra": "ignore"}  # Allow unknown fields
-
-    # Common field names for number of layers across different architectures
-    num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
-    num_layers: Annotated[int, Field(ge=0)] | None = None
-    n_layer: Annotated[int, Field(ge=0)] | None = None
-    n_layers: Annotated[int, Field(ge=0)] | None = None  # Sometimes used
-    num_decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Transformer models
-    decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Some architectures
-    hidden_size: Annotated[int, Field(ge=0)] | None = None
-
-    @property
-    def layer_count(self) -> int:
-        # Check common field names for layer count
-        layer_fields = [
-            self.num_hidden_layers,
-            self.num_layers,
-            self.n_layer,
-            self.n_layers,
-            self.num_decoder_layers,
-            self.decoder_layers,
-        ]
-
-        for layer_count in layer_fields:
-            if layer_count is not None:
-                return layer_count
-
-        raise ValueError(
-            f"No layer count found in config.json: {self.model_dump_json()}"
-        )
-
-
-async def get_config_data(model_id: str) -> ConfigData:
-    """Downloads and parses config.json for a model."""
-    target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
-    await aios.makedirs(target_dir, exist_ok=True)
-    config_path = await download_file_with_retry(
-        model_id,
-        "main",
-        "config.json",
-        target_dir,
-        lambda curr_bytes, total_bytes, is_renamed: logger.info(
-            f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
-        ),
-    )
-    async with aiofiles.open(config_path, "r") as f:
-        return ConfigData.model_validate_json(await f.read())
-
-
-async def get_safetensors_size(model_id: str) -> Memory:
-    """Gets model size from safetensors index or falls back to HF API."""
-    target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
-    await aios.makedirs(target_dir, exist_ok=True)
-    index_path = await download_file_with_retry(
-        model_id,
-        "main",
-        "model.safetensors.index.json",
-        target_dir,
-        lambda curr_bytes, total_bytes, is_renamed: logger.info(
-            f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
-        ),
-    )
-    async with aiofiles.open(index_path, "r") as f:
-        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
-
-    metadata = index_data.metadata
-    if metadata is not None:
-        return Memory.from_bytes(metadata.total_size)
-
-    info = model_info(model_id)
-    if info.safetensors is None:
-        raise ValueError(f"No safetensors info found for {model_id}")
-    return Memory.from_bytes(info.safetensors.total)
-
-
-_model_card_cache: dict[str, ModelCard] = {}
-
-
-async def get_model_card(model_id: str) -> ModelCard:
-    if model_id in _model_card_cache:
-        return _model_card_cache[model_id]
-    model_card = await _get_model_card(model_id)
-    _model_card_cache[model_id] = model_card
-    return model_card
-
-
-async def _get_model_card(model_id: str) -> ModelCard:
-    """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
-    config_data = await get_config_data(model_id)
-    num_layers = config_data.layer_count
-    mem_size_bytes = await get_safetensors_size(model_id)
-    model_card = next(
-        (card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
-        None,
-    )
-
-    return ModelCard(
-        model_id=ModelId(model_id),
-        storage_size=mem_size_bytes,
-        n_layers=num_layers,
-        hidden_size=config_data.hidden_size or 0,
-        # TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
-        supports_tensor=model_card.supports_tensor if model_card is not None else False,
-    )
@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
 
 
 class PlaceInstanceParams(BaseModel):
-    model_id: str
+    model_id: ModelId
     sharding: Sharding = Sharding.Pipeline
     instance_meta: InstanceMeta = InstanceMeta.MlxRing
     min_nodes: int = 1
@@ -25,6 +25,14 @@ class NodeId(Id):
     pass
 
 
+class ModelId(Id):
+    def normalize(self) -> str:
+        return self.replace("/", "--")
+
+    def short(self) -> str:
+        return self.split("/")[-1]
+
+
 class SessionId(CamelCaseModel):
     master_node_id: NodeId
     election_clock: int
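ModelId now lives in exo.shared.types.common next to NodeId. A quick illustration of the two helpers (the model name below is just an example string; Id is assumed to be a str subclass, which is consistent with the replace/split calls above):

```python
from exo.shared.types.common import ModelId

mid = ModelId("mlx-community/Llama-3.1-8B-Instruct-4bit")
print(mid.normalize())  # "mlx-community--Llama-3.1-8B-Instruct-4bit", safe as a directory name
print(mid.short())      # "Llama-3.1-8B-Instruct-4bit", the part after the last "/"
```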
@@ -1,3 +1,8 @@
+from datetime import timedelta
+from typing import Literal
+
+from pydantic import BaseModel, ConfigDict, Field, PositiveInt
+
 from exo.shared.types.common import NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.worker.shards import ShardMetadata

@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
 DownloadProgress = (
     DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
 )
+
+
+class ModelSafetensorsIndexMetadata(BaseModel):
+    total_size: PositiveInt
+
+
+class ModelSafetensorsIndex(BaseModel):
+    metadata: ModelSafetensorsIndexMetadata | None
+    weight_map: dict[str, str]
+
+
+class FileListEntry(BaseModel):
+    type: Literal["file", "directory"]
+    path: str
+    size: int | None = None
+
+
+class RepoFileDownloadProgress(BaseModel):
+    repo_id: str
+    repo_revision: str
+    file_path: str
+    downloaded: Memory
+    downloaded_this_session: Memory
+    total: Memory
+    speed: float
+    eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+    start_time: float
+
+    model_config = ConfigDict(frozen=True)
+
+
+class RepoDownloadProgress(BaseModel):
+    repo_id: str
+    repo_revision: str
+    shard: ShardMetadata
+    completed_files: int
+    total_files: int
+    downloaded_bytes: Memory
+    downloaded_bytes_this_session: Memory
+    total_bytes: Memory
+    overall_speed: float
+    overall_eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+    file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
+
+    model_config = ConfigDict(frozen=True)
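ModelSafetensorsIndex mirrors the shape of a Hugging Face model.safetensors.index.json file. A small parsing sketch with a made-up payload:

```python
from exo.shared.types.worker.downloads import ModelSafetensorsIndex

# Invented payload in the same shape as model.safetensors.index.json.
raw = """{
  "metadata": {"total_size": 4294967296},
  "weight_map": {"model.embed_tokens.weight": "model-00001-of-00002.safetensors"}
}"""
index = ModelSafetensorsIndex.model_validate_json(raw)
if index.metadata is not None:
    print(index.metadata.total_size)  # 4294967296, fed to Memory.from_bytes(...) by callers
```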
@@ -17,17 +17,20 @@ import aiohttp
 import certifi
 from loguru import logger
 from pydantic import (
-    BaseModel,
-    ConfigDict,
     DirectoryPath,
-    Field,
-    PositiveInt,
     TypeAdapter,
 )
 
 from exo.shared.constants import EXO_MODELS_DIR
+from exo.shared.types.common import ModelId
 from exo.shared.types.memory import Memory
-from exo.shared.types.worker.downloads import DownloadProgressData
+from exo.shared.types.worker.downloads import (
+    DownloadProgressData,
+    FileListEntry,
+    ModelSafetensorsIndex,
+    RepoDownloadProgress,
+    RepoFileDownloadProgress,
+)
 from exo.shared.types.worker.shards import ShardMetadata
 from exo.worker.download.huggingface_utils import (
     filter_repo_objects,

@@ -37,53 +40,6 @@ from exo.worker.download.huggingface_utils import (
 )
 
 
-class ModelSafetensorsIndexMetadata(BaseModel):
-    total_size: PositiveInt
-
-
-class ModelSafetensorsIndex(BaseModel):
-    metadata: ModelSafetensorsIndexMetadata | None
-    weight_map: dict[str, str]
-
-
-class FileListEntry(BaseModel):
-    type: Literal["file", "directory"]
-    path: str
-    size: int | None = None
-
-
-class RepoFileDownloadProgress(BaseModel):
-    repo_id: str
-    repo_revision: str
-    file_path: str
-    downloaded: Memory
-    downloaded_this_session: Memory
-    total: Memory
-    speed: float
-    eta: timedelta
-    status: Literal["not_started", "in_progress", "complete"]
-    start_time: float
-
-    model_config = ConfigDict(frozen=True)
-
-
-class RepoDownloadProgress(BaseModel):
-    repo_id: str
-    repo_revision: str
-    shard: ShardMetadata
-    completed_files: int
-    total_files: int
-    downloaded_bytes: Memory
-    downloaded_bytes_this_session: Memory
-    total_bytes: Memory
-    overall_speed: float
-    overall_eta: timedelta
-    status: Literal["not_started", "in_progress", "complete"]
-    file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
-
-    model_config = ConfigDict(frozen=True)
-
-
 def trim_etag(etag: str) -> str:
     if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
         return etag[1:-1]
@@ -125,12 +81,12 @@ def map_repo_download_progress_to_download_progress_data(
     )
 
 
-def build_model_path(model_id: str) -> DirectoryPath:
-    return EXO_MODELS_DIR / model_id.replace("/", "--")
+def build_model_path(model_id: ModelId) -> DirectoryPath:
+    return EXO_MODELS_DIR / model_id.normalize()
 
 
-async def resolve_model_path_for_repo(repo_id: str) -> Path:
-    return (await ensure_models_dir()) / repo_id.replace("/", "--")
+async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
+    return (await ensure_models_dir()) / model_id.normalize()
 
 
 async def ensure_models_dir() -> Path:

@@ -138,8 +94,8 @@ async def ensure_models_dir() -> Path:
     return EXO_MODELS_DIR
 
 
-async def delete_model(repo_id: str) -> bool:
-    model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
+async def delete_model(model_id: ModelId) -> bool:
+    model_dir = await ensure_models_dir() / model_id.normalize()
     if not await aios.path.exists(model_dir):
         return False
     await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)

@@ -164,19 +120,17 @@ async def seed_models(seed_dir: str | Path):
 
 
 async def fetch_file_list_with_cache(
-    repo_id: str, revision: str = "main", recursive: bool = False
+    model_id: ModelId, revision: str = "main", recursive: bool = False
 ) -> list[FileListEntry]:
-    target_dir = (
-        (await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
-    )
+    target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
     await aios.makedirs(target_dir, exist_ok=True)
-    cache_file = (
-        target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
-    )
+    cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
     if await aios.path.exists(cache_file):
         async with aiofiles.open(cache_file, "r") as f:
             return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
-    file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
+    file_list = await fetch_file_list_with_retry(
+        model_id, revision, recursive=recursive
+    )
     await aios.makedirs(cache_file.parent, exist_ok=True)
     async with aiofiles.open(cache_file, "w") as f:
         await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -184,25 +138,25 @@ async def fetch_file_list_with_cache(
 
 
 async def fetch_file_list_with_retry(
-    repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
+    model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
 ) -> list[FileListEntry]:
     n_attempts = 30
     for attempt in range(n_attempts):
         try:
-            return await _fetch_file_list(repo_id, revision, path, recursive)
+            return await _fetch_file_list(model_id, revision, path, recursive)
         except Exception as e:
             if attempt == n_attempts - 1:
                 raise e
             await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
     raise Exception(
-        f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
+        f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
     )
 
 
 async def _fetch_file_list(
-    repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
+    model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
 ) -> list[FileListEntry]:
-    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
+    api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
     url = f"{api_url}/{path}" if path else api_url
 
     headers = await get_download_headers()

@@ -219,7 +173,7 @@ async def _fetch_file_list(
             files.append(FileListEntry.model_validate(item))
         elif item.type == "directory" and recursive:
             subfiles = await _fetch_file_list(
-                repo_id, revision, item.path, recursive
+                model_id, revision, item.path, recursive
             )
             files.extend(subfiles)
     return files
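Both retry helpers above share the same capped exponential backoff. A standalone sketch of the delay schedule (the 0.1 s base, the doubling, and the 8 s cap come from the diff; the loop itself is only for illustration):

```python
# Delay before retrying attempt `attempt`: 0.1 s, 0.2 s, 0.4 s, ... capped at 8 s.
for attempt in range(10):
    delay = min(8, 0.1 * (2.0**attempt))
    print(f"attempt {attempt}: sleep {delay:.1f}s")
# Attempts 0-6 double from 0.1 s up to 6.4 s; from attempt 7 onward the 8 s cap applies.
```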
@@ -276,10 +230,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
 
 
 async def file_meta(
-    repo_id: str, revision: str, path: str, redirected_location: str | None = None
+    model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
 ) -> tuple[int, str]:
     url = (
-        urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+        urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
         if redirected_location is None
         else f"{get_hf_endpoint()}{redirected_location}"
     )

@@ -298,7 +252,7 @@ async def file_meta(
             return content_length, etag
         # Otherwise, follow the redirect to get authoritative size/hash
         redirected_location = r.headers.get("location")
-        return await file_meta(repo_id, revision, path, redirected_location)
+        return await file_meta(model_id, revision, path, redirected_location)
     content_length = int(
         r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
     )

@@ -310,7 +264,7 @@ async def file_meta(
 
 
 async def download_file_with_retry(
-    repo_id: str,
+    model_id: ModelId,
     revision: str,
     path: str,
     target_dir: Path,
@@ -320,23 +274,23 @@
     for attempt in range(n_attempts):
         try:
             return await _download_file(
-                repo_id, revision, path, target_dir, on_progress
+                model_id, revision, path, target_dir, on_progress
             )
         except Exception as e:
             if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
                 raise e
             logger.error(
-                f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
+                f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
             )
             logger.error(traceback.format_exc())
             await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
     raise Exception(
-        f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
+        f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
     )
 
 
 async def _download_file(
-    repo_id: str,
+    model_id: ModelId,
     revision: str,
     path: str,
     target_dir: Path,
@@ -345,7 +299,7 @@
     if await aios.path.exists(target_dir / path):
         return target_dir / path
     await aios.makedirs((target_dir / path).parent, exist_ok=True)
-    length, etag = await file_meta(repo_id, revision, path)
+    length, etag = await file_meta(model_id, revision, path)
     remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
     partial_path = target_dir / f"{path}.partial"
     resume_byte_pos = (

@@ -354,7 +308,7 @@
         else None
     )
     if resume_byte_pos != length:
-        url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+        url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
         headers = await get_download_headers()
         if resume_byte_pos:
             headers["Range"] = f"bytes={resume_byte_pos}-"

@@ -394,7 +348,7 @@
 
 def calculate_repo_progress(
     shard: ShardMetadata,
-    repo_id: str,
+    model_id: ModelId,
     revision: str,
     file_progress: dict[str, RepoFileDownloadProgress],
     all_start_time: float,

@@ -423,7 +377,7 @@
         else "not_started"
     )
     return RepoDownloadProgress(
-        repo_id=repo_id,
+        repo_id=model_id,
         repo_revision=revision,
         shard=shard,
         completed_files=len(
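The resume logic above compares the size of the .partial file with the remote length and, when they differ, requests only the missing tail via an HTTP Range header. The same idea in isolation (the file name and total length are placeholders):

```python
import os


def resume_headers(partial_path: str, total_length: int) -> dict[str, str]:
    """Build a Range header asking only for the bytes not yet on disk."""
    resume_byte_pos = (
        os.path.getsize(partial_path) if os.path.exists(partial_path) else None
    )
    if resume_byte_pos is None or resume_byte_pos == total_length:
        return {}  # nothing downloaded yet, or the file is already complete
    return {"Range": f"bytes={resume_byte_pos}-"}


print(resume_headers("model-00001-of-00002.safetensors.partial", 4_000_000_000))
```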
@@ -442,11 +396,11 @@
     )
 
 
-async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
-    target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
+async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
+    target_dir = (await ensure_models_dir()) / model_id.normalize()
     await aios.makedirs(target_dir, exist_ok=True)
     index_file = await download_file_with_retry(
-        repo_id, revision, "model.safetensors.index.json", target_dir
+        model_id, revision, "model.safetensors.index.json", target_dir
     )
     async with aiofiles.open(index_file, "r") as f:
         index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -504,7 +458,7 @@ async def download_shard(
     # TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
     # Update: <- This does not seem to be the case. Yay?
     file_list = await fetch_file_list_with_cache(
-        str(shard.model_card.model_id), revision, recursive=True
+        shard.model_card.model_id, revision, recursive=True
     )
     filtered_file_list = list(
         filter_repo_objects(

@@ -538,7 +492,7 @@ async def download_shard(
             else timedelta(seconds=0)
         )
         file_progress[file.path] = RepoFileDownloadProgress(
-            repo_id=str(shard.model_card.model_id),
+            repo_id=shard.model_card.model_id,
             repo_revision=revision,
             file_path=file.path,
             downloaded=Memory.from_bytes(curr_bytes),

@@ -555,7 +509,7 @@ async def download_shard(
             shard,
             calculate_repo_progress(
                 shard,
-                str(shard.model_card.model_id),
+                shard.model_card.model_id,
                 revision,
                 file_progress,
                 all_start_time,

@@ -565,7 +519,7 @@ async def download_shard(
     for file in filtered_file_list:
         downloaded_bytes = await get_downloaded_size(target_dir / file.path)
         file_progress[file.path] = RepoFileDownloadProgress(
-            repo_id=str(shard.model_card.model_id),
+            repo_id=shard.model_card.model_id,
            repo_revision=revision,
            file_path=file.path,
            downloaded=Memory.from_bytes(downloaded_bytes),

@@ -589,7 +543,7 @@ async def download_shard(
     async def download_with_semaphore(file: FileListEntry) -> None:
         async with semaphore:
             await download_file_with_retry(
-                str(shard.model_card.model_id),
+                shard.model_card.model_id,
                 revision,
                 file.path,
                 target_dir,

@@ -603,7 +557,7 @@ async def download_shard(
         *[download_with_semaphore(file) for file in filtered_file_list]
     )
     final_repo_progress = calculate_repo_progress(
-        shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
+        shard, shard.model_card.model_id, revision, file_progress, all_start_time
     )
     await on_progress(shard, final_repo_progress)
     if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
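download_shard creates one coroutine per file but bounds concurrency with a semaphore before gathering the results. The pattern on its own (the sleep stands in for the real file download; the limit of 8 mirrors the max_parallel_downloads default seen in the next hunk):

```python
import asyncio


async def fetch(name: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for the real HTTP download
    return name


async def main() -> None:
    semaphore = asyncio.Semaphore(8)  # at most 8 downloads in flight

    async def download_with_semaphore(name: str) -> str:
        async with semaphore:
            return await fetch(name)

    results = await asyncio.gather(
        *[download_with_semaphore(f"file-{i}") for i in range(32)]
    )
    print(len(results))  # 32


asyncio.run(main())
```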
@@ -3,8 +3,7 @@ from collections.abc import Awaitable
 from pathlib import Path
 from typing import AsyncIterator, Callable
 
-from exo.shared.models.model_cards import MODEL_CARDS
-from exo.shared.models.model_meta import get_model_card
+from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
 from exo.shared.types.worker.shards import (
     PipelineShardMetadata,
     ShardMetadata,

@@ -19,8 +18,8 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
     )
 
 
-async def build_base_shard(model_id: str) -> ShardMetadata:
-    model_card = await get_model_card(model_id)
+async def build_base_shard(model_id: ModelId) -> ShardMetadata:
+    model_card = await ModelCard.from_hf(model_id)
     return PipelineShardMetadata(
         model_card=model_card,
         device_rank=0,

@@ -31,7 +30,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
     )
 
 
-async def build_full_shard(model_id: str) -> PipelineShardMetadata:
+async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
     base_shard = await build_base_shard(model_id)
     return PipelineShardMetadata(
         model_card=base_shard.model_card,

@@ -148,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
         self,
     ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
         async def _status_for_model(
-            model_id: str,
+            model_id: ModelId,
         ) -> tuple[Path, RepoDownloadProgress]:
             """Helper coroutine that builds the shard for a model and gets its download status."""
             shard = await build_full_shard(model_id)
@@ -83,11 +83,11 @@ class CustomMlxLayer(nn.Module):
 
     def __init__(self, original_layer: _LayerCallable):
         super().__init__()
-        object.__setattr__(self, "_original_layer", original_layer)
+        dict.__setitem__(self, "_original_layer", original_layer)  # pyright: ignore[reportUnknownMemberType]
 
     @property
     def original_layer(self) -> _LayerCallable:
-        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
+        return cast(_LayerCallable, self["_original_layer"])
 
     # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
     if not TYPE_CHECKING:

@@ -96,7 +96,7 @@ class CustomMlxLayer(nn.Module):
             try:
                 return super().__getattr__(name)
             except AttributeError:
-                original_layer = object.__getattribute__(self, "_original_layer")
+                original_layer = cast(_LayerCallable, self["_original_layer"])
                 return getattr(original_layer, name)
 
 

@@ -334,7 +334,7 @@ def tensor_auto_parallel(
         group=group,
     )
 
-    if hasattr(model, "shard"):
+    if hasattr(model, "shard") and not isinstance(model, GptOssModel):
         try:
             model.shard(group)  # type: ignore
             return patch_tensor_model(model)

@@ -383,7 +383,6 @@ def tensor_auto_parallel(
         all_to_sharded_linear_in_place,
         sharded_to_all_linear_in_place,
     )
-
     else:
         raise ValueError(f"Unsupported model type: {type(model)}")
 
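The wrapper now stores the original layer with dict.__setitem__ and reads it back with self[...] rather than going through object.__setattr__ and object.__getattribute__. For a dict subclass these are two different storage locations; a plain-Python illustration of the distinction (an analogy only, not mlx code):

```python
class Bag(dict):
    pass


b = Bag()
object.__setattr__(b, "x", 1)  # stored on the instance, invisible to the mapping
dict.__setitem__(b, "y", 2)    # stored in the mapping itself

print("x" in b, b.__dict__["x"])  # False 1
print("y" in b, b["y"])           # True 2
```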
@@ -23,6 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 
+from exo.shared.models.model_cards import ModelId
 from exo.worker.engines.mlx.constants import (
     CACHE_GROUP_SIZE,
     KV_CACHE_BITS,

@@ -296,7 +297,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
     return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
 
 
-def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
+def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
     """
     Get the EOS token IDs for a model based on its ID.
 

@@ -320,7 +321,9 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
     return None
 
 
-def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
+def load_tokenizer_for_model_id(
+    model_id: ModelId, model_path: Path
+) -> TokenizerWrapper:
     """
     Load tokenizer for a model given its ID and local path.
 
@@ -11,8 +11,9 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from exo.shared.constants import EXO_MODELS_DIR
-from exo.shared.models.model_cards import ModelCard, ModelId
+from exo.shared.models.model_cards import ModelCard
 from exo.shared.types.api import ChatCompletionMessage
+from exo.shared.types.common import ModelId
 from exo.shared.types.memory import Memory
 from exo.shared.types.tasks import ChatCompletionTaskParams
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
@@ -11,7 +11,7 @@ from pathlib import Path
 
 import pytest
 
-from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
+from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
 from exo.worker.download.download_utils import (
     download_file_with_retry,
     ensure_models_dir,

@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
     return False
 
 
-async def download_tokenizer_files(model_id: str) -> Path:
+async def download_tokenizer_files(model_id: ModelId) -> Path:
     """Download only the tokenizer-related files for a model."""
-    target_dir = await ensure_models_dir() / model_id.replace("/", "--")
+    target_dir = await ensure_models_dir() / model_id.normalize()
     target_dir.mkdir(parents=True, exist_ok=True)
 
     file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)

@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: str) -> Path:
 
 
 # Get a sample of models to test (one per family to keep tests fast)
-def get_test_models() -> list[tuple[str, ModelCard]]:
+def get_test_models() -> list[ModelCard]:
     """Get a representative sample of models to test."""
     # Pick one model from each family to test
-    families: dict[str, tuple[str, ModelCard]] = {}
-    for _, card in MODEL_CARDS.items():
+    families: dict[str, ModelCard] = {}
+    for card in MODEL_CARDS.values():
         # Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
         parts = card.model_id.short().split("-")
         family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
 
         if family not in families:
-            families[family] = (card.model_id.short(), card)
+            families[family] = card
 
     return list(families.values())
 
 
-TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
+TEST_MODELS: list[ModelCard] = get_test_models()
 
 pytestmark = pytest.mark.slow
 

@@ -101,14 +101,13 @@ def event_loop():
 
 
 @pytest.mark.parametrize(
-    "short_id,model_card",
+    "model_card",
     TEST_MODELS,
-    ids=[m[0] for m in TEST_MODELS],
 )
 @pytest.mark.asyncio
 async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
     """Test that tokenizer can encode and decode text correctly."""
-    model_id = str(model_card.model_id)
+    model_id = model_card.model_id
 
     # Download tokenizer files
     model_path = await download_tokenizer_files(model_id)

@@ -167,16 +166,15 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
 
 
 @pytest.mark.parametrize(
-    "short_id,model_card",
+    "model_card",
     TEST_MODELS,
-    ids=[m[0] for m in TEST_MODELS],
 )
 @pytest.mark.asyncio
 async def test_tokenizer_has_required_attributes(
     short_id: str, model_card: ModelCard
 ) -> None:
     """Test that tokenizer has required attributes for inference."""
-    model_id = str(model_card.model_id)
+    model_id = model_card.model_id
 
     model_path = await download_tokenizer_files(model_id)
 

@@ -209,19 +207,18 @@ async def test_tokenizer_has_required_attributes(
 
 
 @pytest.mark.parametrize(
-    "short_id,model_card",
+    "model_card",
     TEST_MODELS,
-    ids=[m[0] for m in TEST_MODELS],
 )
 @pytest.mark.asyncio
-async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
+async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
     """Test that tokenizer can encode text containing special tokens.
 
     This is critical because the actual inference path uses prompts with
     special tokens from chat templates. If special tokens aren't handled
     correctly, encoding will fail.
     """
-    model_id = str(model_card.model_id)
+    model_id = model_card.model_id
 
     model_path = await download_tokenizer_files(model_id)
 

@@ -301,16 +298,14 @@ async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) ->
 async def test_kimi_tokenizer_specifically():
     """Test Kimi tokenizer with its specific patches and quirks."""
     kimi_models = [
-        (short_id, card)
-        for short_id, card in MODEL_CARDS.items()
-        if "kimi" in short_id.lower()
+        card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
     ]
 
     if not kimi_models:
         pytest.skip("No Kimi models found in MODEL_CARDS")
 
-    _, model_card = kimi_models[0]
-    model_id = str(model_card.model_id)
+    model_card = kimi_models[0]
+    model_id = model_card.model_id
 
     model_path = await download_tokenizer_files(model_id)
 

@@ -349,17 +344,15 @@ async def test_kimi_tokenizer_specifically():
 @pytest.mark.asyncio
 async def test_glm_tokenizer_specifically():
     """Test GLM tokenizer with its specific EOS tokens."""
-    glm_models = [
-        (short_id, card)
-        for short_id, card in MODEL_CARDS.items()
-        if "glm" in short_id.lower()
+    glm_model_cards = [
+        card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
     ]
 
-    if not glm_models:
+    if not glm_model_cards:
         pytest.skip("No GLM models found in MODEL_CARDS")
 
-    _, model_card = glm_models[0]
-    model_id = str(model_card.model_id)
+    model_card = glm_model_cards[0]
+    model_id = model_card.model_id
 
     model_path = await download_tokenizer_files(model_id)
 
@@ -1,6 +1,5 @@
 import exo.worker.plan as plan_mod
-from exo.shared.models.model_cards import ModelId
-from exo.shared.types.common import NodeId
+from exo.shared.types.common import ModelId, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.tasks import LoadModel
 from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
uv.lock (generated): 12 changed lines

@@ -248,6 +248,7 @@ dependencies = [
     { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
 ]
 

@@ -281,6 +282,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.11.7" },
     { name = "rustworkx", specifier = ">=0.17.1" },
     { name = "tiktoken", specifier = ">=0.12.0" },
+    { name = "tomlkit", specifier = ">=0.14.0" },
     { name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
 ]
 

@@ -315,6 +317,16 @@ dev = [
     { name = "pytest-asyncio", specifier = ">=1.0.0" },
 ]
 
+[[package]]
+name = "tomlkit"
+version = "0.14.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
+]
+
+
 [[package]]
 name = "fastapi"
 version = "0.128.0"