Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)

Commit: MLX LM type stubs
.github/scripts/bench.py  (vendored, 1 line changed)
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3

-# type: ignore
 """
 Unified benchmark script for EXO.
 Runs single or multi-stage benchmarks with configurable load patterns.

@@ -2614,7 +2614,7 @@ type MX_ARRAY_TREE = (
     | Mapping[str, MX_ARRAY_TREE]
 )

-def eval(*args: MX_ARRAY_TREE) -> None:
+def eval(*args: MX_ARRAY_TREE | None) -> None:
     """
     Evaluate an :class:`array` or tree of :class:`array`.

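Example (not part of the commit): a minimal sketch of what the widened eval annotation above is meant to allow, assuming mx.eval simply skips None entries in the values it is given.

    import mlx.core as mx

    a = mx.ones((2, 2))
    # Passing None alongside real arrays reflects the relaxed
    # `MX_ARRAY_TREE | None` signature in the stub above.
    mx.eval(a, None)
    print(a)
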
.mlx_typings/mlx_lm/__init__.pyi  (new file, 2 lines)
@@ -0,0 +1,2 @@
import models as models
import tokenizer_utils as tokenizer_utils

.mlx_typings/mlx_lm/convert.pyi  (new file, 45 lines)
@@ -0,0 +1,45 @@
"""
This type stub file was generated by pyright.
"""

import argparse
from typing import Callable, Optional, Union

import mlx.nn as nn

def mixed_quant_predicate_builder(
    recipe: str, model: nn.Module, group_size: int = ...
) -> Callable[[str, nn.Module, dict], Union[bool, dict]]: ...

QUANT_RECIPES = ...
MODEL_CONVERSION_DTYPES = ...

def convert(
    hf_path: str,
    mlx_path: str = ...,
    quantize: bool = ...,
    q_group_size: int = ...,
    q_bits: int = ...,
    q_mode: str = ...,
    dtype: Optional[str] = ...,
    upload_repo: str = ...,
    revision: Optional[str] = ...,
    dequantize: bool = ...,
    quant_predicate: Optional[
        Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str]
    ] = ...,
    trust_remote_code: bool = ...,
):  # -> None:
    ...

def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """

def main():  # -> None:
    ...

if __name__ == "__main__": ...

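Example (not part of the commit): a minimal sketch of driving the conversion API these stubs describe, assuming the runtime mlx_lm.convert matches the signature above; the repo id and output path are placeholders.

    from mlx_lm.convert import convert

    # Convert and 4-bit-quantize a Hugging Face checkpoint into MLX format.
    convert(
        hf_path="mlx-community/Mistral-7B-Instruct-v0.3",  # placeholder repo id
        mlx_path="./mistral-7b-4bit",                      # placeholder output dir
        quantize=True,
        q_bits=4,
        q_group_size=64,
    )
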
.mlx_typings/mlx_lm/generate.pyi  (new file, 324 lines)
@@ -0,0 +1,324 @@
"""
This type stub file was generated by pyright.
"""

import contextlib
from dataclasses import dataclass
from typing import Any, Callable, Generator, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer

from .tokenizer_utils import TokenizerWrapper

DEFAULT_PROMPT = ...
DEFAULT_MAX_TOKENS = ...
DEFAULT_TEMP = ...
DEFAULT_TOP_P = ...
DEFAULT_MIN_P = ...
DEFAULT_TOP_K = ...
DEFAULT_XTC_PROBABILITY = ...
DEFAULT_XTC_THRESHOLD = ...
DEFAULT_MIN_TOKENS_TO_KEEP = ...
DEFAULT_SEED = ...
DEFAULT_MODEL = ...
DEFAULT_QUANTIZED_KV_START = ...

def str2bool(string):  # -> bool:
    ...

def setup_arg_parser():  # -> ArgumentParser:
    """Set up and return the argument parser."""

generation_stream = ...

@contextlib.contextmanager
def wired_limit(
    model: nn.Module, streams: Optional[List[mx.Stream]] = ...
):  # -> Generator[None, Any, None]:
    """
    A context manager to temporarily change the wired limit.

    Note, the wired limit should not be changed during an async eval. If an
    async eval could be running pass in the streams to synchronize with prior
    to exiting the context manager.
    """

@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        from_draft (bool): Whether the token was generated by the draft model.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
        finish_reason (str): The reason the response is being sent: "length", "stop" or `None`
    """

    text: str
    token: int
    logprobs: mx.array
    from_draft: bool
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    peak_memory: float
    finish_reason: Optional[str] = ...

def maybe_quantize_kv_cache(
    prompt_cache, quantized_kv_start, kv_group_size, kv_bits
):  # -> None:
    ...

def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = ...,
    sampler: Optional[Callable[[mx.array], mx.array]] = ...,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
    max_kv_size: Optional[int] = ...,
    prompt_cache: Optional[Any] = ...,
    prefill_step_size: int = ...,
    kv_bits: Optional[int] = ...,
    kv_group_size: int = ...,
    quantized_kv_start: int = ...,
    prompt_progress_callback: Optional[Callable[[int], int]] = ...,
    input_embeddings: Optional[mx.array] = ...,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
            generator. Default: ``256``.
        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
            token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the processed
            logits. Default: ``None``.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.
        prompt_progress_callback (Callable[[int], int]): A call-back which takes the
            prompt tokens processed so far and the total number of prompt tokens.
        input_embeddings (mx.array, optional): Input embeddings to use instead of or in
            conjunction with prompt tokens. Default: ``None``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """

def speculative_generate_step(
    prompt: mx.array,
    model: nn.Module,
    draft_model: nn.Module,
    *,
    num_draft_tokens: int = ...,
    max_tokens: int = ...,
    sampler: Optional[Callable[[mx.array], mx.array]] = ...,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
    prompt_cache: Optional[Any] = ...,
    prefill_step_size: int = ...,
    kv_bits: Optional[int] = ...,
    kv_group_size: int = ...,
    quantized_kv_start: int = ...,
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        draft_model (nn.Module): The draft model for speculative decoding.
        num_draft_tokens (int, optional): The number of draft tokens for
            speculative decoding. Default: ``2``.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
            generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
            token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the processed
            logits. Default: ``None``.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place. The cache must be trimmable.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.

    Yields:
        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
        and a bool indicating if the token was generated by the draft model.
    """

def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    max_tokens: int = ...,
    draft_model: Optional[nn.Module] = ...,
    **kwargs,
) -> Generator[GenerationResponse, None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or
            integer tokens.
        max_tokens (int): The maximum number of tokens to generate.
            Default: ``256``.
        draft_model (Optional[nn.Module]): An optional draft model. If provided
            then speculative decoding is used. The draft model must use the same
            tokenizer as the main model. Default: ``None``.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        GenerationResponse: An instance containing the generated text segment and
        associated metadata. See :class:`GenerationResponse` for details.
    """

def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, List[int]],
    verbose: bool = ...,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        kwargs: The remaining options get passed to :func:`stream_generate`.
            See :func:`stream_generate` for more details.
    """

@dataclass
class BatchStats:
    """
    A data object to hold generation stats.

    Args:
        prompt_tokens (int): The number of prompt tokens processed.
        prompt_tps (float): The prompt processing tokens-per-second.
        prompt_time (float): The time in seconds spent in prompt processing.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        generation_time (float): The time in seconds spent in generation.
        peak_memory (float): The peak memory used so far in GB.
    """

    prompt_tokens: int = ...
    prompt_tps: float = ...
    prompt_time: float = ...
    generation_tokens: int = ...
    generation_tps: float = ...
    generation_time: float = ...
    peak_memory: float = ...

@dataclass
class BatchResponse:
    """
    A data object to hold a batch generation response.

    Args:
        texts (List[str]): The generated text for each prompt.
        stats (BatchStats): Statistics about the generation.
    """

    texts: List[str]
    stats: BatchStats

@dataclass
class Batch:
    uids: List[int]
    y: mx.array
    logprobs: mx.array
    max_tokens: List[int]
    num_tokens: List[int]
    cache: List[Any]
    def __len__(self):  # -> int:
        ...
    def filter(self, keep_idx: List[int]):  # -> None:
        ...
    def extend(self, other):  # -> None:
        ...

class BatchGenerator:
    @dataclass
    class Response:
        uid: int
        token: int
        logprobs: mx.array
        finish_reason: Optional[str]

    def __init__(
        self,
        model,
        max_tokens: int = ...,
        stop_tokens: Optional[set] = ...,
        sampler: Optional[Callable[[mx.array], mx.array]] = ...,
        completion_batch_size: int = ...,
        prefill_batch_size: int = ...,
        prefill_step_size: int = ...,
    ) -> None: ...
    def insert(
        self, prompts, max_tokens: Union[List[int], int, None] = ...
    ):  # -> list[Any]:
        ...
    def stats(self):  # -> BatchStats:
        ...
    def next(self):  # -> list[Any]:
        ...

def batch_generate(
    model,
    tokenizer,
    prompts: List[int],
    max_tokens: Union[int, List[int]] = ...,
    verbose: bool = ...,
    **kwargs,
) -> BatchResponse:
    """
    Generate responses for the given batch of prompts.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (List[List[int]]): The input prompts.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        max_tokens (Union[int, List[int]]): Maximum number of output tokens. This
            can be per prompt if a list is provided.
        kwargs: The remaining options get passed to :obj:`BatchGenerator`.
            See :obj:`BatchGenerator` for more details.
    """

def main():  # -> None:
    ...

if __name__ == "__main__": ...

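Example (not part of the commit): a sketch of the streaming API the stubs above describe, assuming the runtime mlx_lm package matches them; the model id is a placeholder.

    from mlx_lm import load, stream_generate

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # placeholder id

    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Explain KV caching in one sentence."}],
        add_generation_prompt=True,
    )

    last = None
    for last in stream_generate(model, tokenizer, prompt, max_tokens=128):
        # Each GenerationResponse carries the decoded segment plus timing stats.
        print(last.text, end="", flush=True)
    if last is not None:
        print(f"\n{last.generation_tps:.1f} generation tok/s")
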
.mlx_typings/mlx_lm/models/__init__.pyi  (new file, 1 line)
@@ -0,0 +1 @@
import cache as cache

.mlx_typings/mlx_lm/models/base.pyi  (new file, 47 lines)
@@ -0,0 +1,47 @@
"""
This type stub file was generated by pyright.
"""

from dataclasses import dataclass
from typing import Optional

import mlx.core as mx

@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):  # -> Self:
        ...

def create_causal_mask(
    N: int,
    offset: int = ...,
    window_size: Optional[int] = ...,
    right_padding: Optional[mx.array] = ...,
    left_padding: Optional[mx.array] = ...,
):  # -> array:
    ...

def create_attention_mask(
    h, cache=..., window_size: Optional[int] = ..., return_array: bool = ...
):  # -> array | Literal['causal'] | None:
    ...

def create_ssm_mask(h, cache=...):  # -> None:
    ...

def quantized_scaled_dot_product_attention(
    queries: mx.array,
    q_keys: tuple[mx.array, mx.array, mx.array],
    q_values: tuple[mx.array, mx.array, mx.array],
    scale: float,
    mask: Optional[mx.array],
    group_size: int = ...,
    bits: int = ...,
) -> mx.array: ...

def scaled_dot_product_attention(
    queries,
    keys,
    values,
    cache,
    scale: float,
    mask: Optional[mx.array],
    sinks: Optional[mx.array] = ...,
) -> mx.array: ...

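Example (not part of the commit): a rough sketch of the mask and attention helpers typed above, assuming the runtime functions match the stubs; the shapes and the cache=None call are assumptions.

    import mlx.core as mx
    from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention

    # Illustrative shapes: batch=1, heads=8, sequence length N=16, head_dim=64.
    q = k = v = mx.zeros((1, 8, 16, 64))
    mask = create_causal_mask(16, offset=0)
    out = scaled_dot_product_attention(q, k, v, cache=None, scale=64**-0.5, mask=mask)
    print(out.shape)
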
.mlx_typings/mlx_lm/models/bitlinear_layers.pyi  (new file, 26 lines)
@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""

import mlx.nn as nn

def bitnet_quantize(model, quantization_config: dict): ...

def make_bitlinear_kernel():
    """
    Custom Metal kernel that performs matrix multiplication directly on
    packed weights and scales the output. This eliminates the need to
    store unpacked weights in memory.
    """

_bitlinear_kernel = ...

class BitLinear(nn.Module):
    """
    BitLinear module with memory-efficient weight handling.
    """
    def __init__(
        self, in_features, out_features, bias=..., invert_weight_scales=...
    ) -> None: ...
    def execute_matmul_kernel(self, x, packed_weights): ...
    def __call__(self, x):  # -> array:
        ...

.mlx_typings/mlx_lm/models/cache.pyi  (new file, 354 lines)
@@ -0,0 +1,354 @@
"""
This type stub file was generated by pyright.
"""

from typing import Any, Dict, List, Optional

import mlx.nn as nn
from mlx.core import array

def make_prompt_cache(
    model: nn.Module, max_kv_size: Optional[int] = ...
) -> List[KVCache | Any]:
    """
    Construct the model's cache for use in generation.

    This function will defer the cache construction to the model if it has a
    ``make_cache`` method, otherwise it will make a default KV cache.

    Args:
        model (nn.Module): The language model.
        max_kv_size (Optional[int]): If provided and the model does not have a
            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
            size of ``max_kv_size``
    """

def save_prompt_cache(
    file_name: str, cache: List[Any], metadata: Dict[str, str] = ...
) -> None:
    """
    Save a pre-computed prompt cache to a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        cache (List[Any]): The model state.
        metadata (Dict[str, str]): Optional metadata to save along with model
            state.
    """

def load_prompt_cache(
    file_name, return_metadata=...
):  # -> tuple[list[Any], Any] | list[Any]:
    """
    Load a prompt cache from a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        return_metadata (bool): Whether or not to return metadata.
            Default: ``False``.

    Returns:
        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
        the metadata if requested.
    """

def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Check if model's cache can be trimmed.
    """

def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
    """
    Trim the model's cache by the given number of tokens.

    This function will trim the cache if possible (in-place) and return the
    number of tokens that were trimmed.

    Args:
        cache (List[Any]): The model's cache.
        num_tokens (int): The number of tokens to trim.

    Returns:
        (int): The number of tokens that were trimmed.
    """

def create_attention_mask(
    N: int, offset: int, return_array: bool, window_size: Optional[int]
):  # -> array | Literal['causal'] | None:
    ...

class _BaseCache:
    @property
    def state(self):  # -> list[Any]:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    @property
    def meta_state(self):  # -> Literal['']:
        ...
    @meta_state.setter
    def meta_state(self, v):  # -> None:
        ...
    def is_trimmable(self):  # -> Literal[False]:
        ...
    @classmethod
    def from_state(cls, state, meta_state):  # -> Self:
        ...

class ConcatenateKVCache(_BaseCache):
    """ConcatenateKVCache is the simplest KV cache implementation.

    Can be used as a mock KV cache or when large blocks are being processed at
    a time in which case KVCache isn't necessarily faster. Consider using the
    KVCache with a larger step size before using this cache.
    """
    def __init__(self) -> None: ...
    def update_and_fetch(self, keys, values):  # -> tuple[Any | array, Any | array]:
        ...
    @property
    def state(self):  # -> tuple[Any | array | None, Any | array | None]:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    def is_trimmable(self):  # -> Literal[True]:
        ...
    def trim(self, n):  # -> int:
        ...
    def make_mask(self, *args, **kwargs):  # -> array | Literal['causal'] | None:
        ...

class QuantizedKVCache(_BaseCache):
    step = ...
    def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
    def update_and_fetch(self, keys, values):  # -> Any:
        ...
    @property
    def state(
        self,
    ):  # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    @property
    def meta_state(self):  # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v):  # -> None:
        ...
    def is_trimmable(self):  # -> Literal[True]:
        ...
    def trim(self, n):  # -> int:
        ...
    def make_mask(self, *args, **kwargs):  # -> array | Literal['causal'] | None:
        ...

class KVCache(_BaseCache):
    step = ...
    def __init__(self) -> None: ...
    def update_and_fetch(self, keys, values):  # -> tuple[array | Any, array | Any]:
        ...
    @property
    def state(
        self,
    ) -> tuple[array, array]: ...
    @state.setter
    def state(self, v) -> None: ...
    def is_trimmable(self):  # -> Literal[True]:
        ...
    def trim(self, n):  # -> int:
        ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(self, *args, **kwargs):  # -> array | Literal['causal'] | None:
        ...

class RotatingKVCache(_BaseCache):
    step = ...
    def __init__(self, max_size, keep=...) -> None: ...
    def update_and_fetch(
        self, keys, values
    ):  # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
        ...
    @property
    def state(
        self,
    ):  # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    @property
    def meta_state(self):  # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v):  # -> None:
        ...
    def is_trimmable(self):  # -> bool:
        ...
    def trim(self, n):  # -> int:
        ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(
        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
    ):  # -> array | Literal['causal'] | None:
        ...

class ArraysCache(_BaseCache):
    def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
    def __setitem__(self, idx, value):  # -> None:
        ...
    def __getitem__(self, idx): ...
    @property
    def state(self):  # -> list[Any | array] | list[array]:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    def filter(self, batch_indices):  # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """

    def extend(self, other):  # -> None:
        """
        In-place extend this cache with the other cache.
        """

    def make_mask(self, N: int):  # -> array | None:
        ...

class MambaCache(ArraysCache):
    def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...

class ChunkedKVCache(KVCache):
    def __init__(self, chunk_size) -> None: ...
    def maybe_trim_front(self):  # -> None:
        ...
    def update_and_fetch(self, keys, values):  # -> tuple[array, array]:
        ...
    def trim(self, n):  # -> int:
        ...
    @property
    def meta_state(self):  # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v):  # -> None:
        ...

class CacheList(_BaseCache):
    def __init__(self, *caches) -> None: ...
    def __getitem__(self, idx): ...
    def is_trimmable(self):  # -> bool:
        ...
    def trim(self, n): ...
    @property
    def state(self):  # -> list[Any]:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    def filter(self, batch_indices):  # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """

    def extend(self, other):  # -> None:
        """
        In-place extend this cache with the other cache.
        """

class BatchKVCache(_BaseCache):
    step = ...
    def __init__(self, left_padding: List[int]) -> None:
        """
        The BatchKV cache expects inputs to be left-padded.

        E.g. the following prompts:

        [1, 3, 5]
        [7]
        [2, 6, 8, 9]

        Should be padded like so:

        [0, 1, 3, 5]
        [0, 0, 0, 7]
        [2, 6, 8, 9]

        And ``left_padding`` specifies the amount of padding for each.
        In this case, ``left_padding = [1, 3, 0]``.
        """

    def update_and_fetch(self, keys, values):  # -> tuple[array | Any, array | Any]:
        ...
    @property
    def state(
        self,
    ):  # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    def is_trimmable(self):  # -> Literal[True]:
        ...
    def trim(self, n):  # -> int | float:
        ...
    def make_mask(self, N: int, return_array: bool = ..., **kwargs):  # -> array:
        ...
    def filter(self, batch_indices):  # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """

    def extend(self, other):  # -> None:
        """
        In-place extend this cache with the other cache.
        """

class BatchRotatingKVCache(_BaseCache):
    step = ...
    def __init__(self, max_size, left_padding: List[int]) -> None: ...
    def update_and_fetch(
        self, keys, values
    ):  # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
        ...
    @property
    def state(
        self,
    ):  # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
        ...
    @state.setter
    def state(self, v):  # -> None:
        ...
    @property
    def meta_state(self):  # -> tuple[str, ...]:
        ...
    @meta_state.setter
    def meta_state(self, v):  # -> None:
        ...
    def is_trimmable(self):  # -> bool:
        ...
    def trim(self, n):  # -> int:
        ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(
        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
    ):  # -> array:
        ...
    def filter(self, batch_indices):  # -> None:
        """
        In-place filter to keep just the given indices in the cache.
        """

    def extend(self, other):  # -> None:
        """
        In-place extend this cache with the other cache.
        """

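Example (not part of the commit): a sketch of building, persisting, and trimming a prompt cache with the API described above, assuming the runtime package matches the stubs; the model id and file name are placeholders.

    from mlx_lm import load, stream_generate
    from mlx_lm.models.cache import (
        make_prompt_cache,
        save_prompt_cache,
        trim_prompt_cache,
    )

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # placeholder id

    # Build a reusable cache, run one generation against it, then persist it so a
    # later run can skip re-processing the shared prompt prefix.
    cache = make_prompt_cache(model)
    for _ in stream_generate(
        model, tokenizer, "You are a helpful assistant.", prompt_cache=cache
    ):
        pass
    save_prompt_cache("prefix.safetensors", cache, {"note": "system prompt prefix"})

    # Default KV caches are trimmable: drop the last 8 cached tokens in place.
    trim_prompt_cache(cache, 8)
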
.mlx_typings/mlx_lm/models/switch_layers.pyi  (new file, 79 lines)
@@ -0,0 +1,79 @@
"""
This type stub file was generated by pyright.
"""

from functools import partial

import mlx.core as mx
import mlx.nn as nn

class QuantizedSwitchLinear(nn.Module):
    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        num_experts: int,
        bias: bool = ...,
        group_size: int = ...,
        bits: int = ...,
        mode: str = ...,
    ) -> None: ...
    @property
    def input_dims(self):  # -> int:
        ...
    @property
    def output_dims(self):  # -> int:
        ...
    @property
    def num_experts(self):  # -> int:
        ...
    def __call__(self, x, indices, sorted_indices=...):  # -> array:
        ...

class SwitchLinear(nn.Module):
    def __init__(
        self, input_dims: int, output_dims: int, num_experts: int, bias: bool = ...
    ) -> None: ...
    @property
    def input_dims(self):  # -> int:
        ...
    @property
    def output_dims(self):  # -> int:
        ...
    @property
    def num_experts(self):  # -> int:
        ...
    def __call__(self, x, indices, sorted_indices=...): ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ..., mode: str = ...
    ):  # -> QuantizedSwitchLinear:
        ...

@partial(mx.compile, shapeless=True)
def swiglu(x, gate): ...

class SwiGLU(nn.Module):
    def __init__(self) -> None: ...
    def __call__(self, x, gate): ...

class SwitchGLU(nn.Module):
    def __init__(
        self,
        input_dims: int,
        hidden_dims: int,
        num_experts: int,
        activation=...,
        bias: bool = ...,
    ) -> None: ...
    def __call__(self, x, indices) -> mx.array: ...

class SwitchMLP(nn.Module):
    def __init__(
        self,
        input_dims: int,
        hidden_dims: int,
        num_experts: int,
        activation=...,
        bias: bool = ...,
    ) -> None: ...
    def __call__(self, x, indices) -> mx.array: ...

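Example (not part of the commit): a small sketch of the mixture-of-experts blocks typed above, assuming the runtime classes match the stubs; the dimensions and top-2 routing are illustrative only.

    import mlx.core as mx
    from mlx_lm.models.switch_layers import SwitchGLU

    # Illustrative MoE block: 4 experts, model dim 32, hidden dim 64.
    moe = SwitchGLU(input_dims=32, hidden_dims=64, num_experts=4)

    x = mx.random.normal((2, 6, 32))              # (batch, tokens, dim)
    indices = mx.random.randint(0, 4, (2, 6, 2))  # two selected experts per token
    y = moe(x, indices)                           # per-expert outputs for each token
    print(y.shape)
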
.mlx_typings/mlx_lm/sample_utils.pyi  (new file, 148 lines)
@@ -0,0 +1,148 @@
"""
This type stub file was generated by pyright.
"""

from functools import partial
from typing import Callable, Dict, List, Optional

import mlx.core as mx

def make_sampler(
    temp: float = ...,
    top_p: float = ...,
    min_p: float = ...,
    min_tokens_to_keep: int = ...,
    top_k: int = ...,
    xtc_probability: float = ...,
    xtc_threshold: float = ...,
    xtc_special_tokens: List[int] = ...,
) -> Callable[[mx.array], mx.array]:
    """
    Make a sampler function for use with ``generate_step``.

    Args:
        temp (float): The temperature for sampling, if 0 the argmax is used.
            Default: ``0``.
        top_p (float, optional): Nucleus sampling; higher means the model
            considers more less-likely words.
        min_p (float, optional): The minimum value (scaled by the top token's
            probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered by min_p sampling.
        top_k (int, optional): The top k tokens ranked by probability to constrain
            the sampling to.
        xtc_probability (float, optional): The probability of applying XTC
            sampling.
        xtc_threshold (float, optional): The threshold the probs need to reach
            for being sampled.
        xtc_special_tokens (list(int), optional): List of special token IDs to
            be excluded from XTC sampling.

    Returns:
        Callable[mx.array, mx.array]:
        A sampler which takes log-probabilities and returns tokens.
    """

def make_logits_processors(
    logit_bias: Optional[Dict[int, float]] = ...,
    repetition_penalty: Optional[float] = ...,
    repetition_context_size: Optional[int] = ...,
):  # -> list[Any]:
    """
    Make logits processors for use with ``generate_step``.

    Args:
        repetition_penalty (float, optional): The penalty factor for repeating
            tokens.
        repetition_context_size (int, optional): The number of tokens to
            consider for repetition penalty. Default: ``20``.
        logit_bias (dictionary, optional): Additive logit bias.

    Returns:
        List[Callable[[mx.array, mx.array], mx.array]]:
        A list of logits processors. Each processor in the list is a
        callable which takes an array of tokens and an array of logits
        and returns the updated logits.
    """

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_top_k(logprobs: mx.array, top_k: int) -> mx.array:
    """
    Sample from only the top K tokens ranked by probability.

    Args:
        logprobs: A vector of log probabilities.
        top_k (int): Top k tokens to sample from.
    """

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_min_p(
    logprobs: mx.array, min_p: float, min_tokens_to_keep: int = ...
) -> mx.array:
    """
    Apply min-p sampling to the logprobs.

    Min-p keeps all tokens that are above a minimum probability, scaled by the
    probability of the most likely token. As a result, the filter is more
    aggressive given a very high-probability token.

    Args:
        logprobs: A vector of log probabilities.
        min_p (float): Minimum token probability. Typical values are in the
            0.01-0.2 range, comparably selective as setting `top_p` in the
            0.99-0.8 range.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered. Default: ``1``.
    """

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_top_p(logprobs: mx.array, top_p: float) -> mx.array:
    """
    Apply top-p (nucleus) sampling to logits.

    Args:
        logprobs: A vector of log probabilities.
        top_p: The cumulative probability threshold for top-p filtering.
    Returns:
        token selected based on the top-p criterion.
    """

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def apply_xtc(
    logits: mx.array,
    xtc_probability: float,
    xtc_threshold: float,
    xtc_special_tokens: List[int],
) -> mx.array:
    """
    Apply XTC sampling to the logits.

    Args:
        logits: The logits from the model's output.
        xtc_probability (float): Probability of XTC sampling to happen for each token.
        xtc_threshold (float): The threshold the probs need to reach for being sampled.
        xtc_special_tokens (list(int)): List of special token IDs to be excluded from XTC sampling.
    """

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits, temp):  # -> array:
    ...

def make_repetition_penalty(
    penalty: float, context_size: int = ...
):  # -> Callable[..., Any]:
    """
    Make a repetition penalty processor.

    Paper: https://arxiv.org/abs/1909.05858

    Args:
        penalty (float): The repetition penalty factor to be applied.
        context_size (int): The number of previous tokens to use.
            Default: ``20``.

    Returns:
        Callable[[mx.array, List[int]], mx.array]:
        The repetition penalty processor.
    """

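Example (not part of the commit): a sketch of wiring a sampler and logits processors into generation, assuming the runtime mlx_lm matches the stubs; the model id is a placeholder.

    from mlx_lm import generate, load
    from mlx_lm.sample_utils import make_logits_processors, make_sampler

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # placeholder id

    sampler = make_sampler(temp=0.7, top_p=0.9, min_p=0.05)
    processors = make_logits_processors(repetition_penalty=1.1, repetition_context_size=20)

    # Both objects are forwarded through generate -> stream_generate -> generate_step.
    text = generate(
        model,
        tokenizer,
        "Write a haiku about type stubs.",
        sampler=sampler,
        logits_processors=processors,
        max_tokens=64,
        verbose=False,
    )
    print(text)
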
.mlx_typings/mlx_lm/tokenizer_utils.pyi  (new file, 168 lines)
@@ -0,0 +1,168 @@
"""
This type stub file was generated by pyright.
"""

from functools import partial
from pathlib import Path

from transformers import PreTrainedTokenizerFast

class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.

    Example usage is as follows:

        detokenizer = ...

        # Reset the tokenizer state
        detokenizer.reset()

        for token in generate(...):
            detokenizer.add_token(token.item())

            # Contains the whole text so far. Some tokens may not be included
            # since it contains whole words usually.
            detokenizer.text

            # Contains the printable segment (usually a word) since the last
            # time it was accessed
            detokenizer.last_segment

            # Contains all the tokens added so far
            detokenizer.tokens

        # Make sure that we detokenize any remaining tokens
        detokenizer.finalize()

        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
    """

    __slots__ = ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""

class NaiveStreamingDetokenizer(StreamingDetokenizer):
    """NaiveStreamingDetokenizer relies on the underlying tokenizer
    implementation and should work with every tokenizer.

    Its complexity is O(T^2) where T is the longest line since it will
    repeatedly detokenize the same tokens until a new line is generated.
    """
    def __init__(self, tokenizer) -> None: ...
    def reset(self):  # -> None:
        ...
    def add_token(self, token):  # -> None:
        ...
    def finalize(self):  # -> None:
        ...
    @property
    def text(self):  # -> str:
        ...

class SPMStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for SPM models.

    It adds tokens to the text if the next token starts with the special SPM
    underscore which results in linear complexity.
    """
    def __init__(self, tokenizer, trim_space=...) -> None: ...
    def reset(self):  # -> None:
        ...
    def add_token(self, token):  # -> None:
        ...
    def finalize(self):  # -> None:
        ...

class BPEStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for OpenAI style BPE models.

    It adds tokens to the text if the next token starts with a space similar to
    the SPM detokenizer.
    """

    _byte_decoder = ...
    _space_matches = ...
    def __init__(self, tokenizer) -> None: ...
    def reset(self):  # -> None:
        ...
    def add_token(self, token):  # -> None:
        ...
    def finalize(self):  # -> None:
        ...
    @classmethod
    def make_byte_decoder(cls):  # -> None:
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""

class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.

    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """
    def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
    def add_eos_token(self, token: str):  # -> None:
        ...
    @property
    def has_thinking(self):  # -> bool:
        ...
    @property
    def think_start(self):  # -> str | None:
        ...
    @property
    def think_end(self):  # -> str | None:
        ...
    @property
    def has_tool_calling(self):  # -> bool:
        ...
    @property
    def tool_call_start(self):  # -> str | None:
        ...
    @property
    def tool_call_end(self):  # -> str | None:
        ...
    @property
    def detokenizer(self):  # -> NaiveStreamingDetokenizer:
        """
        Get a stateful streaming detokenizer.
        """

    def __getattr__(self, attr):  # -> set[Any] | Any:
        ...
    def __setattr__(self, attr, value):  # -> None:
        ...

class NewlineTokenizer(PreTrainedTokenizerFast):
    """A tokenizer that replaces newlines with <n> and <n> with new line."""
    def __init__(self, *args, **kwargs) -> None: ...
    def encode(self, text, **kwargs):  # -> list[int]:
        ...
    def encode_batch(self, texts, **kwargs): ...
    def decode(self, *args, **kwargs):  # -> str:
        ...
    def batch_decode(self, *args, **kwargs):  # -> list[str]:
        ...

def load_tokenizer(
    model_path: Path,
    tokenizer_config_extra=...,
    return_tokenizer=...,
    eos_token_ids=...,
) -> (
    TokenizerWrapper
    | type[SPMStreamingDetokenizer]
    | partial[SPMStreamingDetokenizer]
    | type[BPEStreamingDetokenizer]
    | type[NaiveStreamingDetokenizer]
):
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.

    Note, to use a fast streaming tokenizer, pass a local file path rather than
    a Hugging Face repo ID.
    """

def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...

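Example (not part of the commit): a sketch of pairing TokenizerWrapper's streaming detokenizer with generate_step, assuming the runtime API matches the stubs; the model id and the 16-token cap are placeholders.

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.generate import generate_step

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # placeholder id

    prompt = mx.array(tokenizer.encode("The capital of France is"))
    detok = tokenizer.detokenizer  # stateful streaming detokenizer
    detok.reset()

    for (token, _logprobs), _ in zip(generate_step(prompt, model), range(16)):
        detok.add_token(int(token))
        print(detok.last_segment, end="", flush=True)
    detok.finalize()
    print(detok.last_segment)
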
.mlx_typings/mlx_lm/utils.pyi  (new file, 195 lines)
@@ -0,0 +1,195 @@
"""
This type stub file was generated by pyright.
"""

import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import mlx.nn as nn
from transformers.utils.auto_docstring import ModelArgs

from .tokenizer_utils import TokenizerWrapper

if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": ...
else: ...
MODEL_REMAPPING = ...
MAX_FILE_SIZE_GB = ...

def compute_bits_per_weight(model): ...
def hf_repo_to_path(hf_repo):  # -> Path:
    ...
def load_config(model_path: Path) -> dict: ...
def load_model(
    model_path: Path,
    lazy: bool = False,
    strict: bool = True,
    model_config: dict[str, Any] = {},
    get_model_classes: Callable[
        [dict[str, Any]], Tuple[Type[nn.Module], Type[ModelArgs]]
    ] = ...,
) -> Tuple[nn.Module, dict[str, Any]]:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        strict (bool): Whether or not to raise an exception if weights don't
            match. Default: ``True``
        model_config (dict, optional): Optional configuration parameters for the
            model. Defaults to an empty dictionary.
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
            A function that returns the model class and model args class given a config.
            Defaults to the ``_get_classes`` function.

    Returns:
        Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """

def load(
    path_or_hf_repo: str,
    tokenizer_config=...,
    model_config=...,
    adapter_path: Optional[str] = ...,
    lazy: bool = ...,
    return_config: bool = ...,
    revision: str = ...,
) -> Union[
    Tuple[nn.Module, TokenizerWrapper],
    Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],
]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If ``False`` eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        return_config (bool): If ``True`` return the model config as the last item.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]:
            A tuple containing the loaded model, tokenizer and, if requested, the model config.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """

def make_shards(weights: dict, max_file_size_gb: int = ...) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """

def create_model_card(
    path: Union[str, Path], hf_path: Union[str, Path, None]
):  # -> None:
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (Union[str, Path]): Local path to the model.
        hf_path (Union[str, Path, None]): Path to the original Hugging Face model.
    """

def upload_to_hub(path: str, upload_repo: str):  # -> None:
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
    """

def save_model(
    save_path: Union[str, Path], model: nn.Module, *, donate_model: bool = ...
) -> None:
    """Save model weights and metadata index into specified directory."""

def quantize_model(
    model: nn.Module,
    config: dict,
    group_size: int,
    bits: int,
    mode: str = ...,
    quant_predicate: Optional[Callable[[str, nn.Module], Union[bool, dict]]] = ...,
) -> Tuple[nn.Module, dict]:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        group_size (int): Group size for quantization.
        bits (int): Bits per weight for quantization.
        mode (str): The quantization mode.
        quant_predicate (Callable): A callable that decides how to quantize
            each layer based on the path. Accepts the layer `path` and the
            `module`. Returns either a bool to signify quantize/no quantize or
            a dict of quantization parameters to pass to `to_quantized`.

    Returns:
        Tuple: Tuple containing quantized model and config.
    """

def save_config(config: dict, config_path: Union[str, Path]) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """

def save(
    dst_path: Union[str, Path],
    src_path_or_repo: Union[str, Path],
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    config: Dict[str, Any],
    donate_model: bool = ...,
):  # -> None:
    ...

def common_prefix_len(list1, list2):  # -> int:
    """
    Calculates the length of the common prefix of two lists.

    Args:
        list1: The first list of strings.
        list2: The second list of strings.

    Returns:
        The length of the common prefix. Returns 0 if lists are empty
        or do not match at the first element.
    """

def does_model_support_input_embeddings(model: nn.Module) -> bool:
    """
    Check if the model supports input_embeddings in its call signature.

    Args:
        model (nn.Module): The model to check.

    Returns:
        bool: True if the model supports input_embeddings, False otherwise.
    """

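Example (not part of the commit): a sketch of the two loading entry points typed above, assuming the runtime package matches the stubs; the repo id and local path are placeholders.

    from pathlib import Path

    from mlx_lm.utils import load, load_model

    # High-level loader: model + tokenizer, plus the config when requested.
    model, tokenizer, config = load(
        "mlx-community/Qwen2.5-0.5B-Instruct-4bit",  # placeholder repo id
        return_config=True,
    )

    # Lower-level entry point for when the tokenizer is handled separately,
    # as in the runner changes later in this commit.
    model, config = load_model(
        Path("./models/my-model"),  # placeholder local path
        lazy=True,
        strict=False,
    )
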
justfile  (2 lines changed)
@@ -1,5 +1,5 @@
 fmt:
-    uv run ruff format src typings
+    uv run ruff format src .mlx_typings

 lint:
     uv run ruff check --fix src

@@ -81,7 +81,7 @@ build-backend = "uv_build"
 ###

 [tool.basedpyright]
-include = [".venv/lib/mlx", "src"]
+include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
 typeCheckingMode = "strict"
 failOnWarnings = true

@@ -97,8 +97,8 @@ reportUnnecessaryTypeIgnoreComment = "error"
 pythonVersion = "3.13"
 pythonPlatform = "Darwin"

-exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"]
-stubPath = "typings"
+exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust", "**/.github"]
+stubPath = ".mlx_typings"

 [[tool.basedpyright.executionEnvironments]]
 root = "src"

@@ -9,10 +9,10 @@ from mlx_lm.models.cache import KVCache
 from mlx_lm.sample_utils import make_sampler

 try:
-    from mlx_lm.tokenizer_utils import load_tokenizer  # type: ignore
+    from mlx_lm.tokenizer_utils import load_tokenizer
 except ImportError:
     from mlx_lm.tokenizer_utils import load as load_tokenizer  # type: ignore
-from mlx_lm.utils import load_model  # type: ignore
+from mlx_lm.utils import load_model
 from pydantic import RootModel

 import mlx.core as mx

@@ -167,12 +167,11 @@ def shard_and_load(
         f"loading model from {model_path} with strategy {model_shard_meta.strategy}"
     )

-    model, config = load_model(model_path, lazy=True, strict=False)  # type: ignore
+    model, config = load_model(model_path, lazy=True, strict=False)
     runner_print(f"{config=}")
     assert isinstance(model, nn.Module)

-    tokenizer = load_tokenizer(model_path)  # type: ignore
-    tokenizer = cast(TokenizerWrapper, tokenizer)
+    tokenizer = cast(TokenizerWrapper, load_tokenizer(model_path))

     runner_print(f"Group size: {group.size()}, group rank: {group.rank()}")

@@ -31,7 +31,7 @@ async def build_base_shard(model_id: str) -> ShardMetadata:
     )


-async def build_full_shard(model_id: str) -> PipelineShardMetadata | None:
+async def build_full_shard(model_id: str) -> PipelineShardMetadata:
     base_shard = await build_base_shard(model_id)
     return PipelineShardMetadata(
         model_meta=base_shard.model_meta,

@@ -150,11 +150,9 @@ class ResumableShardDownloader(ShardDownloader):
        # print("get_shard_download_status")
        async def _status_for_model(
            model_id: str,
-        ) -> tuple[Path, RepoDownloadProgress] | None:
+        ) -> tuple[Path, RepoDownloadProgress]:
            """Helper coroutine that builds the shard for a model and gets its download status."""
            shard = await build_full_shard(model_id)
-            if shard is None:
-                return None
            return await download_shard(
                shard, self.on_progress_wrapper, skip_download=True
            )

@@ -168,8 +166,6 @@ class ResumableShardDownloader(ShardDownloader):
        for task in asyncio.as_completed(tasks):
            try:
                result = await task
-                if result is None:
-                    continue
                path, progress = result
                yield (path, progress)
            except Exception as e:

@@ -35,16 +35,18 @@ generation_stream = mx.new_stream(mx.default_device())


 def maybe_quantize_kv_cache(
-    prompt_cache: list[Any],
+    prompt_cache: list[KVCache | Any],
     quantized_kv_start: int,
     kv_group_size: int,
     kv_bits: int | None,
 ) -> None:
     if kv_bits is None:
         return
-    for e, c in enumerate(prompt_cache):  # type: ignore[type-arg]
-        if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start:  # type: ignore[type-arg]
-            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)  # type: ignore[type-arg]
+    for e, c in enumerate(prompt_cache):
+        if (
+            hasattr(c, "to_quantized") and c.offset >= quantized_kv_start  # type: ignore
+        ):
+            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)


 def generate_step(

@@ -189,7 +191,7 @@
        quantize_cache_fn(prompt_cache)

        start_time = time.time()
-        mx.eval([c.state for c in prompt_cache])  # type: ignore
+        mx.eval([c.state for c in prompt_cache])
        eval_time = time.time() - start_time
        prompt_processed_tokens += n_to_process

@@ -221,9 +223,17 @@
    n = 0

    while True:
-        mx.eval(y, logprobs)
+        assert y is not None
+        assert logprobs is not None
+        if n != max_tokens:
+            next_y, next_logprobs = _step(y)
+            mx.async_eval(next_y, next_logprobs)
+        if n == 0:
+            mx.eval(y)
+            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+        if n == max_tokens:
+            break
        yield int(y.item()), logprobs
-        n += 1
        if n % 256 == 0:
            mx.clear_cache()
        if n == max_tokens:

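Example (not part of the commit): an illustrative skeleton of the decode-pipelining pattern the new loop above uses; `step`, `y`, and `max_tokens` stand in for the real locals of generate_step and are assumptions.

    import mlx.core as mx

    def pipelined_tokens(step, y, max_tokens):
        # Start evaluating token n+1 in the background (mx.async_eval) before
        # yielding token n, so the accelerator stays busy while the caller
        # consumes the current token.
        n = 0
        next_y = None
        while True:
            if n != max_tokens:
                next_y = step(y)        # lazily build the next decode step
                mx.async_eval(next_y)   # kick off its evaluation asynchronously
            if n == max_tokens:
                break
            yield int(y.item())         # forces y; next_y keeps computing
            y = next_y
            n += 1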