MLX LM type stubs

rltakashige
2025-11-06 13:59:29 -08:00
committed by GitHub
parent 19e90572e6
commit ff00b165c5
45 changed files with 1417 additions and 24 deletions

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# type: ignore
"""
Unified benchmark script for EXO.
Runs single or multi-stage benchmarks with configurable load patterns.

View File

View File

@@ -2614,7 +2614,7 @@ type MX_ARRAY_TREE = (
| Mapping[str, MX_ARRAY_TREE]
)
def eval(*args: MX_ARRAY_TREE) -> None:
def eval(*args: MX_ARRAY_TREE | None) -> None:
"""
Evaluate an :class:`array` or tree of :class:`array`.

View File

@@ -0,0 +1,2 @@
import models as models
import tokenizer_utils as tokenizer_utils

View File

@@ -0,0 +1,45 @@
"""
This type stub file was generated by pyright.
"""
import argparse
from typing import Callable, Optional, Union
import mlx.nn as nn
def mixed_quant_predicate_builder(
recipe: str, model: nn.Module, group_size: int = ...
) -> Callable[[str, nn.Module, dict], Union[bool, dict]]: ...
QUANT_RECIPES = ...
MODEL_CONVERSION_DTYPES = ...
def convert(
hf_path: str,
mlx_path: str = ...,
quantize: bool = ...,
q_group_size: int = ...,
q_bits: int = ...,
q_mode: str = ...,
dtype: Optional[str] = ...,
upload_repo: str = ...,
revision: Optional[str] = ...,
dequantize: bool = ...,
quant_predicate: Optional[
Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str]
] = ...,
trust_remote_code: bool = ...,
): # -> None:
...
def configure_parser() -> argparse.ArgumentParser:
"""
Configures and returns the argument parser for the script.
Returns:
argparse.ArgumentParser: Configured argument parser.
"""
def main(): # -> None:
...
if __name__ == "__main__": ...
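
For orientation, a minimal usage sketch of the `convert` entry point stubbed above; the Hugging Face repo name, output directory, and quantization values are illustrative, and only keywords that appear in the signature above are used.

# Hedged sketch: quantize a Hugging Face checkpoint into an MLX model directory.
from mlx_lm import convert

convert(
    hf_path="mistralai/Mistral-7B-Instruct-v0.3",  # example source repo
    mlx_path="mlx_model",                          # example output directory
    quantize=True,
    q_group_size=64,
    q_bits=4,
)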

View File

@@ -0,0 +1,324 @@
"""
This type stub file was generated by pyright.
"""
import contextlib
from dataclasses import dataclass
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer
from .tokenizer_utils import TokenizerWrapper
DEFAULT_PROMPT = ...
DEFAULT_MAX_TOKENS = ...
DEFAULT_TEMP = ...
DEFAULT_TOP_P = ...
DEFAULT_MIN_P = ...
DEFAULT_TOP_K = ...
DEFAULT_XTC_PROBABILITY = ...
DEFAULT_XTC_THRESHOLD = ...
DEFAULT_MIN_TOKENS_TO_KEEP = ...
DEFAULT_SEED = ...
DEFAULT_MODEL = ...
DEFAULT_QUANTIZED_KV_START = ...
def str2bool(string): # -> bool:
...
def setup_arg_parser(): # -> ArgumentParser:
"""Set up and return the argument parser."""
generation_stream = ...
@contextlib.contextmanager
def wired_limit(
model: nn.Module, streams: Optional[List[mx.Stream]] = ...
): # -> Generator[None, Any, None]:
"""
A context manager to temporarily change the wired limit.
Note, the wired limit should not be changed during an async eval. If an
async eval could be running, pass in the streams to synchronize with prior
to exiting the context manager.
"""
@dataclass
class GenerationResponse:
"""
The output of :func:`stream_generate`.
Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
from_draft (bool): Whether the token was generated by the draft model.
logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
peak_memory (float): The peak memory used so far in GB.
finish_reason (str): The reason the response is being sent: "length", "stop" or `None`
"""
text: str
token: int
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: Optional[str] = ...
def maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
): # -> None:
...
def generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
max_kv_size: Optional[int] = ...,
prompt_cache: Optional[Any] = ...,
prefill_step_size: int = ...,
kv_bits: Optional[int] = ...,
kv_group_size: int = ...,
quantized_kv_start: int = ...,
prompt_progress_callback: Optional[Callable[[int], int]] = ...,
input_embeddings: Optional[mx.array] = ...,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
prompt_progress_callback (Callable[[int], int]): A callback which takes the
number of prompt tokens processed so far and the total number of prompt tokens.
input_embeddings (mx.array, optional): Input embeddings to use instead of or in
conjunction with prompt tokens. Default: ``None``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
draft_model: nn.Module,
*,
num_draft_tokens: int = ...,
max_tokens: int = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
prompt_cache: Optional[Any] = ...,
prefill_step_size: int = ...,
kv_bits: Optional[int] = ...,
kv_group_size: int = ...,
quantized_kv_start: int = ...,
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
draft_model (nn.Module): The draft model for speculative decoding.
num_draft_tokens (int, optional): The number of draft tokens for
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
and a bool indicating if the token was generated by the draft model
"""
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = ...,
draft_model: Optional[nn.Module] = ...,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or
integer tokens.
max_tokens (int): The maximum number of tokens to generate.
Default: ``256``.
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
GenerationResponse: An instance containing the generated text segment and
associated metadata. See :class:`GenerationResponse` for details.
"""
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, List[int]],
verbose: bool = ...,
**kwargs,
) -> str:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
See :func:`stream_generate` for more details.
"""
@dataclass
class BatchStats:
"""
A data object to hold generation stats.
Args:
prompt_tokens (int): The number of prompt tokens processed.
prompt_tps (float): The prompt processing tokens-per-second.
prompt_time (float): The time in seconds spent in prompt processing.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
generation_time (float): The time in seconds spent in generation.
peak_memory (float): The peak memory used so far in GB.
"""
prompt_tokens: int = ...
prompt_tps: float = ...
prompt_time: float = ...
generation_tokens: int = ...
generation_tps: float = ...
generation_time: float = ...
peak_memory: float = ...
@dataclass
class BatchResponse:
"""
A data object to hold a batch generation response.
Args:
texts (List[str]): The generated text for each prompt.
stats (BatchStats): Statistics about the generation.
"""
texts: List[str]
stats: BatchStats
@dataclass
class Batch:
uids: List[int]
y: mx.array
logprobs: mx.array
max_tokens: List[int]
num_tokens: List[int]
cache: List[Any]
def __len__(self): # -> int:
...
def filter(self, keep_idx: List[int]): # -> None:
...
def extend(self, other): # -> None:
...
class BatchGenerator:
@dataclass
class Response:
uid: int
token: int
logprobs: mx.array
finish_reason: Optional[str]
def __init__(
self,
model,
max_tokens: int = ...,
stop_tokens: Optional[set] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
def batch_generate(
model,
tokenizer,
prompts: List[int],
max_tokens: Union[int, List[int]] = ...,
verbose: bool = ...,
**kwargs,
) -> BatchResponse:
"""
Generate responses for the given batch of prompts.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompts (List[List[int]]): The input prompts.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
max_tokens (Union[int, List[int]]): Maximum number of output tokens. This
can be set per prompt if a list is provided.
kwargs: The remaining options get passed to :obj:`BatchGenerator`.
See :obj:`BatchGenerator` for more details.
"""
def main(): # -> None:
...
if __name__ == "__main__": ...

View File

@@ -0,0 +1 @@
import cache as cache

View File

@@ -0,0 +1,47 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
@dataclass
class BaseModelArgs:
@classmethod
def from_dict(cls, params): # -> Self:
...
def create_causal_mask(
N: int,
offset: int = ...,
window_size: Optional[int] = ...,
right_padding: Optional[mx.array] = ...,
left_padding: Optional[mx.array] = ...,
): # -> array:
...
def create_attention_mask(
h, cache=..., window_size: Optional[int] = ..., return_array: bool = ...
): # -> array | Literal['causal'] | None:
...
def create_ssm_mask(h, cache=...): # -> None:
...
def quantized_scaled_dot_product_attention(
queries: mx.array,
q_keys: tuple[mx.array, mx.array, mx.array],
q_values: tuple[mx.array, mx.array, mx.array],
scale: float,
mask: Optional[mx.array],
group_size: int = ...,
bits: int = ...,
) -> mx.array: ...
def scaled_dot_product_attention(
queries,
keys,
values,
cache,
scale: float,
mask: Optional[mx.array],
sinks: Optional[mx.array] = ...,
) -> mx.array: ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
def bitnet_quantize(model, quantization_config: dict): ...
def make_bitlinear_kernel():
"""
Custom Metal kernel that performs matrix multiplication directly on
packed weights and scales the output. This eliminates the need to
store unpacked weights in memory.
"""
_bitlinear_kernel = ...
class BitLinear(nn.Module):
"""
BitLinear module with memory-efficient weight handling.
"""
def __init__(
self, in_features, out_features, bias=..., invert_weight_scales=...
) -> None: ...
def execute_matmul_kernel(self, x, packed_weights): ...
def __call__(self, x): # -> array:
...

View File

@@ -0,0 +1,354 @@
"""
This type stub file was generated by pyright.
"""
from typing import Any, Dict, List, Optional
import mlx.nn as nn
from mlx.core import array
def make_prompt_cache(
model: nn.Module, max_kv_size: Optional[int] = ...
) -> List[KVCache | Any]:
"""
Construct the model's cache for use in generation.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
def save_prompt_cache(
file_name: str, cache: List[Any], metadata: Dict[str, str] = ...
) -> None:
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
def load_prompt_cache(
file_name, return_metadata=...
): # -> tuple[list[Any], Any] | list[Any]:
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
def create_attention_mask(
N: int, offset: int, return_array: bool, window_size: Optional[int]
): # -> array | Literal['causal'] | None:
...
class _BaseCache:
@property
def state(self): # -> list[Any]:
...
@state.setter
def state(self, v): # -> None:
...
@property
def meta_state(self): # -> Literal['']:
...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[False]:
...
@classmethod
def from_state(cls, state, meta_state): # -> Self:
...
class ConcatenateKVCache(_BaseCache):
"""ConcatenateKVCache the simplest KV cache implementation.
Can be used as a mock KV cache or when large blocks are being processed at
a time in which case KVCache isn't necessarily faster. Consider using the
KVCache with a larger step size before using this cache.
"""
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
@property
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class QuantizedKVCache(_BaseCache):
step = ...
def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
def update_and_fetch(self, keys, values): # -> Any:
...
@property
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
@state.setter
def state(self, v): # -> None:
...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class KVCache(_BaseCache):
step = ...
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
@property
def state(
self,
) -> tuple[array, array]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class RotatingKVCache(_BaseCache):
step = ...
def __init__(self, max_size, keep=...) -> None: ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
@property
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
@state.setter
def state(self, v): # -> None:
...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): # -> array | Literal['causal'] | None:
...
class ArraysCache(_BaseCache):
def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
def __setitem__(self, idx, value): # -> None:
...
def __getitem__(self, idx): ...
@property
def state(self): # -> list[Any | array] | list[array]:
...
@state.setter
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
def make_mask(self, N: int): # -> array | None:
...
class MambaCache(ArraysCache):
def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...
class ChunkedKVCache(KVCache):
def __init__(self, chunk_size) -> None: ...
def maybe_trim_front(self): # -> None:
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): # -> None:
...
class CacheList(_BaseCache):
def __init__(self, *caches) -> None: ...
def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool:
...
def trim(self, n): ...
@property
def state(self): # -> list[Any]:
...
@state.setter
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
class BatchKVCache(_BaseCache):
step = ...
def __init__(self, left_padding: List[int]) -> None:
"""
The BatchKV cache expects inputs to be left-padded.
E.g. the following prompts:
[1, 3, 5]
[7]
[2, 6, 8, 9]
Should be padded like so:
[0, 1, 3, 5]
[0, 0, 0, 7]
[2, 6, 8, 9]
And ``left_padding`` specifies the amount of padding for each.
In this case, ``left_padding = [1, 3, 0]``.
"""
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
@property
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int | float:
...
def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
class BatchRotatingKVCache(_BaseCache):
step = ...
def __init__(self, max_size, left_padding: List[int]) -> None: ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
@property
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
@state.setter
def state(self, v): # -> None:
...
@property
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): # -> array:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""

View File

@@ -0,0 +1,79 @@
"""
This type stub file was generated by pyright.
"""
from functools import partial
import mlx.core as mx
import mlx.nn as nn
class QuantizedSwitchLinear(nn.Module):
def __init__(
self,
input_dims: int,
output_dims: int,
num_experts: int,
bias: bool = ...,
group_size: int = ...,
bits: int = ...,
mode: str = ...,
) -> None: ...
@property
def input_dims(self): # -> int:
...
@property
def output_dims(self): # -> int:
...
@property
def num_experts(self): # -> int:
...
def __call__(self, x, indices, sorted_indices=...): # -> array:
...
class SwitchLinear(nn.Module):
def __init__(
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = ...
) -> None: ...
@property
def input_dims(self): # -> int:
...
@property
def output_dims(self): # -> int:
...
@property
def num_experts(self): # -> int:
...
def __call__(self, x, indices, sorted_indices=...): ...
def to_quantized(
self, group_size: int = ..., bits: int = ..., mode: str = ...
): # -> QuantizedSwitchLinear:
...
@partial(mx.compile, shapeless=True)
def swiglu(x, gate): ...
class SwiGLU(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x, gate): ...
class SwitchGLU(nn.Module):
def __init__(
self,
input_dims: int,
hidden_dims: int,
num_experts: int,
activation=...,
bias: bool = ...,
) -> None: ...
def __call__(self, x, indices) -> mx.array: ...
class SwitchMLP(nn.Module):
def __init__(
self,
input_dims: int,
hidden_dims: int,
num_experts: int,
activation=...,
bias: bool = ...,
) -> None: ...
def __call__(self, x, indices) -> mx.array: ...

View File

@@ -0,0 +1,148 @@
"""
This type stub file was generated by pyright.
"""
from functools import partial
from typing import Callable, Dict, List, Optional
import mlx.core as mx
def make_sampler(
temp: float = ...,
top_p: float = ...,
min_p: float = ...,
min_tokens_to_keep: int = ...,
top_k: int = ...,
xtc_probability: float = ...,
xtc_threshold: float = ...,
xtc_special_tokens: List[int] = ...,
) -> Callable[[mx.array], mx.array]:
"""
Make a sampler function for use with ``generate_step``.
Args:
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
top_p (float, optional): Nucleus sampling; higher values mean the model
considers more, less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
top_k (int, optional): The top k tokens ranked by probability to constrain
the sampling to.
xtc_probability (float, optional): The probability of applying XTC
sampling.
xtc_threshold (float, optional): The threshold the probs need to reach
for being sampled.
xtc_special_tokens (list(int), optional): List of special tokens IDs to
be excluded from XTC sampling.
Returns:
Callable[mx.array, mx.array]:
A sampler which takes log-probabilities and returns tokens.
"""
def make_logits_processors(
logit_bias: Optional[Dict[int, float]] = ...,
repetition_penalty: Optional[float] = ...,
repetition_context_size: Optional[int] = ...,
): # -> list[Any]:
"""
Make logits processors for use with ``generate_step``.
Args:
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
logit_bias (dictionary, optional): Additive logit bias.
Returns:
List[Callable[[mx.array, mx.array], mx.array]]:
A list of logits processors. Each processor in the list is a
callable which takes an array of tokens and an array of logits
and returns the updated logits.
"""
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_top_k(logprobs: mx.array, top_k: int) -> mx.array:
"""
Sample from only the top K tokens ranked by probability.
Args:
logprobs: A vector of log probabilities.
top_k (int): Top k tokens to sample from.
"""
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_min_p(
logprobs: mx.array, min_p: float, min_tokens_to_keep: int = ...
) -> mx.array:
"""
Apply min-p sampling to the logprobs.
Min-p keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter is more
aggressive given a very high-probability token.
Args:
logprobs: A vector of log probabilities.
min_p (float): Minimum token probability. Typical values are in the
0.01-0.2 range, comparably selective to setting `top_p` in the
0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.
"""
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_top_p(logprobs: mx.array, top_p: float) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.
Args:
logprobs: A vector of log probabilities.
top_p: The cumulative probability threshold for top-p filtering.
Returns:
token selected based on the top-p criterion.
"""
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_xtc(
logits: mx.array,
xtc_probability: float,
xtc_threshold: float,
xtc_special_tokens: List[int],
) -> mx.array:
"""
Apply XTC sampling to the logits.
Args:
logits: The logits from the model's output.
xtc_probability (float): The probability of applying XTC sampling for each token.
xtc_threshold (float): The threshold the probs need to reach for being sampled.
xtc_special_tokens (list(int)): List of special token IDs to be excluded from XTC sampling.
"""
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits, temp): # -> array:
...
def make_repetition_penalty(
penalty: float, context_size: int = ...
): # -> Callable[..., Any]:
"""
Make repetition penalty processor.
Paper: https://arxiv.org/abs/1909.05858
Args:
penalty (float): The repetition penalty factor to be applied.
context_size (int): The number of previous tokens to use.
Default: ``20``.
Returns:
Callable[[mx.array, List[int]], mx.array]:
The repetition penalty processor.
"""

View File

@@ -0,0 +1,168 @@
"""
This type stub file was generated by pyright.
"""
from functools import partial
from pathlib import Path
from transformers import PreTrainedTokenizerFast
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
Example usage is as follows:
detokenizer = ...
# Reset the tokenizer state
detokenizer.reset()
for token in generate(...):
detokenizer.add_token(token.item())
# Contains the whole text so far. Some tokens may not be included
# since it contains whole words usually.
detokenizer.text
# Contains the printable segment (usually a word) since the last
# time it was accessed
detokenizer.last_segment
# Contains all the tokens added so far
detokenizer.tokens
# Make sure that we detokenize any remaining tokens
detokenizer.finalize()
# Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
"""
__slots__ = ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
@property
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
implementation and should work with every tokenizer.
Its complexity is O(T^2) where T is the longest line since it will
repeatedly detokenize the same tokens until a new line is generated.
"""
def __init__(self, tokenizer) -> None: ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
@property
def text(self): # -> str:
...
class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
It adds tokens to the text if the next token starts with the special SPM
underscore which results in linear complexity.
"""
def __init__(self, tokenizer, trim_space=...) -> None: ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
class BPEStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for OpenAI style BPE models.
It adds tokens to the text if the next token starts with a space similar to
the SPM detokenizer.
"""
_byte_decoder = ...
_space_matches = ...
def __init__(self, tokenizer) -> None: ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
@classmethod
def make_byte_decoder(cls): # -> None:
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
class TokenizerWrapper:
"""A wrapper that combines an HF tokenizer and a detokenizer.
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
def add_eos_token(self, token: str): # -> None:
...
@property
def has_thinking(self): # -> bool:
...
@property
def think_start(self): # -> str | None:
...
@property
def think_end(self): # -> str | None:
...
@property
def has_tool_calling(self): # -> bool:
...
@property
def tool_call_start(self): # -> str | None:
...
@property
def tool_call_end(self): # -> str | None:
...
@property
def detokenizer(self): # -> NaiveStreamingDetokenizer:
"""
Get a stateful streaming detokenizer.
"""
def __getattr__(self, attr): # -> set[Any] | Any:
...
def __setattr__(self, attr, value): # -> None:
...
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
def __init__(self, *args, **kwargs) -> None: ...
def encode(self, text, **kwargs): # -> list[int]:
...
def encode_batch(self, texts, **kwargs): ...
def decode(self, *args, **kwargs): # -> str:
...
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load_tokenizer(
model_path: Path,
tokenizer_config_extra=...,
return_tokenizer=...,
eos_token_ids=...,
) -> (
TokenizerWrapper
| type[SPMStreamingDetokenizer]
| partial[SPMStreamingDetokenizer]
| type[BPEStreamingDetokenizer]
| type[NaiveStreamingDetokenizer]
):
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
Note, to use a fast streaming tokenizer, pass a local file path rather than
a Hugging Face repo ID.
"""
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...

View File

@@ -0,0 +1,195 @@
"""
This type stub file was generated by pyright.
"""
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import mlx.nn as nn
from transformers.utils.auto_docstring import ModelArgs
from .tokenizer_utils import TokenizerWrapper
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": ...
else: ...
MODEL_REMAPPING = ...
MAX_FILE_SIZE_GB = ...
def compute_bits_per_weight(model): ...
def hf_repo_to_path(hf_repo): # -> Path:
...
def load_config(model_path: Path) -> dict: ...
def load_model(
model_path: Path,
lazy: bool = False,
strict: bool = True,
model_config: dict[str, Any] = {},
get_model_classes: Callable[
[dict[str, Any]], Tuple[Type[nn.Module], Type[ModelArgs]]
] = ...,
) -> Tuple[nn.Module, dict[str, Any]]:
"""
Load and initialize the model from a given path.
Args:
model_path (Path): The path to load the model from.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
strict (bool): Whether or not to raise an exception if weights don't
match. Default: ``True``
model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config.
Defaults to the ``_get_classes`` function.
Returns:
Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.
Raises:
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
def load(
path_or_hf_repo: str,
tokenizer_config=...,
model_config=...,
adapter_path: Optional[str] = ...,
lazy: bool = ...,
return_config: bool = ...,
revision: str = ...,
) -> Union[
Tuple[nn.Module, TokenizerWrapper],
Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],
]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
Args:
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config (dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If ``False`` eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
return_config (bool): If ``True``, return the model config as the last item.
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
Returns:
Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]:
A tuple containing the loaded model, tokenizer and, if requested, the model config.
Raises:
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
def make_shards(weights: dict, max_file_size_gb: int = ...) -> list:
"""
Splits the weights into smaller shards.
Args:
weights (dict): Model weights.
max_file_size_gb (int): Maximum size of each shard in gigabytes.
Returns:
list: List of weight shards.
"""
def create_model_card(
path: Union[str, Path], hf_path: Union[str, Path, None]
): # -> None:
"""
Creates a model card (README) for the converted model.
Args:
path (Union[str, Path]): Local path to the model.
hf_path (Union[str, Path, None]): Path to the original Hugging Face model.
"""
def upload_to_hub(path: str, upload_repo: str): # -> None:
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
"""
def save_model(
save_path: Union[str, Path], model: nn.Module, *, donate_model: bool = ...
) -> None:
"""Save model weights and metadata index into specified directory."""
def quantize_model(
model: nn.Module,
config: dict,
group_size: int,
bits: int,
mode: str = ...,
quant_predicate: Optional[Callable[[str, nn.Module], Union[bool, dict]]] = ...,
) -> Tuple[nn.Module, dict]:
"""
Applies quantization to the model weights.
Args:
model (nn.Module): The model to be quantized.
config (dict): Model configuration.
group_size (int): Group size for quantization.
bits (int): Bits per weight for quantization.
mode (str): The quantization mode.
quant_predicate (Callable): A callable that decides how to quantize
each layer based on the path. Accepts the layer `path` and the
`module`. Returns either a bool to signify quantize/no quantize or
a dict of quantization parameters to pass to `to_quantized`.
Returns:
Tuple: Tuple containing quantized model and config.
"""
def save_config(config: dict, config_path: Union[str, Path]) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
def save(
dst_path: Union[str, Path],
src_path_or_repo: Union[str, Path],
model: nn.Module,
tokenizer: TokenizerWrapper,
config: Dict[str, Any],
donate_model: bool = ...,
): # -> None:
...
def common_prefix_len(list1, list2): # -> int:
"""
Calculates the length of the common prefix of two lists.
Args:
list1: The first list of strings.
list2: The second list of strings.
Returns:
The length of the common prefix. Returns 0 if lists are empty
or do not match at the first element.
"""
def does_model_support_input_embeddings(model: nn.Module) -> bool:
"""
Check if the model supports input_embeddings in its call signature.
Args:
model (nn.Module): The model to check.
Returns:
bool: True if the model supports input_embeddings, False otherwise.
"""

View File

@@ -1,5 +1,5 @@
fmt:
uv run ruff format src typings
uv run ruff format src .mlx_typings
lint:
uv run ruff check --fix src

View File

@@ -81,7 +81,7 @@ build-backend = "uv_build"
###
[tool.basedpyright]
include = [".venv/lib/mlx", "src"]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
typeCheckingMode = "strict"
failOnWarnings = true
@@ -97,8 +97,8 @@ reportUnnecessaryTypeIgnoreComment = "error"
pythonVersion = "3.13"
pythonPlatform = "Darwin"
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"]
stubPath = "typings"
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust", "**/.github"]
stubPath = ".mlx_typings"
[[tool.basedpyright.executionEnvironments]]
root = "src"

View File

@@ -9,10 +9,10 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
try:
from mlx_lm.tokenizer_utils import load_tokenizer # type: ignore
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
from mlx_lm.utils import load_model # type: ignore
from mlx_lm.utils import load_model
from pydantic import RootModel
import mlx.core as mx
@@ -167,12 +167,11 @@ def shard_and_load(
f"loading model from {model_path} with strategy {model_shard_meta.strategy}"
)
model, config = load_model(model_path, lazy=True, strict=False) # type: ignore
model, config = load_model(model_path, lazy=True, strict=False)
runner_print(f"{config=}")
assert isinstance(model, nn.Module)
tokenizer = load_tokenizer(model_path) # type: ignore
tokenizer = cast(TokenizerWrapper, tokenizer)
tokenizer = cast(TokenizerWrapper, load_tokenizer(model_path))
runner_print(f"Group size: {group.size()}, group rank: {group.rank()}")

View File

@@ -31,7 +31,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata | None:
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_meta=base_shard.model_meta,
@@ -150,11 +150,9 @@ class ResumableShardDownloader(ShardDownloader):
# print("get_shard_download_status")
async def _status_for_model(
model_id: str,
) -> tuple[Path, RepoDownloadProgress] | None:
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)
if shard is None:
return None
return await download_shard(
shard, self.on_progress_wrapper, skip_download=True
)
@@ -168,8 +166,6 @@ class ResumableShardDownloader(ShardDownloader):
for task in asyncio.as_completed(tasks):
try:
result = await task
if result is None:
continue
path, progress = result
yield (path, progress)
except Exception as e:

View File

@@ -35,16 +35,18 @@ generation_stream = mx.new_stream(mx.default_device())
def maybe_quantize_kv_cache(
prompt_cache: list[Any],
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache): # type: ignore[type-arg]
if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start: # type: ignore[type-arg]
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) # type: ignore[type-arg]
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def generate_step(
@@ -189,7 +191,7 @@ def generate_step(
quantize_cache_fn(prompt_cache)
start_time = time.time()
mx.eval([c.state for c in prompt_cache]) # type: ignore
mx.eval([c.state for c in prompt_cache])
eval_time = time.time() - start_time
prompt_processed_tokens += n_to_process
@@ -221,9 +223,17 @@ def generate_step(
n = 0
while True:
mx.eval(y, logprobs)
assert y is not None
assert logprobs is not None
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
if n == max_tokens:
break
yield int(y.item()), logprobs
n += 1
if n % 256 == 0:
mx.clear_cache()
if n == max_tokens: