mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-03 14:46:00 -05:00
fix(generation): unnecessary casting [ec2 build] [wheel build]
This breaks when compiled to cfunc. Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,7 @@ class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
self.stop_sequences, self.tokenizer = stop_sequences, tokenizer
|
||||
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **_: t.Any) -> bool:
    """Return True when the decoded text of the first sequence ends with any configured stop sequence.

    Decodes ``input_ids[0]`` with ``self.tokenizer`` and checks each entry of
    ``self.stop_sequences`` as a suffix. ``scores`` is accepted for interface
    compatibility but unused.
    """
    decoded_text = self.tokenizer.decode(input_ids.tolist()[0])
    for candidate in self.stop_sequences:
        if decoded_text.endswith(candidate):
            return True
    return False
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
    """Return True when the most recently generated token id is a stop token.

    ``scores`` is accepted for interface compatibility with
    ``transformers.StoppingCriteria`` but unused.
    """
    # NOTE: the t.cast(int, ...) wrapper was removed — it is a runtime no-op
    # and broke compilation to cfunc (see commit message). Behavior unchanged.
    return input_ids[0][-1] in {50278, 50279, 50277, 1, 0}
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
    """Return True when the last token of the first sequence is a known stop-token id.

    ``scores`` is accepted for interface compatibility but unused.
    """
    stop_token_ids = {50278, 50279, 50277, 1, 0}
    last_token = input_ids[0][-1]
    return last_token in stop_token_ids
|
||||
def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsProcessorList:
|
||||
generation_config = config.generation_config
|
||||
logits_processor = transformers.LogitsProcessorList()
|
||||
|
||||
Reference in New Issue
Block a user