"""
|
|
This type stub file was generated by pyright.
|
|
"""
|
|
|
|
import contextlib
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
from .tokenizer_utils import TokenizerWrapper
|
|
|
|
DEFAULT_PROMPT = ...
DEFAULT_MAX_TOKENS = ...
DEFAULT_TEMP = ...
DEFAULT_TOP_P = ...
DEFAULT_MIN_P = ...
DEFAULT_TOP_K = ...
DEFAULT_XTC_PROBABILITY = ...
DEFAULT_XTC_THRESHOLD = ...
DEFAULT_MIN_TOKENS_TO_KEEP = ...
DEFAULT_SEED = ...
DEFAULT_MODEL = ...
DEFAULT_QUANTIZED_KV_START = ...

def str2bool(string):  # -> bool:
    ...

def setup_arg_parser():  # -> ArgumentParser:
    """Set up and return the argument parser."""

generation_stream = ...

@contextlib.contextmanager
def wired_limit(
    model: nn.Module, streams: Optional[List[mx.Stream]] = ...
):  # -> Generator[None, Any, None]:
    """
    A context manager to temporarily change the wired limit.

    Note: the wired limit should not be changed during an async eval. If an
    async eval could be running, pass in the streams to synchronize with prior
    to exiting the context manager.
    """

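# Illustrative sketch, not part of the generated stub: wrapping a generation loop
# in `wired_limit` so the wired memory limit is only raised while the model runs.
# The prompt array and token budget are placeholder values for illustration.
def _example_wired_limit(model: nn.Module, prompt: mx.array) -> None:
    with wired_limit(model, [generation_stream]):
        for token, _logprobs in generate_step(prompt, model, max_tokens=8):
            mx.eval(token)  # force evaluation inside the managed region
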
@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        from_draft (bool): Whether the token was generated by the draft model.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
        finish_reason (str): The reason the response is being sent: "length",
            "stop", or ``None``.
    """

    text: str
    token: int
    logprobs: mx.array
    from_draft: bool
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    peak_memory: float
    finish_reason: Optional[str] = ...

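# Illustrative sketch, not part of the generated stub: formatting the throughput
# and memory statistics carried on a GenerationResponse. The field names match
# the dataclass above; the report wording is an assumption for illustration.
def _example_format_stats(response: GenerationResponse) -> str:
    return (
        f"prompt: {response.prompt_tokens} tokens at {response.prompt_tps:.1f} tok/s, "
        f"generation: {response.generation_tokens} tokens at {response.generation_tps:.1f} tok/s, "
        f"peak memory: {response.peak_memory:.2f} GB"
    )
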
def maybe_quantize_kv_cache(
    prompt_cache, quantized_kv_start, kv_group_size, kv_bits
):  # -> None:
    ...

def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = ...,
    sampler: Optional[Callable[[mx.array], mx.array]] = ...,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
    max_kv_size: Optional[int] = ...,
    prompt_cache: Optional[Any] = ...,
    prefill_step_size: int = ...,
    kv_bits: Optional[int] = ...,
    kv_group_size: int = ...,
    quantized_kv_start: int = ...,
    prompt_progress_callback: Optional[Callable[[int], int]] = ...,
    input_embeddings: Optional[mx.array] = ...,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an
            infinite generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for
            sampling a token from a vector of log probabilities.
            Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the
            processed logits. Default: ``None``.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note,
            if provided, the cache will be updated in place.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache
            quantization. ``None`` implies no cache quantization.
            Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization.
            Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.
        prompt_progress_callback (Callable[[int], int]): A callback which takes
            the prompt tokens processed so far and the total number of prompt
            tokens.
        input_embeddings (mx.array, optional): Input embeddings to use instead
            of or in conjunction with prompt tokens. Default: ``None``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """

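# Illustrative sketch, not part of the generated stub: driving generate_step
# directly with a greedy sampler and stopping at an end-of-sequence token.
# The eos_token_id parameter and the token budget are placeholder assumptions.
def _example_generate_step(
    model: nn.Module, prompt: mx.array, eos_token_id: int
) -> List[int]:
    def greedy(logprobs: mx.array) -> mx.array:
        # Pick the highest-probability token from the log-probability vector.
        return mx.argmax(logprobs, axis=-1)

    tokens: List[int] = []
    for token, _logprobs in generate_step(prompt, model, max_tokens=64, sampler=greedy):
        token_id = int(token.item())
        if token_id == eos_token_id:
            break
        tokens.append(token_id)
    return tokens
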
def speculative_generate_step(
    prompt: mx.array,
    model: nn.Module,
    draft_model: nn.Module,
    *,
    num_draft_tokens: int = ...,
    max_tokens: int = ...,
    sampler: Optional[Callable[[mx.array], mx.array]] = ...,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,
    prompt_cache: Optional[Any] = ...,
    prefill_step_size: int = ...,
    kv_bits: Optional[int] = ...,
    kv_group_size: int = ...,
    quantized_kv_start: int = ...,
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        draft_model (nn.Module): The draft model for speculative decoding.
        num_draft_tokens (int, optional): The number of draft tokens for
            speculative decoding. Default: ``2``.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an
            infinite generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for
            sampling a token from a vector of log probabilities.
            Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the
            processed logits. Default: ``None``.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note,
            if provided, the cache will be updated in place. The cache must be
            trimmable.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache
            quantization. ``None`` implies no cache quantization.
            Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization.
            Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.

    Yields:
        Tuple[mx.array, mx.array, bool]: One token, a vector of log
        probabilities, and a bool indicating whether the token was generated
        by the draft model.
    """

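# Illustrative sketch, not part of the generated stub: speculative decoding with a
# smaller draft model, counting how many emitted tokens came from the draft. The
# token budget and draft token count are placeholder values.
def _example_speculative(
    model: nn.Module, draft_model: nn.Module, prompt: mx.array
) -> Tuple[List[int], int]:
    tokens: List[int] = []
    from_draft_count = 0
    for token, _logprobs, from_draft in speculative_generate_step(
        prompt, model, draft_model, num_draft_tokens=3, max_tokens=64
    ):
        tokens.append(int(token.item()))
        from_draft_count += int(from_draft)
    return tokens, from_draft_count
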
def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    max_tokens: int = ...,
    draft_model: Optional[nn.Module] = ...,
    **kwargs,
) -> Generator[GenerationResponse, None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or
            integer tokens.
        max_tokens (int): The maximum number of tokens to generate.
            Default: ``256``.
        draft_model (Optional[nn.Module]): An optional draft model. If provided
            then speculative decoding is used. The draft model must use the same
            tokenizer as the main model. Default: ``None``.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        GenerationResponse: An instance containing the generated text segment
        and associated metadata. See :class:`GenerationResponse` for details.
    """

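# Illustrative sketch, not part of the generated stub: consuming stream_generate
# and collecting the decoded text until the generator reports a finish reason.
# The prompt string and token budget are placeholder values.
def _example_stream(model: nn.Module, tokenizer: TokenizerWrapper) -> str:
    pieces: List[str] = []
    for response in stream_generate(model, tokenizer, "Hello", max_tokens=32):
        pieces.append(response.text)
        if response.finish_reason is not None:  # "length" or "stop"
            break
    return "".join(pieces)
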
def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, List[int]],
    verbose: bool = ...,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        kwargs: The remaining options get passed to :func:`stream_generate`.
            See :func:`stream_generate` for more details.
    """

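# Illustrative sketch, not part of the generated stub: one-shot generation with
# an extra option (max_tokens) forwarded through **kwargs to stream_generate.
# The prompt string and token budget are placeholder values.
def _example_generate(model: nn.Module, tokenizer: TokenizerWrapper) -> str:
    return generate(
        model,
        tokenizer,
        prompt="Write a haiku about the sea.",
        max_tokens=64,
        verbose=False,
    )
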
@dataclass
class BatchStats:
    """
    A data object to hold generation stats.

    Args:
        prompt_tokens (int): The number of prompt tokens processed.
        prompt_tps (float): The prompt processing tokens-per-second.
        prompt_time (float): The time in seconds spent in prompt processing.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        generation_time (float): The time in seconds spent in generation.
        peak_memory (float): The peak memory used so far in GB.
    """

    prompt_tokens: int = ...
    prompt_tps: float = ...
    prompt_time: float = ...
    generation_tokens: int = ...
    generation_tps: float = ...
    generation_time: float = ...
    peak_memory: float = ...

@dataclass
class BatchResponse:
    """
    A data object to hold a batch generation response.

    Args:
        texts (List[str]): The generated text for each prompt.
        stats (BatchStats): Statistics about the generation.
    """

    texts: List[str]
    stats: BatchStats

@dataclass
class Batch:
    uids: List[int]
    y: mx.array
    logprobs: mx.array
    max_tokens: List[int]
    num_tokens: List[int]
    cache: List[Any]

    def __len__(self):  # -> int:
        ...

    def filter(self, keep_idx: List[int]):  # -> None:
        ...

    def extend(self, other):  # -> None:
        ...

class BatchGenerator:
    @dataclass
    class Response:
        uid: int
        token: int
        logprobs: mx.array
        finish_reason: Optional[str]

    def __init__(
        self,
        model,
        max_tokens: int = ...,
        stop_tokens: Optional[set] = ...,
        sampler: Optional[Callable[[mx.array], mx.array]] = ...,
        completion_batch_size: int = ...,
        prefill_batch_size: int = ...,
        prefill_step_size: int = ...,
    ) -> None: ...
    def insert(
        self, prompts, max_tokens: Union[List[int], int, None] = ...
    ):  # -> list[Any]:
        ...

    def stats(self):  # -> BatchStats:
        ...

    def next(self):  # -> list[Any]:
        ...

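# Illustrative sketch, not part of the generated stub: a possible driver loop for
# BatchGenerator, based only on the stubbed method signatures above. The assumed
# semantics (insert returning per-prompt uids, next returning Response objects
# until the batch is exhausted) are assumptions, not documented behavior.
def _example_batch_generator(model, prompts: List[List[int]]) -> dict:
    gen = BatchGenerator(model, max_tokens=32)
    uids = gen.insert(prompts)
    tokens_by_uid = {uid: [] for uid in uids}
    while responses := gen.next():
        for r in responses:
            tokens_by_uid[r.uid].append(r.token)
    return tokens_by_uid
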
def batch_generate(
    model,
    tokenizer,
    prompts: List[int],
    max_tokens: Union[int, List[int]] = ...,
    verbose: bool = ...,
    **kwargs,
) -> BatchResponse:
    """
    Generate responses for the given batch of prompts.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompts (List[List[int]]): The input prompts.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        max_tokens (Union[int, List[int]]): Maximum number of output tokens.
            This can be per prompt if a list is provided.
        kwargs: The remaining options get passed to :obj:`BatchGenerator`.
            See :obj:`BatchGenerator` for more details.
    """

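# Illustrative sketch, not part of the generated stub: batched generation over
# pre-tokenized prompts, reading the texts and aggregate stats off the returned
# BatchResponse. The per-prompt max_tokens list is a placeholder.
def _example_batch_generate(model, tokenizer, prompts: List[List[int]]) -> List[str]:
    response = batch_generate(model, tokenizer, prompts, max_tokens=[32] * len(prompts))
    # response.stats is a BatchStats with throughput and memory figures.
    print(f"{response.stats.generation_tps:.1f} generation tok/s")
    return response.texts
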
def main():  # -> None:
    ...

if __name__ == "__main__": ...