fix(generate): only pass prompt_token_ids if it is a valid, non-empty list

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Author: paperspace
Date:   2024-05-29 02:42:13 +00:00
Parent: fd527352b2
Commit: c820cececb

2 changed files with 5 additions and 2 deletions

@@ -179,7 +179,7 @@ class LLM:
         k: config.__getitem__(k) for k in set(inspect.signature(SamplingParams).parameters.keys())
       }),
       request_id=request_id,
-      prompt_token_ids=prompt_token_ids,
+      prompt_token_ids=prompt_token_ids if prompt_token_ids else None,
     ):
       yield generations
   except Exception as err:
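The guard added here relies on Python truthiness: an empty list is falsy, so both [] and None collapse to None rather than an empty token-id list being forwarded to the engine. A minimal sketch of that behavior (illustration only; normalize_token_ids is a hypothetical name, not code from this repository):

# Hypothetical helper showing the effect of the ternary in the diff above.
def normalize_token_ids(prompt_token_ids):
    # An empty list and None are both falsy, so both normalize to None;
    # a non-empty list passes through unchanged.
    return prompt_token_ids if prompt_token_ids else None

assert normalize_token_ids([]) is None
assert normalize_token_ids(None) is None
assert normalize_token_ids([1, 2, 3]) == [1, 2, 3]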

@@ -106,7 +106,10 @@ class LLMService:
     finish_reason_sent = [False] * _config['n']
     async for generations in self.llm.generate_iterator(
-      prompt=prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **llm_config
+      prompt=prompt,
+      prompt_token_ids=prompt_token_ids,
+      request_id=request_id,
+      **core.utils.dict_filter_none(llm_config),
     ):
       for output in generations.outputs:
         i = output.index
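The implementation of core.utils.dict_filter_none is not shown in this diff; from its call site it presumably drops None-valued entries from llm_config before the dict is splatted into generate_iterator, so unset options do not shadow the engine's defaults. A hedged sketch of such a helper (assumed implementation, not necessarily the one in core.utils):

from typing import Any, Dict

def dict_filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only entries whose value is not None, so None never reaches
    # generate_iterator as an explicit keyword argument.
    return {k: v for k, v in d.items() if v is not None}

assert dict_filter_none({'temperature': 0.7, 'top_k': None}) == {'temperature': 0.7}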