diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index 08c77cdb..fd4e277f 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -981,17 +981,17 @@ class LLM(LLMInterface[M, T], ReprMixin):
     for i in range(self.config['max_new_tokens']):
       torch.cuda.synchronize()
       if i == 0:  # prefill
-        out = self.model(torch.as_tensor([input_ids]), use_cache=True)
+        out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True)
         logits = out.logits
         past_key_values = out.past_key_values
       else:  # decoding
-        out = self.model(input_ids=torch.as_tensor([[token]]), use_cache=True, past_key_values=past_key_values)
+        out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
         logits = out.logits
         past_key_values = out.past_key_values
       if logits_processor:
         if self.config['repetition_penalty'] > 1.0:
-          tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=logits.device)
+          tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device)
         else:
           tmp_output_ids = None
         last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]