mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
MLX LM type stubs
This commit is contained in:
1
.github/scripts/bench.py
vendored
1
.github/scripts/bench.py
vendored
@@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# type: ignore
|
||||
"""
|
||||
Unified benchmark script for EXO.
|
||||
Runs single or multi-stage benchmarks with configurable load patterns.
|
||||
|
||||
@@ -2614,7 +2614,7 @@ type MX_ARRAY_TREE = (
|
||||
| Mapping[str, MX_ARRAY_TREE]
|
||||
)
|
||||
|
||||
def eval(*args: MX_ARRAY_TREE) -> None:
|
||||
def eval(*args: MX_ARRAY_TREE | None) -> None:
|
||||
"""
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
2
.mlx_typings/mlx_lm/__init__.pyi
Normal file
2
.mlx_typings/mlx_lm/__init__.pyi
Normal file
@@ -0,0 +1,2 @@
|
||||
# Re-export the stubbed submodules of this package. The relative
# ``from . import X as X`` form is the standard stub re-export idiom; a bare
# ``import X as X`` would refer to a top-level module named ``X`` instead.
from . import models as models
from . import tokenizer_utils as tokenizer_utils
|
||||
45
.mlx_typings/mlx_lm/convert.pyi
Normal file
45
.mlx_typings/mlx_lm/convert.pyi
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
def mixed_quant_predicate_builder(
|
||||
recipe: str, model: nn.Module, group_size: int = ...
|
||||
) -> Callable[[str, nn.Module, dict], Union[bool, dict]]: ...
|
||||
|
||||
QUANT_RECIPES = ...
|
||||
MODEL_CONVERSION_DTYPES = ...
|
||||
|
||||
def convert(
|
||||
hf_path: str,
|
||||
mlx_path: str = ...,
|
||||
quantize: bool = ...,
|
||||
q_group_size: int = ...,
|
||||
q_bits: int = ...,
|
||||
q_mode: str = ...,
|
||||
dtype: Optional[str] = ...,
|
||||
upload_repo: str = ...,
|
||||
revision: Optional[str] = ...,
|
||||
dequantize: bool = ...,
|
||||
quant_predicate: Optional[
|
||||
Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str]
|
||||
] = ...,
|
||||
trust_remote_code: bool = ...,
|
||||
): # -> None:
|
||||
...
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
|
||||
def main(): # -> None:
|
||||
...
|
||||
|
||||
if __name__ == "__main__": ...
|
||||
324
.mlx_typings/mlx_lm/generate.pyi
Normal file
324
.mlx_typings/mlx_lm/generate.pyi
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .tokenizer_utils import TokenizerWrapper
|
||||
|
||||
DEFAULT_PROMPT = ...
|
||||
DEFAULT_MAX_TOKENS = ...
|
||||
DEFAULT_TEMP = ...
|
||||
DEFAULT_TOP_P = ...
|
||||
DEFAULT_MIN_P = ...
|
||||
DEFAULT_TOP_K = ...
|
||||
DEFAULT_XTC_PROBABILITY = ...
|
||||
DEFAULT_XTC_THRESHOLD = ...
|
||||
DEFAULT_MIN_TOKENS_TO_KEEP = ...
|
||||
DEFAULT_SEED = ...
|
||||
DEFAULT_MODEL = ...
|
||||
DEFAULT_QUANTIZED_KV_START = ...
|
||||
|
||||
def str2bool(string): # -> bool:
|
||||
...
|
||||
def setup_arg_parser(): # -> ArgumentParser:
|
||||
"""Set up and return the argument parser."""
|
||||
|
||||
generation_stream = ...
|
||||
|
||||
@contextlib.contextmanager
|
||||
def wired_limit(
|
||||
model: nn.Module, streams: Optional[List[mx.Stream]] = ...
|
||||
): # -> Generator[None, Any, None]:
|
||||
"""
|
||||
A context manager to temporarily change the wired limit.
|
||||
|
||||
Note, the wired limit should not be changed during an async eval. If an
|
||||
async eval could be running pass in the streams to synchronize with prior
|
||||
to exiting the context manager.
|
||||
"""
|
||||
@dataclass
|
||||
class GenerationResponse:
|
||||
"""
|
||||
The output of :func:`stream_generate`.
|
||||
|
||||
Args:
|
||||
text (str): The next segment of decoded text. This can be an empty string.
|
||||
token (int): The next token.
|
||||
from_draft (bool): Whether the token was generated by the draft model.
|
||||
logprobs (mx.array): A vector of log probabilities.
|
||||
prompt_tokens (int): The number of tokens in the prompt.
|
||||
prompt_tps (float): The prompt processing tokens-per-second.
|
||||
generation_tokens (int): The number of generated tokens.
|
||||
generation_tps (float): The tokens-per-second for generation.
|
||||
peak_memory (float): The peak memory used so far in GB.
|
||||
finish_reason (str): The reason the response is being sent: "length", "stop" or `None`
|
||||
"""
|
||||
|
||||
text: str
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
from_draft: bool
|
||||
prompt_tokens: int
|
||||
prompt_tps: float
|
||||
generation_tokens: int
|
||||
generation_tps: float
|
||||
peak_memory: float
|
||||
finish_reason: Optional[str] = ...
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
): # -> None:
|
||||
...
|
||||
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = ...,
    sampler: Optional[Callable[[mx.array], mx.array]] = ...,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
    max_kv_size: Optional[int] = ...,
    prompt_cache: Optional[Any] = ...,
    prefill_step_size: int = ...,
    kv_bits: Optional[int] = ...,
    kv_group_size: int = ...,
    quantized_kv_start: int = ...,
    prompt_progress_callback: Optional[Callable[[int, int], Any]] = ...,
    input_embeddings: Optional[mx.array] = ...,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an
          infinite generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for
          sampling a token from a vector of log probabilities.
          Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
          A list of functions that take tokens and logits and return the
          processed logits. Default: ``None``.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
          entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note,
          if provided, the cache will be updated in place.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache
          quantization. None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization.
          Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
          when ``kv_bits`` is non-None. Default: ``0``.
        prompt_progress_callback (Callable[[int, int], Any], optional): A
          call-back which takes the prompt tokens processed so far and the
          total number of prompt tokens. NOTE(review): the return value is
          presumably ignored by the generator -- confirm against the mlx-lm
          implementation.
        input_embeddings (mx.array, optional): Input embeddings to use instead
          of or in conjunction with prompt tokens. Default: ``None``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """
|
||||
|
||||
def speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
draft_model: nn.Module,
|
||||
*,
|
||||
num_draft_tokens: int = ...,
|
||||
max_tokens: int = ...,
|
||||
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
|
||||
prompt_cache: Optional[Any] = ...,
|
||||
prefill_step_size: int = ...,
|
||||
kv_bits: Optional[int] = ...,
|
||||
kv_group_size: int = ...,
|
||||
quantized_kv_start: int = ...,
|
||||
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
draft_model (nn.Module): The draft model for speculative decoding.
|
||||
num_draft_tokens (int, optional): The number of draft tokens for
|
||||
speculative decoding. Default: ``2``.
|
||||
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
|
||||
generator. Default: ``256``.
|
||||
sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
|
||||
token from a vector of log probabilities. Default: ``None``.
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place. The cache must be trimmable.
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
None implies no cache quantization. Default: ``None``.
|
||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
|
||||
Yields:
|
||||
Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
|
||||
and a bool indicating if the token was generated by the draft model
|
||||
"""
|
||||
|
||||
def stream_generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompt: Union[str, mx.array, List[int]],
|
||||
max_tokens: int = ...,
|
||||
draft_model: Optional[nn.Module] = ...,
|
||||
**kwargs,
|
||||
) -> Generator[GenerationResponse, None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to use for generation.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (Union[str, mx.array, List[int]]): The input prompt string or
|
||||
integer tokens.
|
||||
max_tokens (int): The maximum number of tokens to generate.
|
||||
Default: ``256``.
|
||||
draft_model (Optional[nn.Module]): An optional draft model. If provided
|
||||
then speculative decoding is used. The draft model must use the same
|
||||
tokenizer as the main model. Default: ``None``.
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
|
||||
Yields:
|
||||
GenerationResponse: An instance containing the generated text segment and
|
||||
associated metadata. See :class:`GenerationResponse` for details.
|
||||
"""
|
||||
|
||||
def generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompt: Union[str, List[int]],
|
||||
verbose: bool = ...,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a complete response from the model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
|
||||
verbose (bool): If ``True``, print tokens and timing information.
|
||||
Default: ``False``.
|
||||
kwargs: The remaining options get passed to :func:`stream_generate`.
|
||||
See :func:`stream_generate` for more details.
|
||||
"""
|
||||
@dataclass
|
||||
class BatchStats:
|
||||
"""
|
||||
A data object to hold generation stats.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of prompt tokens processed.
|
||||
prompt_tps (float): The prompt processing tokens-per-second.
|
||||
prompt_time (float): The time in seconds spent in prompt processing.
|
||||
generation_tokens (int): The number of generated tokens.
|
||||
generation_tps (float): The tokens-per-second for generation.
|
||||
generation_time (float): The time in seconds spent in generation.
|
||||
peak_memory (float): The peak memory used so far in GB.
|
||||
"""
|
||||
|
||||
prompt_tokens: int = ...
|
||||
prompt_tps: float = ...
|
||||
prompt_time: float = ...
|
||||
generation_tokens: int = ...
|
||||
generation_tps: float = ...
|
||||
generation_time: float = ...
|
||||
peak_memory: float = ...
|
||||
|
||||
@dataclass
|
||||
class BatchResponse:
|
||||
"""
|
||||
A data object to hold a batch generation response.
|
||||
|
||||
Args:
|
||||
texts (List[str]): The generated text for each prompt.
|
||||
stats (BatchStats): Statistics about the generation.
|
||||
"""
|
||||
|
||||
texts: List[str]
|
||||
stats: BatchStats
|
||||
|
||||
@dataclass
|
||||
class Batch:
|
||||
uids: List[int]
|
||||
y: mx.array
|
||||
logprobs: mx.array
|
||||
max_tokens: List[int]
|
||||
num_tokens: List[int]
|
||||
cache: List[Any]
|
||||
def __len__(self): # -> int:
|
||||
...
|
||||
def filter(self, keep_idx: List[int]): # -> None:
|
||||
...
|
||||
def extend(self, other): # -> None:
|
||||
...
|
||||
|
||||
class BatchGenerator:
|
||||
@dataclass
|
||||
class Response:
|
||||
uid: int
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
finish_reason: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_tokens: int = ...,
|
||||
stop_tokens: Optional[set] = ...,
|
||||
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
|
||||
completion_batch_size: int = ...,
|
||||
prefill_batch_size: int = ...,
|
||||
prefill_step_size: int = ...,
|
||||
) -> None: ...
|
||||
def insert(
|
||||
self, prompts, max_tokens: Union[List[int], int, None] = ...
|
||||
): # -> list[Any]:
|
||||
...
|
||||
def stats(self): # -> BatchStats:
|
||||
...
|
||||
def next(self): # -> list[Any]:
|
||||
...
|
||||
|
||||
def batch_generate(
    model,
    tokenizer,
    prompts: List[List[int]],
    max_tokens: Union[int, List[int]] = ...,
    verbose: bool = ...,
    **kwargs,
) -> BatchResponse:
    """
    Generate responses for the given batch of prompts.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompts (List[List[int]]): The input prompts.
        max_tokens (Union[int, List[int]]): Maximum number of output tokens.
          This can be per prompt if a list is provided.
        verbose (bool): If ``True``, print tokens and timing information.
          Default: ``False``.
        kwargs: The remaining options get passed to :obj:`BatchGenerator`.
          See :obj:`BatchGenerator` for more details.

    Returns:
        BatchResponse: The batch generation response.
    """
