mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-05 07:36:15 -05:00
fix: synchronize device for inference
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -981,17 +981,17 @@ class LLM(LLMInterface[M, T], ReprMixin):
         for i in range(self.config['max_new_tokens']):
             torch.cuda.synchronize()
             if i == 0:  # prefill
-                out = self.model(torch.as_tensor([input_ids]), use_cache=True)
+                out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True)
                 logits = out.logits
                 past_key_values = out.past_key_values
             else:  # decoding
-                out = self.model(input_ids=torch.as_tensor([[token]]), use_cache=True, past_key_values=past_key_values)
+                out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
                 logits = out.logits
                 past_key_values = out.past_key_values

             if logits_processor:
                 if self.config['repetition_penalty'] > 1.0:
-                    tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=logits.device)
+                    tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device)
                 else:
                     tmp_output_ids = None
                 last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
Reference in New Issue
Block a user