fix(loading): make sure to cast the model to cuda if PyTorch

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-08-09 01:42:11 +00:00
parent ae35ee8115
commit deaee67b47
4 changed files with 10 additions and 13 deletions

View File

@@ -861,10 +861,17 @@ class LLM(LLMInterface[M, T], ReprMixin):
@property
def model(self) -> M:
# Run check for GPU
if DEBUG: traceback.print_stack()
if self.config["requires_gpu"] and device_count() < 1: raise GpuNotAvailableError(f"{self} only supports running with GPU (None available).") from None
# NOTE: the signature of load_model here is the wrapper under _wrapped_load_model
if self.__llm_model__ is None: self.__llm_model__ = self.load_model(*self._model_decls, **self._model_attrs)
if self.__llm_model__ is None:
model = self.load_model(*self._model_decls, **self._model_attrs)
# If OOM, then it is probably you don't have enough VRAM to run this model.
if self.__llm_implementation__ == "pt" and is_torch_available() and torch.cuda.is_available() and torch.cuda.device_count() == 1:
try:
model = model.to("cuda")
except Exception as err:
raise OpenLLMException(f"Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.") from err
self.__llm_model__ = model
return self.__llm_model__
@property

View File

@@ -37,7 +37,6 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
if self.config.use_half_precision: model.half()
if torch.cuda.is_available() and torch.cuda.device_count() == 1: model = model.to("cuda")
return model
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:

View File

@@ -26,16 +26,9 @@ logger = logging.getLogger(__name__)
class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer"]):
__openllm_internal__ = True
def llm_post_init(self):
self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
torch_dtype = attrs.pop("torch_dtype", self.dtype)
return transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs)
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}

View File

@@ -179,8 +179,6 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
device_map = attrs.pop("device_map", "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs).eval()
# If OOM, then it is probably you don't have enough VRAM to run this model.
if torch.cuda.is_available() and torch.cuda.device_count() == 1: model = model.cuda()
# BetterTransformer is currently only supported on PyTorch.
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
if llm.__llm_implementation__ in {"pt", "vllm"}: check_unintialised_params(model)