|
||||
|
||||
def main(): # -> None:
|
||||
...
|
||||
|
||||
if __name__ == "__main__": ...
|
||||
1
.mlx_typings/mlx_lm/models/__init__.pyi
Normal file
1
.mlx_typings/mlx_lm/models/__init__.pyi
Normal file
@@ -0,0 +1 @@
|
||||
# Re-export the ``cache`` submodule of this package. The relative form is the
# standard stub re-export idiom; a bare ``import cache as cache`` would refer
# to a top-level module named ``cache`` instead.
from . import cache as cache
|
||||
47
.mlx_typings/mlx_lm/models/base.pyi
Normal file
47
.mlx_typings/mlx_lm/models/base.pyi
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@dataclass
|
||||
class BaseModelArgs:
|
||||
@classmethod
|
||||
def from_dict(cls, params): # -> Self:
|
||||
...
|
||||
|
||||
def create_causal_mask(
|
||||
N: int,
|
||||
offset: int = ...,
|
||||
window_size: Optional[int] = ...,
|
||||
right_padding: Optional[mx.array] = ...,
|
||||
left_padding: Optional[mx.array] = ...,
|
||||
): # -> array:
|
||||
...
|
||||
def create_attention_mask(
|
||||
h, cache=..., window_size: Optional[int] = ..., return_array: bool = ...
|
||||
): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
def create_ssm_mask(h, cache=...): # -> None:
|
||||
...
|
||||
def quantized_scaled_dot_product_attention(
|
||||
queries: mx.array,
|
||||
q_keys: tuple[mx.array, mx.array, mx.array],
|
||||
q_values: tuple[mx.array, mx.array, mx.array],
|
||||
scale: float,
|
||||
mask: Optional[mx.array],
|
||||
group_size: int = ...,
|
||||
bits: int = ...,
|
||||
) -> mx.array: ...
|
||||
def scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache,
|
||||
scale: float,
|
||||
mask: Optional[mx.array],
|
||||
sinks: Optional[mx.array] = ...,
|
||||
) -> mx.array: ...
|
||||
26
.mlx_typings/mlx_lm/models/bitlinear_layers.pyi
Normal file
26
.mlx_typings/mlx_lm/models/bitlinear_layers.pyi
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
def bitnet_quantize(model, quantization_config: dict): ...
|
||||
def make_bitlinear_kernel():
|
||||
"""
|
||||
Custom Metal kernel that performs matrix multiplication directly on
|
||||
packed weights and scales the output. This eliminates the need to
|
||||
store unpacked weights in memory.
|
||||
"""
|
||||
|
||||
_bitlinear_kernel = ...
|
||||
|
||||
class BitLinear(nn.Module):
|
||||
"""
|
||||
BitLinear module with memory-efficient weight handling.
|
||||
"""
|
||||
def __init__(
|
||||
self, in_features, out_features, bias=..., invert_weight_scales=...
|
||||
) -> None: ...
|
||||
def execute_matmul_kernel(self, x, packed_weights): ...
|
||||
def __call__(self, x): # -> array:
|
||||
...
|
||||
354
.mlx_typings/mlx_lm/models/cache.pyi
Normal file
354
.mlx_typings/mlx_lm/models/cache.pyi
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.nn as nn
|
||||
from mlx.core import array
|
||||
|
||||
def make_prompt_cache(
|
||||
model: nn.Module, max_kv_size: Optional[int] = ...
|
||||
) -> List[KVCache | Any]:
|
||||
"""
|
||||
Construct the model's cache for use in generation.
|
||||
|
||||
This function will defer the cache construction to the model if it has a
|
||||
``make_cache`` method, otherwise it will make a default KV cache.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
max_kv_size (Optional[int]): If provided and the model does not have a
|
||||
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
||||
size of ``max_kv_size``
|
||||
"""
|
||||
|
||||
def save_prompt_cache(
|
||||
file_name: str, cache: List[Any], metadata: Dict[str, str] = ...
|
||||
) -> None:
|
||||
"""
|
||||
Save a pre-computed prompt cache to a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
cache (List[Any]): The model state.
|
||||
metadata (Dict[str, str]): Optional metadata to save along with model
|
||||
state.
|
||||
"""
|
||||
|
||||
def load_prompt_cache(
|
||||
file_name, return_metadata=...
|
||||
): # -> tuple[list[Any], Any] | list[Any]:
|
||||
"""
|
||||
Load a prompt cache from a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
return_metadata (bool): Whether or not to return metadata.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
|
||||
the metadata if requested.
|
||||
"""
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Any]) -> bool:
|
||||
"""
|
||||
Check if model's cache can be trimmed.
|
||||
"""
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
    """
    Trim the model's cache by the given number of tokens.

    This function will trim the cache if possible (in-place) and return the
    number of tokens that were trimmed.

    Args:
        cache (List[Any]): The model's cache.
        num_tokens (int): The number of tokens to trim.

    Returns:
        (int): The number of tokens that were trimmed.
    """
