Compare commits

..

1 Commit

Author: Evan
SHA1: ec8c963f2b
Message: remove sharding that has been upstreamed into mlx-lm
Llama, Ministral, GPT-OSS, MiniMax, and DeepSeek all have their tensor sharding upstreamed into mlx-lm, so we can remove the sharding code from exo.
Date: 2026-01-20 16:04:40 +00:00
19 changed files with 350 additions and 676 deletions
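For context, the shape of the change in the tensor-parallel code: instead of exo carrying per-architecture sharding strategies for these model families, it can defer to the shard(group) method that mlx-lm models now expose, keeping exo-side strategies only as a fallback. A minimal sketch of that pattern (the helper name is hypothetical and the error message illustrative; this is not the exact exo implementation):

import mlx.core as mx
import mlx.nn as nn

def shard_if_supported(model: nn.Module, group: mx.distributed.Group) -> nn.Module:
    # mlx-lm models with upstreamed tensor sharding expose a shard() method.
    if hasattr(model, "shard"):
        model.shard(group)  # type: ignore[attr-defined]
        return model
    # Anything else still needs an exo-side sharding strategy
    # (see the tensor_auto_parallel hunk further down).
    raise ValueError(f"Unsupported model type: {type(model)}")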

View File

@@ -24,7 +24,6 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -19,11 +19,8 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -44,7 +41,7 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -73,7 +70,7 @@ from exo.utils.event_buffer import OrderedBuffer
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -85,25 +82,16 @@ def chunk_to_response(
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
if isinstance(chunk, TokenChunk)
else StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=[tool.model_dump() for tool in chunk.tool_calls],
),
finish_reason="tool_calls",
)
],
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
async def resolve_model_card(model_id: str) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await ModelCard.from_hf(model_id)
return await get_model_card(model_id)
class API:
@@ -147,9 +135,7 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[
CommandId, Sender[TokenChunk | ToolCallChunk]
] = {}
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -250,7 +236,7 @@ class API:
async def get_placement(
self,
model_id: ModelId,
model_id: str,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
@@ -415,21 +401,16 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk | ToolCallChunk, None]:
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[
TokenChunk | ToolCallChunk
]()
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if (
isinstance(chunk, TokenChunk)
and chunk.finish_reason is not None
):
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
@@ -451,7 +432,7 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, TokenChunk) and chunk.finish_reason == "error":
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
@@ -470,7 +451,7 @@ class API:
yield f"data: {chunk_response.model_dump_json()}\n\n"
if isinstance(chunk, ToolCallChunk) or chunk.finish_reason is not None:
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
@@ -483,24 +464,6 @@ class API:
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
finish_reason = "tool_calls"
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model or chunk.model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
tool_calls=[
tool.model_dump() for tool in chunk.tool_calls
],
),
)
],
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -544,11 +507,6 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
raise HTTPException(
status_code=500,
detail="Tool call in bench",
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -593,7 +551,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(
@@ -620,7 +578,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(
@@ -699,7 +657,7 @@ class API:
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, (TokenChunk, ToolCallChunk))
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
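
Since the streaming path now emits only TokenChunks, the SSE contract is simply: "data:" lines carrying ChatCompletionResponse JSON, terminated by "data: [DONE]". A hypothetical client sketch under that assumption (the endpoint path and payload fields are OpenAI-style guesses, not taken from this diff):

import json
import httpx

async def stream_chat(base_url: str, payload: dict) -> str:
    parts: list[str] = []
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                delta = chunk["choices"][0].get("delta") or {}
                parts.append(delta.get("content") or "")
    return "".join(parts)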

View File

@@ -1,18 +1,16 @@
from typing import Annotated
from pydantic import PositiveInt
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class ModelCard(CamelCaseModel):
@@ -22,43 +20,6 @@ class ModelCard(CamelCaseModel):
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
)
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
@@ -347,99 +308,3 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
}
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
@property
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],
["Qwen3MoeForCausalLM"],
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)

View File

@@ -0,0 +1,122 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_card(model_id: str) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)
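
Usage of the new helper is straightforward: known ids resolve from MODEL_CARDS via resolve_model_card(), anything else goes through get_model_card(), which fetches config.json and the safetensors index once and then answers from the in-process cache. A hypothetical usage sketch (the repo id is just an example):

import asyncio

from exo.shared.models.model_meta import get_model_card

async def main() -> None:
    card = await get_model_card("some-org/some-model")  # example repo id
    print(card.n_layers, card.storage_size, card.supports_tensor)
    # A second call for the same id is served from _model_card_cache,
    # so no further downloads or HF API calls are made.
    await get_model_card("some-org/some-model")

asyncio.run(main())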

View File

@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: ModelId
model_id: str
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1

View File

@@ -5,7 +5,6 @@ from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .worker.runner_response import ToolCallItem
class ChunkType(str, Enum):
@@ -26,12 +25,8 @@ class TokenChunk(BaseChunk):
error_message: str | None = None
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
class ImageChunk(BaseChunk):
data: bytes
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk
GenerationChunk = TokenChunk | ImageChunk

View File

@@ -25,14 +25,6 @@ class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int

View File

@@ -1,8 +1,3 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -47,50 +42,3 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)

View File

@@ -1,5 +1,5 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
from exo.utils.pydantic_ext import TaggedModel
class BaseRunnerResponse(TaggedModel):
@@ -10,11 +10,6 @@ class TokenizedResponse(BaseRunnerResponse):
prompt_tokens: int
class ToolCallItem(CamelCaseModel):
arguments: str
name: str
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
@@ -23,9 +18,5 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -17,20 +17,17 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
DownloadProgressData,
FileListEntry,
ModelSafetensorsIndex,
RepoDownloadProgress,
RepoFileDownloadProgress,
)
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -40,6 +37,53 @@ from exo.worker.download.huggingface_utils import (
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -81,12 +125,12 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: ModelId) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.normalize()
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.replace("/", "--")
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
return (await ensure_models_dir()) / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def ensure_models_dir() -> Path:
@@ -94,8 +138,8 @@ async def ensure_models_dir() -> Path:
return EXO_MODELS_DIR
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -120,17 +164,19 @@ async def seed_models(seed_dir: str | Path):
async def fetch_file_list_with_cache(
model_id: ModelId, revision: str = "main", recursive: bool = False
repo_id: str, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -138,25 +184,25 @@ async def fetch_file_list_with_cache(
async def fetch_file_list_with_retry(
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
return await _fetch_file_list(repo_id, revision, path, recursive)
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
)
async def _fetch_file_list(
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_download_headers()
@@ -173,7 +219,7 @@ async def _fetch_file_list(
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
model_id, revision, item.path, recursive
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
@@ -230,10 +276,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
async def file_meta(
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
repo_id: str, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
url = (
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
@@ -252,7 +298,7 @@ async def file_meta(
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(model_id, revision, path, redirected_location)
return await file_meta(repo_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -264,7 +310,7 @@ async def file_meta(
async def download_file_with_retry(
model_id: ModelId,
repo_id: str,
revision: str,
path: str,
target_dir: Path,
@@ -274,23 +320,23 @@ async def download_file_with_retry(
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress
repo_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
)
async def _download_file(
model_id: ModelId,
repo_id: str,
revision: str,
path: str,
target_dir: Path,
@@ -299,7 +345,7 @@ async def _download_file(
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
length, etag = await file_meta(repo_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
@@ -308,7 +354,7 @@ async def _download_file(
else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -348,7 +394,7 @@ async def _download_file(
def calculate_repo_progress(
shard: ShardMetadata,
model_id: ModelId,
repo_id: str,
revision: str,
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
@@ -377,7 +423,7 @@ def calculate_repo_progress(
else "not_started"
)
return RepoDownloadProgress(
repo_id=model_id,
repo_id=repo_id,
repo_revision=revision,
shard=shard,
completed_files=len(
@@ -396,11 +442,11 @@ def calculate_repo_progress(
)
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / model_id.normalize()
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
model_id, revision, "model.safetensors.index.json", target_dir
repo_id, revision, "model.safetensors.index.json", target_dir
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -431,6 +477,53 @@ async def get_downloaded_size(path: Path) -> int:
return 0
async def download_progress_for_local_path(
repo_id: str, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
file_progress: dict[str, RepoFileDownloadProgress] = {}
total_files = 0
total_bytes = 0
if await aios.path.isdir(local_path):
for root, _, files in os.walk(local_path):
for f in files:
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
file_path = Path(root) / f
size = (await aios.stat(file_path)).st_size
rel_path = str(file_path.relative_to(local_path))
file_progress[rel_path] = RepoFileDownloadProgress(
repo_id=repo_id,
repo_revision="local",
file_path=rel_path,
downloaded=Memory.from_bytes(size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(size),
speed=0,
eta=timedelta(0),
status="complete",
start_time=time.time(),
)
total_files += 1
total_bytes += size
else:
raise ValueError(f"Local path {local_path} is not a directory")
return RepoDownloadProgress(
repo_id=repo_id,
repo_revision="local",
shard=shard,
completed_files=total_files,
total_files=total_files,
downloaded_bytes=Memory.from_bytes(total_bytes),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(total_bytes),
overall_speed=0,
overall_eta=timedelta(0),
status="complete",
file_progress=file_progress,
)
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -441,6 +534,14 @@ async def download_shard(
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
"/", "--"
@@ -451,14 +552,13 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
shard.model_card.model_id, revision, recursive=True
str(shard.model_card.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -492,7 +592,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -509,7 +609,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
shard.model_card.model_id,
str(shard.model_card.model_id),
revision,
file_progress,
all_start_time,
@@ -519,7 +619,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -543,7 +643,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
shard.model_card.model_id,
str(shard.model_card.model_id),
revision,
file.path,
target_dir,
@@ -557,7 +657,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, shard.model_card.model_id, revision, file_progress, all_start_time
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
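
One recurring pattern in this file: with the ModelId type no longer flowing through these helpers, the repo-id to on-disk directory mapping is now spelled inline as repo_id.replace("/", "--"). A minimal sketch of that convention (the helper name is hypothetical):

def repo_dir_name(repo_id: str) -> str:
    # "org/model" is stored under a directory named "org--model".
    return repo_id.replace("/", "--")

assert repo_dir_name("mlx-community/some-model") == "mlx-community--some-model"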

View File

@@ -3,7 +3,8 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -18,8 +19,8 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
)
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -30,7 +31,7 @@ async def build_base_shard(model_id: ModelId) -> ShardMetadata:
)
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
@@ -147,7 +148,7 @@ class ResumableShardDownloader(ShardDownloader):
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async def _status_for_model(
model_id: ModelId,
model_id: str,
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)

View File

@@ -13,17 +13,11 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
@@ -83,11 +77,11 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, self["_original_layer"])
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -96,7 +90,7 @@ class CustomMlxLayer(nn.Module):
try:
return super().__getattr__(name)
except AttributeError:
original_layer = cast(_LayerCallable, self["_original_layer"])
original_layer = object.__getattribute__(self, "_original_layer")
return getattr(original_layer, name)
@@ -334,40 +328,14 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, MiniMaxModel):
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
if isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -375,14 +343,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -417,34 +377,6 @@ class TensorParallelShardingStrategy(ABC):
) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
if layer.self_attn.n_kv_heads is not None:
layer.self_attn.n_kv_heads //= self.N
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
@@ -471,105 +403,6 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
else:
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
layer.self_attn.kv_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
return model
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.block_sparse_moe.switch_mlp.down_proj
)
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -622,58 +455,3 @@ class ShardedQwenMoE(CustomMlxLayer):
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn.num_key_value_groups = (
layer.self_attn.num_attention_heads
// layer.self_attn.num_key_value_heads
)
layer.self_attn.sinks = layer.self_attn.sinks[
layer.self_attn.num_attention_heads
* self.group.rank() : layer.self_attn.num_attention_heads
* (self.group.rank() + 1)
]
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
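
A side note on the CustomMlxLayer change above: _original_layer is now stored with object.__setattr__ rather than in the module's dict, which keeps the wrapped layer out of nn.Module's dict-based attribute bookkeeping (MLX modules are dict subclasses). A standalone sketch of that wrapper pattern, with the motivation stated as my assumption rather than taken from the diff:

import mlx.nn as nn

class Wrapped(nn.Module):
    """Hypothetical minimal version of the wrapper pattern used by CustomMlxLayer."""

    def __init__(self, original: nn.Module):
        super().__init__()
        # object.__setattr__ bypasses nn.Module's dict-based attribute storage,
        # so the wrapped layer is not registered as a child module.
        object.__setattr__(self, "_original_layer", original)

    @property
    def original_layer(self) -> nn.Module:
        return object.__getattribute__(self, "_original_layer")

    def __call__(self, x):
        return self.original_layer(x)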

View File

@@ -23,7 +23,6 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
@@ -297,7 +296,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
@@ -321,9 +320,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
return None
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.

View File

@@ -449,7 +449,7 @@ class Worker:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug("Fetching and emitting existing download progress...")
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
@@ -480,7 +480,7 @@ class Worker:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug("Done emitting existing download progress.")
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -1,8 +1,6 @@
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -13,10 +11,9 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -36,8 +33,6 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallItem,
ToolCallResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -207,13 +202,7 @@ def main(
mlx_generator, tokenizer
)
# TODO: Add call parser here
if (
tokenizer.tool_parser # pyright: ignore[reportAny]
or tokenizer.tool_call_start
or tokenizer.tool_call_end
) is not None:
mlx_generator = parse_tool_calls(mlx_generator, tokenizer)
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
@@ -223,7 +212,7 @@ def main(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token, # hang on --- is this wrong?? this is totally wrong
idx=response.token,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
@@ -232,18 +221,6 @@ def main(
),
)
)
case ToolCallResponse():
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ToolCallChunk(
idx=4,
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
),
)
)
# can we make this more explicit?
except Exception as e:
@@ -352,50 +329,6 @@ def parse_thinking_models(
yield response
def parse_tool_calls(
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
assert tokenizer.tool_call_start is not None
assert tokenizer.tool_call_end is not None
tool_parser = cast(
Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None,
tokenizer.tool_parser,
) # first arg has followup args, but we don't care about them rn
assert tool_parser is not None
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
# assumption: the tool call start is one token
if response.text == tokenizer.tool_call_start:
in_tool_call = True
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
continue
# assumption: the tool call end is one token
if response.text == tokenizer.tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
if isinstance(parsed, list):
tools = [
ToolCallItem.model_validate_json(json.dumps(tool))
for tool in parsed
]
else:
tools = [ToolCallItem.model_validate_json(json.dumps(parsed))]
yield ToolCallResponse(tool_calls=tools)
except (json.JSONDecodeError, ValidationError):
in_tool_call = False
tool_call_text_parts = []
# fallthrough
yield response
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -11,9 +11,8 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
return False
async def download_tokenizer_files(model_id: ModelId) -> Path:
async def download_tokenizer_files(model_id: str) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.normalize()
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,22 +72,22 @@ async def download_tokenizer_files(model_id: ModelId) -> Path:
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[ModelCard]:
def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = card
families[family] = (card.model_id.short(), card)
return list(families.values())
TEST_MODELS: list[ModelCard] = get_test_models()
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
pytestmark = pytest.mark.slow
@@ -101,13 +101,14 @@ def event_loop():
@pytest.mark.parametrize(
"model_card",
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = model_card.model_id
model_id = str(model_card.model_id)
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
@@ -166,15 +167,16 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
@pytest.mark.parametrize(
"model_card",
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = model_card.model_id
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
@@ -207,18 +209,19 @@ async def test_tokenizer_has_required_attributes(
@pytest.mark.parametrize(
"model_card",
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = model_card.model_id
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
@@ -298,14 +301,16 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
model_card = kimi_models[0]
model_id = model_card.model_id
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
@@ -344,15 +349,17 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
]
if not glm_model_cards:
if not glm_models:
pytest.skip("No GLM models found in MODEL_CARDS")
model_card = glm_model_cards[0]
model_id = model_card.model_id
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import ModelId, NodeId
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress

uv.lock (generated, 12 changed lines)
View File

@@ -248,7 +248,6 @@ dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -282,7 +281,6 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
@@ -317,16 +315,6 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"