class StopOnTokens(transformers.StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop token id.

    Only the first sequence of ``input_ids`` is inspected (index ``0``).
    """

    # NOTE(review): presumably tokenizer-specific special / end-of-turn token
    # ids -- confirm against the model this criterion is used with.
    _STOP_TOKEN_IDS = frozenset({50278, 50279, 50277, 1, 0})

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
        """Return True when the last token of the first sequence is a stop id.

        Args:
            input_ids: Token ids generated so far; indexed as ``input_ids[0][-1]``.
            scores: Unused; accepted to satisfy the stopping-criteria call signature.
            **_: Extra keyword arguments, ignored.

        Returns:
            ``True`` if the most recent token id is one of the stop ids.
        """
        # A tensor element hashes by object identity, so testing the tensor
        # itself against a set of ints never matches. Convert to a plain
        # Python int so the membership test compares by value.
        last_token_id = int(input_ids[0][-1])
        return last_token_id in self._STOP_TOKEN_IDS