feat(llm): custom load_model

This has to do with loading models that require more attention
than the default bentoml.transformers.load_model provides.
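
A minimal sketch of the new hook (a hypothetical subclass; the surrounding
requirements of an openllm.LLM subclass are elided, and the real example is
the falcon change below):

```python
import bentoml
import openllm

class MyLLM(openllm.LLM):
    # Hypothetical override: use this when the stock
    # bentoml.transformers.load_model is not flexible enough, e.g. the model
    # needs custom_objects, a dtype, or a device map at load time.
    def load_model(self, tag: bentoml.Tag, *args, **attrs):
        ref = bentoml.transformers.get(tag)
        return bentoml.transformers.load_model(ref, *args, **attrs)
```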

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron
Date:   2023-06-08 04:07:07 -04:00
parent  4369395520
commit  378b209d67

3 changed files with 38 additions and 13 deletions


@@ -53,8 +53,6 @@ else:
 logger = logging.getLogger(__name__)
 _object_setattr = object.__setattr__
 
 # NOTE: `1-2` -> text-generation and text2text-generation
 FRAMEWORK_TO_AUTOCLASS_MAPPING = {
     "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
@@ -256,6 +254,11 @@ class LLMInterface(ABC):
"""This function can be implemented if default import_model doesn't satisfy your needs."""
raise NotImplementedError
def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
"""This function can be implemented to override th default load_model behaviour. See falcon for
example implementation."""
raise NotImplementedError
class LLMMetaclass(ABCMeta):
def __new__(
@@ -333,6 +336,11 @@ class LLMMetaclass(ABCMeta):
         else:
             logger.debug("Using custom 'import_model' for %s", cls_name)
 
+        if "load_model" not in namespace:
+            namespace["load_model"] = bentoml.transformers.load_model
+        else:
+            logger.debug("Using custom 'load_model' for %s", cls_name)
+
         # NOTE: populate with default cache.
         namespace.update({k: None for k in ("__llm_bentomodel__", "__llm_model__", "__llm_tokenizer__")})
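
Aside: the defaulting above is plain metaclass namespace injection. A
self-contained sketch of the pattern, with illustrative names rather than
OpenLLM's actual internals:

```python
from abc import ABCMeta

def default_load():
    # Stand-in for bentoml.transformers.load_model.
    return "default"

class Meta(ABCMeta):
    def __new__(mcls, name, bases, namespace, **kwargs):
        # Inject the default only when the class body didn't define its own.
        if "load_model" not in namespace:
            namespace["load_model"] = staticmethod(default_load)
        return super().__new__(mcls, name, bases, namespace, **kwargs)

class Base(metaclass=Meta):
    pass

class Custom(Base):
    def load_model(self):
        return "custom"

assert Base().load_model() == "default"
assert Custom().load_model() == "custom"
```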
@@ -354,14 +362,14 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     __llm_model__: LLMModel | None = None
     __llm_tokenizer__: LLMTokenizer | None = None
 
     __llm_implementation__: t.Literal["pt", "tf", "flax"]
-    __llm_kwargs__: dict[str, t.Any]
-    __llm_args__: tuple[t.Any, ...]
 
     __openllm_start_name__: str
     __openllm_requires_gpu__: bool
     __openllm_post_init__: t.Callable[[t.Self], None] | None
 
     load_in_mha: bool
+    _llm_attrs: dict[str, t.Any]
+    _llm_args: tuple[t.Any, ...]
 
     # NOTE: the following is the similar interface to HuggingFace pretrained protocol.
@@ -461,8 +469,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         self._pretrained = pretrained
 
         # NOTE: Save the args and kwargs for later load
-        self.__llm_args__ = args
-        self.__llm_kwargs__ = attrs
+        self._llm_args = args
+        self._llm_attrs = attrs
 
         if self.__openllm_post_init__:
             self.__openllm_post_init__(self)
@@ -591,8 +599,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         return tag
 
     def ensure_pretrained_exists(self):
-        trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
-        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self.__llm_kwargs__)
+        trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self._llm_attrs)
         try:
             return bentoml.transformers.get(tag)
         except bentoml.exceptions.BentoMLException:
@@ -614,7 +622,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         return self.import_model(
             self._pretrained,
             tag,
-            *self.__llm_args__,
+            *self._llm_args,
             tokenizer_kwds=tokenizer_kwds,
             trust_remote_code=trust_remote_code,
             **kwds,
@@ -628,10 +636,10 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     def model(self) -> LLMModel | torch.nn.Module:
         """The model to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
         # Run check for GPU
-        trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+        trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
         self.config.check_if_gpu_is_available(self.__llm_implementation__)
 
-        kwds = {k: v for k, v in self.__llm_kwargs__.items() if not k.startswith("_tokenizer_")}
+        kwds = {k: v for k, v in self._llm_attrs.items() if not k.startswith("_tokenizer_")}
 
         if self.import_kwargs:
             kwds = {**{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")}, **kwds}
@@ -642,7 +650,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
kwds["accelerator"] = "bettertransformer"
if self.__llm_model__ is None:
self.__llm_model__ = self._bentomodel.load_model(*self.__llm_args__, **kwds)
self.__llm_model__ = self.load_model(self.tag, *self._llm_args, **kwds)
if (
self.load_in_mha
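
Aside: the property above is a lazy, cached load that now dispatches through
the overridable load_model hook instead of the BentoML model ref directly.
The caching pattern in isolation (a sketch with stand-in names):

```python
import typing as t

class LazyModel:
    def __init__(self) -> None:
        self.__llm_model__ = None  # cache, mirroring the class attribute above

    def load_model(self) -> t.Any:
        # Stand-in for the (possibly overridden) load_model hook.
        print("loading...")
        return object()

    @property
    def model(self) -> t.Any:
        # Load once on first access, reuse the cached instance afterwards.
        if self.__llm_model__ is None:
            self.__llm_model__ = self.load_model()
        return self.__llm_model__

m = LazyModel()
assert m.model is m.model  # "loading..." prints only once
```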


@@ -153,7 +153,10 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
 class _LazyAutoMapping(ConfigModelOrderedDict):
-    """Based on transformers.models.auto.configuration_auto._LazyAutoMapping"""
+    """Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
+
+    Unlike a plain OrderedDict, values() and keys() return lists, so you don't
+    have to call list(mapping.values()) to get the values as a list.
+    """
 
     def __init__(self, config_mapping: OrderedDict[str, str], model_mapping: OrderedDict[str, str]):
         self._config_mapping = config_mapping
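
Aside: a sketch of the list-returning behaviour the new docstring describes
(illustrative; the real _LazyAutoMapping also resolves entries lazily):

```python
from collections import OrderedDict

class ListMapping(OrderedDict):
    # keys()/values() return plain lists rather than dict views.
    def keys(self):
        return list(super().keys())

    def values(self):
        return list(super().values())

m = ListMapping(a=1, b=2)
assert m.keys() == ["a", "b"]
assert m.values() == [1, 2]
```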


@@ -59,6 +59,20 @@ class Falcon(openllm.LLM):
         )
 
         return bentoml.transformers.save_model(tag, pipeline, custom_objects={"tokenizer": tokenizer})
 
+    def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
+        trust_remote_code = attrs.pop("trust_remote_code", True)
+        torch_dtype = attrs.pop("torch_dtype", torch.bfloat16)
+        device_map = attrs.pop("device_map", "auto")
+        _ref = bentoml.transformers.get(tag)
+        return bentoml.transformers.load_model(
+            _ref,
+            tokenizer=_ref.custom_objects["tokenizer"],
+            trust_remote_code=trust_remote_code,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            **attrs,
+        )
+
     def sanitize_parameters(
         self,
         prompt: str,
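
Aside: the defaults in the override above roughly correspond to this direct
transformers call (a sketch; the real code resolves the model and tokenizer
through the BentoML store, and the model ID here is only illustrative):

```python
import torch
import transformers

# bfloat16 weights, automatic device placement, remote code trusted: the
# same defaults Falcon's load_model applies unless the caller overrides them.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
```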