OpenLLM/openllm-python/src/openllm/_generation.py
Aaron Pham b7af7765d4 fix(yapf): align weird new lines break [generated] [skip ci] (#284)

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
2023-09-01 05:34:22 -04:00


# mypy: disable-error-code="misc"
from __future__ import annotations
import typing as t
import transformers
if t.TYPE_CHECKING:
  import torch
  import openllm
# reexport from transformers
LogitsProcessorList = transformers.LogitsProcessorList
StoppingCriteriaList = transformers.StoppingCriteriaList
class StopSequenceCriteria(transformers.StoppingCriteria):
  '''Stop generation once the decoded output ends with any of the given stop sequences.'''
  def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
    if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
    self.stop_sequences, self.tokenizer = stop_sequences, tokenizer
  def __call__(self, input_ids: torch.Tensor, scores: t.Any, **_: t.Any) -> bool:
    return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
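# --- Hypothetical usage sketch (not part of the original module) ---
# Shows how StopSequenceCriteria is meant to be wrapped in a StoppingCriteriaList and
# handed to transformers' generate(); the 'gpt2' checkpoint and the prompt below are
# placeholder assumptions for illustration only.
def _example_stop_sequence_usage() -> str:
  tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
  model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
  inputs = tokenizer('Q: What is OpenLLM?\nA:', return_tensors='pt')
  # Stop as soon as the decoded text ends with a new question prefix.
  criteria = StoppingCriteriaList([StopSequenceCriteria('\nQ:', tokenizer)])
  outputs = model.generate(**inputs, max_new_tokens=32, stopping_criteria=criteria)
  return tokenizer.decode(outputs[0], skip_special_tokens=True)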
class StopOnTokens(transformers.StoppingCriteria):
  '''Stop generation when the last generated token is one of a hard-coded set of stop ids (the StableLM-Tuned-Alpha stop tokens plus eos/pad).'''
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
    # Convert the 0-d tensor to a Python int so the set membership test compares by value.
    return int(input_ids[0][-1]) in {50278, 50279, 50277, 1, 0}
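# --- Hypothetical usage sketch (not part of the original module) ---
# Exercises StopOnTokens' check directly on tensors of token ids; the ids below are
# arbitrary except for the final 50278, which is one of the hard-coded stop ids above.
def _example_stop_on_tokens() -> None:
  import torch
  criteria = StopOnTokens()
  running = torch.tensor([[12, 345, 6789]], dtype=torch.long)
  stopped = torch.tensor([[12, 345, 50278]], dtype=torch.long)
  scores = torch.zeros(1, 1)  # scores are ignored by this criterion
  assert not criteria(running, scores)
  assert criteria(stopped, scores)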
def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsProcessorList:
  '''Build a LogitsProcessorList (temperature, repetition penalty, top-p, top-k) from the generation options on an openllm.LLMConfig.'''
  generation_config = config.generation_config
  logits_processor = transformers.LogitsProcessorList()
  # skip the temperature warper when it is a no-op (1.0) or so small that decoding is effectively greedy
  if generation_config['temperature'] >= 1e-5 and generation_config['temperature'] != 1.0:
    logits_processor.append(transformers.TemperatureLogitsWarper(generation_config['temperature']))
  if generation_config['repetition_penalty'] > 1.0:
    logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(generation_config['repetition_penalty']))
  if 1e-8 <= generation_config['top_p']:
    logits_processor.append(transformers.TopPLogitsWarper(generation_config['top_p']))
  if generation_config['top_k'] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
  return logits_processor
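# --- Hypothetical usage sketch (not part of the original module) ---
# prepare_logits_processor expects an openllm.LLMConfig; the tiny stand-in below only
# mimics the mapping-style access used above ('generation_config[key]') so the selection
# logic can be exercised without a real model config.
def _example_prepare_logits_processor() -> transformers.LogitsProcessorList:
  class _FakeConfig:
    generation_config = {'temperature': 0.7, 'repetition_penalty': 1.1, 'top_p': 0.95, 'top_k': 40}
  processors = prepare_logits_processor(_FakeConfig())  # type: ignore[arg-type]
  # temperature != 1.0, repetition_penalty > 1.0, top_p >= 1e-8 and top_k > 0,
  # so all four processors are appended.
  assert len(processors) == 4
  return processors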
# NOTE: The ordering here is important. Some models have two of these and we have a preference for which value gets used.
SEQLEN_KEYS = ['max_sequence_length', 'seq_length', 'max_position_embeddings', 'max_seq_len', 'model_max_length']
def get_context_length(config: transformers.PretrainedConfig) -> int:
  '''Return the model's maximum context length, scaled by the RoPE scaling factor when present, defaulting to 2048.'''
  rope_scaling = getattr(config, 'rope_scaling', None)
  rope_scaling_factor = rope_scaling['factor'] if rope_scaling else 1.0
  for key in SEQLEN_KEYS:
    if getattr(config, key, None) is not None: return int(rope_scaling_factor * getattr(config, key))
  return 2048
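# --- Hypothetical usage sketch (not part of the original module) ---
# get_context_length works on any transformers.PretrainedConfig; 'gpt2' is only a
# placeholder checkpoint, whose config is expected to expose max_position_embeddings.
def _example_context_length() -> int:
  config = transformers.AutoConfig.from_pretrained('gpt2')
  return get_context_length(config)  # expected to be 1024, gpt2's context window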
def is_sentence_complete(output: str) -> bool:
  '''Check whether the output ends with common sentence-final punctuation (ASCII, CJK fullwidth, ellipsis, or a closing quote).'''
  return output.endswith(('.', '?', '!', '...', '。', '？', '！', '…', '"', "'", '”'))
def is_partial_stop(output: str, stop_str: str) -> bool:
  '''Check whether the output ends with a prefix of the stop str, i.e. a partial stop match.'''
  for i in range(0, min(len(output), len(stop_str))):
    if stop_str.startswith(output[-i:]): return True
  return False
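# --- Hypothetical usage sketch (not part of the original module) ---
# Shows how a streaming loop might use is_partial_stop to hold back text that could be
# the start of a stop string, and is_sentence_complete to decide whether a final flush
# ends cleanly. The chunk stream and stop string below are hard-coded for illustration.
def _example_streaming_stop(stop_str: str = '\nUser:') -> list[str]:
  buffered, flushed = '', []
  for chunk in ('The answer is 42.', '\nUs', 'er: next question'):
    buffered += chunk
    if stop_str in buffered:
      buffered = buffered[:buffered.index(stop_str)]  # hard stop: drop the stop string and everything after it
      break
    if is_partial_stop(buffered, stop_str):
      continue  # the tail might be the start of the stop string, keep buffering
    flushed.append(buffered)  # safe to emit this piece to the client
    buffered = ''
  if buffered and is_sentence_complete(buffered): flushed.append(buffered)
  return flushed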