fix: make sure to only update fields when correct type is parse

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
paperspace
2024-05-14 06:42:46 +00:00
parent 03c51ab133
commit e9246e7772

View File

@@ -72,7 +72,11 @@ class LLMService:
request_id: t.Optional[str] = None,
llm_config: t.Dict[str, t.Any] = pydantic.Field(default=LLMConfig, description='LLM Config'),
) -> core.GenerationOutput:
llm_config.update(stop=stop, stop_token_ids=stop_token_ids)
if stop:
llm_config.update(stop=stop)
if stop_token_ids:
llm_config.update(stop_token_ids=stop_token_ids)
return await self.llm.generate(
prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config
)
@@ -87,7 +91,10 @@ class LLMService:
request_id: t.Optional[str] = None,
llm_config: t.Dict[str, t.Any] = pydantic.Field(default=LLMConfig, description='LLM Config'),
) -> t.AsyncGenerator[str, None]:
llm_config.update(stop=stop, stop_token_ids=stop_token_ids)
if stop:
llm_config.update(stop=stop)
if stop_token_ids:
llm_config.update(stop_token_ids=stop_token_ids)
_config = LLMConfig.model_construct_env(**core.utils.dict_filter_none(llm_config))
previous_texts = [''] * _config['n']