Compare commits

..

4 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
541339aae6 Don't warn on single-node jaccl placement 2026-01-20 18:28:51 +00:00
rltakashige
758464703d Fix GPT OSS tensor sharding with upstream MLX LM (#1223)
## Motivation
MLX LM has added a `shard` method to GPT OSS, but MLX itself does not yet have a matching update, so exo keeps using its own GPT OSS tensor sharding strategy.
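
A minimal sketch of the resulting behaviour; `shard_model` here is a hypothetical wrapper (the real logic lives in exo's tensor auto-parallel code), but the guard itself mirrors the one-line change in the diff below.

```python
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.gpt_oss import Model as GptOssModel


def shard_model(model: nn.Module, group: mx.distributed.Group) -> nn.Module:
    """Hypothetical wrapper: prefer upstream sharding, except for GPT OSS."""
    # GPT OSS now exposes an upstream `shard` method, but exo still applies
    # its own GptOssShardingStrategy, so the generic path is skipped for it.
    if hasattr(model, "shard") and not isinstance(model, GptOssModel):
        model.shard(group)  # type: ignore[attr-defined]
        return model
    raise NotImplementedError(
        "per-architecture sharding strategies are elided in this sketch"
    )
```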

2026-01-20 18:24:54 +00:00
rltakashige
9e2179c848 Register original layer in CustomMlxLayer (#1229)
## Motivation
Pipeline RDMA for Kimi K2 Thinking was previously broken.

## Why It Works
The exact root cause was not identified, but registering the original layer in the module dictionary (instead of attaching it with `object.__setattr__`) fixes the failure in practice; the pattern is sketched below.
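
A minimal sketch of the registration change; it assumes, as the diff does, that MLX's `nn.Module` behaves as a `dict`, and it omits the rest of the exo wrapper.

```python
import mlx.nn as nn


class CustomMlxLayer(nn.Module):
    """Base class for replacing an MLX layer with a custom implementation."""

    def __init__(self, original_layer: nn.Module):
        super().__init__()
        # Before: object.__setattr__(self, "_original_layer", original_layer)
        # hid the wrapped layer from nn.Module's dict-based bookkeeping.
        # Storing it in the module dict registers it as a child module.
        dict.__setitem__(self, "_original_layer", original_layer)

    @property
    def original_layer(self) -> nn.Module:
        return self["_original_layer"]
```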

## Test Plan

### Manual Testing
Kimi K2 Thinking and GPT OSS work at the same time over Pipeline RDMA.
A more thorough check with exo bench is still needed.

### Automated Testing
Layer composition tests still pass.
2026-01-20 18:20:01 +00:00
Evan Quiney
22b5d836ef swap all instances of model_id: str for model_id: ModelId (#1221)
This change uses the more strongly typed ModelId and introduces some
convenience methods. It also cleans up some code left over from #1204.

## Changes

`model_id: str -> model_id: ModelId`
`repo_id: str -> model_id: ModelId`

Introduces methods on ModelId, in particular `ModelId.normalize()` to
replace `/` with `--` (see the sketch below).

This PR introduced some circular imports, so some code has been moved
around to limit them.
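
A minimal sketch of the new type, assuming `Id` is a thin `str` subclass (which the `replace`/`split` calls in the diff imply); the repo name below is only an example.

```python
class Id(str):
    """Assumed base: a plain string subtype used for typed identifiers."""


class ModelId(Id):
    def normalize(self) -> str:
        # Hugging Face repo ids contain "/", which is not filesystem-safe.
        return self.replace("/", "--")

    def short(self) -> str:
        # Last path component of the repo id.
        return self.split("/")[-1]


model_id = ModelId("openai/gpt-oss-20b")
assert model_id.normalize() == "openai--gpt-oss-20b"
assert model_id.short() == "gpt-oss-20b"
```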

## Test Plan

Tests still pass, types still check. As this is about metadata, I
haven't tested inference.
2026-01-20 17:38:06 +00:00
17 changed files with 406 additions and 355 deletions

View File

@@ -24,6 +24,7 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -19,8 +19,11 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
)
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -86,12 +89,12 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: str) -> ModelCard:
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await get_model_card(model_id)
return await ModelCard.from_hf(model_id)
class API:
@@ -236,7 +239,7 @@ class API:
async def get_placement(
self,
model_id: str,
model_id: ModelId,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
@@ -551,7 +554,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(payload.model)
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -578,7 +581,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(payload.model)
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(

View File

@@ -120,7 +120,7 @@ def place_instance(
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
logger.debug(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)

View File

@@ -1,16 +1,18 @@
from pydantic import PositiveInt
from typing import Annotated
from exo.shared.types.common import Id
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
_card_cache: dict[str, "ModelCard"] = {}
class ModelCard(CamelCaseModel):
@@ -20,6 +22,43 @@ class ModelCard(CamelCaseModel):
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
)
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
@@ -308,3 +347,99 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
}
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
@property
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],
["Qwen3MoeForCausalLM"],
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)

View File

@@ -1,122 +0,0 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_card(model_id: str) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)

View File

@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: str
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1

View File

@@ -25,6 +25,14 @@ class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int

View File

@@ -1,3 +1,8 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)

View File

@@ -17,17 +17,20 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.downloads import (
DownloadProgressData,
FileListEntry,
ModelSafetensorsIndex,
RepoDownloadProgress,
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -37,53 +40,6 @@ from exo.worker.download.huggingface_utils import (
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -125,12 +81,12 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.replace("/", "--")
def build_model_path(model_id: ModelId) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
return (await ensure_models_dir()) / model_id.normalize()
async def ensure_models_dir() -> Path:
@@ -138,8 +94,8 @@ async def ensure_models_dir() -> Path:
return EXO_MODELS_DIR
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -164,19 +120,17 @@ async def seed_models(seed_dir: str | Path):
async def fetch_file_list_with_cache(
repo_id: str, revision: str = "main", recursive: bool = False
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -184,25 +138,25 @@ async def fetch_file_list_with_cache(
async def fetch_file_list_with_retry(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(repo_id, revision, path, recursive)
return await _fetch_file_list(model_id, revision, path, recursive)
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
async def _fetch_file_list(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_download_headers()
@@ -219,7 +173,7 @@ async def _fetch_file_list(
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
@@ -276,10 +230,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
async def file_meta(
repo_id: str, revision: str, path: str, redirected_location: str | None = None
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
url = (
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
@@ -298,7 +252,7 @@ async def file_meta(
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
return await file_meta(model_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -310,7 +264,7 @@ async def file_meta(
async def download_file_with_retry(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -320,23 +274,23 @@ async def download_file_with_retry(
for attempt in range(n_attempts):
try:
return await _download_file(
repo_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
async def _download_file(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -345,7 +299,7 @@ async def _download_file(
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(repo_id, revision, path)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
@@ -354,7 +308,7 @@ async def _download_file(
else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -394,7 +348,7 @@ async def _download_file(
def calculate_repo_progress(
shard: ShardMetadata,
repo_id: str,
model_id: ModelId,
revision: str,
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
@@ -423,7 +377,7 @@ def calculate_repo_progress(
else "not_started"
)
return RepoDownloadProgress(
repo_id=repo_id,
repo_id=model_id,
repo_revision=revision,
shard=shard,
completed_files=len(
@@ -442,11 +396,11 @@ def calculate_repo_progress(
)
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
model_id, revision, "model.safetensors.index.json", target_dir
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -504,7 +458,7 @@ async def download_shard(
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -538,7 +492,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -555,7 +509,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_card.model_id),
shard.model_card.model_id,
revision,
file_progress,
all_start_time,
@@ -565,7 +519,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -589,7 +543,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_card.model_id),
shard.model_card.model_id,
revision,
file.path,
target_dir,
@@ -603,7 +557,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
shard, shard.model_card.model_id, revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -3,8 +3,7 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -19,8 +18,8 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
)
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -31,7 +30,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
@@ -148,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async def _status_for_model(
model_id: str,
model_id: ModelId,
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Protocol, cast
import mlx.core as mx
import mlx.nn as nn
@@ -67,16 +67,27 @@ def eval_with_timeout(
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
"""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
@property
def original_layer(self) -> nn.Module:
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, self["_original_layer"])
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -85,57 +96,56 @@ class CustomMlxLayer(nn.Module):
try:
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
original_layer = cast(_LayerCallable, self["_original_layer"])
return getattr(original_layer, name)
def patch_pipeline_first_layer(
pipeline_layer: nn.Module, group: mx.distributed.Group
) -> nn.Module:
cls = type(pipeline_layer)
orig_call = cast(Callable[..., mx.array], cls.__call__)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.group = group
rank = group.rank()
class PatchedFirstLayer(cls):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if rank != 0:
x = mx.distributed.recv_like(x, (rank - 1), group=group)
return orig_call(self, x, *args, **kwargs)
pipeline_layer.__class__ = PatchedFirstLayer
return pipeline_layer
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
return self.original_layer(x, *args, **kwargs)
def patch_pipeline_last_layer(
pipeline_layer: nn.Module, group: mx.distributed.Group
) -> nn.Module:
cls = type(pipeline_layer)
orig_call = cast(Callable[..., mx.array], cls.__call__)
orig_call_sig = signature(orig_call)
class PipelineLastLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
s: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
rank = group.rank()
size = group.size()
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
class PatchedLastLayer(cls):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = orig_call_sig.bind_partial(x, *args, **kwargs).arguments.get(
"cache", None
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output: mx.array = orig_call(self, x, *args, **kwargs)
if rank != size - 1:
output = mx.distributed.send(output, (rank + 1) % size, group=group)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
pipeline_layer.__class__ = PatchedLastLayer
return pipeline_layer
return output
def _inner_model(model: nn.Module) -> nn.Module:
@@ -150,13 +160,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
# Handle both model.layers and model.h cases
layers: list[nn.Module]
layers: list[_LayerCallable]
if hasattr(inner_model_instance, "layers"):
layers = cast(list[nn.Module], inner_model_instance.layers)
layers = cast(list[_LayerCallable], inner_model_instance.layers)
elif hasattr(inner_model_instance, "h"):
layers = cast(list[nn.Module], inner_model_instance.h)
layers = cast(list[_LayerCallable], inner_model_instance.h)
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -181,12 +191,15 @@ def pipeline_auto_parallel(
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = patch_pipeline_first_layer(layers[0], group)
layers[-1] = patch_pipeline_last_layer(
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],
group,
device_rank,
world_size,
group=group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
@@ -321,7 +334,7 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
@@ -370,7 +383,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -433,7 +445,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
@@ -508,17 +520,17 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x) # type: ignore
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
@@ -552,7 +564,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
@@ -586,7 +598,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -599,17 +611,17 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x) # type: ignore
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
@@ -661,7 +673,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x) # type: ignore
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -23,6 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
@@ -296,7 +297,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
@@ -320,7 +321,9 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.

View File

@@ -11,14 +11,15 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load, apply_chat_template
from exo.worker.engines.mlx.utils_mlx import shard_and_load
class MockLayer(nn.Module):
@@ -116,11 +117,12 @@ def run_gpt_oss_pipeline_device(
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model, tokenizer=tokenizer, task=task, prompt=prompt
model=model,
tokenizer=tokenizer,
task=task,
):
generated_text += response.text
if response.finish_reason is not None:
@@ -182,11 +184,11 @@ def run_gpt_oss_tensor_parallel_device(
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model, tokenizer=tokenizer, task=task, prompt=prompt
model=model,
tokenizer=tokenizer,
task=task,
):
generated_text += response.text
if response.finish_reason is not None:

View File

@@ -10,8 +10,8 @@ import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
patch_pipeline_first_layer,
patch_pipeline_last_layer,
PipelineFirstLayer,
PipelineLastLayer,
patch_pipeline_model,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -50,8 +50,8 @@ def run_pipeline_device(
group = mx.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = patch_pipeline_first_layer(mock, group)
composed = patch_pipeline_last_layer(first, group)
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
inner_model = MockModel([composed])
@@ -78,8 +78,8 @@ def test_composed_wrappers_delegate_attributes() -> None:
mock = MockLayer()
group = mx.distributed.init()
first = patch_pipeline_first_layer(mock, group)
composed = patch_pipeline_last_layer(first, group)
first = PipelineFirstLayer(mock, r=0, group=group)
composed = PipelineLastLayer(first, r=0, s=1, group=group)
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
assert composed.use_sliding is True # type: ignore[attr-defined]

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
return False
async def download_tokenizer_files(model_id: str) -> Path:
async def download_tokenizer_files(model_id: ModelId) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir = await ensure_models_dir() / model_id.normalize()
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: str) -> Path:
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (card.model_id.short(), card)
families[family] = card
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
TEST_MODELS: list[ModelCard] = get_test_models()
pytestmark = pytest.mark.slow
@@ -101,14 +101,13 @@ def event_loop():
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
@@ -167,16 +166,15 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -209,19 +207,18 @@ async def test_tokenizer_has_required_attributes(
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -301,16 +298,14 @@ async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) ->
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_card = kimi_models[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -349,17 +344,15 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
]
if not glm_models:
if not glm_model_cards:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_card = glm_model_cards[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)

View File

@@ -1,6 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress

uv.lock (generated, 12 changes)
View File

@@ -248,6 +248,7 @@ dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -281,6 +282,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
@@ -315,6 +317,16 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"