Files
OpenLLM/openllm-python/src/openllm/_generation.py
Aaron Pham 4a6f13ddd2 feat(type): provide structured annotations stubs (#663)
* feat(type): provide client stubs

separation of concern for more brevity code base

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* docs: update changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-16 02:58:45 -05:00

58 lines
2.3 KiB
Python

import transformers
class StopSequenceCriteria(transformers.StoppingCriteria):
    """Stop generation once the decoded text ends with any of the configured stop sequences.

    Only the first sequence in the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, stop_sequences, tokenizer):
        """
        Args:
            stop_sequences: A single stop string, an iterable of stop strings, or ``None``
                (no stop sequences). Empty strings are discarded: ``str.endswith('')`` is
                always true, so an empty entry would otherwise halt generation immediately.
            tokenizer: Object whose ``decode(list_of_ids)`` returns the text for token ids
                (presumably a Hugging Face tokenizer — confirm at call sites).
        """
        if stop_sequences is None:
            stop_sequences = []
        elif isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        # Materialise into a list so a one-shot iterable is not exhausted after the first call.
        self.stop_sequences = [seq for seq in stop_sequences if seq]
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the first sequence's decoded text ends with a stop sequence."""
        if not self.stop_sequences:
            return False
        # Convert only row 0; `input_ids.tolist()[0]` converted the entire batch first.
        text = self.tokenizer.decode(input_ids[0].tolist())
        return text.endswith(tuple(self.stop_sequences))
class StopOnTokens(transformers.StoppingCriteria):
    """Stop generation when the first sequence's most recent token is a designated stop id."""

    # Stop ids used by the original implementation. NOTE(review): these look like
    # model-specific special-token ids — confirm which tokenizer they belong to.
    DEFAULT_STOP_TOKEN_IDS = frozenset({50278, 50279, 50277, 1, 0})

    def __init__(self, stop_token_ids=None):
        """
        Args:
            stop_token_ids: Optional iterable of integer token ids to stop on. When omitted,
                ``DEFAULT_STOP_TOKEN_IDS`` is used, so ``StopOnTokens()`` behaves as before.
        """
        super().__init__()
        if stop_token_ids is None:
            self.stop_token_ids = self.DEFAULT_STOP_TOKEN_IDS
        else:
            self.stop_token_ids = frozenset(int(token_id) for token_id in stop_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if the last token of ``input_ids[0]`` is one of the stop ids."""
        # `int(...)` is required: torch.Tensor.__hash__ returns id(self), so checking a 0-d
        # tensor for membership in a set of Python ints never finds a match.
        return int(input_ids[0][-1]) in self.stop_token_ids
def prepare_logits_processor(config):
    """Build the logits processors/warpers described by ``config.generation_config``.

    ``config.generation_config`` is indexed with the keys ``temperature``,
    ``repetition_penalty``, ``top_p`` and ``top_k`` (a missing key raises ``KeyError``).
    Each processor is added only when its setting changes the distribution:

    - temperature: when ``>= 1e-5`` and ``!= 1.0``;
    - repetition penalty: when ``> 1.0``;
    - top-p: when ``1e-8 <= top_p < 1.0`` (``top_p == 1.0`` is effectively a no-op);
    - top-k: when ``> 0``.

    Returns:
        transformers.LogitsProcessorList ordered temperature, repetition penalty, top-p, top-k.
    """
    generation_config = config.generation_config
    # Read in the same order the original conditions did, so a missing key fails identically.
    temperature = generation_config['temperature']
    repetition_penalty = generation_config['repetition_penalty']
    top_p = generation_config['top_p']
    top_k = generation_config['top_k']

    logits_processor = transformers.LogitsProcessorList()
    if temperature >= 1e-5 and temperature != 1.0:
        logits_processor.append(transformers.TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(repetition_penalty))
    # Upper bound restored: at top_p == 1.0 the warper keeps (effectively) the whole
    # vocabulary and would only add a per-step sort.
    if 1e-8 <= top_p < 1.0:
        logits_processor.append(transformers.TopPLogitsWarper(top_p))
    if top_k > 0:
        logits_processor.append(transformers.TopKLogitsWarper(top_k))
    return logits_processor
# NOTE: Order matters. get_context_length returns the value of the first attribute in this
# list that is set on the config, so when a config defines several, the earlier key wins.
SEQLEN_KEYS = [
    'max_sequence_length',
    'seq_length',
    'max_position_embeddings',
    'max_seq_len',
    'model_max_length',
]
def get_context_length(config, default=2048):
    """Return the context length, in tokens, described by a model config object.

    Uses the first attribute named in ``SEQLEN_KEYS`` that exists on ``config`` and is not
    ``None``. When ``config.rope_scaling`` is truthy and has a non-``None`` ``'factor'``,
    the value is multiplied by that factor.

    Args:
        config: Model config; attributes are read with ``getattr`` (missing ones are skipped).
        default: Value returned when none of ``SEQLEN_KEYS`` is set. Defaults to 2048,
            the previously hard-coded fallback.

    Returns:
        int: ``int(factor * value)``, or ``default``.
    """
    rope_scaling = getattr(config, 'rope_scaling', None)
    factor = None
    if rope_scaling:
        # Some rope_scaling dicts carry no 'factor' (or an explicit None); treat those as
        # unscaled instead of raising KeyError / TypeError.
        factor = rope_scaling.get('factor')
    if factor is None:
        factor = 1.0
    for key in SEQLEN_KEYS:
        value = getattr(config, key, None)
        if value is not None:
            return int(factor * value)
    return default
def is_sentence_complete(output):
    """Return True if ``output`` ends with sentence-final punctuation or a closing quote.

    Recognised endings: ASCII ``.``, ``?``, ``!``; the ideographic full stop, full-width
    question and exclamation marks, the horizontal ellipsis, ``"``, ``'`` and the right
    double quotation mark. An empty string is never complete.
    """
    # The previous literal contained mis-encoded '' entries; since every string ends with
    # '', the function always returned True. Non-ASCII symbols are spelled as escapes so
    # they survive re-encoding. ('...' needs no entry: it is covered by '.')
    end_symbols = (
        '.',
        '?',
        '!',
        '\u3002',  # ideographic full stop
        '\uff1f',  # full-width question mark
        '\uff01',  # full-width exclamation mark
        '\u2026',  # horizontal ellipsis
        '"',
        "'",
        '\u201d',  # right double quotation mark
    )
    return output.endswith(end_symbols)
def is_partial_stop(output, stop_str):
    """Check whether the output contains a partial stop str.

    Returns False if either string is empty. Otherwise, with
    ``limit = min(len(output), len(stop_str))``, returns True as soon as ``stop_str``
    starts with one of:

    - the whole of ``output`` (what the slice ``output[-0:]`` yields), or
    - a suffix of ``output`` whose length is between 1 and ``limit - 1``.

    When ``output`` is no longer than ``stop_str`` this probes every suffix length. When
    ``output`` is longer, the suffix of length ``len(stop_str)`` is never probed, so a
    trailing full copy of ``stop_str`` does not by itself produce True.
    """
    limit = min(len(output), len(stop_str))
    if limit == 0:
        return False
    if stop_str.startswith(output):
        return True
    return any(stop_str.startswith(output[-width:]) for width in range(1, limit))