Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-21 04:22:21 -05:00)

Compare commits: leo/fix-sm ... foo (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 8f6f2f3065 |  |
|  | e6af53c2ae |  |
@@ -24,7 +24,6 @@ dependencies = [
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",
    "httpx>=0.28.1",
    "tomlkit>=0.14.0",
]

[project.scripts]
@@ -19,11 +19,8 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
    MODEL_CARDS,
    ModelCard,
    ModelId,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.api import (
    BenchChatCompletionResponse,
    BenchChatCompletionTaskParams,
@@ -89,12 +86,12 @@ def chunk_to_response(
)


async def resolve_model_card(model_id: ModelId) -> ModelCard:
async def resolve_model_card(model_id: str) -> ModelCard:
    if model_id in MODEL_CARDS:
        model_card = MODEL_CARDS[model_id]
        return model_card
    else:
        return await ModelCard.from_hf(model_id)
    return await get_model_card(model_id)


class API:
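Note: with this hunk, `resolve_model_card` takes a plain string and delegates anything outside `MODEL_CARDS` to `get_model_card` in the new `exo.shared.models.model_meta` module. A small hedged usage sketch; the repo id is illustrative and not part of the diff:

```python
async def example() -> None:
    # Hypothetical repo id: curated entries come straight from MODEL_CARDS; any other
    # Hugging Face repo id is resolved from its config.json / safetensors index
    # via get_model_card.
    card = await resolve_model_card("mlx-community/SomeModel-4bit")
    print(card.model_id, card.n_layers, card.hidden_size)
```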
@@ -239,7 +236,7 @@ class API:

    async def get_placement(
        self,
        model_id: ModelId,
        model_id: str,
        sharding: Sharding = Sharding.Pipeline,
        instance_meta: InstanceMeta = InstanceMeta.MlxRing,
        min_nodes: int = 1,
@@ -554,7 +551,7 @@ class API:
        self, payload: ChatCompletionTaskParams
    ) -> ChatCompletionResponse | StreamingResponse:
        """Handle chat completions, supporting both streaming and non-streaming responses."""
        model_card = await resolve_model_card(ModelId(payload.model))
        model_card = await resolve_model_card(payload.model)
        payload.model = model_card.model_id

        if not any(
@@ -581,7 +578,7 @@ class API:
    async def bench_chat_completions(
        self, payload: BenchChatCompletionTaskParams
    ) -> BenchChatCompletionResponse:
        model_card = await resolve_model_card(ModelId(payload.model))
        model_card = await resolve_model_card(payload.model)
        payload.model = model_card.model_id

        if not any(

@@ -1,18 +1,16 @@
from typing import Annotated
from pydantic import PositiveInt

import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt

from exo.shared.types.common import ModelId
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel

_card_cache: dict[str, "ModelCard"] = {}

class ModelId(Id):
    def normalize(self) -> str:
        return self.replace("/", "--")

    def short(self) -> str:
        return self.split("/")[-1]


class ModelCard(CamelCaseModel):
@@ -22,43 +20,6 @@ class ModelCard(CamelCaseModel):
    hidden_size: PositiveInt
    supports_tensor: bool

    async def save(self, path: Path) -> None:
        async with await open_file(path, "w") as f:
            py = self.model_dump()
            data = tomlkit.dumps(py)  # pyright: ignore[reportUnknownMemberType]
            await f.write(data)

    @staticmethod
    async def load_from_path(path: Path) -> "ModelCard":
        async with await open_file(path, "r") as f:
            py = tomlkit.loads(await f.read())
            return ModelCard.model_validate(py)

    @staticmethod
    async def load(model_id: ModelId) -> "ModelCard":
        if model_id in MODEL_CARDS:
            return MODEL_CARDS[model_id]
        return await ModelCard.from_hf(model_id)

    @staticmethod
    async def from_hf(model_id: ModelId) -> "ModelCard":
        """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
        if (mc := _card_cache.get(model_id)) is not None:
            return mc
        config_data = await get_config_data(model_id)
        num_layers = config_data.layer_count
        mem_size_bytes = await get_safetensors_size(model_id)

        mc = ModelCard(
            model_id=ModelId(model_id),
            storage_size=mem_size_bytes,
            n_layers=num_layers,
            hidden_size=config_data.hidden_size or 0,
            supports_tensor=config_data.supports_tensor,
        )
        _card_cache[model_id] = mc
        return mc


MODEL_CARDS: dict[str, ModelCard] = {
    # deepseek v3
@@ -347,99 +308,3 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
}
|
||||
|
||||
from exo.worker.download.download_utils import ( # noqa: E402
|
||||
ModelSafetensorsIndex,
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
# Common field names for number of layers across different architectures
|
||||
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
num_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layer: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
architectures: list[str] | None = None
|
||||
|
||||
@property
|
||||
def supports_tensor(self) -> bool:
|
||||
return self.architectures in [
|
||||
["Glm4MoeLiteForCausalLM"],
|
||||
["DeepseekV32ForCausalLM"],
|
||||
["DeepseekV3ForCausalLM"],
|
||||
["Qwen3NextForCausalLM"],
|
||||
["Qwen3MoeForCausalLM"],
|
||||
["MiniMaxM2ForCausalLM"],
|
||||
["LlamaForCausalLM"],
|
||||
["GptOssForCausalLM"],
|
||||
]
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
# Check common field names for layer count
|
||||
layer_fields = [
|
||||
self.num_hidden_layers,
|
||||
self.num_layers,
|
||||
self.n_layer,
|
||||
self.n_layers,
|
||||
self.num_decoder_layers,
|
||||
self.decoder_layers,
|
||||
]
|
||||
|
||||
for layer_count in layer_fields:
|
||||
if layer_count is not None:
|
||||
return layer_count
|
||||
|
||||
raise ValueError(
|
||||
f"No layer count found in config.json: {self.model_dump_json()}"
|
||||
)
|
||||
|
||||
|
||||
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
config_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
async with aiofiles.open(config_path, "r") as f:
|
||||
return ConfigData.model_validate_json(await f.read())
|
||||
|
||||
|
||||
async def get_safetensors_size(model_id: ModelId) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"model.safetensors.index.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.info(
|
||||
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
|
||||
),
|
||||
)
|
||||
async with aiofiles.open(index_path, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
|
||||
metadata = index_data.metadata
|
||||
if metadata is not None:
|
||||
return Memory.from_bytes(metadata.total_size)
|
||||
|
||||
info = model_info(model_id)
|
||||
if info.safetensors is None:
|
||||
raise ValueError(f"No safetensors info found for {model_id}")
|
||||
return Memory.from_bytes(info.safetensors.total)
|
||||
|
||||
src/exo/shared/models/model_meta.py (new file, 122 lines)
@@ -0,0 +1,122 @@
from typing import Annotated

import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,
    ensure_models_dir,
)


class ConfigData(BaseModel):
    model_config = {"extra": "ignore"}  # Allow unknown fields

    # Common field names for number of layers across different architectures
    num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
    num_layers: Annotated[int, Field(ge=0)] | None = None
    n_layer: Annotated[int, Field(ge=0)] | None = None
    n_layers: Annotated[int, Field(ge=0)] | None = None  # Sometimes used
    num_decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Transformer models
    decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Some architectures
    hidden_size: Annotated[int, Field(ge=0)] | None = None

    @property
    def layer_count(self) -> int:
        # Check common field names for layer count
        layer_fields = [
            self.num_hidden_layers,
            self.num_layers,
            self.n_layer,
            self.n_layers,
            self.num_decoder_layers,
            self.decoder_layers,
        ]

        for layer_count in layer_fields:
            if layer_count is not None:
                return layer_count

        raise ValueError(
            f"No layer count found in config.json: {self.model_dump_json()}"
        )


async def get_config_data(model_id: str) -> ConfigData:
    """Downloads and parses config.json for a model."""
    target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    config_path = await download_file_with_retry(
        model_id,
        "main",
        "config.json",
        target_dir,
        lambda curr_bytes, total_bytes, is_renamed: logger.info(
            f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
        ),
    )
    async with aiofiles.open(config_path, "r") as f:
        return ConfigData.model_validate_json(await f.read())


async def get_safetensors_size(model_id: str) -> Memory:
    """Gets model size from safetensors index or falls back to HF API."""
    target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    index_path = await download_file_with_retry(
        model_id,
        "main",
        "model.safetensors.index.json",
        target_dir,
        lambda curr_bytes, total_bytes, is_renamed: logger.info(
            f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
        ),
    )
    async with aiofiles.open(index_path, "r") as f:
        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

    metadata = index_data.metadata
    if metadata is not None:
        return Memory.from_bytes(metadata.total_size)

    info = model_info(model_id)
    if info.safetensors is None:
        raise ValueError(f"No safetensors info found for {model_id}")
    return Memory.from_bytes(info.safetensors.total)


_model_card_cache: dict[str, ModelCard] = {}


async def get_model_card(model_id: str) -> ModelCard:
    if model_id in _model_card_cache:
        return _model_card_cache[model_id]
    model_card = await _get_model_card(model_id)
    _model_card_cache[model_id] = model_card
    return model_card


async def _get_model_card(model_id: str) -> ModelCard:
    """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
    config_data = await get_config_data(model_id)
    num_layers = config_data.layer_count
    mem_size_bytes = await get_safetensors_size(model_id)
    model_card = next(
        (card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
        None,
    )

    return ModelCard(
        model_id=ModelId(model_id),
        storage_size=mem_size_bytes,
        n_layers=num_layers,
        hidden_size=config_data.hidden_size or 0,
        # TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
        supports_tensor=model_card.supports_tensor if model_card is not None else False,
    )
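For reference, a short hedged sketch of how this new module is used; the repo id is hypothetical and not part of the diff:

```python
import asyncio

from exo.shared.models.model_meta import get_model_card


async def main() -> None:
    # Hypothetical repo id; repeated calls return the cached ModelCard because
    # _model_card_cache is a plain module-level dict with no eviction (a dict is
    # used rather than functools.lru_cache, which would cache a one-shot coroutine).
    card = await get_model_card("org/some-model")
    print(card.storage_size, card.n_layers, card.supports_tensor)


asyncio.run(main())
```

get_safetensors_size expects the standard model.safetensors.index.json layout, i.e. `{"metadata": {"total_size": ...}, "weight_map": {...}}`, and only falls back to `huggingface_hub.model_info` when the index carries no metadata block.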
@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):


class PlaceInstanceParams(BaseModel):
    model_id: ModelId
    model_id: str
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1

@@ -25,14 +25,6 @@ class NodeId(Id):
    pass


class ModelId(Id):
    def normalize(self) -> str:
        return self.replace("/", "--")

    def short(self) -> str:
        return self.split("/")[-1]


class SessionId(CamelCaseModel):
    master_node_id: NodeId
    election_clock: int

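`ModelId` (with its `normalize()`/`short()` helpers) now lives in `exo.shared.models.model_cards` instead of `exo.shared.types.common`; the download-layer call sites in the hunks below switch to plain `str` and inline the same transformation. A hypothetical helper, not present in the diff, equivalent to the inlined `str(...).replace("/", "--")` calls:

```python
def normalize_repo_id(repo_id: str) -> str:
    # Same transformation ModelId.normalize() performed: make a Hugging Face
    # repo id filesystem-safe by replacing the org/name separator.
    return repo_id.replace("/", "--")
```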
@@ -1,8 +1,3 @@
from datetime import timedelta
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, PositiveInt

from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -47,50 +42,3 @@ class DownloadOngoing(BaseDownloadProgress):
|
||||
DownloadProgress = (
|
||||
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
|
||||
)
|
||||
|
||||
|
||||
class ModelSafetensorsIndexMetadata(BaseModel):
|
||||
total_size: PositiveInt
|
||||
|
||||
|
||||
class ModelSafetensorsIndex(BaseModel):
|
||||
metadata: ModelSafetensorsIndexMetadata | None
|
||||
weight_map: dict[str, str]
|
||||
|
||||
|
||||
class FileListEntry(BaseModel):
|
||||
type: Literal["file", "directory"]
|
||||
path: str
|
||||
size: int | None = None
|
||||
|
||||
|
||||
class RepoFileDownloadProgress(BaseModel):
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
file_path: str
|
||||
downloaded: Memory
|
||||
downloaded_this_session: Memory
|
||||
total: Memory
|
||||
speed: float
|
||||
eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class RepoDownloadProgress(BaseModel):
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
shard: ShardMetadata
|
||||
completed_files: int
|
||||
total_files: int
|
||||
downloaded_bytes: Memory
|
||||
downloaded_bytes_this_session: Memory
|
||||
total_bytes: Memory
|
||||
overall_speed: float
|
||||
overall_eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@@ -17,20 +17,17 @@ import aiohttp
|
||||
import certifi
|
||||
from loguru import logger
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
DirectoryPath,
|
||||
Field,
|
||||
PositiveInt,
|
||||
TypeAdapter,
|
||||
)
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadProgressData,
|
||||
FileListEntry,
|
||||
ModelSafetensorsIndex,
|
||||
RepoDownloadProgress,
|
||||
RepoFileDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import DownloadProgressData
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
@@ -40,6 +37,53 @@ from exo.worker.download.huggingface_utils import (
|
||||
)
|
||||
|
||||
|
||||
class ModelSafetensorsIndexMetadata(BaseModel):
|
||||
total_size: PositiveInt
|
||||
|
||||
|
||||
class ModelSafetensorsIndex(BaseModel):
|
||||
metadata: ModelSafetensorsIndexMetadata | None
|
||||
weight_map: dict[str, str]
|
||||
|
||||
|
||||
class FileListEntry(BaseModel):
|
||||
type: Literal["file", "directory"]
|
||||
path: str
|
||||
size: int | None = None
|
||||
|
||||
|
||||
class RepoFileDownloadProgress(BaseModel):
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
file_path: str
|
||||
downloaded: Memory
|
||||
downloaded_this_session: Memory
|
||||
total: Memory
|
||||
speed: float
|
||||
eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class RepoDownloadProgress(BaseModel):
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
shard: ShardMetadata
|
||||
completed_files: int
|
||||
total_files: int
|
||||
downloaded_bytes: Memory
|
||||
downloaded_bytes_this_session: Memory
|
||||
total_bytes: Memory
|
||||
overall_speed: float
|
||||
overall_eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
def trim_etag(etag: str) -> str:
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
|
||||
return etag[1:-1]
|
||||
@@ -81,12 +125,12 @@ def map_repo_download_progress_to_download_progress_data(
|
||||
)
|
||||
|
||||
|
||||
def build_model_path(model_id: ModelId) -> DirectoryPath:
|
||||
return EXO_MODELS_DIR / model_id.normalize()
|
||||
def build_model_path(model_id: str) -> DirectoryPath:
|
||||
return EXO_MODELS_DIR / model_id.replace("/", "--")
|
||||
|
||||
|
||||
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
|
||||
return (await ensure_models_dir()) / model_id.normalize()
|
||||
async def resolve_model_path_for_repo(repo_id: str) -> Path:
|
||||
return (await ensure_models_dir()) / repo_id.replace("/", "--")
|
||||
|
||||
|
||||
async def ensure_models_dir() -> Path:
|
||||
@@ -94,8 +138,8 @@ async def ensure_models_dir() -> Path:
|
||||
return EXO_MODELS_DIR
|
||||
|
||||
|
||||
async def delete_model(model_id: ModelId) -> bool:
|
||||
model_dir = await ensure_models_dir() / model_id.normalize()
|
||||
async def delete_model(repo_id: str) -> bool:
|
||||
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
|
||||
if not await aios.path.exists(model_dir):
|
||||
return False
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
@@ -120,17 +164,19 @@ async def seed_models(seed_dir: str | Path):
|
||||
|
||||
|
||||
async def fetch_file_list_with_cache(
|
||||
model_id: ModelId, revision: str = "main", recursive: bool = False
|
||||
repo_id: str, revision: str = "main", recursive: bool = False
|
||||
) -> list[FileListEntry]:
|
||||
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
|
||||
target_dir = (
|
||||
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
|
||||
)
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
|
||||
cache_file = (
|
||||
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
|
||||
)
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
|
||||
await aios.makedirs(cache_file.parent, exist_ok=True)
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
|
||||
@@ -138,25 +184,25 @@ async def fetch_file_list_with_cache(


async def fetch_file_list_with_retry(
    model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
    repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
    n_attempts = 30
    for attempt in range(n_attempts):
        try:
            return await _fetch_file_list(model_id, revision, path, recursive)
            return await _fetch_file_list(repo_id, revision, path, recursive)
        except Exception as e:
            if attempt == n_attempts - 1:
                raise e
            await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
    raise Exception(
        f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
        f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
    )


async def _fetch_file_list(
    model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
    repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
    api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
    url = f"{api_url}/{path}" if path else api_url

    headers = await get_download_headers()
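The retry loop above caps its exponential backoff at 8 seconds (0.1 s, 0.2 s, 0.4 s, ... up to 8 s over 30 attempts). A generic, self-contained sketch of the same pattern, independent of the exo helpers:

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def retry_with_backoff(
    op: Callable[[], Awaitable[T]],
    n_attempts: int = 30,
    base: float = 0.1,
    cap: float = 8.0,
) -> T:
    # Mirrors the schedule used above: sleep min(cap, base * 2**attempt) between
    # tries, re-raising the last error once the attempt budget is exhausted.
    for attempt in range(n_attempts):
        try:
            return await op()
        except Exception:
            if attempt == n_attempts - 1:
                raise
            await asyncio.sleep(min(cap, base * 2**attempt))
    raise RuntimeError("unreachable")
```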
@@ -173,7 +219,7 @@ async def _fetch_file_list(
|
||||
files.append(FileListEntry.model_validate(item))
|
||||
elif item.type == "directory" and recursive:
|
||||
subfiles = await _fetch_file_list(
|
||||
model_id, revision, item.path, recursive
|
||||
repo_id, revision, item.path, recursive
|
||||
)
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
@@ -230,10 +276,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
|
||||
|
||||
|
||||
async def file_meta(
|
||||
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
|
||||
repo_id: str, revision: str, path: str, redirected_location: str | None = None
|
||||
) -> tuple[int, str]:
|
||||
url = (
|
||||
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
|
||||
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
if redirected_location is None
|
||||
else f"{get_hf_endpoint()}{redirected_location}"
|
||||
)
|
||||
@@ -252,7 +298,7 @@ async def file_meta(
|
||||
return content_length, etag
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
redirected_location = r.headers.get("location")
|
||||
return await file_meta(model_id, revision, path, redirected_location)
|
||||
return await file_meta(repo_id, revision, path, redirected_location)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
@@ -264,7 +310,7 @@ async def file_meta(
|
||||
|
||||
|
||||
async def download_file_with_retry(
|
||||
model_id: ModelId,
|
||||
repo_id: str,
|
||||
revision: str,
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
@@ -274,23 +320,23 @@ async def download_file_with_retry(
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
repo_id, revision, path, target_dir, on_progress
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
|
||||
raise e
|
||||
logger.error(
|
||||
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
|
||||
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
|
||||
raise Exception(
|
||||
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
|
||||
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
|
||||
|
||||
async def _download_file(
|
||||
model_id: ModelId,
|
||||
repo_id: str,
|
||||
revision: str,
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
@@ -299,7 +345,7 @@ async def _download_file(
|
||||
if await aios.path.exists(target_dir / path):
|
||||
return target_dir / path
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
length, etag = await file_meta(repo_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
partial_path = target_dir / f"{path}.partial"
|
||||
resume_byte_pos = (
|
||||
@@ -308,7 +354,7 @@ async def _download_file(
|
||||
else None
|
||||
)
|
||||
if resume_byte_pos != length:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
headers = await get_download_headers()
|
||||
if resume_byte_pos:
|
||||
headers["Range"] = f"bytes={resume_byte_pos}-"
|
||||
@@ -348,7 +394,7 @@ async def _download_file(
|
||||
|
||||
def calculate_repo_progress(
|
||||
shard: ShardMetadata,
|
||||
model_id: ModelId,
|
||||
repo_id: str,
|
||||
revision: str,
|
||||
file_progress: dict[str, RepoFileDownloadProgress],
|
||||
all_start_time: float,
|
||||
@@ -377,7 +423,7 @@ def calculate_repo_progress(
|
||||
else "not_started"
|
||||
)
|
||||
return RepoDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_id=repo_id,
|
||||
repo_revision=revision,
|
||||
shard=shard,
|
||||
completed_files=len(
|
||||
@@ -396,11 +442,11 @@ def calculate_repo_progress(
|
||||
)
|
||||
|
||||
|
||||
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
|
||||
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_file = await download_file_with_retry(
|
||||
model_id, revision, "model.safetensors.index.json", target_dir
|
||||
repo_id, revision, "model.safetensors.index.json", target_dir
|
||||
)
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
@@ -458,7 +504,7 @@ async def download_shard(
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
# Update: <- This does not seem to be the case. Yay?
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
shard.model_card.model_id, revision, recursive=True
|
||||
str(shard.model_card.model_id), revision, recursive=True
|
||||
)
|
||||
filtered_file_list = list(
|
||||
filter_repo_objects(
|
||||
@@ -492,7 +538,7 @@ async def download_shard(
|
||||
else timedelta(seconds=0)
|
||||
)
|
||||
file_progress[file.path] = RepoFileDownloadProgress(
|
||||
repo_id=shard.model_card.model_id,
|
||||
repo_id=str(shard.model_card.model_id),
|
||||
repo_revision=revision,
|
||||
file_path=file.path,
|
||||
downloaded=Memory.from_bytes(curr_bytes),
|
||||
@@ -509,7 +555,7 @@ async def download_shard(
|
||||
shard,
|
||||
calculate_repo_progress(
|
||||
shard,
|
||||
shard.model_card.model_id,
|
||||
str(shard.model_card.model_id),
|
||||
revision,
|
||||
file_progress,
|
||||
all_start_time,
|
||||
@@ -519,7 +565,7 @@ async def download_shard(
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
|
||||
file_progress[file.path] = RepoFileDownloadProgress(
|
||||
repo_id=shard.model_card.model_id,
|
||||
repo_id=str(shard.model_card.model_id),
|
||||
repo_revision=revision,
|
||||
file_path=file.path,
|
||||
downloaded=Memory.from_bytes(downloaded_bytes),
|
||||
@@ -543,7 +589,7 @@ async def download_shard(
|
||||
async def download_with_semaphore(file: FileListEntry) -> None:
|
||||
async with semaphore:
|
||||
await download_file_with_retry(
|
||||
shard.model_card.model_id,
|
||||
str(shard.model_card.model_id),
|
||||
revision,
|
||||
file.path,
|
||||
target_dir,
|
||||
@@ -557,7 +603,7 @@ async def download_shard(
|
||||
*[download_with_semaphore(file) for file in filtered_file_list]
|
||||
)
|
||||
final_repo_progress = calculate_repo_progress(
|
||||
shard, shard.model_card.model_id, revision, file_progress, all_start_time
|
||||
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
|
||||
)
|
||||
await on_progress(shard, final_repo_progress)
|
||||
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
|
||||
|
||||
@@ -3,7 +3,8 @@ from collections.abc import Awaitable
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_card
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
@@ -18,8 +19,8 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
)
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
async def build_base_shard(model_id: str) -> ShardMetadata:
|
||||
model_card = await get_model_card(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -30,7 +31,7 @@ async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
)
|
||||
|
||||
|
||||
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
|
||||
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
|
||||
base_shard = await build_base_shard(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=base_shard.model_card,
|
||||
@@ -147,7 +148,7 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
self,
|
||||
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
|
||||
async def _status_for_model(
|
||||
model_id: ModelId,
|
||||
model_id: str,
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
"""Helper coroutine that builds the shard for a model and gets its download status."""
|
||||
shard = await build_full_shard(model_id)
|
||||
|
||||
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -67,27 +67,16 @@ def eval_with_timeout(
|
||||
completed.set()
|
||||
|
||||
|
||||
class _LayerCallable(Protocol):
|
||||
"""Structural type that any compatible layer must satisfy.
|
||||
|
||||
We require a single positional input of type ``mx.array`` and an
|
||||
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
|
||||
protocol matches the vast majority of `mlx.nn.Module` subclasses.
|
||||
"""
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
|
||||
|
||||
|
||||
class CustomMlxLayer(nn.Module):
|
||||
"""Base class for replacing an MLX layer with a custom implementation."""
|
||||
|
||||
def __init__(self, original_layer: _LayerCallable):
|
||||
def __init__(self, original_layer: nn.Module):
|
||||
super().__init__()
|
||||
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
|
||||
object.__setattr__(self, "_original_layer", original_layer)
|
||||
|
||||
@property
|
||||
def original_layer(self) -> _LayerCallable:
|
||||
return cast(_LayerCallable, self["_original_layer"])
|
||||
def original_layer(self) -> nn.Module:
|
||||
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
|
||||
|
||||
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
|
||||
if not TYPE_CHECKING:
|
||||
@@ -96,81 +85,57 @@ class CustomMlxLayer(nn.Module):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
original_layer = cast(_LayerCallable, self["_original_layer"])
|
||||
original_layer = object.__getattribute__(self, "_original_layer")
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
r: int,
|
||||
group: mx.distributed.Group,
|
||||
):
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.group = group
|
||||
def patch_pipeline_first_layer(
|
||||
pipeline_layer: nn.Module, group: mx.distributed.Group
|
||||
) -> nn.Module:
|
||||
cls = type(pipeline_layer)
|
||||
orig_call = cast(Callable[..., mx.array], cls.__call__)
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if self.r != 0:
|
||||
original_shape = x.shape
|
||||
original_last_dim = original_shape[-1]
|
||||
# Pad to 8192 on last dim (16KB for float16)
|
||||
padded_last_dim = 8192
|
||||
padded_shape = (*original_shape[:-1], padded_last_dim)
|
||||
padded_template = mx.zeros(padded_shape, dtype=x.dtype)
|
||||
rank = group.rank()
|
||||
|
||||
logger.info(f"[recv] expecting padded shape={padded_shape}, original={original_shape}")
|
||||
received_padded = mx.distributed.recv_like(padded_template, self.r - 1, group=self.group)
|
||||
mx.eval(received_padded)
|
||||
class PatchedFirstLayer(cls):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
if rank != 0:
|
||||
x = mx.distributed.recv_like(x, (rank - 1), group=group)
|
||||
return orig_call(self, x, *args, **kwargs)
|
||||
|
||||
# Slice off padding to get original data
|
||||
x = received_padded[..., :original_last_dim]
|
||||
mx.eval(x)
|
||||
logger.info(f"[recv] after slice: shape={x.shape}, sum={x.sum().item():.4f}, mean={x.mean().item():.4f}")
|
||||
return self.original_layer(x, *args, **kwargs)
|
||||
pipeline_layer.__class__ = PatchedFirstLayer
|
||||
|
||||
return pipeline_layer
|
||||
|
||||
|
||||
class PipelineLastLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
r: int,
|
||||
s: int,
|
||||
group: mx.distributed.Group,
|
||||
):
|
||||
super().__init__(original_layer)
|
||||
self.r: int = r
|
||||
self.s: int = s
|
||||
self.group = group
|
||||
self.original_layer_signature = signature(self.original_layer.__call__)
|
||||
def patch_pipeline_last_layer(
|
||||
pipeline_layer: nn.Module, group: mx.distributed.Group
|
||||
) -> nn.Module:
|
||||
cls = type(pipeline_layer)
|
||||
orig_call = cast(Callable[..., mx.array], cls.__call__)
|
||||
orig_call_sig = signature(orig_call)
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = self.original_layer_signature.bind_partial(
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
rank = group.rank()
|
||||
size = group.size()
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
class PatchedLastLayer(cls):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
|
||||
"cache", None
|
||||
)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
mx.eval(output)
|
||||
original_shape = output.shape
|
||||
original_last_dim = original_shape[-1]
|
||||
logger.info(f"[send] original shape={original_shape}, sum={output.sum().item():.4f}, mean={output.mean().item():.4f}")
|
||||
output: mx.array = orig_call(self, x, *args, **kwargs)
|
||||
|
||||
# Pad to 8192 on last dim (16KB for float16)
|
||||
padded_last_dim = 8192
|
||||
pad_size = padded_last_dim - original_last_dim
|
||||
padding = mx.zeros((*original_shape[:-1], pad_size), dtype=output.dtype)
|
||||
padded_output = mx.concatenate([output, padding], axis=-1)
|
||||
mx.eval(padded_output)
|
||||
logger.info(f"[send] padded shape={padded_output.shape}")
|
||||
if rank != size - 1:
|
||||
output = mx.distributed.send(output, (rank + 1) % size, group=group)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
sent = mx.distributed.send(padded_output, self.r + 1, group=self.group)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, sent) # type: ignore[reportUnknownMemberType]
|
||||
return output
|
||||
|
||||
return output
|
||||
pipeline_layer.__class__ = PatchedLastLayer
|
||||
|
||||
return pipeline_layer
|
||||
|
||||
|
||||
def _inner_model(model: nn.Module) -> nn.Module:
|
||||
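For readers skimming the hunk above: the `PipelineFirstLayer`/`PipelineLastLayer` wrapper classes (and their experimental pad-to-8192 send/recv path) are replaced by `patch_pipeline_first_layer` / `patch_pipeline_last_layer`, which swap a layer's `__class__` to a dynamically created subclass so the original module keeps its parameters while its `__call__` gains a recv/send at the pipeline boundary. A stripped-down toy sketch of that patching technique; everything here is a stand-in, only the mechanism matches the diff:

```python
from typing import Any


class ToyLayer:
    def __call__(self, x: Any, *args: object, **kwargs: object) -> Any:
        return x + 1


def patch_with_recv(layer: ToyLayer, rank: int) -> ToyLayer:
    cls = type(layer)
    orig_call = cls.__call__

    class Patched(cls):  # dynamically created subclass, as in the new patch_* helpers
        def __call__(self, x: Any, *args: object, **kwargs: object) -> Any:
            if rank != 0:
                # Stand-in for mx.distributed.recv_like(x, rank - 1, group=group).
                x = x * 10
            return orig_call(self, x, *args, **kwargs)

    layer.__class__ = Patched  # instance keeps its state; only behaviour changes
    return layer


layer = patch_with_recv(ToyLayer(), rank=1)
print(layer(1))  # 11: the "received" input flows into the original __call__
```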
@@ -185,13 +150,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
|
||||
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
|
||||
|
||||
|
||||
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
|
||||
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
|
||||
# Handle both model.layers and model.h cases
|
||||
layers: list[_LayerCallable]
|
||||
layers: list[nn.Module]
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
layers = cast(list[_LayerCallable], inner_model_instance.layers)
|
||||
layers = cast(list[nn.Module], inner_model_instance.layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
layers = cast(list[_LayerCallable], inner_model_instance.h)
|
||||
layers = cast(list[nn.Module], inner_model_instance.h)
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
@@ -216,15 +181,12 @@ def pipeline_auto_parallel(
|
||||
layers = _get_layers(inner_model_instance)
|
||||
|
||||
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[0] = patch_pipeline_first_layer(layers[0], group)
|
||||
layers[-1] = patch_pipeline_last_layer(
|
||||
layers[-1],
|
||||
device_rank,
|
||||
world_size,
|
||||
group=group,
|
||||
group,
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, GptOssMoeModel):
|
||||
@@ -359,7 +321,7 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
|
||||
if hasattr(model, "shard"):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return patch_tensor_model(model)
|
||||
@@ -408,6 +370,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -470,7 +433,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
@@ -545,17 +508,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
|
||||
class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
y = self.original_layer.__call__(x) # type: ignore
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
@@ -589,7 +552,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return model
|
||||
@@ -623,7 +586,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
# Shard the MLP
|
||||
@@ -636,17 +599,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
|
||||
class ShardedQwenMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
y = self.original_layer.__call__(x) # type: ignore
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
@@ -698,7 +661,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
y = self.original_layer(x) # type: ignore
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import functools
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.models import cache as mlx_cache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
@@ -34,202 +30,6 @@ from exo.worker.runner.bootstrap import logger
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineGenerationResponse:
|
||||
"""Response from pipeline_stream_generate."""
|
||||
|
||||
text: str
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
prompt_tokens: int
|
||||
prompt_tps: float
|
||||
generation_tokens: int
|
||||
generation_tps: float
|
||||
peak_memory: float
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
def pipeline_generate_step(
|
||||
prompt: mx.array,
|
||||
model: Model,
|
||||
*,
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
prefill_step_size: int = 2048,
|
||||
kv_bits: int | None = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
Synchronous generate_step for pipeline parallelism.
|
||||
No async_eval - everything is synchronous to ensure proper RDMA ordering.
|
||||
"""
|
||||
tokens = None
|
||||
|
||||
if prompt_cache is None:
|
||||
prompt_cache = mlx_cache.make_prompt_cache(model)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
def _step(y: mx.array) -> tuple[mx.array, mx.array]:
|
||||
nonlocal tokens
|
||||
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if logits_processors:
|
||||
tokens = mx.concatenate([tokens, y]) if tokens is not None else y
|
||||
for processor in logits_processors:
|
||||
logits = processor(tokens, logits)
|
||||
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
sampled = sampler(logprobs)
|
||||
|
||||
# Synchronous eval - critical for pipeline RDMA
|
||||
mx.eval(sampled, logprobs)
|
||||
|
||||
return sampled, logprobs.squeeze(0)
|
||||
|
||||
# === PREFILL PHASE ===
|
||||
# Process prompt in chunks, synchronously
|
||||
total_prompt_tokens = len(prompt)
|
||||
prompt_processed = 0
|
||||
|
||||
logger.info(f"[pipeline] Starting prefill: {total_prompt_tokens} tokens")
|
||||
|
||||
while total_prompt_tokens - prompt_processed > 1:
|
||||
remaining = (total_prompt_tokens - prompt_processed) - 1
|
||||
n_to_process = min(prefill_step_size, remaining)
|
||||
|
||||
chunk = prompt[:n_to_process]
|
||||
model(chunk[None], cache=prompt_cache)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
# Synchronous eval of cache state
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
|
||||
prompt_processed += n_to_process
|
||||
prompt = prompt[n_to_process:]
|
||||
|
||||
logger.info(f"[pipeline] Prefill progress: {prompt_processed}/{total_prompt_tokens}")
|
||||
mx.clear_cache()
|
||||
|
||||
# Process final token of prefill
|
||||
y, logprobs = _step(prompt)
|
||||
logger.info("[pipeline] Prefill complete, starting decode")
|
||||
|
||||
# === DECODE PHASE ===
|
||||
# Generate tokens one at a time, fully synchronous
|
||||
n = 0
|
||||
while True:
|
||||
if n == max_tokens:
|
||||
break
|
||||
|
||||
yield y.item(), logprobs
|
||||
|
||||
# Generate next token - synchronous
|
||||
y, logprobs = _step(y)
|
||||
|
||||
n += 1
|
||||
if n % 10 == 0:
|
||||
logger.info(f"[pipeline] Decode progress: {n}/{max_tokens}")
|
||||
if n % 256 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
|
||||
def pipeline_stream_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str | mx.array | list[int],
|
||||
max_tokens: int = 256,
|
||||
**kwargs: Any,
|
||||
) -> Generator[PipelineGenerationResponse, None, None]:
|
||||
"""
|
||||
Synchronous stream_generate for pipeline parallelism.
|
||||
Uses mx.eval instead of mx.async_eval to ensure proper RDMA synchronization.
|
||||
"""
|
||||
if not isinstance(tokenizer, TokenizerWrapper):
|
||||
tokenizer = TokenizerWrapper(tokenizer)
|
||||
|
||||
if not isinstance(prompt, mx.array):
|
||||
if isinstance(prompt, str):
|
||||
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
|
||||
tokenizer.bos_token
|
||||
)
|
||||
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
token_generator = pipeline_generate_step(prompt, model, **kwargs)
|
||||
|
||||
tic = time.perf_counter()
|
||||
prompt_time = 0.0
|
||||
prompt_tps = 0.0
|
||||
|
||||
for n, (token, logprobs) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
tic = time.perf_counter()
|
||||
|
||||
if token in tokenizer.eos_token_ids:
|
||||
detokenizer.finalize()
|
||||
yield PipelineGenerationResponse(
|
||||
text=detokenizer.last_segment,
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
prompt_tokens=prompt.size,
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic) if n > 0 else 0,
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason="stop",
|
||||
)
|
||||
return
|
||||
|
||||
detokenizer.add_token(token)
|
||||
|
||||
if (n + 1) == max_tokens:
|
||||
detokenizer.finalize()
|
||||
yield PipelineGenerationResponse(
|
||||
text=detokenizer.last_segment,
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
prompt_tokens=prompt.size,
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason="length",
|
||||
)
|
||||
return
|
||||
|
||||
yield PipelineGenerationResponse(
|
||||
text=detokenizer.last_segment,
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
prompt_tokens=prompt.size,
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic) if n > 0 else 0,
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[KVCache | Any],
|
||||
quantized_kv_start: int,
|
||||
@@ -320,7 +120,6 @@ def mlx_generate(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
prompt: str,
|
||||
use_pipeline_generate: bool = False,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -346,37 +145,19 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
|
||||
# Use synchronous pipeline generate for pipeline parallelism (RDMA)
|
||||
if use_pipeline_generate:
|
||||
logger.info("[mlx_generate] Using synchronous pipeline_stream_generate for RDMA")
|
||||
generator = pipeline_stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
)
|
||||
else:
|
||||
generator = stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
)
|
||||
|
||||
for out in generator:
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(out.text)
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -23,7 +23,6 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
@@ -297,7 +296,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
|
||||
"""
|
||||
Get the EOS token IDs for a model based on its ID.
|
||||
|
||||
@@ -321,9 +320,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
return None
|
||||
|
||||
|
||||
def load_tokenizer_for_model_id(
|
||||
model_id: ModelId, model_path: Path
|
||||
) -> TokenizerWrapper:
|
||||
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
@@ -184,14 +184,11 @@ def main(
|
||||
prompt = apply_chat_template(tokenizer, task_params)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
# Use synchronous generate only for JACCL/RDMA backend
|
||||
is_jaccl = isinstance(instance, MlxJacclInstance)
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
use_pipeline_generate=is_jaccl,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
|
||||
@@ -11,15 +11,14 @@ import mlx.core as mx
import mlx.nn as nn

from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template


class MockLayer(nn.Module):
@@ -117,12 +116,11 @@ def run_gpt_oss_pipeline_device(
        messages=[ChatCompletionMessage(role="user", content=prompt_text)],
        max_tokens=max_tokens,
    )
    prompt = apply_chat_template(tokenizer, task)

    generated_text = ""
    for response in mlx_generate(
        model=model,
        tokenizer=tokenizer,
        task=task,
        model=model, tokenizer=tokenizer, task=task, prompt=prompt
    ):
        generated_text += response.text
        if response.finish_reason is not None:
@@ -184,11 +182,11 @@ def run_gpt_oss_tensor_parallel_device(
        max_tokens=max_tokens,
    )

    prompt = apply_chat_template(tokenizer, task)

    generated_text = ""
    for response in mlx_generate(
        model=model,
        tokenizer=tokenizer,
        task=task,
        model=model, tokenizer=tokenizer, task=task, prompt=prompt
    ):
        generated_text += response.text
        if response.finish_reason is not None:

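Aside (illustration only, not from the repo): both tests above keep the same consumption pattern, accumulating response.text from the streamed generator until a finish_reason arrives. The sketch below shows that pattern against a stub generator so it runs standalone.

from dataclasses import dataclass
from typing import Iterator

@dataclass
class StubResponse:
    text: str
    finish_reason: str | None = None

def stub_generate() -> Iterator[StubResponse]:
    # Stands in for mlx_generate(model=..., tokenizer=..., task=..., prompt=...).
    yield StubResponse("Hello")
    yield StubResponse(", world", finish_reason="stop")

generated_text = ""
for response in stub_generate():
    generated_text += response.text
    if response.finish_reason is not None:
        break

assert generated_text == "Hello, world"
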
@@ -10,8 +10,8 @@ import pytest

from exo.worker.engines.mlx.auto_parallel import (
    CustomMlxLayer,
    PipelineFirstLayer,
    PipelineLastLayer,
    patch_pipeline_first_layer,
    patch_pipeline_last_layer,
    patch_pipeline_model,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -50,8 +50,8 @@ def run_pipeline_device(
    group = mx.distributed.init(backend="ring", strict=True)

    mock = MockLayerInner()
    first = PipelineFirstLayer(mock, r=rank, group=group)
    composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
    first = patch_pipeline_first_layer(mock, group)
    composed = patch_pipeline_last_layer(first, group)

    # Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
    inner_model = MockModel([composed])
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
    mock = MockLayer()
    group = mx.distributed.init()

    first = PipelineFirstLayer(mock, r=0, group=group)
    composed = PipelineLastLayer(first, r=0, s=1, group=group)
    first = patch_pipeline_first_layer(mock, group)
    composed = patch_pipeline_last_layer(first, group)

    assert composed.custom_attr == "test_value"  # type: ignore[attr-defined]
    assert composed.use_sliding is True  # type: ignore[attr-defined]

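Aside (an assumption, not the repo's implementation): the delegation the last test asserts, where custom_attr and use_sliding are readable through the patched wrappers, is the kind of behaviour a wrapper gets by forwarding unknown attribute lookups to the layer it wraps, roughly like this self-contained sketch.

class DelegatingWrapper:
    def __init__(self, inner: object) -> None:
        self.inner = inner

    def __getattr__(self, name: str) -> object:
        # Called only when normal lookup fails, so the wrapper's own
        # attributes still take precedence over the wrapped layer's.
        return getattr(self.inner, name)

class Layer:
    custom_attr = "test_value"
    use_sliding = True

wrapped = DelegatingWrapper(DelegatingWrapper(Layer()))
assert wrapped.custom_attr == "test_value"
assert wrapped.use_sliding is True
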
@@ -11,7 +11,7 @@ from pathlib import Path

import pytest

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
    download_file_with_retry,
    ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
    return False


async def download_tokenizer_files(model_id: ModelId) -> Path:
async def download_tokenizer_files(model_id: str) -> Path:
    """Download only the tokenizer-related files for a model."""
    target_dir = await ensure_models_dir() / model_id.normalize()
    target_dir = await ensure_models_dir() / model_id.replace("/", "--")
    target_dir.mkdir(parents=True, exist_ok=True)

    file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)

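Aside (not part of the diff): with a plain str model id, the directory-name normalization shown above reduces to a single replace call. The repo id below is only an example value.

model_id = "mlx-community/Llama-3.1-8B-Instruct-4bit"  # example id, not from the diff
target_dir_name = model_id.replace("/", "--")
assert target_dir_name == "mlx-community--Llama-3.1-8B-Instruct-4bit"
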
@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: ModelId) -> Path:


# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[ModelCard]:
def get_test_models() -> list[tuple[str, ModelCard]]:
    """Get a representative sample of models to test."""
    # Pick one model from each family to test
    families: dict[str, ModelCard] = {}
    for card in MODEL_CARDS.values():
    families: dict[str, tuple[str, ModelCard]] = {}
    for _, card in MODEL_CARDS.items():
        # Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
        parts = card.model_id.short().split("-")
        family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]

        if family not in families:
            families[family] = card
            families[family] = (card.model_id.short(), card)

    return list(families.values())


TEST_MODELS: list[ModelCard] = get_test_models()
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()

pytestmark = pytest.mark.slow

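Aside (hedged illustration, all values invented): the parametrization style the tests below adopt, (short_id, model_card) tuples plus ids=, makes pytest node names read as test_name[llama-3.1] instead of an opaque object repr.

import pytest

CASES: list[tuple[str, str]] = [
    ("llama-3.1", "meta-llama/Llama-3.1-8B-Instruct"),  # invented pair for the example
    ("qwen-2.5", "Qwen/Qwen2.5-7B-Instruct"),           # invented pair for the example
]

@pytest.mark.parametrize("short_id,model_id", CASES, ids=[c[0] for c in CASES])
def test_short_id_labels_the_case(short_id: str, model_id: str) -> None:
    # The id string only labels the generated test; the real data still
    # arrives through the tuple parameters.
    assert "/" in model_id
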
@@ -101,13 +101,14 @@ def event_loop():


@pytest.mark.parametrize(
    "model_card",
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
    """Test that tokenizer can encode and decode text correctly."""
    model_id = model_card.model_id
    model_id = str(model_card.model_id)

    # Download tokenizer files
    model_path = await download_tokenizer_files(model_id)
@@ -166,15 +167,16 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->


@pytest.mark.parametrize(
    "model_card",
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
    short_id: str, model_card: ModelCard
) -> None:
    """Test that tokenizer has required attributes for inference."""
    model_id = model_card.model_id
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

@@ -207,18 +209,19 @@ async def test_tokenizer_has_required_attributes(


@pytest.mark.parametrize(
    "model_card",
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
    """Test that tokenizer can encode text containing special tokens.

    This is critical because the actual inference path uses prompts with
    special tokens from chat templates. If special tokens aren't handled
    correctly, encoding will fail.
    """
    model_id = model_card.model_id
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

@@ -298,14 +301,16 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
    """Test Kimi tokenizer with its specific patches and quirks."""
    kimi_models = [
        card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
        (short_id, card)
        for short_id, card in MODEL_CARDS.items()
        if "kimi" in short_id.lower()
    ]

    if not kimi_models:
        pytest.skip("No Kimi models found in MODEL_CARDS")

    model_card = kimi_models[0]
    model_id = model_card.model_id
    _, model_card = kimi_models[0]
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

@@ -344,15 +349,17 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
    """Test GLM tokenizer with its specific EOS tokens."""
    glm_model_cards = [
        card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
    glm_models = [
        (short_id, card)
        for short_id, card in MODEL_CARDS.items()
        if "glm" in short_id.lower()
    ]

    if not glm_model_cards:
    if not glm_models:
        pytest.skip("No GLM models found in MODEL_CARDS")

    model_card = glm_model_cards[0]
    model_id = model_card.model_id
    _, model_card = glm_models[0]
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)


@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import ModelId, NodeId
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress

uv.lock (generated, 12 changed lines)
@@ -248,7 +248,6 @@ dependencies = [
    { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]

@@ -282,7 +281,6 @@ requires-dist = [
    { name = "pydantic", specifier = ">=2.11.7" },
    { name = "rustworkx", specifier = ">=0.17.1" },
    { name = "tiktoken", specifier = ">=0.12.0" },
    { name = "tomlkit", specifier = ">=0.14.0" },
    { name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]

@@ -317,16 +315,6 @@ dev = [
    { name = "pytest-asyncio", specifier = ">=1.0.0" },
]

[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]


[[package]]
name = "fastapi"
version = "0.128.0"