feat(generation): add support for eos_token_id (#714)

chore: add support for custom eos_token_id

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-21 02:01:36 -05:00
committed by GitHub
parent fde78a2c78
commit e70246ca5d

View File

@@ -454,6 +454,11 @@ class LLM(t.Generic[M, T], ReprMixin):
if stop_token_ids is None:
stop_token_ids = []
eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
if eos_token_id is not None:
if not isinstance(eos_token_id, list):
eos_token_id = [eos_token_id]
stop_token_ids.extend(eos_token_id)
if self.tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(self.tokenizer.eos_token_id)
if stop is None: