From e9246e7772f09fd93f312fd829a94590fc58c6d4 Mon Sep 17 00:00:00 2001 From: paperspace <29749331+aarnphm@users.noreply.github.com> Date: Tue, 14 May 2024 06:42:46 +0000 Subject: [PATCH] fix: make sure to only update fields when correct type is parse Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --- openllm-python/src/_openllm_tiny/_service.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/openllm-python/src/_openllm_tiny/_service.py b/openllm-python/src/_openllm_tiny/_service.py index cb08689a..2f88bff3 100644 --- a/openllm-python/src/_openllm_tiny/_service.py +++ b/openllm-python/src/_openllm_tiny/_service.py @@ -72,7 +72,11 @@ class LLMService: request_id: t.Optional[str] = None, llm_config: t.Dict[str, t.Any] = pydantic.Field(default=LLMConfig, description='LLM Config'), ) -> core.GenerationOutput: - llm_config.update(stop=stop, stop_token_ids=stop_token_ids) + if stop: + llm_config.update(stop=stop) + if stop_token_ids: + llm_config.update(stop_token_ids=stop_token_ids) + return await self.llm.generate( prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config ) @@ -87,7 +91,10 @@ class LLMService: request_id: t.Optional[str] = None, llm_config: t.Dict[str, t.Any] = pydantic.Field(default=LLMConfig, description='LLM Config'), ) -> t.AsyncGenerator[str, None]: - llm_config.update(stop=stop, stop_token_ids=stop_token_ids) + if stop: + llm_config.update(stop=stop) + if stop_token_ids: + llm_config.update(stop_token_ids=stop_token_ids) _config = LLMConfig.model_construct_env(**core.utils.dict_filter_none(llm_config)) previous_texts = [''] * _config['n']