feat(gpu): run models on GPU when available

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
authored by aarnphm-ec2-dev on 2023-05-25 02:10:36 +00:00
committed by Aaron
parent 135bafacaf
commit 73d152fc77
7 changed files with 85 additions and 38 deletions

View File

@@ -49,7 +49,7 @@ dependencies = [
# bentoml[grpc,grpc-reflection] include grpcio, grpcio-reflection
"bentoml[io-image,io-pandas,io-file,grpc,grpc-reflection]>=1.0.19",
# bentoml[torch] includes torch and transformers
"transformers[torch]>=4.29.0",
"transformers[torch,accelerate,tokenizers,onnxruntime,onnx]>=4.29.0",
# Super fast JSON serialization
"orjson",
"inflection",

View File

@@ -431,6 +431,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
__openllm_start_name__: str = ""
__openllm_timeout__: int = 0
__openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
__openllm_trust_remote_code__: bool = False
GenerationConfig: type[t.Any] = GenerationConfig
def __init_subclass__(
@@ -438,6 +439,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
*,
default_timeout: int | None = None,
name_type: t.Literal["dasherize", "lowercase"] = "dasherize",
trust_remote_code: bool = False,
**kwargs: t.Any,
):
if default_timeout is None:
@@ -446,6 +448,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
if name_type not in ("dasherize", "lowercase"):
raise RuntimeError(f"Unknown name_type {name_type}. Only 'dasherize' and 'lowercase' are allowed.")
cls.__openllm_name_type__ = name_type
cls.__openllm_trust_remote_code__ = trust_remote_code
super(LLMConfig, cls).__init_subclass__(**kwargs)
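
A minimal sketch of the new class keyword, assuming `openllm` is importable; `MyRemoteCodeConfig` is a hypothetical name used only for illustration:

```
import openllm

# Hypothetical config: opt into remote code at class-definition time. The
# flag is stored on __openllm_trust_remote_code__ and consumed later when
# the weights and tokenizer are imported.
class MyRemoteCodeConfig(openllm.LLMConfig, name_type="lowercase", trust_remote_code=True):
    """Configuration for a model whose repository ships custom modeling code."""

assert MyRemoteCodeConfig.__openllm_trust_remote_code__ is True
```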

View File

@@ -35,7 +35,7 @@ if t.TYPE_CHECKING:
import transformers
from bentoml._internal.runner.strategy import Strategy
from .types import LLMModel, LLMTokenizer, ModelSignatureDict
from ._types import LLMModel, LLMTokenizer, ModelSignatureDict
else:
ModelSignatureDict = dict
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
@@ -296,7 +296,8 @@ class LLM(LLMInterface):
dolly_v2_runner = openllm.Runner("dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map='auto')
```
Note: If you implement your own `import_model`, then `import_kwargs` will be ignored.
Note: If you implement your own `import_model`, then `import_kwargs` will be the default kwargs for every load. You can still override those
via ``openllm.Runner``.
Note that this tag will be generated based on `self.default_model` or the given `pretrained` kwds
passed from the `__init__` constructor.
@@ -334,11 +335,26 @@ class LLM(LLMInterface):
@property
def _bentomodel(self) -> bentoml.Model:
if self.__bentomodel__ is None:
tag, kwds = openllm.utils.generate_tags(self._pretrained, prefix=self._implementation, **self._kwargs)
trust_remote_code = self._kwargs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
tag, kwds = openllm.utils.generate_tags(
self._pretrained, prefix=self._implementation, trust_remote_code=trust_remote_code, **self._kwargs
)
tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}
if self.import_kwargs:
tokenizer_kwds = {
**{
k[len("_tokenizer_") :]: v for k, v in self.import_kwargs.items() if k.startswith("_tokenizer_")
},
**tokenizer_kwds,
}
kwds = {
**{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")},
**kwds,
}
try:
self.__bentomodel__ = bentoml.transformers.get(tag)
except bentoml.exceptions.BentoMLException:
@@ -348,17 +364,21 @@ class LLM(LLMInterface):
if hasattr(self, "import_model"):
logger.debug("Using custom 'import_model' defined in subclass.")
self.__bentomodel__ = self.import_model(
self._pretrained, tag, *self._args, tokenizer_kwds=tokenizer_kwds, **kwds
self._pretrained,
tag,
*self._args,
tokenizer_kwds=tokenizer_kwds,
trust_remote_code=trust_remote_code,
**kwds,
)
else:
if self.import_kwargs:
kwds = {**self.import_kwargs, **kwds}
# NOTE: In this branch, we just use the default implementation.
self.__bentomodel__ = import_model(
self._pretrained,
tag,
*self._args,
tokenizer_kwds=tokenizer_kwds,
trust_remote_code=trust_remote_code,
__openllm_framework__=self._implementation,
**kwds,
)
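
The merge above means class-level `import_kwargs` act as defaults and caller-provided kwargs win on conflict, with `_tokenizer_`-prefixed keys routed to the tokenizer. A standalone sketch of that precedence, with illustrative values:

```
# Illustrative defaults and caller overrides (not the real values).
import_kwargs = {"device_map": "auto", "_tokenizer_padding_size": "left"}
caller_kwargs = {"device_map": "sequential", "torch_dtype": "float16"}

tokenizer_kwds = {
    **{k[len("_tokenizer_"):]: v for k, v in import_kwargs.items() if k.startswith("_tokenizer_")},
    **{k[len("_tokenizer_"):]: v for k, v in caller_kwargs.items() if k.startswith("_tokenizer_")},
}
kwds = {
    **{k: v for k, v in import_kwargs.items() if not k.startswith("_tokenizer_")},
    **{k: v for k, v in caller_kwargs.items() if not k.startswith("_tokenizer_")},
}

assert kwds == {"device_map": "sequential", "torch_dtype": "float16"}  # caller value overrides the default
assert tokenizer_kwds == {"padding_size": "left"}  # tokenizer kwargs have their prefix stripped
```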
@@ -473,10 +493,24 @@ class LLM(LLMInterface):
def Runner(start_name: str, **kwds: t.Any):
"""Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'
Args:
start_name: Supported model name from 'openllm models'
init_local: Whether to eagerly initialize this Runner in the local process. This is useful during development. (Defaults to False)
**kwds: The remaining kwargs are passed through to the LLM. Refer to the LLM documentation for their behaviour.
"""
init_local = kwds.pop("init_local", False)
envvar = openllm.utils.get_framework_env(start_name)
if envvar == "flax":
return openllm.AutoFlaxLLM.create_runner(start_name, **kwds)
runner = openllm.AutoFlaxLLM.create_runner(start_name, **kwds)
elif envvar == "tf":
return openllm.AutoTFLLM.create_runner(start_name, **kwds)
runner = openllm.AutoTFLLM.create_runner(start_name, **kwds)
else:
return openllm.AutoLLM.create_runner(start_name, **kwds)
runner = openllm.AutoLLM.create_runner(start_name, **kwds)
if init_local:
runner.init_local()
return runner
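
Usage sketch of the new `init_local` flag, assuming the model weights can be fetched or are already cached locally:

```
import openllm

# Eagerly initialize the runner in-process, which is convenient while
# iterating locally instead of serving it through BentoML.
dolly_v2_runner = openllm.Runner("dolly-v2", init_local=True)
```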

View File

@@ -16,7 +16,7 @@ from __future__ import annotations
import openllm
class ChatGLMConfig(openllm.LLMConfig, name_type="lowercase"):
class ChatGLMConfig(openllm.LLMConfig, name_type="lowercase", trust_remote_code=True, default_timeout=3600000):
"""Configuration for the ChatGLM model."""
retain_history: bool = True
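
For context, `trust_remote_code=True` is forwarded down to transformers when the weights are imported. A minimal sketch of what the flag does for a repository that ships custom modeling code; the direct `transformers` calls are shown only for illustration:

```
import transformers

# ChatGLM repositories ship their own modeling code, so loading them
# requires explicitly trusting remote code.
tokenizer = transformers.AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = transformers.AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```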

View File

@@ -44,6 +44,8 @@ class ChatGLM(openllm.LLM, _internal=True):
variants = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def model_post_init(self, _: t.Any):
self.history: list[tuple[str, str]] = []
@@ -70,9 +72,9 @@ class ChatGLM(openllm.LLM, _internal=True):
temperature: float | None = None,
**kwargs: t.Any,
) -> t.Any:
if self.config.use_half_precision:
self.model = self.model.half()
if torch.cuda.is_available():
if self.config.use_half_precision:
self.model = self.model.half()
self.model = self.model.cuda()
else:
self.model = self.model.float()
@@ -90,19 +92,22 @@ class ChatGLM(openllm.LLM, _internal=True):
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n"
prompt_text += f"[Round {len(self.history)}]\n问:{prompt}\n答:"
inputs = self.tokenizer([prompt_text], return_tensors="pt").to(self.model.device)
outputs = self.model.generate(
**inputs,
generation_config=self.config.with_options(
max_length=max_length,
num_beams=num_beams,
top_p=top_p,
temperature=temperature,
do_sample=True,
**kwargs,
).to_generation_config(),
logits_processor=logit_processor,
)
inputs = self.tokenizer([prompt_text], return_tensors="pt").to(self.device)
with torch.device(self.device):
outputs = self.model.generate(
**inputs,
generation_config=self.config.with_options(
max_length=max_length,
num_beams=num_beams,
top_p=top_p,
temperature=temperature,
do_sample=True,
**kwargs,
).to_generation_config(),
logits_processor=logit_processor,
)
if torch.cuda.is_available():
outputs = outputs.cpu()
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = self.tokenizer.decode(outputs)
response = self.model.process_response(response)
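
The device handling above follows a simple pattern: half precision plus `.cuda()` when a GPU is available, full precision on CPU otherwise, and generated ids moved back to the CPU before decoding. A standalone sketch of that pattern:

```
import torch

def place_model(model: torch.nn.Module, use_half_precision: bool = True) -> torch.nn.Module:
    # Run on GPU (optionally in half precision) when available, otherwise CPU in float32.
    if torch.cuda.is_available():
        if use_half_precision:
            model = model.half()
        return model.cuda()
    return model.float()

def generated_ids_to_list(outputs: torch.Tensor) -> list:
    # Generated ids live on the model's device; copy them back to CPU before decoding.
    if torch.cuda.is_available():
        outputs = outputs.cpu()
    return outputs.tolist()
```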

View File

@@ -25,7 +25,7 @@ from .configuration_dolly_v2 import (DEFAULT_PROMPT_TEMPLATE, END_KEY,
if t.TYPE_CHECKING:
import torch
from openllm.types import LLMTokenizer
from ..._types import LLMTokenizer
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -61,6 +61,8 @@ class DollyV2(openllm.LLM, _internal=True):
import_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16, "_tokenizer_padding_size": "left"}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@torch.inference_mode()
def generate(
self,
@@ -114,20 +116,23 @@ class DollyV2(openllm.LLM, _internal=True):
else:
in_b = input_ids.shape[0]
generated_sequence = self.model.generate(
input_ids=input_ids.to(self.model.device),
attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
pad_token_id=self.tokenizer.pad_token_id,
do_sample=do_sample,
eos_token_id=eos_token_id,
generation_config=llm_config.to_generation_config(),
)
with torch.device(self.device):
generated_sequence = self.model.generate(
input_ids=input_ids.to(self.device),
attention_mask=attention_mask.to(self.device) if attention_mask is not None else None,
pad_token_id=self.tokenizer.pad_token_id,
do_sample=do_sample,
eos_token_id=eos_token_id,
generation_config=llm_config.to_generation_config(),
)
out_b = generated_sequence.shape[0]
generated_sequence: list[list[int]] = (
generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])[0].numpy().tolist()
)
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])[0]
if torch.cuda.is_available():
generated_sequence = generated_sequence.cpu()
generated_sequence: list[list[int]] = generated_sequence.numpy().tolist()
records: list[dict[str, t.Any]] = []
for sequence in generated_sequence:
# The response will be set to this variable if we can identify it.
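
The extra `.cpu()` step above matters because `Tensor.numpy()` only works on CPU tensors; calling it on a CUDA tensor raises a TypeError. A small sketch of the conversion:

```
import torch

def sequences_to_lists(generated_sequence: torch.Tensor) -> list:
    # numpy() requires a CPU tensor, so copy back from the GPU first.
    if generated_sequence.is_cuda:
        generated_sequence = generated_sequence.cpu()
    return generated_sequence.numpy().tolist()
```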