diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py
index 2602aa04..9b07c326 100644
--- a/src/openllm/_llm.py
+++ b/src/openllm/_llm.py
@@ -53,8 +53,6 @@ else:
 
 logger = logging.getLogger(__name__)
 
-_object_setattr = object.__setattr__
-
 # NOTE: `1-2` -> text-generation and text2text-generation
 FRAMEWORK_TO_AUTOCLASS_MAPPING = {
     "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
@@ -256,6 +254,11 @@ class LLMInterface(ABC):
         """This function can be implemented if default import_model doesn't satisfy your needs."""
         raise NotImplementedError
 
+    def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
+        """This function can be implemented to override the default load_model behaviour. See falcon for
+        an example implementation."""
+        raise NotImplementedError
+
 
 class LLMMetaclass(ABCMeta):
     def __new__(
@@ -333,6 +336,11 @@ class LLMMetaclass(ABCMeta):
         else:
             logger.debug("Using custom 'import_model' for %s", cls_name)
 
+        if "load_model" not in namespace:
+            namespace["load_model"] = bentoml.transformers.load_model
+        else:
+            logger.debug("Using custom 'load_model' for %s", cls_name)
+
         # NOTE: populate with default cache.
         namespace.update({k: None for k in ("__llm_bentomodel__", "__llm_model__", "__llm_tokenizer__")})
 
@@ -354,14 +362,14 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     __llm_model__: LLMModel | None = None
     __llm_tokenizer__: LLMTokenizer | None = None
     __llm_implementation__: t.Literal["pt", "tf", "flax"]
-    __llm_kwargs__: dict[str, t.Any]
-    __llm_args__: tuple[t.Any, ...]
 
     __openllm_start_name__: str
     __openllm_requires_gpu__: bool
     __openllm_post_init__: t.Callable[[t.Self], None] | None
 
     load_in_mha: bool
 
+    _llm_attrs: dict[str, t.Any]
+    _llm_args: tuple[t.Any, ...]
 
     # NOTE: the following is a similar interface to the HuggingFace pretrained protocol.
@@ -461,8 +469,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         self._pretrained = pretrained
 
         # NOTE: Save the args and kwargs for later load
-        self.__llm_args__ = args
-        self.__llm_kwargs__ = attrs
+        self._llm_args = args
+        self._llm_attrs = attrs
 
         if self.__openllm_post_init__:
             self.__openllm_post_init__(self)
@@ -591,8 +599,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         return tag
 
     def ensure_pretrained_exists(self):
-        trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
-        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self.__llm_kwargs__)
+        trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self._llm_attrs)
         try:
             return bentoml.transformers.get(tag)
         except bentoml.exceptions.BentoMLException:
@@ -614,7 +622,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
             return self.import_model(
                 self._pretrained,
                 tag,
-                *self.__llm_args__,
+                *self._llm_args,
                 tokenizer_kwds=tokenizer_kwds,
                 trust_remote_code=trust_remote_code,
                 **kwds,
@@ -628,10 +636,10 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     def model(self) -> LLMModel | torch.nn.Module:
         """The model to use for this LLM.
         This shouldn't be set at runtime, rather let OpenLLM handle it."""
         # Run check for GPU
-        trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+        trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
         self.config.check_if_gpu_is_available(self.__llm_implementation__)
 
-        kwds = {k: v for k, v in self.__llm_kwargs__.items() if not k.startswith("_tokenizer_")}
+        kwds = {k: v for k, v in self._llm_attrs.items() if not k.startswith("_tokenizer_")}
         if self.import_kwargs:
             kwds = {**{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")}, **kwds}
@@ -642,7 +650,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
             kwds["accelerator"] = "bettertransformer"
 
         if self.__llm_model__ is None:
-            self.__llm_model__ = self._bentomodel.load_model(*self.__llm_args__, **kwds)
+            self.__llm_model__ = self.load_model(self.tag, *self._llm_args, **kwds)
 
         if (
             self.load_in_mha
diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py
index a2203afa..1bb00736 100644
--- a/src/openllm/models/auto/factory.py
+++ b/src/openllm/models/auto/factory.py
@@ -153,7 +153,10 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
 
 
 class _LazyAutoMapping(ConfigModelOrderedDict):
-    """Based on transformers.models.auto.configuration_auto._LazyAutoMapping"""
+    """Based on transformers.models.auto.configuration_auto._LazyAutoMapping
+    This OrderedDict's values() and keys() return lists instead, so you don't
+    have to do list(mapping.values()) to get the list of values.
+    """
 
     def __init__(self, config_mapping: OrderedDict[str, str], model_mapping: OrderedDict[str, str]):
         self._config_mapping = config_mapping
diff --git a/src/openllm/models/falcon/modeling_falcon.py b/src/openllm/models/falcon/modeling_falcon.py
index 9b0ee0a3..c35f3e64 100644
--- a/src/openllm/models/falcon/modeling_falcon.py
+++ b/src/openllm/models/falcon/modeling_falcon.py
@@ -59,6 +59,20 @@ class Falcon(openllm.LLM):
         )
         return bentoml.transformers.save_model(tag, pipeline, custom_objects={"tokenizer": tokenizer})
 
+    def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
+        trust_remote_code = attrs.pop("trust_remote_code", True)
+        torch_dtype = attrs.pop("torch_dtype", torch.bfloat16)
+        device_map = attrs.pop("device_map", "auto")
+        _ref = bentoml.transformers.get(tag)
+        return bentoml.transformers.load_model(
+            _ref,
+            tokenizer=_ref.custom_objects["tokenizer"],
+            trust_remote_code=trust_remote_code,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            **attrs,
+        )
+
     def sanitize_parameters(
         self,
         prompt: str,