"""
|
|
This type stub file was generated by pyright.
|
|
"""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
|
|
|
import mlx.nn as nn
|
|
from transformers.utils.auto_docstring import ModelArgs
|
|
|
|
from .tokenizer_utils import TokenizerWrapper
|
|
|
|
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": ...
|
|
else: ...
|
|
MODEL_REMAPPING = ...
|
|
MAX_FILE_SIZE_GB = ...
|
|
|
|
def compute_bits_per_weight(model): ...
|
|
def hf_repo_to_path(hf_repo): # -> Path:
|
|
...
|
|
def load_config(model_path: Path) -> dict: ...
|
|
def load_model(
|
|
model_path: Path,
|
|
lazy: bool = False,
|
|
strict: bool = True,
|
|
model_config: dict[str, Any] = {},
|
|
get_model_classes: Callable[
|
|
[dict[str, Any]], Tuple[Type[nn.Module], Type[ModelArgs]]
|
|
] = ...,
|
|
) -> Tuple[nn.Module, dict[str, Any]]:
|
|
"""
|
|
Load and initialize the model from a given path.
|
|
|
|
Args:
|
|
model_path (Path): The path to load the model from.
|
|
lazy (bool): If False eval the model parameters to make sure they are
|
|
loaded in memory before returning, otherwise they will be loaded
|
|
when needed. Default: ``False``
|
|
strict (bool): Whether or not to raise an exception if weights don't
|
|
match. Default: ``True``
|
|
model_config (dict, optional): Optional configuration parameters for the
|
|
model. Defaults to an empty dictionary.
|
|
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
|
A function that returns the model class and model args class given a config.
|
|
Defaults to the ``_get_classes`` function.
|
|
|
|
Returns:
|
|
Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.
|
|
|
|
Raises:
|
|
FileNotFoundError: If the weight files (.safetensors) are not found.
|
|
ValueError: If the model class or args class are not found or cannot be instantiated.
|
|
"""

def load(
    path_or_hf_repo: str,
    tokenizer_config=...,
    model_config=...,
    adapter_path: Optional[str] = ...,
    lazy: bool = ...,
    return_config: bool = ...,
    revision: str = ...,
) -> Union[
    Tuple[nn.Module, TokenizerWrapper],
    Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],
]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If ``False`` eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        return_config (bool): If ``True``, return the model config as the last item.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]:
            A tuple containing the loaded model, tokenizer and, if requested, the model config.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """

def make_shards(weights: dict, max_file_size_gb: int = ...) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """

def create_model_card(
    path: Union[str, Path], hf_path: Union[str, Path, None]
): # -> None:
    """
    Creates a model card for the converted model, referencing the original model.

    Args:
        path (Union[str, Path]): Local path to the model.
        hf_path (Union[str, Path, None]): Path to the original Hugging Face model.
    """

def upload_to_hub(path: str, upload_repo: str): # -> None:
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
    """

def save_model(
    save_path: Union[str, Path], model: nn.Module, *, donate_model: bool = ...
) -> None:
    """Save model weights and metadata index into specified directory."""

def quantize_model(
    model: nn.Module,
    config: dict,
    group_size: int,
    bits: int,
    mode: str = ...,
    quant_predicate: Optional[Callable[[str, nn.Module], Union[bool, dict]]] = ...,
) -> Tuple[nn.Module, dict]:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        group_size (int): Group size for quantization.
        bits (int): Bits per weight for quantization.
        mode (str): The quantization mode.
        quant_predicate (Callable): A callable that decides how to quantize
            each layer based on the path. Accepts the layer `path` and the
            `module`. Returns either a bool to signify quantize/no quantize or
            a dict of quantization parameters to pass to `to_quantized`.

    Returns:
        Tuple: Tuple containing quantized model and config.
    """

def save_config(config: dict, config_path: Union[str, Path]) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """

def save(
    dst_path: Union[str, Path],
    src_path_or_repo: Union[str, Path],
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    config: Dict[str, Any],
    donate_model: bool = ...,
): # -> None:
    ...

def common_prefix_len(list1, list2): # -> int:
    """
    Calculates the length of the common prefix of two lists.

    Args:
        list1: The first list of strings.
        list2: The second list of strings.

    Returns:
        The length of the common prefix. Returns 0 if lists are empty
        or do not match at the first element.
    """

def does_model_support_input_embeddings(model: nn.Module) -> bool:
    """
    Check if the model supports input_embeddings in its call signature.

    Args:
        model (nn.Module): The model to check.

    Returns:
        bool: True if the model supports input_embeddings, False otherwise.
    """
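
# Usage sketch (illustrative only): gate an embeddings-based call path on whether
# the loaded model accepts input_embeddings.
#
#     if does_model_support_input_embeddings(model):
#         ...  # call the model with precomputed input_embeddings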