diff --git a/.github/scripts/bench.py b/.github/scripts/bench.py
index 4f607b69..44733da1 100644
--- a/.github/scripts/bench.py
+++ b/.github/scripts/bench.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
-# type: ignore
 """
 Unified benchmark script for EXO.

 Runs single or multi-stage benchmarks with configurable load patterns.
diff --git a/typings/.gitkeep b/.mlx_typings/.gitkeep
similarity index 100%
rename from typings/.gitkeep
rename to .mlx_typings/.gitkeep
diff --git a/typings/mlx/core/__init__.pyi b/.mlx_typings/mlx/core/__init__.pyi
similarity index 99%
rename from typings/mlx/core/__init__.pyi
rename to .mlx_typings/mlx/core/__init__.pyi
index 8edb9832..48680a80 100644
--- a/typings/mlx/core/__init__.pyi
+++ b/.mlx_typings/mlx/core/__init__.pyi
@@ -2614,7 +2614,7 @@ type MX_ARRAY_TREE = (
     | Mapping[str, MX_ARRAY_TREE]
 )

-def eval(*args: MX_ARRAY_TREE) -> None:
+def eval(*args: MX_ARRAY_TREE | None) -> None:
     """
     Evaluate an :class:`array` or tree of :class:`array`.

diff --git a/typings/mlx/core/cuda/__init__.pyi b/.mlx_typings/mlx/core/cuda/__init__.pyi
similarity index 100%
rename from typings/mlx/core/cuda/__init__.pyi
rename to .mlx_typings/mlx/core/cuda/__init__.pyi
diff --git a/typings/mlx/core/distributed/__init__.pyi b/.mlx_typings/mlx/core/distributed/__init__.pyi
similarity index 100%
rename from typings/mlx/core/distributed/__init__.pyi
rename to .mlx_typings/mlx/core/distributed/__init__.pyi
diff --git a/typings/mlx/core/metal/__init__.pyi b/.mlx_typings/mlx/core/metal/__init__.pyi
similarity index 100%
rename from typings/mlx/core/metal/__init__.pyi
rename to .mlx_typings/mlx/core/metal/__init__.pyi
diff --git a/typings/mlx/core/random/__init__.pyi b/.mlx_typings/mlx/core/random/__init__.pyi
similarity index 100%
rename from typings/mlx/core/random/__init__.pyi
rename to .mlx_typings/mlx/core/random/__init__.pyi
diff --git a/typings/mlx/nn/__init__.pyi b/.mlx_typings/mlx/nn/__init__.pyi
similarity index 100%
rename from typings/mlx/nn/__init__.pyi
rename to .mlx_typings/mlx/nn/__init__.pyi
diff --git a/typings/mlx/nn/init.pyi b/.mlx_typings/mlx/nn/init.pyi
similarity index 100%
rename from typings/mlx/nn/init.pyi
rename to .mlx_typings/mlx/nn/init.pyi
diff --git a/typings/mlx/nn/layers/__init__.pyi b/.mlx_typings/mlx/nn/layers/__init__.pyi
similarity index 100%
rename from typings/mlx/nn/layers/__init__.pyi
rename to .mlx_typings/mlx/nn/layers/__init__.pyi
diff --git a/typings/mlx/nn/layers/activations.pyi b/.mlx_typings/mlx/nn/layers/activations.pyi
similarity index 100%
rename from typings/mlx/nn/layers/activations.pyi
rename to .mlx_typings/mlx/nn/layers/activations.pyi
diff --git a/typings/mlx/nn/layers/base.pyi b/.mlx_typings/mlx/nn/layers/base.pyi
similarity index 100%
rename from typings/mlx/nn/layers/base.pyi
rename to .mlx_typings/mlx/nn/layers/base.pyi
diff --git a/typings/mlx/nn/layers/containers.pyi b/.mlx_typings/mlx/nn/layers/containers.pyi
similarity index 100%
rename from typings/mlx/nn/layers/containers.pyi
rename to .mlx_typings/mlx/nn/layers/containers.pyi
diff --git a/typings/mlx/nn/layers/convolution.pyi b/.mlx_typings/mlx/nn/layers/convolution.pyi
similarity index 100%
rename from typings/mlx/nn/layers/convolution.pyi
rename to .mlx_typings/mlx/nn/layers/convolution.pyi
diff --git a/typings/mlx/nn/layers/convolution_transpose.pyi b/.mlx_typings/mlx/nn/layers/convolution_transpose.pyi
similarity index 100%
rename from typings/mlx/nn/layers/convolution_transpose.pyi
rename to .mlx_typings/mlx/nn/layers/convolution_transpose.pyi
diff --git a/typings/mlx/nn/layers/distributed.pyi b/.mlx_typings/mlx/nn/layers/distributed.pyi
similarity index 100%
rename from typings/mlx/nn/layers/distributed.pyi
rename to .mlx_typings/mlx/nn/layers/distributed.pyi
diff --git a/typings/mlx/nn/layers/dropout.pyi b/.mlx_typings/mlx/nn/layers/dropout.pyi
similarity index 100%
rename from typings/mlx/nn/layers/dropout.pyi
rename to .mlx_typings/mlx/nn/layers/dropout.pyi
diff --git a/typings/mlx/nn/layers/embedding.pyi b/.mlx_typings/mlx/nn/layers/embedding.pyi
similarity index 100%
rename from typings/mlx/nn/layers/embedding.pyi
rename to .mlx_typings/mlx/nn/layers/embedding.pyi
diff --git a/typings/mlx/nn/layers/linear.pyi b/.mlx_typings/mlx/nn/layers/linear.pyi
similarity index 100%
rename from typings/mlx/nn/layers/linear.pyi
rename to .mlx_typings/mlx/nn/layers/linear.pyi
diff --git a/typings/mlx/nn/layers/normalization.pyi b/.mlx_typings/mlx/nn/layers/normalization.pyi
similarity index 100%
rename from typings/mlx/nn/layers/normalization.pyi
rename to .mlx_typings/mlx/nn/layers/normalization.pyi
diff --git a/typings/mlx/nn/layers/pooling.pyi b/.mlx_typings/mlx/nn/layers/pooling.pyi
similarity index 100%
rename from typings/mlx/nn/layers/pooling.pyi
rename to .mlx_typings/mlx/nn/layers/pooling.pyi
diff --git a/typings/mlx/nn/layers/positional_encoding.pyi b/.mlx_typings/mlx/nn/layers/positional_encoding.pyi
similarity index 100%
rename from typings/mlx/nn/layers/positional_encoding.pyi
rename to .mlx_typings/mlx/nn/layers/positional_encoding.pyi
diff --git a/typings/mlx/nn/layers/quantized.pyi b/.mlx_typings/mlx/nn/layers/quantized.pyi
similarity index 100%
rename from typings/mlx/nn/layers/quantized.pyi
rename to .mlx_typings/mlx/nn/layers/quantized.pyi
diff --git a/typings/mlx/nn/layers/recurrent.pyi b/.mlx_typings/mlx/nn/layers/recurrent.pyi
similarity index 100%
rename from typings/mlx/nn/layers/recurrent.pyi
rename to .mlx_typings/mlx/nn/layers/recurrent.pyi
diff --git a/typings/mlx/nn/layers/transformer.pyi b/.mlx_typings/mlx/nn/layers/transformer.pyi
similarity index 100%
rename from typings/mlx/nn/layers/transformer.pyi
rename to .mlx_typings/mlx/nn/layers/transformer.pyi
diff --git a/typings/mlx/nn/layers/upsample.pyi b/.mlx_typings/mlx/nn/layers/upsample.pyi
similarity index 100%
rename from typings/mlx/nn/layers/upsample.pyi
rename to .mlx_typings/mlx/nn/layers/upsample.pyi
diff --git a/typings/mlx/nn/losses.pyi b/.mlx_typings/mlx/nn/losses.pyi
similarity index 100%
rename from typings/mlx/nn/losses.pyi
rename to .mlx_typings/mlx/nn/losses.pyi
diff --git a/typings/mlx/nn/utils.pyi b/.mlx_typings/mlx/nn/utils.pyi
similarity index 100%
rename from typings/mlx/nn/utils.pyi
rename to .mlx_typings/mlx/nn/utils.pyi
diff --git a/typings/mlx/utils.pyi b/.mlx_typings/mlx/utils.pyi
similarity index 100%
rename from typings/mlx/utils.pyi
rename to .mlx_typings/mlx/utils.pyi
diff --git a/.mlx_typings/mlx_lm/__init__.pyi b/.mlx_typings/mlx_lm/__init__.pyi
new file mode 100644
index 00000000..fee89807
--- /dev/null
+++ b/.mlx_typings/mlx_lm/__init__.pyi
@@ -0,0 +1,2 @@
+import models as models
+import tokenizer_utils as tokenizer_utils
diff --git a/.mlx_typings/mlx_lm/convert.pyi b/.mlx_typings/mlx_lm/convert.pyi
new file mode 100644
index 00000000..aff4de7b
--- /dev/null
+++ b/.mlx_typings/mlx_lm/convert.pyi
@@ -0,0 +1,45 @@
+"""
+This type stub file was generated by pyright.
+""" + +import argparse +from typing import Callable, Optional, Union + +import mlx.nn as nn + +def mixed_quant_predicate_builder( + recipe: str, model: nn.Module, group_size: int = ... +) -> Callable[[str, nn.Module, dict], Union[bool, dict]]: ... + +QUANT_RECIPES = ... +MODEL_CONVERSION_DTYPES = ... + +def convert( + hf_path: str, + mlx_path: str = ..., + quantize: bool = ..., + q_group_size: int = ..., + q_bits: int = ..., + q_mode: str = ..., + dtype: Optional[str] = ..., + upload_repo: str = ..., + revision: Optional[str] = ..., + dequantize: bool = ..., + quant_predicate: Optional[ + Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str] + ] = ..., + trust_remote_code: bool = ..., +): # -> None: + ... +def configure_parser() -> argparse.ArgumentParser: + """ + Configures and returns the argument parser for the script. + + Returns: + argparse.ArgumentParser: Configured argument parser. + """ + +def main(): # -> None: + ... + +if __name__ == "__main__": ... diff --git a/.mlx_typings/mlx_lm/generate.pyi b/.mlx_typings/mlx_lm/generate.pyi new file mode 100644 index 00000000..4711fce0 --- /dev/null +++ b/.mlx_typings/mlx_lm/generate.pyi @@ -0,0 +1,324 @@ +""" +This type stub file was generated by pyright. +""" + +import contextlib +from dataclasses import dataclass +from typing import Any, Callable, Generator, List, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +from transformers import PreTrainedTokenizer + +from .tokenizer_utils import TokenizerWrapper + +DEFAULT_PROMPT = ... +DEFAULT_MAX_TOKENS = ... +DEFAULT_TEMP = ... +DEFAULT_TOP_P = ... +DEFAULT_MIN_P = ... +DEFAULT_TOP_K = ... +DEFAULT_XTC_PROBABILITY = ... +DEFAULT_XTC_THRESHOLD = ... +DEFAULT_MIN_TOKENS_TO_KEEP = ... +DEFAULT_SEED = ... +DEFAULT_MODEL = ... +DEFAULT_QUANTIZED_KV_START = ... + +def str2bool(string): # -> bool: + ... +def setup_arg_parser(): # -> ArgumentParser: + """Set up and return the argument parser.""" + +generation_stream = ... + +@contextlib.contextmanager +def wired_limit( + model: nn.Module, streams: Optional[List[mx.Stream]] = ... +): # -> Generator[None, Any, None]: + """ + A context manager to temporarily change the wired limit. + + Note, the wired limit should not be changed during an async eval. If an + async eval could be running pass in the streams to synchronize with prior + to exiting the context manager. + """ +@dataclass +class GenerationResponse: + """ + The output of :func:`stream_generate`. + + Args: + text (str): The next segment of decoded text. This can be an empty string. + token (int): The next token. + from_draft (bool): Whether the token was generated by the draft model. + logprobs (mx.array): A vector of log probabilities. + prompt_tokens (int): The number of tokens in the prompt. + prompt_tps (float): The prompt processing tokens-per-second. + generation_tokens (int): The number of generated tokens. + generation_tps (float): The tokens-per-second for generation. + peak_memory (float): The peak memory used so far in GB. + finish_reason (str): The reason the response is being sent: "length", "stop" or `None` + """ + + text: str + token: int + logprobs: mx.array + from_draft: bool + prompt_tokens: int + prompt_tps: float + generation_tokens: int + generation_tps: float + peak_memory: float + finish_reason: Optional[str] = ... + +def maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits +): # -> None: + ... 
+def generate_step( + prompt: mx.array, + model: nn.Module, + *, + max_tokens: int = ..., + sampler: Optional[Callable[[mx.array], mx.array]] = ..., + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ..., + max_kv_size: Optional[int] = ..., + prompt_cache: Optional[Any] = ..., + prefill_step_size: int = ..., + kv_bits: Optional[int] = ..., + kv_group_size: int = ..., + quantized_kv_start: int = ..., + prompt_progress_callback: Optional[Callable[[int], int]] = ..., + input_embeddings: Optional[mx.array] = ..., +) -> Generator[Tuple[mx.array, mx.array], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. + max_kv_size (int, optional): Maximum size of the key-value cache. Old + entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. + prefill_step_size (int): Step size for processing the prompt. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. + prompt_progress_callback (Callable[[int], int]): A call-back which takes the + prompt tokens processed so far and the total number of prompt tokens. + input_embeddings (mx.array, optional): Input embeddings to use instead of or in + conjunction with prompt tokens. Default: ``None``. + + Yields: + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. + """ + +def speculative_generate_step( + prompt: mx.array, + model: nn.Module, + draft_model: nn.Module, + *, + num_draft_tokens: int = ..., + max_tokens: int = ..., + sampler: Optional[Callable[[mx.array], mx.array]] = ..., + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ..., + prompt_cache: Optional[Any] = ..., + prefill_step_size: int = ..., + kv_bits: Optional[int] = ..., + kv_group_size: int = ..., + quantized_kv_start: int = ..., +) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + draft_model (nn.Module): The draft model for speculative decoding. + num_draft_tokens (int, optional): The number of draft tokens for + speculative decoding. Default: ``2``. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. 
Default: ``None``. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. The cache must be trimmable. + prefill_step_size (int): Step size for processing the prompt. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. + + Yields: + Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities, + and a bool indicating if the token was generated by the draft model + """ + +def stream_generate( + model: nn.Module, + tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], + prompt: Union[str, mx.array, List[int]], + max_tokens: int = ..., + draft_model: Optional[nn.Module] = ..., + **kwargs, +) -> Generator[GenerationResponse, None, None]: + """ + A generator producing text based on the given prompt from the model. + + Args: + model (nn.Module): The model to use for generation. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (Union[str, mx.array, List[int]]): The input prompt string or + integer tokens. + max_tokens (int): The maximum number of tokens to generate. + Default: ``256``. + draft_model (Optional[nn.Module]): An optional draft model. If provided + then speculative decoding is used. The draft model must use the same + tokenizer as the main model. Default: ``None``. + kwargs: The remaining options get passed to :func:`generate_step`. + See :func:`generate_step` for more details. + + Yields: + GenerationResponse: An instance containing the generated text segment and + associated metadata. See :class:`GenerationResponse` for details. + """ + +def generate( + model: nn.Module, + tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], + prompt: Union[str, List[int]], + verbose: bool = ..., + **kwargs, +) -> str: + """ + Generate a complete response from the model. + + Args: + model (nn.Module): The language model. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. + verbose (bool): If ``True``, print tokens and timing information. + Default: ``False``. + kwargs: The remaining options get passed to :func:`stream_generate`. + See :func:`stream_generate` for more details. + """ +@dataclass +class BatchStats: + """ + An data object to hold generation stats. + + Args: + prompt_tokens (int): The number of prompt tokens processed. + prompt_tps (float): The prompt processing tokens-per-second. + prompt_time (float): The time in seconds spent in prompt processing. + generation_tokens (int): The number of generated tokens. + generation_tps (float): The tokens-per-second for generation. + generation_time (float): The time in seconds spent in generation . + peak_memory (float): The peak memory used so far in GB. + """ + + prompt_tokens: int = ... + prompt_tps: float = ... + prompt_time: float = ... + generation_tokens: int = ... + generation_tps: float = ... + generation_time: float = ... + peak_memory: float = ... + +@dataclass +class BatchResponse: + """ + An data object to hold a batch generation response. + + Args: + texts: (List[str]): The generated text for each prompt. + stats (BatchStats): Statistics about the generation. 
+ """ + + texts: List[str] + stats: BatchStats + +@dataclass +class Batch: + uids: List[int] + y: mx.array + logprobs: mx.array + max_tokens: List[int] + num_tokens: List[int] + cache: List[Any] + def __len__(self): # -> int: + ... + def filter(self, keep_idx: List[int]): # -> None: + ... + def extend(self, other): # -> None: + ... + +class BatchGenerator: + @dataclass + class Response: + uid: int + token: int + logprobs: mx.array + finish_reason: Optional[str] + + def __init__( + self, + model, + max_tokens: int = ..., + stop_tokens: Optional[set] = ..., + sampler: Optional[Callable[[mx.array], mx.array]] = ..., + completion_batch_size: int = ..., + prefill_batch_size: int = ..., + prefill_step_size: int = ..., + ) -> None: ... + def insert( + self, prompts, max_tokens: Union[List[int], int, None] = ... + ): # -> list[Any]: + ... + def stats(self): # -> BatchStats: + ... + def next(self): # -> list[Any]: + ... + +def batch_generate( + model, + tokenizer, + prompts: List[int], + max_tokens: Union[int, List[int]] = ..., + verbose: bool = ..., + **kwargs, +) -> BatchResponse: + """ + Generate responses for the given batch of prompts. + + Args: + model (nn.Module): The language model. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (List[List[int]]): The input prompts. + verbose (bool): If ``True``, print tokens and timing information. + Default: ``False``. + max_tokens (Union[int, List[int]): Maximum number of output tokens. This + can be per prompt if a list is provided. + kwargs: The remaining options get passed to :obj:`BatchGenerator`. + See :obj:`BatchGenerator` for more details. + """ + +def main(): # -> None: + ... + +if __name__ == "__main__": ... diff --git a/.mlx_typings/mlx_lm/models/__init__.pyi b/.mlx_typings/mlx_lm/models/__init__.pyi new file mode 100644 index 00000000..e09bd4fc --- /dev/null +++ b/.mlx_typings/mlx_lm/models/__init__.pyi @@ -0,0 +1 @@ +import cache as cache diff --git a/.mlx_typings/mlx_lm/models/base.pyi b/.mlx_typings/mlx_lm/models/base.pyi new file mode 100644 index 00000000..e549e624 --- /dev/null +++ b/.mlx_typings/mlx_lm/models/base.pyi @@ -0,0 +1,47 @@ +""" +This type stub file was generated by pyright. +""" + +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx + +@dataclass +class BaseModelArgs: + @classmethod + def from_dict(cls, params): # -> Self: + ... + +def create_causal_mask( + N: int, + offset: int = ..., + window_size: Optional[int] = ..., + right_padding: Optional[mx.array] = ..., + left_padding: Optional[mx.array] = ..., +): # -> array: + ... +def create_attention_mask( + h, cache=..., window_size: Optional[int] = ..., return_array: bool = ... +): # -> array | Literal['causal'] | None: + ... +def create_ssm_mask(h, cache=...): # -> None: + ... +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = ..., + bits: int = ..., +) -> mx.array: ... +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], + sinks: Optional[mx.array] = ..., +) -> mx.array: ... diff --git a/.mlx_typings/mlx_lm/models/bitlinear_layers.pyi b/.mlx_typings/mlx_lm/models/bitlinear_layers.pyi new file mode 100644 index 00000000..fa1caa82 --- /dev/null +++ b/.mlx_typings/mlx_lm/models/bitlinear_layers.pyi @@ -0,0 +1,26 @@ +""" +This type stub file was generated by pyright. 
+""" + +import mlx.nn as nn + +def bitnet_quantize(model, quantization_config: dict): ... +def make_bitlinear_kernel(): + """ + Custom Metal kernel that performs matrix multiplication directly on + packed weights and scales the output. This eliminates the need to + store unpacked weights in memory. + """ + +_bitlinear_kernel = ... + +class BitLinear(nn.Module): + """ + BitLinear module with memory-efficient weight handling. + """ + def __init__( + self, in_features, out_features, bias=..., invert_weight_scales=... + ) -> None: ... + def execute_matmul_kernel(self, x, packed_weights): ... + def __call__(self, x): # -> array: + ... diff --git a/.mlx_typings/mlx_lm/models/cache.pyi b/.mlx_typings/mlx_lm/models/cache.pyi new file mode 100644 index 00000000..30fe1b85 --- /dev/null +++ b/.mlx_typings/mlx_lm/models/cache.pyi @@ -0,0 +1,354 @@ +""" +This type stub file was generated by pyright. +""" + +from typing import Any, Dict, List, Optional + +import mlx.nn as nn +from mlx.core import array + +def make_prompt_cache( + model: nn.Module, max_kv_size: Optional[int] = ... +) -> List[KVCache | Any]: + """ + Construct the model's cache for use in generation. + + This function will defer the cache construction to the model if it has a + ``make_cache`` method, otherwise it will make a default KV cache. + + Args: + model (nn.Module): The language model. + max_kv_size (Optional[int]): If provided and the model does not have a + ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum + size of ``max_kv_size`` + """ + +def save_prompt_cache( + file_name: str, cache: List[Any], metadata: Dict[str, str] = ... +) -> None: + """ + Save a pre-computed prompt cache to a file. + + Args: + file_name (str): The ``.safetensors`` file name. + cache (List[Any]): The model state. + metadata (Dict[str, str]): Optional metadata to save along with model + state. + """ + +def load_prompt_cache( + file_name, return_metadata=... +): # -> tuple[list[Any], Any] | list[Any]: + """ + Load a prompt cache from a file. + + Args: + file_name (str): The ``.safetensors`` file name. + return_metadata (bool): Whether or not to return metadata. + Default: ``False``. + + Returns: + List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and + the metadata if requested. + """ + +def can_trim_prompt_cache(cache: List[Any]) -> bool: + """ + Check if model's cache can be trimmed. + """ + +def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: + """ + Trim the model's cache by the given number of tokens. + + This function will trim the cache if possible (in-place) and return the + number of tokens that were trimmed. + + Args: + cache (List[Any]): The model's cache. + num_tokens (int): The number of tokens to trim. + + Returns: + (int): The number of tokens that were trimmed. + """ + +def create_attention_mask( + N: int, offset: int, return_array: bool, window_size: Optional[int] +): # -> array | Literal['causal'] | None: + ... + +class _BaseCache: + @property + def state(self): # -> list[Any]: + ... + @state.setter + def state(self, v): # -> None: + ... + @property + def meta_state(self): # -> Literal['']: + ... + @meta_state.setter + def meta_state(self, v): # -> None: + ... + def is_trimmable(self): # -> Literal[False]: + ... + @classmethod + def from_state(cls, state, meta_state): # -> Self: + ... + +class ConcatenateKVCache(_BaseCache): + """ConcatenateKVCache the simplest KV cache implementation. 
+ + Can be used as a mock KV cache or when large blocks are being processed at + a time in which case KVCache isn't necessarily faster. Consider using the + KVCache with a larger step size before using this cache. + """ + def __init__(self) -> None: ... + def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]: + ... + @property + def state(self): # -> tuple[Any | array | None, Any | array | None]: + ... + @state.setter + def state(self, v): # -> None: + ... + def is_trimmable(self): # -> Literal[True]: + ... + def trim(self, n): # -> int: + ... + def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None: + ... + +class QuantizedKVCache(_BaseCache): + step = ... + def __init__(self, group_size: int = ..., bits: int = ...) -> None: ... + def update_and_fetch(self, keys, values): # -> Any: + ... + @property + def state( + self, + ): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any: + ... + @state.setter + def state(self, v): # -> None: + ... + @property + def meta_state(self): # -> tuple[str, ...]: + ... + @meta_state.setter + def meta_state(self, v): # -> None: + ... + def is_trimmable(self): # -> Literal[True]: + ... + def trim(self, n): # -> int: + ... + def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None: + ... + +class KVCache(_BaseCache): + step = ... + def __init__(self) -> None: ... + def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]: + ... + @property + def state( + self, + ) -> tuple[array, array]: ... + @state.setter + def state(self, v) -> None: ... + def is_trimmable(self): # -> Literal[True]: + ... + def trim(self, n): # -> int: + ... + def to_quantized( + self, group_size: int = ..., bits: int = ... + ) -> QuantizedKVCache: ... + def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None: + ... + +class RotatingKVCache(_BaseCache): + step = ... + def __init__(self, max_size, keep=...) -> None: ... + def update_and_fetch( + self, keys, values + ): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]: + ... + @property + def state( + self, + ): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]: + ... + @state.setter + def state(self, v): # -> None: + ... + @property + def meta_state(self): # -> tuple[str, ...]: + ... + @meta_state.setter + def meta_state(self, v): # -> None: + ... + def is_trimmable(self): # -> bool: + ... + def trim(self, n): # -> int: + ... + def to_quantized( + self, group_size: int = ..., bits: int = ... + ) -> QuantizedKVCache: ... + def make_mask( + self, N: int, window_size: Optional[int] = ..., return_array: bool = ... + ): # -> array | Literal['causal'] | None: + ... + +class ArraysCache(_BaseCache): + def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ... + def __setitem__(self, idx, value): # -> None: + ... + def __getitem__(self, idx): ... + @property + def state(self): # -> list[Any | array] | list[array]: + ... + @state.setter + def state(self, v): # -> None: + ... + def filter(self, batch_indices): # -> None: + """ + In-place filter to keep just the given indices in the cache. + """ + + def extend(self, other): # -> None: + """ + In-place extend this cache with the other cache. + """ + + def make_mask(self, N: int): # -> array | None: + ... + +class MambaCache(ArraysCache): + def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ... 
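A hedged usage sketch of the prompt-cache helpers stubbed in cache.pyi above; the repo id and file name are placeholders.

from mlx_lm.models.cache import (
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)
from mlx_lm.utils import load

model, tokenizer = load("mlx-community/Example-Model-4bit")  # placeholder repo id
cache = make_prompt_cache(model, max_kv_size=4096)

# ... run prefill / generation that updates `cache` in place ...

save_prompt_cache("prompt.safetensors", cache, {"note": "demo"})  # metadata is optional
cache, metadata = load_prompt_cache("prompt.safetensors", return_metadata=True)
trimmed = trim_prompt_cache(cache, num_tokens=16)  # how many tokens were actually trimmed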
+ +class ChunkedKVCache(KVCache): + def __init__(self, chunk_size) -> None: ... + def maybe_trim_front(self): # -> None: + ... + def update_and_fetch(self, keys, values): # -> tuple[array, array]: + ... + def trim(self, n): # -> int: + ... + @property + def meta_state(self): # -> tuple[str, ...]: + ... + @meta_state.setter + def meta_state(self, v): # -> None: + ... + +class CacheList(_BaseCache): + def __init__(self, *caches) -> None: ... + def __getitem__(self, idx): ... + def is_trimmable(self): # -> bool: + ... + def trim(self, n): ... + @property + def state(self): # -> list[Any]: + ... + @state.setter + def state(self, v): # -> None: + ... + def filter(self, batch_indices): # -> None: + """ + In-place filter to keep just the given indices in the cache. + """ + + def extend(self, other): # -> None: + """ + In-place extend this cache with the other cache. + """ + +class BatchKVCache(_BaseCache): + step = ... + def __init__(self, left_padding: List[int]) -> None: + """ + The BatchKV cache expects inputs to be left-padded. + + E.g. the following prompts: + + [1, 3, 5] + [7] + [2, 6, 8, 9] + + Should be padded like so: + + [0, 1, 3, 5] + [0, 0, 0, 7] + [2, 6, 8, 9] + + And ``left_padding`` specifies the amount of padding for each. + In this case, ``left_padding = [1, 3, 0]``. + """ + + def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]: + ... + @property + def state( + self, + ): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]: + ... + @state.setter + def state(self, v): # -> None: + ... + def is_trimmable(self): # -> Literal[True]: + ... + def trim(self, n): # -> int | float: + ... + def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array: + ... + def filter(self, batch_indices): # -> None: + """ + In-place filter to keep just the given indices in the cache. + """ + + def extend(self, other): # -> None: + """ + In-place extend this cache with the other cache. + """ + +class BatchRotatingKVCache(_BaseCache): + step = ... + def __init__(self, max_size, left_padding: List[int]) -> None: ... + def update_and_fetch( + self, keys, values + ): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]: + ... + @property + def state( + self, + ): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]: + ... + @state.setter + def state(self, v): # -> None: + ... + @property + def meta_state(self): # -> tuple[str, ...]: + ... + @meta_state.setter + def meta_state(self, v): # -> None: + ... + def is_trimmable(self): # -> bool: + ... + def trim(self, n): # -> int: + ... + def to_quantized( + self, group_size: int = ..., bits: int = ... + ) -> QuantizedKVCache: ... + def make_mask( + self, N: int, window_size: Optional[int] = ..., return_array: bool = ... + ): # -> array: + ... + def filter(self, batch_indices): # -> None: + """ + In-place filter to keep just the given indices in the cache. + """ + + def extend(self, other): # -> None: + """ + In-place extend this cache with the other cache. + """ diff --git a/.mlx_typings/mlx_lm/models/switch_layers.pyi b/.mlx_typings/mlx_lm/models/switch_layers.pyi new file mode 100644 index 00000000..c50c999a --- /dev/null +++ b/.mlx_typings/mlx_lm/models/switch_layers.pyi @@ -0,0 +1,79 @@ +""" +This type stub file was generated by pyright. 
+""" + +from functools import partial + +import mlx.core as mx +import mlx.nn as nn + +class QuantizedSwitchLinear(nn.Module): + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + bias: bool = ..., + group_size: int = ..., + bits: int = ..., + mode: str = ..., + ) -> None: ... + @property + def input_dims(self): # -> int: + ... + @property + def output_dims(self): # -> int: + ... + @property + def num_experts(self): # -> int: + ... + def __call__(self, x, indices, sorted_indices=...): # -> array: + ... + +class SwitchLinear(nn.Module): + def __init__( + self, input_dims: int, output_dims: int, num_experts: int, bias: bool = ... + ) -> None: ... + @property + def input_dims(self): # -> int: + ... + @property + def output_dims(self): # -> int: + ... + @property + def num_experts(self): # -> int: + ... + def __call__(self, x, indices, sorted_indices=...): ... + def to_quantized( + self, group_size: int = ..., bits: int = ..., mode: str = ... + ): # -> QuantizedSwitchLinear: + ... + +@partial(mx.compile, shapeless=True) +def swiglu(x, gate): ... + +class SwiGLU(nn.Module): + def __init__(self) -> None: ... + def __call__(self, x, gate): ... + +class SwitchGLU(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=..., + bias: bool = ..., + ) -> None: ... + def __call__(self, x, indices) -> mx.array: ... + +class SwitchMLP(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=..., + bias: bool = ..., + ) -> None: ... + def __call__(self, x, indices) -> mx.array: ... diff --git a/.mlx_typings/mlx_lm/sample_utils.pyi b/.mlx_typings/mlx_lm/sample_utils.pyi new file mode 100644 index 00000000..bc6955a7 --- /dev/null +++ b/.mlx_typings/mlx_lm/sample_utils.pyi @@ -0,0 +1,148 @@ +""" +This type stub file was generated by pyright. +""" + +from functools import partial +from typing import Callable, Dict, List, Optional + +import mlx.core as mx + +def make_sampler( + temp: float = ..., + top_p: float = ..., + min_p: float = ..., + min_tokens_to_keep: int = ..., + top_k: int = ..., + xtc_probability: float = ..., + xtc_threshold: float = ..., + xtc_special_tokens: List[int] = ..., +) -> Callable[[mx.array], mx.array]: + """ + Make a sampler function for use with ``generate_step``. + + Args: + temp (float): The temperature for sampling, if 0 the argmax is used. + Default: ``0``. + top_p (float, optional): Nulceus sampling, higher means model considers + more less likely words. + min_p (float, optional): The minimum value (scaled by the top token's + probability) that a token probability must have to be considered. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered by min_p sampling. + top_k (int, optional): The top k tokens ranked by probability to constrain + the sampling to. + xtc_probability (float, optional): The probability of applying XTC + sampling. + xtc_threshold (float, optional): The threshold the probs need to reach + for being sampled. + xtc_special_tokens (list(int), optional): List of special tokens IDs to + be excluded from XTC sampling. + + + Returns: + Callable[mx.array, mx.array]: + A sampler which takes log-probabilities and returns tokens. + """ + +def make_logits_processors( + logit_bias: Optional[Dict[int, float]] = ..., + repetition_penalty: Optional[float] = ..., + repetition_context_size: Optional[int] = ..., +): # -> list[Any]: + """ + Make logits processors for use with ``generate_step``. 
+ + Args: + repetition_penalty (float, optional): The penalty factor for repeating + tokens. + repetition_context_size (int, optional): The number of tokens to + consider for repetition penalty. Default: ``20``. + logit_bias (dictionary, optional): Additive logit bias. + + Returns: + List[Callable[[mx.array, mx.array], mx.array]]: + A list of logits processors. Each processor in the list is a + callable which takes an array of tokens and an array of logits + and returns the updated logits. + """ + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_top_k(logprobs: mx.array, top_k: int) -> mx.array: + """ + Sample from only the top K tokens ranked by probability. + + Args: + logprobs: A vector of log probabilities. + top_k (int): Top k tokens to sample from. + """ + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_min_p( + logprobs: mx.array, min_p: float, min_tokens_to_keep: int = ... +) -> mx.array: + """ + Apply min-p sampling to the logprobs. + + Min-p keeps all tokens that are above a minimum probability, scaled by the + probability of the most likely token. As a result, the filter is more + aggressive given a very high-probability token. + + Args: + logprobs: A vector of log probabilities. + min_p (float): Minimum token probability. Typical values are in the + 0.01-0.2 range, comparably selective as setting `top_p` in the + 0.99-0.8 range. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered. Default: ``1``. + + """ + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_top_p(logprobs: mx.array, top_p: float) -> mx.array: + """ + Apply top-p (nucleus) sampling to logits. + + Args: + logprobs: A vector of log probabilities. + top_p: The cumulative probability threshold for top-p filtering. + Returns: + token selected based on the top-p criterion. + """ + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def apply_xtc( + logits: mx.array, + xtc_probability: float, + xtc_threshold: float, + xtc_special_tokens: List[int], +) -> mx.array: + """ + Apply XTC sampling to the logits. + + Args: + logits: The logits from the model's output. + xtc_probability (float): Probability of XTC sampling to happen for each token + xtc_threshold (float): The threshold the probs need to reach for being sampled. + special_tokens_ids (list(int)): List of special tokens IDs to be excluded from XTC sampling. + """ + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def categorical_sampling(logits, temp): # -> array: + ... +def make_repetition_penalty( + penalty: float, context_size: int = ... +): # -> Callable[..., Any]: + """ + Make repetition penalty processor. + + Paper: https://arxiv.org/abs/1909.05858 + + Args: + penalty (float): The repetition penalty factor to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. + + Returns: + Callable[[mx.array, List[int]], mx.array]: + The repetition penalty processor. + """ diff --git a/.mlx_typings/mlx_lm/tokenizer_utils.pyi b/.mlx_typings/mlx_lm/tokenizer_utils.pyi new file mode 100644 index 00000000..a0a8355f --- /dev/null +++ b/.mlx_typings/mlx_lm/tokenizer_utils.pyi @@ -0,0 +1,168 @@ +""" +This type stub file was generated by pyright. 
+""" + +from functools import partial +from pathlib import Path + +from transformers import PreTrainedTokenizerFast + +class StreamingDetokenizer: + """The streaming detokenizer interface so that we can detokenize one token at a time. + + Example usage is as follows: + + detokenizer = ... + + # Reset the tokenizer state + detokenizer.reset() + + for token in generate(...): + detokenizer.add_token(token.item()) + + # Contains the whole text so far. Some tokens may not be included + # since it contains whole words usually. + detokenizer.text + + # Contains the printable segment (usually a word) since the last + # time it was accessed + detokenizer.last_segment + + # Contains all the tokens added so far + detokenizer.tokens + + # Make sure that we detokenize any remaining tokens + detokenizer.finalize() + + # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens) + """ + + __slots__ = ... + def reset(self): ... + def add_token(self, token): ... + def finalize(self): ... + @property + def last_segment(self): + """Return the last segment of readable text since last time this property was accessed.""" + +class NaiveStreamingDetokenizer(StreamingDetokenizer): + """NaiveStreamingDetokenizer relies on the underlying tokenizer + implementation and should work with every tokenizer. + + Its complexity is O(T^2) where T is the longest line since it will + repeatedly detokenize the same tokens until a new line is generated. + """ + def __init__(self, tokenizer) -> None: ... + def reset(self): # -> None: + ... + def add_token(self, token): # -> None: + ... + def finalize(self): # -> None: + ... + @property + def text(self): # -> str: + ... + +class SPMStreamingDetokenizer(StreamingDetokenizer): + """A streaming detokenizer for SPM models. + + It adds tokens to the text if the next token starts with the special SPM + underscore which results in linear complexity. + """ + def __init__(self, tokenizer, trim_space=...) -> None: ... + def reset(self): # -> None: + ... + def add_token(self, token): # -> None: + ... + def finalize(self): # -> None: + ... + +class BPEStreamingDetokenizer(StreamingDetokenizer): + """A streaming detokenizer for OpenAI style BPE models. + + It adds tokens to the text if the next token starts with a space similar to + the SPM detokenizer. + """ + + _byte_decoder = ... + _space_matches = ... + def __init__(self, tokenizer) -> None: ... + def reset(self): # -> None: + ... + def add_token(self, token): # -> None: + ... + def finalize(self): # -> None: + ... + @classmethod + def make_byte_decoder(cls): # -> None: + """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale.""" + +class TokenizerWrapper: + """A wrapper that combines an HF tokenizer and a detokenizer. + + Accessing any attribute other than the ``detokenizer`` is forwarded to the + huggingface tokenizer. + """ + def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ... + def add_eos_token(self, token: str): # -> None: + ... + @property + def has_thinking(self): # -> bool: + ... + @property + def think_start(self): # -> str | None: + ... + @property + def think_end(self): # -> str | None: + ... + @property + def has_tool_calling(self): # -> bool: + ... + @property + def tool_call_start(self): # -> str | None: + ... + @property + def tool_call_end(self): # -> str | None: + ... + @property + def detokenizer(self): # -> NaiveStreamingDetokenizer: + """ + Get a stateful streaming detokenizer. + """ + + def __getattr__(self, attr): # -> set[Any] | Any: + ... 
+ def __setattr__(self, attr, value): # -> None: + ... + +class NewlineTokenizer(PreTrainedTokenizerFast): + """A tokenizer that replaces newlines with and with new line.""" + def __init__(self, *args, **kwargs) -> None: ... + def encode(self, text, **kwargs): # -> list[int]: + ... + def encode_batch(self, texts, **kwargs): ... + def decode(self, *args, **kwargs): # -> str: + ... + def batch_decode(self, *args, **kwargs): # -> list[str]: + ... + +def load_tokenizer( + model_path: Path, + tokenizer_config_extra=..., + return_tokenizer=..., + eos_token_ids=..., +) -> ( + TokenizerWrapper + | type[SPMStreamingDetokenizer] + | partial[SPMStreamingDetokenizer] + | type[BPEStreamingDetokenizer] + | type[NaiveStreamingDetokenizer] +): + """Load a huggingface tokenizer and try to infer the type of streaming + detokenizer to use. + + Note, to use a fast streaming tokenizer, pass a local file path rather than + a Hugging Face repo ID. + """ + +def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ... diff --git a/.mlx_typings/mlx_lm/utils.pyi b/.mlx_typings/mlx_lm/utils.pyi new file mode 100644 index 00000000..99b207d1 --- /dev/null +++ b/.mlx_typings/mlx_lm/utils.pyi @@ -0,0 +1,195 @@ +""" +This type stub file was generated by pyright. +""" + +import os +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +import mlx.nn as nn +from transformers.utils.auto_docstring import ModelArgs + +from .tokenizer_utils import TokenizerWrapper + +if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": ... +else: ... +MODEL_REMAPPING = ... +MAX_FILE_SIZE_GB = ... + +def compute_bits_per_weight(model): ... +def hf_repo_to_path(hf_repo): # -> Path: + ... +def load_config(model_path: Path) -> dict: ... +def load_model( + model_path: Path, + lazy: bool = False, + strict: bool = True, + model_config: dict[str, Any] = {}, + get_model_classes: Callable[ + [dict[str, Any]], Tuple[Type[nn.Module], Type[ModelArgs]] + ] = ..., +) -> Tuple[nn.Module, dict[str, Any]]: + """ + Load and initialize the model from a given path. + + Args: + model_path (Path): The path to load the model from. + lazy (bool): If False eval the model parameters to make sure they are + loaded in memory before returning, otherwise they will be loaded + when needed. Default: ``False`` + strict (bool): Whether or not to raise an exception if weights don't + match. Default: ``True`` + model_config (dict, optional): Optional configuration parameters for the + model. Defaults to an empty dictionary. + get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): + A function that returns the model class and model args class given a config. + Defaults to the ``_get_classes`` function. + + Returns: + Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config. + + Raises: + FileNotFoundError: If the weight files (.safetensors) are not found. + ValueError: If the model class or args class are not found or cannot be instantiated. + """ + +def load( + path_or_hf_repo: str, + tokenizer_config=..., + model_config=..., + adapter_path: Optional[str] = ..., + lazy: bool = ..., + return_config: bool = ..., + revision: str = ..., +) -> Union[ + Tuple[nn.Module, TokenizerWrapper], + Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]], +]: + """ + Load the model and tokenizer from a given path or a huggingface repository. + + Args: + path_or_hf_repo (Path): The path or the huggingface repository to load the model from. 
+ tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. + Defaults to an empty dictionary. + model_config(dict, optional): Configuration parameters specifically for the model. + Defaults to an empty dictionary. + adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers + to the model. Default: ``None``. + lazy (bool): If ``False`` eval the model parameters to make sure they are + loaded in memory before returning, otherwise they will be loaded + when needed. Default: ``False`` + return_config (bool: If ``True`` return the model config as the last item.. + revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. + Returns: + Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]: + A tuple containing the loaded model, tokenizer and, if requested, the model config. + + Raises: + FileNotFoundError: If config file or safetensors are not found. + ValueError: If model class or args class are not found. + """ + +def make_shards(weights: dict, max_file_size_gb: int = ...) -> list: + """ + Splits the weights into smaller shards. + + Args: + weights (dict): Model weights. + max_file_size_gb (int): Maximum size of each shard in gigabytes. + + Returns: + list: List of weight shards. + """ + +def create_model_card( + path: Union[str, Path], hf_path: Union[str, Path, None] +): # -> None: + """ + Uploads the model to Hugging Face hub. + + Args: + path (Union[str, Path]): Local path to the model. + hf_path (Union[str, Path, None]): Path to the original Hugging Face model. + """ + +def upload_to_hub(path: str, upload_repo: str): # -> None: + """ + Uploads the model to Hugging Face hub. + + Args: + path (str): Local path to the model. + upload_repo (str): Name of the HF repo to upload to. + """ + +def save_model( + save_path: Union[str, Path], model: nn.Module, *, donate_model: bool = ... +) -> None: + """Save model weights and metadata index into specified directory.""" + +def quantize_model( + model: nn.Module, + config: dict, + group_size: int, + bits: int, + mode: str = ..., + quant_predicate: Optional[Callable[[str, nn.Module], Union[bool, dict]]] = ..., +) -> Tuple[nn.Module, dict]: + """ + Applies quantization to the model weights. + + Args: + model (nn.Module): The model to be quantized. + config (dict): Model configuration. + group_size (int): Group size for quantization. + bits (int): Bits per weight for quantization. + mode (str): The quantization mode. + quant_predicate (Callable): A callable that decides how to quantize + each layer based on the path. Accepts the layer `path` and the + `module`. Returns either a bool to signify quantize/no quantize or + a dict of quantization parameters to pass to `to_quantized`. + + Returns: + Tuple: Tuple containing quantized model and config. + """ + +def save_config(config: dict, config_path: Union[str, Path]) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. + """ + +def save( + dst_path: Union[str, Path], + src_path_or_repo: Union[str, Path], + model: nn.Module, + tokenizer: TokenizerWrapper, + config: Dict[str, Any], + donate_model: bool = ..., +): # -> None: + ... +def common_prefix_len(list1, list2): # -> int: + """ + Calculates the length of the common prefix of two lists. 
+ + Args: + list1: The first list of strings. + list2: The second list of strings. + + Returns: + The length of the common prefix. Returns 0 if lists are empty + or do not match at the first element. + """ + +def does_model_support_input_embeddings(model: nn.Module) -> bool: + """ + Check if the model supports input_embeddings in its call signature. + Args: + model (nn.Module): The model to check. + Returns: + bool: True if the model supports input_embeddings, False otherwise. + """ diff --git a/justfile b/justfile index a61d0bb8..2ef99049 100644 --- a/justfile +++ b/justfile @@ -1,5 +1,5 @@ fmt: - uv run ruff format src typings + uv run ruff format src .mlx_typings lint: uv run ruff check --fix src diff --git a/pyproject.toml b/pyproject.toml index 2113642a..12ff2bdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ build-backend = "uv_build" ### [tool.basedpyright] -include = [".venv/lib/mlx", "src"] +include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"] typeCheckingMode = "strict" failOnWarnings = true @@ -97,8 +97,8 @@ reportUnnecessaryTypeIgnoreComment = "error" pythonVersion = "3.13" pythonPlatform = "Darwin" -exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"] -stubPath = "typings" +exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust", "**/.github"] +stubPath = ".mlx_typings" [[tool.basedpyright.executionEnvironments]] root = "src" diff --git a/src/exo/engines/mlx/utils_mlx.py b/src/exo/engines/mlx/utils_mlx.py index d1216e73..8d7bde05 100644 --- a/src/exo/engines/mlx/utils_mlx.py +++ b/src/exo/engines/mlx/utils_mlx.py @@ -9,10 +9,10 @@ from mlx_lm.models.cache import KVCache from mlx_lm.sample_utils import make_sampler try: - from mlx_lm.tokenizer_utils import load_tokenizer # type: ignore + from mlx_lm.tokenizer_utils import load_tokenizer except ImportError: from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore -from mlx_lm.utils import load_model # type: ignore +from mlx_lm.utils import load_model from pydantic import RootModel import mlx.core as mx @@ -167,12 +167,11 @@ def shard_and_load( f"loading model from {model_path} with strategy {model_shard_meta.strategy}" ) - model, config = load_model(model_path, lazy=True, strict=False) # type: ignore + model, config = load_model(model_path, lazy=True, strict=False) runner_print(f"{config=}") assert isinstance(model, nn.Module) - tokenizer = load_tokenizer(model_path) # type: ignore - tokenizer = cast(TokenizerWrapper, tokenizer) + tokenizer = cast(TokenizerWrapper, load_tokenizer(model_path)) runner_print(f"Group size: {group.size()}, group rank: {group.rank()}") diff --git a/src/exo/worker/download/impl_shard_downloader.py b/src/exo/worker/download/impl_shard_downloader.py index a00ac5a7..d6c59a80 100644 --- a/src/exo/worker/download/impl_shard_downloader.py +++ b/src/exo/worker/download/impl_shard_downloader.py @@ -31,7 +31,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata: ) -async def build_full_shard(model_id: str) -> PipelineShardMetadata | None: +async def build_full_shard(model_id: str) -> PipelineShardMetadata: base_shard = await build_base_shard(model_id) return PipelineShardMetadata( model_meta=base_shard.model_meta, @@ -150,11 +150,9 @@ class ResumableShardDownloader(ShardDownloader): # print("get_shard_download_status") async def _status_for_model( model_id: str, - ) -> tuple[Path, RepoDownloadProgress] | None: + ) -> tuple[Path, RepoDownloadProgress]: """Helper coroutine that builds 
the shard for a model and gets its download status.""" shard = await build_full_shard(model_id) - if shard is None: - return None return await download_shard( shard, self.on_progress_wrapper, skip_download=True ) @@ -168,8 +166,6 @@ class ResumableShardDownloader(ShardDownloader): for task in asyncio.as_completed(tasks): try: result = await task - if result is None: - continue path, progress = result yield (path, progress) except Exception as e: diff --git a/src/exo/worker/runner/generate.py b/src/exo/worker/runner/generate.py index d1497263..9fe58d40 100644 --- a/src/exo/worker/runner/generate.py +++ b/src/exo/worker/runner/generate.py @@ -35,16 +35,18 @@ generation_stream = mx.new_stream(mx.default_device()) def maybe_quantize_kv_cache( - prompt_cache: list[Any], + prompt_cache: list[KVCache | Any], quantized_kv_start: int, kv_group_size: int, kv_bits: int | None, ) -> None: if kv_bits is None: return - for e, c in enumerate(prompt_cache): # type: ignore[type-arg] - if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start: # type: ignore[type-arg] - prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) # type: ignore[type-arg] + for e, c in enumerate(prompt_cache): + if ( + hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore + ): + prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) def generate_step( @@ -189,7 +191,7 @@ def generate_step( quantize_cache_fn(prompt_cache) start_time = time.time() - mx.eval([c.state for c in prompt_cache]) # type: ignore + mx.eval([c.state for c in prompt_cache]) eval_time = time.time() - start_time prompt_processed_tokens += n_to_process @@ -221,9 +223,17 @@ def generate_step( n = 0 while True: - mx.eval(y, logprobs) + assert y is not None + assert logprobs is not None + if n != max_tokens: + next_y, next_logprobs = _step(y) + mx.async_eval(next_y, next_logprobs) + if n == 0: + mx.eval(y) + prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) + if n == max_tokens: + break yield int(y.item()), logprobs - n += 1 if n % 256 == 0: mx.clear_cache() if n == max_tokens:
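The generate_step hunk above replaces the blocking mx.eval(y, logprobs) with a pipelined loop: step n+1 is dispatched with mx.async_eval while the caller consumes token n. A minimal sketch of that pattern, not the exact implementation; `step` stands in for the single-token model step and `y`/`logprobs` for the first sampled token and its log-probabilities.

import mlx.core as mx

def pipelined_decode(y, logprobs, step, max_tokens):
    n = 0
    while True:
        if n != max_tokens:
            # Dispatch the next step before yielding so the GPU keeps working
            # while the caller handles the current token.
            next_y, next_logprobs = step(y)
            mx.async_eval(next_y, next_logprobs)
        if n == max_tokens:
            break
        yield int(y.item()), logprobs  # .item() blocks until y is materialized
        if n % 256 == 0:
            mx.clear_cache()
        y, logprobs = next_y, next_logprobs
        n += 1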