diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 55edfd91..b08b0e75 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -454,6 +454,11 @@ class LLM(t.Generic[M, T], ReprMixin): if stop_token_ids is None: stop_token_ids = [] + eos_token_id = attrs.get('eos_token_id', config['eos_token_id']) + if eos_token_id is not None: + if not isinstance(eos_token_id, list): + eos_token_id = [eos_token_id] + stop_token_ids.extend(eos_token_id) if self.tokenizer.eos_token_id not in stop_token_ids: stop_token_ids.append(self.tokenizer.eos_token_id) if stop is None: