diff --git a/openllm-python/src/_openllm_tiny/_service.py b/openllm-python/src/_openllm_tiny/_service.py index cb08689a..2f88bff3 100644 --- a/openllm-python/src/_openllm_tiny/_service.py +++ b/openllm-python/src/_openllm_tiny/_service.py @@ -72,7 +72,11 @@ class LLMService: request_id: t.Optional[str] = None, llm_config: t.Dict[str, t.Any] = pydantic.Field(default=LLMConfig, description='LLM Config'), ) -> core.GenerationOutput: - llm_config.update(stop=stop, stop_token_ids=stop_token_ids) + if stop: + llm_config.update(stop=stop) + if stop_token_ids: + llm_config.update(stop_token_ids=stop_token_ids) + return await self.llm.generate( prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config ) @@ -87,7 +91,10 @@ class LLMService: request_id: t.Optional[str] = None, llm_config: t.Dict[str, t.Any] = pydantic.Field(default=LLMConfig, description='LLM Config'), ) -> t.AsyncGenerator[str, None]: - llm_config.update(stop=stop, stop_token_ids=stop_token_ids) + if stop: + llm_config.update(stop=stop) + if stop_token_ids: + llm_config.update(stop_token_ids=stop_token_ids) _config = LLMConfig.model_construct_env(**core.utils.dict_filter_none(llm_config)) previous_texts = [''] * _config['n']