From 675b3729811ad23fc8decc9e02fa636d317d98ee Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Wed, 6 Sep 2023 14:08:29 -0400 Subject: [PATCH] fix: synchronize device for inference Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- openllm-python/src/openllm/_llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 08c77cdb..fd4e277f 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -981,17 +981,17 @@ class LLM(LLMInterface[M, T], ReprMixin): for i in range(self.config['max_new_tokens']): torch.cuda.synchronize() if i == 0: # prefill - out = self.model(torch.as_tensor([input_ids]), use_cache=True) + out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True) logits = out.logits past_key_values = out.past_key_values else: # decoding - out = self.model(input_ids=torch.as_tensor([[token]]), use_cache=True, past_key_values=past_key_values) + out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values) logits = out.logits past_key_values = out.past_key_values if logits_processor: if self.config['repetition_penalty'] > 1.0: - tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=logits.device) + tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device) else: tmp_output_ids = None last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]