Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-01-22 14:31:26 -05:00
feat(gpu): Make sure that we run models on GPU if available
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
@@ -49,7 +49,7 @@ dependencies = [
     # bentoml[grpc,grpc-reflection] include grpcio, grpcio-reflection
     "bentoml[io-image,io-pandas,io-file,grpc,grpc-reflection]>=1.0.19",
     # bentoml[torch] includes torch and transformers
-    "transformers[torch]>=4.29.0",
+    "transformers[torch,accelerate,tokenizers,onnxruntime,onnx]>=4.29.0",
     # Super fast JSON serialization
     "orjson",
     "inflection",
@@ -431,6 +431,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
     __openllm_start_name__: str = ""
     __openllm_timeout__: int = 0
     __openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
+    __openllm_trust_remote_code__: bool = False
     GenerationConfig: type[t.Any] = GenerationConfig

     def __init_subclass__(
@@ -438,6 +439,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
         *,
         default_timeout: int | None = None,
         name_type: t.Literal["dasherize", "lowercase"] = "dasherize",
+        trust_remote_code: bool = False,
         **kwargs: t.Any,
     ):
         if default_timeout is None:
@@ -446,6 +448,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
         if name_type not in ("dasherize", "lowercase"):
             raise RuntimeError(f"Unknown name_type {name_type}. Only allowed are 'dasherize' and 'lowercase'.")
         cls.__openllm_name_type__ = name_type
+        cls.__openllm_trust_remote_code__ = trust_remote_code

         super(LLMConfig, cls).__init_subclass__(**kwargs)
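The three hunks above wire a `trust_remote_code` class keyword through `LLMConfig.__init_subclass__`, alongside the `name_type` and `default_timeout` handling. Below is a minimal, self-contained sketch of that class-keyword pattern, independent of OpenLLM: `BaseConfig`, `DemoConfig`, and the fallback timeout of 3600 are placeholders, since the real fallback is assigned outside the lines shown.

```python
from __future__ import annotations

import typing as t


class BaseConfig:
    __openllm_timeout__: int = 0
    __openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
    __openllm_trust_remote_code__: bool = False

    def __init_subclass__(
        cls,
        *,
        default_timeout: int | None = None,
        name_type: t.Literal["dasherize", "lowercase"] = "dasherize",
        trust_remote_code: bool = False,
        **kwargs: t.Any,
    ):
        # Keyword arguments given on the `class` statement are consumed here
        # and stored as class-level attributes for later lookup.
        if default_timeout is None:
            default_timeout = 3600  # placeholder fallback, not OpenLLM's actual default
        cls.__openllm_timeout__ = default_timeout
        if name_type not in ("dasherize", "lowercase"):
            raise RuntimeError(f"Unknown name_type {name_type}. Only allowed are 'dasherize' and 'lowercase'.")
        cls.__openllm_name_type__ = name_type
        cls.__openllm_trust_remote_code__ = trust_remote_code
        super().__init_subclass__(**kwargs)


class DemoConfig(BaseConfig, name_type="lowercase", trust_remote_code=True, default_timeout=3600000):
    pass


assert DemoConfig.__openllm_trust_remote_code__ is True
assert DemoConfig.__openllm_timeout__ == 3600000
```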
@@ -35,7 +35,7 @@ if t.TYPE_CHECKING:
     import transformers
     from bentoml._internal.runner.strategy import Strategy

-    from .types import LLMModel, LLMTokenizer, ModelSignatureDict
+    from ._types import LLMModel, LLMTokenizer, ModelSignatureDict
 else:
     ModelSignatureDict = dict
     transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
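The import swap above lives inside the `t.TYPE_CHECKING` / `LazyLoader` idiom: heavy modules such as `transformers` are imported only for type checkers, while at runtime a lazy proxy defers the real import until first attribute access. A rough sketch of the idea using only the standard library; `_LazyModule` is an illustrative stand-in, not OpenLLM's `LazyLoader`.

```python
from __future__ import annotations

import importlib
import types
import typing as t

if t.TYPE_CHECKING:
    import transformers  # seen by type checkers only, never imported at runtime


class _LazyModule(types.ModuleType):
    """Proxy that imports the real module on first attribute access."""

    def __init__(self, name: str, parent_globals: dict[str, t.Any]):
        super().__init__(name)
        self._parent_globals = parent_globals

    def __getattr__(self, item: str) -> t.Any:
        module = importlib.import_module(self.__name__)
        self._parent_globals[self.__name__] = module  # replace the proxy afterwards
        return getattr(module, item)


if not t.TYPE_CHECKING:
    transformers = _LazyModule("transformers", globals())
```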
@@ -296,7 +296,8 @@ class LLM(LLMInterface):
     dolly_v2_runner = openllm.Runner("dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map='gpu')
     ```

-    Note: If you implement your own `import_model`, then `import_kwargs` will be ignored.
+    Note: If you implement your own `import_model`, then `import_kwargs` will be the default kwargs for every load. You can still override those
+    via ``openllm.Runner``.

     Note that this tag will be generated based on `self.default_model` or the given `pretrained` kwds
     passed from the __init__ constructor.
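The reworded note clarifies that `import_kwargs` now act as per-model defaults for every load rather than being ignored when a custom `import_model` exists, and that they can still be overridden through `openllm.Runner`. A hedged usage sketch follows; exact keyword handling may differ between OpenLLM versions, and `torch_dtype` simply mirrors the docstring example.

```python
import torch

import openllm

# Use the model-level defaults (the DollyV2 class ships
# {"device_map": "auto", "torch_dtype": torch.bfloat16, ...}) as-is.
dolly_v2_runner = openllm.Runner("dolly-v2")

# Override one of the defaults; anything prefixed with "_tokenizer_" is routed
# to the tokenizer rather than the model.
dolly_v2_fp32_runner = openllm.Runner("dolly-v2", torch_dtype=torch.float32)
```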
@@ -334,11 +335,26 @@ class LLM(LLMInterface):
     @property
     def _bentomodel(self) -> bentoml.Model:
         if self.__bentomodel__ is None:
-            tag, kwds = openllm.utils.generate_tags(self._pretrained, prefix=self._implementation, **self._kwargs)
+            trust_remote_code = self._kwargs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+            tag, kwds = openllm.utils.generate_tags(
+                self._pretrained, prefix=self._implementation, trust_remote_code=trust_remote_code, **self._kwargs
+            )

             tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
             kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}

+            if self.import_kwargs:
+                tokenizer_kwds = {
+                    **{
+                        k[len("_tokenizer_") :]: v for k, v in self.import_kwargs.items() if k.startswith("_tokenizer_")
+                    },
+                    **tokenizer_kwds,
+                }
+                kwds = {
+                    **{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")},
+                    **kwds,
+                }
+
             try:
                 self.__bentomodel__ = bentoml.transformers.get(tag)
             except bentoml.exceptions.BentoMLException:
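The expanded `_bentomodel` property splits incoming kwargs on the `_tokenizer_` prefix and then layers `import_kwargs` underneath them, so model-level defaults apply but anything passed explicitly still wins. A stand-alone illustration of that split-and-merge; the helper name and sample values here are made up for the example.

```python
from __future__ import annotations


def split_tokenizer_kwargs(kwds: dict, import_kwargs: dict | None = None) -> tuple[dict, dict]:
    prefix = "_tokenizer_"
    tokenizer_kwds = {k[len(prefix):]: v for k, v in kwds.items() if k.startswith(prefix)}
    model_kwds = {k: v for k, v in kwds.items() if not k.startswith(prefix)}
    if import_kwargs:
        # Defaults from import_kwargs sit underneath the caller's values,
        # so anything passed explicitly to the Runner still wins.
        tokenizer_kwds = {
            **{k[len(prefix):]: v for k, v in import_kwargs.items() if k.startswith(prefix)},
            **tokenizer_kwds,
        }
        model_kwds = {
            **{k: v for k, v in import_kwargs.items() if not k.startswith(prefix)},
            **model_kwds,
        }
    return model_kwds, tokenizer_kwds


model_kwds, tokenizer_kwds = split_tokenizer_kwargs(
    {"_tokenizer_padding_size": "left", "torch_dtype": "bfloat16"},
    import_kwargs={"device_map": "auto", "torch_dtype": "float32"},
)
assert tokenizer_kwds == {"padding_size": "left"}
assert model_kwds == {"torch_dtype": "bfloat16", "device_map": "auto"}
```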
@@ -348,17 +364,21 @@ class LLM(LLMInterface):
                 if hasattr(self, "import_model"):
                     logger.debug("Using custom 'import_model' defined in subclass.")
                     self.__bentomodel__ = self.import_model(
-                        self._pretrained, tag, *self._args, tokenizer_kwds=tokenizer_kwds, **kwds
+                        self._pretrained,
+                        tag,
+                        *self._args,
+                        tokenizer_kwds=tokenizer_kwds,
+                        trust_remote_code=trust_remote_code,
+                        **kwds,
                     )
                 else:
-                    if self.import_kwargs:
-                        kwds = {**self.import_kwargs, **kwds}
                     # NOTE: In this branch, we just use the default implementation.
                     self.__bentomodel__ = import_model(
                         self._pretrained,
                         tag,
                         *self._args,
                         tokenizer_kwds=tokenizer_kwds,
+                        trust_remote_code=trust_remote_code,
                         __openllm_framework__=self._implementation,
                         **kwds,
                     )
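Both branches now forward `trust_remote_code` into the import step. In Hugging Face `transformers`, this is the flag that allows custom modeling code hosted on the Hub to run, which models such as ChatGLM require. A hedged sketch of the underlying call in plain `transformers`, outside OpenLLM; the model id mirrors the ChatGLM variants list above.

```python
import transformers

# Without trust_remote_code=True, from_pretrained refuses to execute the
# custom modeling/tokenization code that ships with the ChatGLM repository.
model = transformers.AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```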
@@ -473,10 +493,24 @@ class LLM(LLMInterface):


 def Runner(start_name: str, **kwds: t.Any):
+    """Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'
+
+    Args:
+        start_name: Supported model name from 'openllm models'
+        init_local: Whether to init_local this given Runner. This is useful during development. (Default to False)
+        **kwds: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs
+                behaviour
+    """
+    init_local = kwds.pop("init_local", False)
     envvar = openllm.utils.get_framework_env(start_name)
     if envvar == "flax":
-        return openllm.AutoFlaxLLM.create_runner(start_name, **kwds)
+        runner = openllm.AutoFlaxLLM.create_runner(start_name, **kwds)
     elif envvar == "tf":
-        return openllm.AutoTFLLM.create_runner(start_name, **kwds)
+        runner = openllm.AutoTFLLM.create_runner(start_name, **kwds)
     else:
-        return openllm.AutoLLM.create_runner(start_name, **kwds)
+        runner = openllm.AutoLLM.create_runner(start_name, **kwds)
+
+    if init_local:
+        runner.init_local()
+
+    return runner
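`Runner` now builds the runner first, optionally initializes it in-process when `init_local=True` is passed, and only then returns it. A short usage sketch based on the docstring above; "dolly-v2" is the model name used in the earlier example.

```python
import openllm

# init_local is popped from the kwargs; when true, runner.init_local() is
# called before the runner is returned, which is convenient during development.
runner = openllm.Runner("dolly-v2", init_local=True)
```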
@@ -16,7 +16,7 @@ from __future__ import annotations
 import openllm


-class ChatGLMConfig(openllm.LLMConfig, name_type="lowercase"):
+class ChatGLMConfig(openllm.LLMConfig, name_type="lowercase", trust_remote_code=True, default_timeout=3600000):
     """Configuration for the ChatGLM model."""

     retain_history: bool = True
@@ -44,6 +44,8 @@ class ChatGLM(openllm.LLM, _internal=True):

     variants = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]

+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     def model_post_init(self, _: t.Any):
         self.history: list[tuple[str, str]] = []
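This is the core of the commit: a class-level `device` chosen once via `torch.cuda.is_available()`, so the rest of the model code can target the GPU when one is present and fall back to CPU otherwise. A minimal sketch of the pattern on its own, with only `torch` assumed.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(2, 3).to(device)  # tensors follow the chosen device
print(x.device)                   # cuda:0 on a GPU machine, cpu otherwise
```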
@@ -70,9 +72,9 @@ class ChatGLM(openllm.LLM, _internal=True):
         temperature: float | None = None,
         **kwargs: t.Any,
     ) -> t.Any:
-        if self.config.use_half_precision:
-            self.model = self.model.half()
+        if torch.cuda.is_available():
+            if self.config.use_half_precision:
+                self.model = self.model.half()
+            self.model = self.model.cuda()
         else:
             self.model = self.model.float()
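The precision handling is now gated on CUDA availability: on a GPU machine the model may be cast to fp16 and moved to the device, while on CPU it stays in fp32, where half-precision inference is poorly supported. A stand-alone sketch with a toy module; `use_half_precision` mirrors the config flag and everything else is illustrative.

```python
import torch

model = torch.nn.Linear(8, 8)
use_half_precision = True

if torch.cuda.is_available():
    if use_half_precision:
        model = model.half()  # fp16 is fine on the GPU
    model = model.cuda()
else:
    model = model.float()     # stay in fp32 on CPU

print(next(model.parameters()).dtype, next(model.parameters()).device)
```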
@@ -90,19 +92,22 @@ class ChatGLM(openllm.LLM, _internal=True):
             prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n"
         prompt_text += f"[Round {len(self.history)}]\n问:{prompt}\n答:"

-        inputs = self.tokenizer([prompt_text], return_tensors="pt").to(self.model.device)
-        outputs = self.model.generate(
-            **inputs,
-            generation_config=self.config.with_options(
-                max_length=max_length,
-                num_beams=num_beams,
-                top_p=top_p,
-                temperature=temperature,
-                do_sample=True,
-                **kwargs,
-            ).to_generation_config(),
-            logits_processor=logit_processor,
-        )
+        inputs = self.tokenizer([prompt_text], return_tensors="pt").to(self.device)
+        with torch.device(self.device):
+            outputs = self.model.generate(
+                **inputs,
+                generation_config=self.config.with_options(
+                    max_length=max_length,
+                    num_beams=num_beams,
+                    top_p=top_p,
+                    temperature=temperature,
+                    do_sample=True,
+                    **kwargs,
+                ).to_generation_config(),
+                logits_processor=logit_processor,
+            )
+        if torch.cuda.is_available():
+            outputs = outputs.cpu()
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
         response = self.tokenizer.decode(outputs)
         response = self.model.process_response(response)
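The new generation path moves the tokenized inputs to `self.device`, runs `generate` under a `torch.device` context, and copies the outputs back to host memory before slicing and decoding. A generic sketch of that round trip, with toy tensors standing in for real model inputs and outputs.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_ids = torch.tensor([[101, 2009, 2003, 102]]).to(device)

with torch.device(device):
    # A real model.generate(...) call would go here; we fake an output that
    # echoes the prompt and appends two new tokens.
    outputs = torch.cat([input_ids, torch.tensor([[7, 8]], device=device)], dim=-1)

if torch.cuda.is_available():
    outputs = outputs.cpu()

# Keep only the newly generated tokens, as the ChatGLM code above does.
new_tokens = outputs.tolist()[0][len(input_ids[0]):]
print(new_tokens)  # [7, 8]
```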
@@ -25,7 +25,7 @@ from .configuration_dolly_v2 import (DEFAULT_PROMPT_TEMPLATE, END_KEY,
 if t.TYPE_CHECKING:
     import torch

-    from openllm.types import LLMTokenizer
+    from ..._types import LLMTokenizer
 else:
     torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -61,6 +61,8 @@ class DollyV2(openllm.LLM, _internal=True):

     import_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16, "_tokenizer_padding_size": "left"}

+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     @torch.inference_mode()
     def generate(
         self,
@@ -114,20 +116,23 @@ class DollyV2(openllm.LLM, _internal=True):
         else:
             in_b = input_ids.shape[0]

-        generated_sequence = self.model.generate(
-            input_ids=input_ids.to(self.model.device),
-            attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
-            pad_token_id=self.tokenizer.pad_token_id,
-            do_sample=do_sample,
-            eos_token_id=eos_token_id,
-            generation_config=llm_config.to_generation_config(),
-        )
+        with torch.device(self.device):
+            generated_sequence = self.model.generate(
+                input_ids=input_ids.to(self.device),
+                attention_mask=attention_mask.to(self.device) if attention_mask is not None else None,
+                pad_token_id=self.tokenizer.pad_token_id,
+                do_sample=do_sample,
+                eos_token_id=eos_token_id,
+                generation_config=llm_config.to_generation_config(),
+            )

         out_b = generated_sequence.shape[0]

-        generated_sequence: list[list[int]] = (
-            generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])[0].numpy().tolist()
-        )
+        generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])[0]
+        if torch.cuda.is_available():
+            generated_sequence = generated_sequence.cpu()
+
+        generated_sequence: list[list[int]] = generated_sequence.numpy().tolist()
         records: list[dict[str, t.Any]] = []
         for sequence in generated_sequence:
             # The response will be set to this variable if we can identify it.
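The reshape now happens on the device, and the sequence is moved to CPU before the `.numpy().tolist()` conversion, since `.numpy()` raises on CUDA tensors. A small sketch of that flow; the shapes (one prompt, two returned sequences) are toy values.

```python
import torch

generated_sequence = torch.arange(12).reshape(2, 6)  # stand-in for model.generate() output
if torch.cuda.is_available():
    generated_sequence = generated_sequence.cuda()

in_b, out_b = 1, generated_sequence.shape[0]
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])[0]

if torch.cuda.is_available():
    generated_sequence = generated_sequence.cpu()  # .numpy() below would fail on a CUDA tensor

sequences: list[list[int]] = generated_sequence.numpy().tolist()
print(sequences)  # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]
```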