|
||||
|
||||
def create_attention_mask(
|
||||
N: int, offset: int, return_array: bool, window_size: Optional[int]
|
||||
): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
class _BaseCache:
|
||||
@property
|
||||
def state(self): # -> list[Any]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@property
|
||||
def meta_state(self): # -> Literal['']:
|
||||
...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[False]:
|
||||
...
|
||||
@classmethod
|
||||
def from_state(cls, state, meta_state): # -> Self:
|
||||
...
|
||||
|
||||
class ConcatenateKVCache(_BaseCache):
|
||||
"""ConcatenateKVCache the simplest KV cache implementation.
|
||||
|
||||
Can be used as a mock KV cache or when large blocks are being processed at
|
||||
a time in which case KVCache isn't necessarily faster. Consider using the
|
||||
KVCache with a larger step size before using this cache.
|
||||
"""
|
||||
def __init__(self) -> None: ...
|
||||
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
|
||||
...
|
||||
@property
|
||||
def state(self): # -> tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
class QuantizedKVCache(_BaseCache):
|
||||
step = ...
|
||||
def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
|
||||
def update_and_fetch(self, keys, values): # -> Any:
|
||||
...
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@property
|
||||
def meta_state(self): # -> tuple[str, ...]:
|
||||
...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
class KVCache(_BaseCache):
|
||||
step = ...
|
||||
def __init__(self) -> None: ...
|
||||
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
|
||||
...
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
) -> tuple[array, array]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
class RotatingKVCache(_BaseCache):
|
||||
step = ...
|
||||
def __init__(self, max_size, keep=...) -> None: ...
|
||||
def update_and_fetch(
|
||||
self, keys, values
|
||||
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
|
||||
...
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@property
|
||||
def meta_state(self): # -> tuple[str, ...]:
|
||||
...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
def make_mask(
|
||||
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
|
||||
): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
class ArraysCache(_BaseCache):
|
||||
def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
|
||||
def __setitem__(self, idx, value): # -> None:
|
||||
...
|
||||
def __getitem__(self, idx): ...
|
||||
@property
|
||||
def state(self): # -> list[Any | array] | list[array]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
def filter(self, batch_indices): # -> None:
|
||||
"""
|
||||
In-place filter to keep just the given indices in the cache.
|
||||
"""
|
||||
|
||||
def extend(self, other): # -> None:
|
||||
"""
|
||||
In-place extend this cache with the other cache.
|
||||
"""
|
||||
|
||||
def make_mask(self, N: int): # -> array | None:
|
||||
...
|
||||
|
||||
class MambaCache(ArraysCache):
|
||||
def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...
|
||||
|
||||
class ChunkedKVCache(KVCache):
|
||||
def __init__(self, chunk_size) -> None: ...
|
||||
def maybe_trim_front(self): # -> None:
|
||||
...
|
||||
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
@property
|
||||
def meta_state(self): # -> tuple[str, ...]:
|
||||
...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v): # -> None:
|
||||
...
|
||||
|
||||
class CacheList(_BaseCache):
|
||||
def __init__(self, *caches) -> None: ...
|
||||
def __getitem__(self, idx): ...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n): ...
|
||||
@property
|
||||
def state(self): # -> list[Any]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
def filter(self, batch_indices): # -> None:
|
||||
"""
|
||||
In-place filter to keep just the given indices in the cache.
|
||||
"""
|
||||
|
||||
def extend(self, other): # -> None:
|
||||
"""
|
||||
In-place extend this cache with the other cache.
|
||||
"""
|
||||
|
||||
class BatchKVCache(_BaseCache):
|
||||
step = ...
|
||||
def __init__(self, left_padding: List[int]) -> None:
|
||||
"""
|
||||
The BatchKV cache expects inputs to be left-padded.
|
||||
|
||||
E.g. the following prompts:
|
||||
|
||||
[1, 3, 5]
|
||||
[7]
|
||||
[2, 6, 8, 9]
|
||||
|
||||
Should be padded like so:
|
||||
|
||||
[0, 1, 3, 5]
|
||||
[0, 0, 0, 7]
|
||||
[2, 6, 8, 9]
|
||||
|
||||
And ``left_padding`` specifies the amount of padding for each.
|
||||
In this case, ``left_padding = [1, 3, 0]``.
|
||||
"""
|
||||
|
||||
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
|
||||
...
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int | float:
|
||||
...
|
||||
def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array:
|
||||
...
|
||||
def filter(self, batch_indices): # -> None:
|
||||
"""
|
||||
In-place filter to keep just the given indices in the cache.
|
||||
"""
|
||||
|
||||
def extend(self, other): # -> None:
|
||||
"""
|
||||
In-place extend this cache with the other cache.
|
||||
"""
|
||||
|
||||
class BatchRotatingKVCache(_BaseCache):
|
||||
step = ...
|
||||
def __init__(self, max_size, left_padding: List[int]) -> None: ...
|
||||
def update_and_fetch(
|
||||
self, keys, values
|
||||
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
|
||||
...
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
|
||||
...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@property
|
||||
def meta_state(self): # -> tuple[str, ...]:
|
||||
...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
def make_mask(
|
||||
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
|
||||
): # -> array:
|
||||
...
|
||||
def filter(self, batch_indices): # -> None:
|
||||
"""
|
||||
In-place filter to keep just the given indices in the cache.
|
||||
"""
|
||||
|
||||
def extend(self, other): # -> None:
|
||||
"""
|
||||
In-place extend this cache with the other cache.
|
||||
"""
|
||||
79
.mlx_typings/mlx_lm/models/switch_layers.pyi
Normal file
79
.mlx_typings/mlx_lm/models/switch_layers.pyi
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
class QuantizedSwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
bias: bool = ...,
|
||||
group_size: int = ...,
|
||||
bits: int = ...,
|
||||
mode: str = ...,
|
||||
) -> None: ...
|
||||
@property
|
||||
def input_dims(self): # -> int:
|
||||
...
|
||||
@property
|
||||
def output_dims(self): # -> int:
|
||||
...
|
||||
@property
|
||||
def num_experts(self): # -> int:
|
||||
...
|
||||
def __call__(self, x, indices, sorted_indices=...): # -> array:
|
||||
...
|
||||
|
||||
class SwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = ...
|
||||
) -> None: ...
|
||||
@property
|
||||
def input_dims(self): # -> int:
|
||||
...
|
||||
@property
|
||||
def output_dims(self): # -> int:
|
||||
...
|
||||
@property
|
||||
def num_experts(self): # -> int:
|
||||
...
|
||||
def __call__(self, x, indices, sorted_indices=...): ...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ..., mode: str = ...
|
||||
): # -> QuantizedSwitchLinear:
|
||||
...
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def swiglu(x, gate): ...
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def __init__(self) -> None: ...
|
||||
def __call__(self, x, gate): ...
|
||||
|
||||
class SwitchGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=...,
|
||||
bias: bool = ...,
|
||||
) -> None: ...
|
||||
def __call__(self, x, indices) -> mx.array: ...
|
||||
|
||||
class SwitchMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=...,
|
||||
bias: bool = ...,
|
||||
) -> None: ...
|
||||
def __call__(self, x, indices) -> mx.array: ...
|
||||
148
.mlx_typings/mlx_lm/sample_utils.pyi
Normal file
148
.mlx_typings/mlx_lm/sample_utils.pyi
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
def make_sampler(
|
||||
temp: float = ...,
|
||||
top_p: float = ...,
|
||||
min_p: float = ...,
|
||||
min_tokens_to_keep: int = ...,
|
||||
top_k: int = ...,
|
||||
xtc_probability: float = ...,
|
||||
xtc_threshold: float = ...,
|
||||
xtc_special_tokens: List[int] = ...,
|
||||
) -> Callable[[mx.array], mx.array]:
|
||||
"""
|
||||
Make a sampler function for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Default: ``0``.
|
||||
top_p (float, optional): Nucleus sampling, higher means model considers
|
||||
more less likely words.
|
||||
min_p (float, optional): The minimum value (scaled by the top token's
|
||||
probability) that a token probability must have to be considered.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered by min_p sampling.
|
||||
top_k (int, optional): The top k tokens ranked by probability to constrain
|
||||
the sampling to.
|
||||
xtc_probability (float, optional): The probability of applying XTC
|
||||
sampling.
|
||||
xtc_threshold (float, optional): The threshold the probs need to reach
|
||||
for being sampled.
|
||||
xtc_special_tokens (list(int), optional): List of special tokens IDs to
|
||||
be excluded from XTC sampling.
|
||||
|
||||
|
||||
Returns:
|
||||
Callable[[mx.array], mx.array]:
|
||||
A sampler which takes log-probabilities and returns tokens.
|
||||
"""
|
||||
|
||||
def make_logits_processors(
|
||||
logit_bias: Optional[Dict[int, float]] = ...,
|
||||
repetition_penalty: Optional[float] = ...,
|
||||
repetition_context_size: Optional[int] = ...,
|
||||
): # -> list[Any]:
|
||||
"""
|
||||
Make logits processors for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
repetition_penalty (float, optional): The penalty factor for repeating
|
||||
tokens.
|
||||
repetition_context_size (int, optional): The number of tokens to
|
||||
consider for repetition penalty. Default: ``20``.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
|
||||
Returns:
|
||||
List[Callable[[mx.array, mx.array], mx.array]]:
|
||||
A list of logits processors. Each processor in the list is a
|
||||
callable which takes an array of tokens and an array of logits
|
||||
and returns the updated logits.
|
||||
"""
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_top_k(logprobs: mx.array, top_k: int) -> mx.array:
|
||||
"""
|
||||
Sample from only the top K tokens ranked by probability.
|
||||
|
||||
Args:
|
||||
logprobs: A vector of log probabilities.
|
||||
top_k (int): Top k tokens to sample from.
|
||||
"""
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_min_p(
|
||||
logprobs: mx.array, min_p: float, min_tokens_to_keep: int = ...
|
||||
) -> mx.array:
|
||||
"""
|
||||
Apply min-p sampling to the logprobs.
|
||||
|
||||
Min-p keeps all tokens that are above a minimum probability, scaled by the
|
||||
probability of the most likely token. As a result, the filter is more
|
||||
aggressive given a very high-probability token.
|
||||
|
||||
Args:
|
||||
logprobs: A vector of log probabilities.
|
||||
min_p (float): Minimum token probability. Typical values are in the
|
||||
0.01-0.2 range, comparably selective as setting `top_p` in the
|
||||
0.99-0.8 range.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered. Default: ``1``.
|
||||
|
||||
"""
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_top_p(logprobs: mx.array, top_p: float) -> mx.array:
|
||||
"""
|
||||
Apply top-p (nucleus) sampling to logits.
|
||||
|
||||
Args:
|
||||
logprobs: A vector of log probabilities.
|
||||
top_p: The cumulative probability threshold for top-p filtering.
|
||||
Returns:
|
||||
token selected based on the top-p criterion.
|
||||
"""
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_xtc(
|
||||
logits: mx.array,
|
||||
xtc_probability: float,
|
||||
xtc_threshold: float,
|
||||
xtc_special_tokens: List[int],
|
||||
) -> mx.array:
|
||||
"""
|
||||
Apply XTC sampling to the logits.
|
||||
|
||||
Args:
|
||||
logits: The logits from the model's output.
|
||||
xtc_probability (float): Probability of XTC sampling to happen for each token
|
||||
xtc_threshold (float): The threshold the probs need to reach for being sampled.
|
||||
xtc_special_tokens (list(int)): List of special token IDs to be excluded from XTC sampling.
|
||||
"""
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def categorical_sampling(logits, temp): # -> array:
|
||||
...
|
||||
def make_repetition_penalty(
|
||||
penalty: float, context_size: int = ...
|
||||
): # -> Callable[..., Any]:
|
||||
"""
|
||||
Make repetition penalty processor.
|
||||
|
||||
Paper: https://arxiv.org/abs/1909.05858
|
||||
|
||||
Args:
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
context_size (int): The number of previous tokens to use.
|
||||
Default: ``20``.
|
||||
|
||||
Returns:
|
||||
Callable[[mx.array, List[int]], mx.array]:
|
||||
The repetition penalty processor.
|
||||
"""
|
||||
168
.mlx_typings/mlx_lm/tokenizer_utils.pyi
Normal file
168
.mlx_typings/mlx_lm/tokenizer_utils.pyi
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
class StreamingDetokenizer:
|
||||
"""The streaming detokenizer interface so that we can detokenize one token at a time.
|
||||
|
||||
Example usage is as follows:
|
||||
|
||||
detokenizer = ...
|
||||
|
||||
# Reset the tokenizer state
|
||||
detokenizer.reset()
|
||||
|
||||
for token in generate(...):
|
||||
detokenizer.add_token(token.item())
|
||||
|
||||
# Contains the whole text so far. Some tokens may not be included
|
||||
# since it contains whole words usually.
|
||||
detokenizer.text
|
||||
|
||||
# Contains the printable segment (usually a word) since the last
|
||||
# time it was accessed
|
||||
detokenizer.last_segment
|
||||
|
||||
# Contains all the tokens added so far
|
||||
detokenizer.tokens
|
||||
|
||||
# Make sure that we detokenize any remaining tokens
|
||||
detokenizer.finalize()
|
||||
|
||||
# Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
|
||||
"""
|
||||
|
||||
__slots__ = ...
|
||||
def reset(self): ...
|
||||
def add_token(self, token): ...
|
||||
def finalize(self): ...
|
||||
@property
|
||||
def last_segment(self):
|
||||
"""Return the last segment of readable text since last time this property was accessed."""
|
||||
|
||||
class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
|
||||
implementation and should work with every tokenizer.
|
||||
|
||||
Its complexity is O(T^2) where T is the longest line since it will
|
||||
repeatedly detokenize the same tokens until a new line is generated.
|
||||
"""
|
||||
def __init__(self, tokenizer) -> None: ...
|
||||
def reset(self): # -> None:
|
||||
...
|
||||
def add_token(self, token): # -> None:
|
||||
...
|
||||
def finalize(self): # -> None:
|
||||
...
|
||||
@property
|
||||
def text(self): # -> str:
|
||||
...
|
||||
|
||||
class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""A streaming detokenizer for SPM models.
|
||||
|
||||
It adds tokens to the text if the next token starts with the special SPM
|
||||
underscore which results in linear complexity.
|
||||
"""
|
||||
def __init__(self, tokenizer, trim_space=...) -> None: ...
|
||||
def reset(self): # -> None:
|
||||
...
|
||||
def add_token(self, token): # -> None:
|
||||
...
|
||||
def finalize(self): # -> None:
|
||||
...
|
||||
|
||||
class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""A streaming detokenizer for OpenAI style BPE models.
|
||||
|
||||
It adds tokens to the text if the next token starts with a space similar to
|
||||
the SPM detokenizer.
|
||||
"""
|
||||
|
||||
_byte_decoder = ...
|
||||
_space_matches = ...
|
||||
def __init__(self, tokenizer) -> None: ...
|
||||
def reset(self): # -> None:
|
||||
...
|
||||
def add_token(self, token): # -> None:
|
||||
...
|
||||
def finalize(self): # -> None:
|
||||
...
|
||||
@classmethod
|
||||
def make_byte_decoder(cls): # -> None:
|
||||
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
|
||||
|
||||
class TokenizerWrapper:
|
||||
"""A wrapper that combines an HF tokenizer and a detokenizer.
|
||||
|
||||
Accessing any attribute other than the ``detokenizer`` is forwarded to the
|
||||
huggingface tokenizer.
|
||||
"""
|
||||
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
|
||||
def add_eos_token(self, token: str): # -> None:
|
||||
...
|
||||
@property
|
||||
def has_thinking(self): # -> bool:
|
||||
...
|
||||
@property
|
||||
def think_start(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def think_end(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def has_tool_calling(self): # -> bool:
|
||||
...
|
||||
@property
|
||||
def tool_call_start(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def tool_call_end(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def detokenizer(self): # -> NaiveStreamingDetokenizer:
|
||||
"""
|
||||
Get a stateful streaming detokenizer.
|
||||
"""
|
||||
|
||||
def __getattr__(self, attr): # -> set[Any] | Any:
|
||||
...
|
||||
def __setattr__(self, attr, value): # -> None:
|
||||
...
|
||||
|
||||
class NewlineTokenizer(PreTrainedTokenizerFast):
|
||||
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def encode(self, text, **kwargs): # -> list[int]:
|
||||
...
|
||||
def encode_batch(self, texts, **kwargs): ...
|
||||
def decode(self, *args, **kwargs): # -> str:
|
||||
...
|
||||
def batch_decode(self, *args, **kwargs): # -> list[str]:
|
||||
...
|
||||
|
||||
def load_tokenizer(
|
||||
model_path: Path,
|
||||
tokenizer_config_extra=...,
|
||||
return_tokenizer=...,
|
||||
eos_token_ids=...,
|
||||
) -> (
|
||||
TokenizerWrapper
|
||||
| type[SPMStreamingDetokenizer]
|
||||
| partial[SPMStreamingDetokenizer]
|
||||
| type[BPEStreamingDetokenizer]
|
||||
| type[NaiveStreamingDetokenizer]
|
||||
):
|
||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||
detokenizer to use.
|
||||
|
||||
Note, to use a fast streaming tokenizer, pass a local file path rather than
|
||||
a Hugging Face repo ID.
|
||||
"""
|
||||
|
||||
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
|
||||
195
.mlx_typings/mlx_lm/utils.pyi
Normal file
195
.mlx_typings/mlx_lm/utils.pyi
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import mlx.nn as nn
|
||||
from transformers.utils.auto_docstring import ModelArgs
|
||||
|
||||
from .tokenizer_utils import TokenizerWrapper
|
||||
|
||||
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": ...
|
||||
else: ...
|
||||
MODEL_REMAPPING = ...
|
||||
MAX_FILE_SIZE_GB = ...
|
||||
|
||||
def compute_bits_per_weight(model): ...
|
||||
def hf_repo_to_path(hf_repo): # -> Path:
|
||||
...
|
||||
def load_config(model_path: Path) -> dict: ...
|
||||
def load_model(
|
||||
model_path: Path,
|
||||
lazy: bool = False,
|
||||
strict: bool = True,
|
||||
model_config: dict[str, Any] = {},
|
||||
get_model_classes: Callable[
|
||||
[dict[str, Any]], Tuple[Type[nn.Module], Type[ModelArgs]]
|
||||
] = ...,
|
||||
) -> Tuple[nn.Module, dict[str, Any]]:
|
||||
"""
|
||||
Load and initialize the model from a given path.
|
||||
|
||||
Args:
|
||||
model_path (Path): The path to load the model from.
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
strict (bool): Whether or not to raise an exception if weights don't
|
||||
match. Default: ``True``
|
||||
model_config (dict, optional): Optional configuration parameters for the
|
||||
model. Defaults to an empty dictionary.
|
||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||
A function that returns the model class and model args class given a config.
|
||||
Defaults to the ``_get_classes`` function.
|
||||
|
||||
Returns:
|
||||
Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the weight files (.safetensors) are not found.
|
||||
ValueError: If the model class or args class are not found or cannot be instantiated.
|
||||
"""
|
||||
|
||||
def load(
|
||||
path_or_hf_repo: str,
|
||||
tokenizer_config=...,
|
||||
model_config=...,
|
||||
adapter_path: Optional[str] = ...,
|
||||
lazy: bool = ...,
|
||||
return_config: bool = ...,
|
||||
revision: str = ...,
|
||||
) -> Union[
|
||||
Tuple[nn.Module, TokenizerWrapper],
|
||||
Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],
|
||||
]:
|
||||
"""
|
||||
Load the model and tokenizer from a given path or a huggingface repository.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo (str): The path or the huggingface repository to load the model from.
|
||||
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
|
||||
Defaults to an empty dictionary.
|
||||
model_config(dict, optional): Configuration parameters specifically for the model.
|
||||
Defaults to an empty dictionary.
|
||||
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
|
||||
to the model. Default: ``None``.
|
||||
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
return_config (bool): If ``True`` return the model config as the last item.
|
||||
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
|
||||
Returns:
|
||||
Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]:
|
||||
A tuple containing the loaded model, tokenizer and, if requested, the model config.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file or safetensors are not found.
|
||||
ValueError: If model class or args class are not found.
|
||||
"""
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gb: int = ...) -> list:
|
||||
"""
|
||||
Splits the weights into smaller shards.
|
||||
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
max_file_size_gb (int): Maximum size of each shard in gigabytes.
|
||||
|
||||
Returns:
|
||||
list: List of weight shards.
|
||||
"""
|
||||
|
||||
def create_model_card(
|
||||
path: Union[str, Path], hf_path: Union[str, Path, None]
|
||||
): # -> None:
|
||||
"""
|
||||
Creates a model card for the model.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]): Local path to the model.
|
||||
hf_path (Union[str, Path, None]): Path to the original Hugging Face model.
|
||||
"""
|
||||
|
||||
def upload_to_hub(path: str, upload_repo: str): # -> None:
|
||||
"""
|
||||
Uploads the model to Hugging Face hub.
|
||||
|
||||
Args:
|
||||
path (str): Local path to the model.
|
||||
upload_repo (str): Name of the HF repo to upload to.
|
||||
"""
|
||||
|
||||
def save_model(
|
||||
save_path: Union[str, Path], model: nn.Module, *, donate_model: bool = ...
|
||||
) -> None:
|
||||
"""Save model weights and metadata index into specified directory."""
|
||||
|
||||
def quantize_model(
|
||||
model: nn.Module,
|
||||
config: dict,
|
||||
group_size: int,
|
||||
bits: int,
|
||||
mode: str = ...,
|
||||
quant_predicate: Optional[Callable[[str, nn.Module], Union[bool, dict]]] = ...,
|
||||
) -> Tuple[nn.Module, dict]:
|
||||
"""
|
||||
Applies quantization to the model weights.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be quantized.
|
||||
config (dict): Model configuration.
|
||||
group_size (int): Group size for quantization.
|
||||
bits (int): Bits per weight for quantization.
|
||||
mode (str): The quantization mode.
|
||||
quant_predicate (Callable): A callable that decides how to quantize
|
||||
each layer based on the path. Accepts the layer `path` and the
|
||||
`module`. Returns either a bool to signify quantize/no quantize or
|
||||
a dict of quantization parameters to pass to `to_quantized`.
|
||||
|
||||
Returns:
|
||||
Tuple: Tuple containing quantized model and config.
|
||||
"""
|
||||
|
||||
def save_config(config: dict, config_path: Union[str, Path]) -> None:
|
||||
"""Save the model configuration to the ``config_path``.
|
||||
|
||||
The final configuration will be sorted before saving for better readability.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
config_path (Union[str, Path]): Model configuration file path.
|
||||
"""
|
||||
|
||||
def save(
|
||||
dst_path: Union[str, Path],
|
||||
src_path_or_repo: Union[str, Path],
|
||||
model: nn.Module,
|
||||
tokenizer: TokenizerWrapper,
|
||||
config: Dict[str, Any],
|
||||
donate_model: bool = ...,
|
||||
): # -> None:
|
||||
...
|
||||
def common_prefix_len(list1, list2): # -> int:
|
||||
"""
|
||||
Calculates the length of the common prefix of two lists.
|
||||
|
||||
Args:
|
||||
list1: The first list of strings.
|
||||
list2: The second list of strings.
|
||||
|
||||
Returns:
|
||||
The length of the common prefix. Returns 0 if lists are empty
|
||||
or do not match at the first element.
|
||||
"""
|
||||
|
||||
def does_model_support_input_embeddings(model: nn.Module) -> bool:
|
||||
"""
|
||||
Check if the model supports input_embeddings in its call signature.
|
||||
Args:
|
||||
model (nn.Module): The model to check.
|
||||
Returns:
|
||||
bool: True if the model supports input_embeddings, False otherwise.
|
||||
"""
|
||||
2
justfile
2
justfile
@@ -1,5 +1,5 @@
|
||||
fmt:
|
||||
uv run ruff format src typings
|
||||
uv run ruff format src .mlx_typings
|
||||
|
||||
lint:
|
||||
uv run ruff check --fix src
|
||||
|
||||
@@ -81,7 +81,7 @@ build-backend = "uv_build"
|
||||
###
|
||||
|
||||
[tool.basedpyright]
|
||||
include = [".venv/lib/mlx", "src"]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
|
||||
typeCheckingMode = "strict"
|
||||
failOnWarnings = true
|
||||
|
||||
@@ -97,8 +97,8 @@ reportUnnecessaryTypeIgnoreComment = "error"
|
||||
pythonVersion = "3.13"
|
||||
pythonPlatform = "Darwin"
|
||||
|
||||
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"]
|
||||
stubPath = "typings"
|
||||
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust", "**/.github"]
|
||||
stubPath = ".mlx_typings"
|
||||
|
||||
[[tool.basedpyright.executionEnvironments]]
|
||||
root = "src"
|
||||
|
||||
@@ -9,10 +9,10 @@ from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer # type: ignore
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
except ImportError:
|
||||
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
|
||||
from mlx_lm.utils import load_model # type: ignore
|
||||
from mlx_lm.utils import load_model
|
||||
from pydantic import RootModel
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -167,12 +167,11 @@ def shard_and_load(
|
||||
f"loading model from {model_path} with strategy {model_shard_meta.strategy}"
|
||||
)
|
||||
|
||||
model, config = load_model(model_path, lazy=True, strict=False) # type: ignore
|
||||
model, config = load_model(model_path, lazy=True, strict=False)
|
||||
runner_print(f"{config=}")
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
tokenizer = load_tokenizer(model_path) # type: ignore
|
||||
tokenizer = cast(TokenizerWrapper, tokenizer)
|
||||
tokenizer = cast(TokenizerWrapper, load_tokenizer(model_path))
|
||||
|
||||
runner_print(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
|
||||
)
|
||||
|
||||
|
||||
async def build_full_shard(model_id: str) -> PipelineShardMetadata | None:
|
||||
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
|
||||
base_shard = await build_base_shard(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_meta=base_shard.model_meta,
|
||||
@@ -150,11 +150,9 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
# print("get_shard_download_status")
|
||||
async def _status_for_model(
|
||||
model_id: str,
|
||||
) -> tuple[Path, RepoDownloadProgress] | None:
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
"""Helper coroutine that builds the shard for a model and gets its download status."""
|
||||
shard = await build_full_shard(model_id)
|
||||
if shard is None:
|
||||
return None
|
||||
return await download_shard(
|
||||
shard, self.on_progress_wrapper, skip_download=True
|
||||
)
|
||||
@@ -168,8 +166,6 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
result = await task
|
||||
if result is None:
|
||||
continue
|
||||
path, progress = result
|
||||
yield (path, progress)
|
||||
except Exception as e:
|
||||
|
||||
@@ -35,16 +35,18 @@ generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[Any],
|
||||
prompt_cache: list[KVCache | Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache): # type: ignore[type-arg]
|
||||
if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start: # type: ignore[type-arg]
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) # type: ignore[type-arg]
|
||||
for e, c in enumerate(prompt_cache):
|
||||
if (
|
||||
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
||||
|
||||
|
||||
def generate_step(
|
||||
@@ -189,7 +191,7 @@ def generate_step(
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
start_time = time.time()
|
||||
mx.eval([c.state for c in prompt_cache]) # type: ignore
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
eval_time = time.time() - start_time
|
||||
prompt_processed_tokens += n_to_process
|
||||
|
||||
@@ -221,9 +223,17 @@ def generate_step(
|
||||
n = 0
|
||||
|
||||
while True:
|
||||
mx.eval(y, logprobs)
|
||||
assert y is not None
|
||||
assert logprobs is not None
|
||||
if n != max_tokens:
|
||||
next_y, next_logprobs = _step(y)
|
||||
mx.async_eval(next_y, next_logprobs)
|
||||
if n == 0:
|
||||
mx.eval(y)
|
||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||
if n == max_tokens:
|
||||
break
|
||||
yield int(y.item()), logprobs
|
||||
n += 1
|
||||
if n % 256 == 0:
|
||||
mx.clear_cache()
|
||||
if n == max_tokens:
|
||||
|
||||
Reference in New Issue
Block a user