fix: synchronize device for inference

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-09-06 14:08:29 -04:00
parent 030a861a08
commit 675b372981

View File

@@ -981,17 +981,17 @@ class LLM(LLMInterface[M, T], ReprMixin):
for i in range(self.config['max_new_tokens']):
torch.cuda.synchronize()
if i == 0: # prefill
out = self.model(torch.as_tensor([input_ids]), use_cache=True)
out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else: # decoding
out = self.model(input_ids=torch.as_tensor([[token]]), use_cache=True, past_key_values=past_key_values)
out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
if self.config['repetition_penalty'] > 1.0:
tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=logits.device)
tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]