mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-14 11:28:43 -04:00
fix(loading): make sure to cast the model to cuda if PyTorch
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -861,10 +861,17 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
@property
|
||||
def model(self) -> M:
|
||||
# Run check for GPU
|
||||
if DEBUG: traceback.print_stack()
|
||||
if self.config["requires_gpu"] and device_count() < 1: raise GpuNotAvailableError(f"{self} only supports running with GPU (None available).") from None
|
||||
# NOTE: the signature of load_model here is the wrapper under _wrapped_load_model
|
||||
if self.__llm_model__ is None: self.__llm_model__ = self.load_model(*self._model_decls, **self._model_attrs)
|
||||
if self.__llm_model__ is None:
|
||||
model = self.load_model(*self._model_decls, **self._model_attrs)
|
||||
# If this raises an OOM error, you probably don't have enough VRAM to run this model.
|
||||
if self.__llm_implementation__ == "pt" and is_torch_available() and torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
try:
|
||||
model = model.to("cuda")
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f"Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.") from err
|
||||
self.__llm_model__ = model
|
||||
return self.__llm_model__
|
||||
|
||||
@property
|
||||
|
||||
@@ -37,7 +37,6 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
|
||||
if self.config.use_half_precision: model.half()
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1: model = model.to("cuda")
|
||||
return model
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
|
||||
@@ -26,16 +26,9 @@ logger = logging.getLogger(__name__)
|
||||
class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def llm_post_init(self):
|
||||
self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
return transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, torch_dtype=torch_dtype, **attrs)
|
||||
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
|
||||
@@ -179,8 +179,6 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
|
||||
device_map = attrs.pop("device_map", "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs).eval()
|
||||
# If this raises an OOM error, you probably don't have enough VRAM to run this model.
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1: model = model.cuda()
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
|
||||
if llm.__llm_implementation__ in {"pt", "vllm"}: check_unintialised_params(model)
|
||||
|
||||
Reference in New Issue
Block a user