exo/.mlx_typings/mlx_lm/generate.pyi

"""
This type stub file was generated by pyright.
"""
import contextlib
from dataclasses import dataclass
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer
from .tokenizer_utils import TokenizerWrapper
DEFAULT_PROMPT = ...
DEFAULT_MAX_TOKENS = ...
DEFAULT_TEMP = ...
DEFAULT_TOP_P = ...
DEFAULT_MIN_P = ...
DEFAULT_TOP_K = ...
DEFAULT_XTC_PROBABILITY = ...
DEFAULT_XTC_THRESHOLD = ...
DEFAULT_MIN_TOKENS_TO_KEEP = ...
DEFAULT_SEED = ...
DEFAULT_MODEL = ...
DEFAULT_QUANTIZED_KV_START = ...
def str2bool(string): # -> bool:
...
def setup_arg_parser(): # -> ArgumentParser:
"""Set up and return the argument parser."""
generation_stream = ...
@contextlib.contextmanager
def wired_limit(
model: nn.Module, streams: Optional[List[mx.Stream]] = ...
): # -> Generator[None, Any, None]:
"""
A context manager to temporarily change the wired limit.
Note: the wired limit should not be changed during an async eval. If an
async eval could be running, pass in the streams to synchronize with prior
to exiting the context manager.
"""
@dataclass
class GenerationResponse:
"""
The output of :func:`stream_generate`.
Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
from_draft (bool): Whether the token was generated by the draft model.
logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
peak_memory (float): The peak memory used so far in GB.
finish_reason (str): The reason the response is being sent: ``"length"``, ``"stop"``, or ``None``.
"""
text: str
token: int
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: Optional[str] = ...
def maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
): # -> None:
...
def generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
max_kv_size: Optional[int] = ...,
prompt_cache: Optional[Any] = ...,
prefill_step_size: int = ...,
kv_bits: Optional[int] = ...,
kv_group_size: int = ...,
quantized_kv_start: int = ...,
prompt_progress_callback: Optional[Callable[[int], int]] = ...,
input_embeddings: Optional[mx.array] = ...,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
prompt_progress_callback (Callable[[int], int]): A callback that takes the
number of prompt tokens processed so far and the total number of prompt tokens.
input_embeddings (mx.array, optional): Input embeddings to use instead of or in
conjunction with prompt tokens. Default: ``None``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
draft_model: nn.Module,
*,
num_draft_tokens: int = ...,
max_tokens: int = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
prompt_cache: Optional[Any] = ...,
prefill_step_size: int = ...,
kv_bits: Optional[int] = ...,
kv_group_size: int = ...,
quantized_kv_start: int = ...,
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
draft_model (nn.Module): The draft model for speculative decoding.
num_draft_tokens (int, optional): The number of draft tokens for
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
and a bool indicating whether the token was generated by the draft model.
"""
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = ...,
draft_model: Optional[nn.Module] = ...,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or
integer tokens.
max_tokens (int): The maximum number of tokens to generate.
Default: ``256``.
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
GenerationResponse: An instance containing the generated text segment and
associated metadata. See :class:`GenerationResponse` for details.
"""
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, List[int]],
verbose: bool = ...,
**kwargs,
) -> str:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
See :func:`stream_generate` for more details.
"""
@dataclass
class BatchStats:
"""
A data object to hold generation stats.
Args:
prompt_tokens (int): The number of prompt tokens processed.
prompt_tps (float): The prompt processing tokens-per-second.
prompt_time (float): The time in seconds spent in prompt processing.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
generation_time (float): The time in seconds spent in generation.
peak_memory (float): The peak memory used so far in GB.
"""
prompt_tokens: int = ...
prompt_tps: float = ...
prompt_time: float = ...
generation_tokens: int = ...
generation_tps: float = ...
generation_time: float = ...
peak_memory: float = ...
@dataclass
class BatchResponse:
"""
A data object to hold a batch generation response.
Args:
texts (List[str]): The generated text for each prompt.
stats (BatchStats): Statistics about the generation.
"""
texts: List[str]
stats: BatchStats
@dataclass
class Batch:
uids: List[int]
y: mx.array
logprobs: mx.array
max_tokens: List[int]
num_tokens: List[int]
cache: List[Any]
def __len__(self): # -> int:
...
def filter(self, keep_idx: List[int]): # -> None:
...
def extend(self, other): # -> None:
...
class BatchGenerator:
@dataclass
class Response:
uid: int
token: int
logprobs: mx.array
finish_reason: Optional[str]
def __init__(
self,
model,
max_tokens: int = ...,
stop_tokens: Optional[set] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
def batch_generate(
model,
tokenizer,
prompts: List[int],
max_tokens: Union[int, List[int]] = ...,
verbose: bool = ...,
**kwargs,
) -> BatchResponse:
"""
Generate responses for the given batch of prompts.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompts (List[List[int]]): The input prompts.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
max_tokens (Union[int, List[int]]): Maximum number of output tokens. This
can be per prompt if a list is provided.
kwargs: The remaining options get passed to :obj:`BatchGenerator`.
See :obj:`BatchGenerator` for more details.
"""
def main(): # -> None:
...
if __name__ == "__main__": ...