mirror of https://github.com/bentoml/OpenLLM.git
synced 2026-01-22 14:31:26 -05:00
feat(llm): custom load_model
This has to do with loading models that require more attention than the default bentoml.transformers.load_model.

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
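As a hedged illustration of what the new hook enables (the class name MyLLM and the float16 default are ours, not part of this commit; the Falcon hunk below is the real in-tree example):

    import typing as t

    import bentoml
    import openllm
    import torch


    class MyLLM(openllm.LLM):
        """Illustrative subclass; a real one still needs the usual OpenLLM config wiring."""

        def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
            # Pop anything that needs special handling; forward the rest untouched.
            torch_dtype = attrs.pop("torch_dtype", torch.float16)  # assumed default, not from the commit
            _ref = bentoml.transformers.get(tag)
            return bentoml.transformers.load_model(_ref, torch_dtype=torch_dtype, **attrs)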
@@ -53,8 +53,6 @@ else:
 logger = logging.getLogger(__name__)

 _object_setattr = object.__setattr__

 # NOTE: `1-2` -> text-generation and text2text-generation
 FRAMEWORK_TO_AUTOCLASS_MAPPING = {
     "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
@@ -256,6 +254,11 @@ class LLMInterface(ABC):
         """This function can be implemented if default import_model doesn't satisfy your needs."""
         raise NotImplementedError

+    def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
+        """This function can be implemented to override the default load_model behaviour. See falcon for
+        an example implementation."""
+        raise NotImplementedError
+

 class LLMMetaclass(ABCMeta):
     def __new__(
@@ -333,6 +336,11 @@ class LLMMetaclass(ABCMeta):
         else:
             logger.debug("Using custom 'import_model' for %s", cls_name)

+        if "load_model" not in namespace:
+            namespace["load_model"] = bentoml.transformers.load_model
+        else:
+            logger.debug("Using custom 'load_model' for %s", cls_name)
+
         # NOTE: populate with default cache.
         namespace.update({k: None for k in ("__llm_bentomodel__", "__llm_model__", "__llm_tokenizer__")})

@@ -354,14 +362,14 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     __llm_model__: LLMModel | None = None
     __llm_tokenizer__: LLMTokenizer | None = None
     __llm_implementation__: t.Literal["pt", "tf", "flax"]
-    __llm_kwargs__: dict[str, t.Any]
-    __llm_args__: tuple[t.Any, ...]

     __openllm_start_name__: str
     __openllm_requires_gpu__: bool
     __openllm_post_init__: t.Callable[[t.Self], None] | None

     load_in_mha: bool
+    _llm_attrs: dict[str, t.Any]
+    _llm_args: tuple[t.Any, ...]

     # NOTE: the following is the similar interface to HuggingFace pretrained protocol.

@@ -461,8 +469,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         self._pretrained = pretrained

         # NOTE: Save the args and kwargs for later load
-        self.__llm_args__ = args
-        self.__llm_kwargs__ = attrs
+        self._llm_args = args
+        self._llm_attrs = attrs

         if self.__openllm_post_init__:
             self.__openllm_post_init__(self)
@@ -591,8 +599,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         return tag

     def ensure_pretrained_exists(self):
-        trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
-        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self.__llm_kwargs__)
+        trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self._llm_attrs)
         try:
             return bentoml.transformers.get(tag)
         except bentoml.exceptions.BentoMLException:
@@ -614,7 +622,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         return self.import_model(
             self._pretrained,
             tag,
-            *self.__llm_args__,
+            *self._llm_args,
             tokenizer_kwds=tokenizer_kwds,
             trust_remote_code=trust_remote_code,
             **kwds,
@@ -628,10 +636,10 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     def model(self) -> LLMModel | torch.nn.Module:
         """The model to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
         # Run check for GPU
-        trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+        trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
         self.config.check_if_gpu_is_available(self.__llm_implementation__)

-        kwds = {k: v for k, v in self.__llm_kwargs__.items() if not k.startswith("_tokenizer_")}
+        kwds = {k: v for k, v in self._llm_attrs.items() if not k.startswith("_tokenizer_")}

         if self.import_kwargs:
             kwds = {**{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")}, **kwds}
@@ -642,7 +650,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
             kwds["accelerator"] = "bettertransformer"

         if self.__llm_model__ is None:
-            self.__llm_model__ = self._bentomodel.load_model(*self.__llm_args__, **kwds)
+            self.__llm_model__ = self.load_model(self.tag, *self._llm_args, **kwds)

         if (
             self.load_in_mha
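Taken together with the Falcon hunk below, this call-site change is what routes model materialization through the overridable hook. A hypothetical trace (the llm variable and no-argument construction are ours):

    llm = Falcon()  # Falcon defines load_model, so the metaclass keeps the override
    _ = llm.model   # first access calls self.load_model(self.tag, *self._llm_args, **kwds),
                    # which lands in Falcon.load_model and then bentoml.transformers.load_model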
@@ -153,7 +153,10 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:


 class _LazyAutoMapping(ConfigModelOrderedDict):
-    """Based on transformers.models.auto.configuration_auto._LazyAutoMapping"""
+    """Based on transformers.models.auto.configuration_auto._LazyAutoMapping
+    This OrderedDict values() and keys() returns the list instead, so you don't
+    have to do list(mapping.values()) to get the list of values.
+    """

     def __init__(self, config_mapping: OrderedDict[str, str], model_mapping: OrderedDict[str, str]):
         self._config_mapping = config_mapping
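To make the new docstring concrete, a standalone toy sketch of the list-returning idea (our illustration, not OpenLLM's actual implementation):

    from collections import OrderedDict


    class ListReturningMapping(OrderedDict):
        """Toy stand-in: keys()/values() return plain lists, as the docstring describes."""

        def keys(self):
            return list(super().keys())

        def values(self):
            return list(super().values())


    m = ListReturningMapping({"falcon": "FalconForCausalLM"})
    assert m.values() == ["FalconForCausalLM"]  # no list(...) wrapper needed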
@@ -59,6 +59,20 @@ class Falcon(openllm.LLM):
         )
         return bentoml.transformers.save_model(tag, pipeline, custom_objects={"tokenizer": tokenizer})

+    def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
+        trust_remote_code = attrs.pop("trust_remote_code", True)
+        torch_dtype = attrs.pop("torch_dtype", torch.bfloat16)
+        device_map = attrs.pop("device_map", "auto")
+        _ref = bentoml.transformers.get(tag)
+        return bentoml.transformers.load_model(
+            _ref,
+            tokenizer=_ref.custom_objects["tokenizer"],
+            trust_remote_code=trust_remote_code,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            **attrs,
+        )
+
     def sanitize_parameters(
         self,
         prompt: str,
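One note on the Falcon override above: it pops trust_remote_code, torch_dtype, and device_map before delegating, so Falcon defaults to bfloat16 weights with automatic device placement (and remote code trusted) unless a caller explicitly passes different values.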