From d5f213e49cdbb097f13ff6256972caadb7dcf48d Mon Sep 17 00:00:00 2001
From: Aaron <29749331+aarnphm@users.noreply.github.com>
Date: Sun, 28 May 2023 19:14:51 -0700
Subject: [PATCH] compat: provide shim for pydantic 1 and 2

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 src/openllm/_llm.py                           | 151 ++++++++++++++----
 .../models/chatglm/modeling_chatglm.py        |   3 +-
 .../models/stablelm/modeling_stablelm.py      |   2 +-
 src/openllm/utils/__init__.py                 |  73 +--------
 src/openllm/utils/dantic.py                   |   9 +-
 5 files changed, 129 insertions(+), 109 deletions(-)

diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py
index 3f70fc98..cfbbf3fe 100644
--- a/src/openllm/_llm.py
+++ b/src/openllm/_llm.py
@@ -19,6 +19,7 @@ import enum
 import functools
 import logging
 import os
+import re
 import typing as t

 from abc import ABC, ABCMeta, abstractmethod
@@ -29,7 +30,7 @@ import openllm

 from ._configuration import ModelSignature
 from .exceptions import ForbiddenAttributeError, OpenLLMException
-from .utils import LazyLoader, bentoml_cattr
+from .utils import ENV_VARS_TRUE_VALUES, LazyLoader, bentoml_cattr

 if t.TYPE_CHECKING:
     import torch
@@ -44,13 +45,21 @@ else:
 logger = logging.getLogger(__name__)

 # NOTE: `1-2` -> text-generation and text2text-generation
-_return_tensors_to_framework_map = {
+FRAMEWORK_TO_AUTOCLASS_MAPPING = {
     "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
     "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
     "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
 }
+
+
+def convert_transformers_model_name(name: str) -> str:
+    if os.path.exists(os.path.dirname(name)):
+        name = os.path.basename(name)
+        logger.debug("Given name is a path, only returning the basename %s", name)
+        return name
+    return re.sub("[^a-zA-Z0-9]+", "-", name)
+
+
 class TypeMeta(enum.EnumMeta):
     def __getitem__(self, key: str) -> enum.Enum:
         # Type safe getters
@@ -133,7 +142,7 @@ def import_model(
     return bentoml.transformers.save_model(
         tag,
         getattr(
-            transformers, _return_tensors_to_framework_map[_model_framework][TaskType[task_type].value - 1]
+            transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[_model_framework][TaskType[task_type].value - 1]
         ).from_pretrained(
             model_name, *model_args, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
         ),
@@ -286,14 +295,12 @@ class LLMMetaclass(ABCMeta):
                 namespace[key] = getattr(config_class, key)

             # NOTE: set a default variable to transform to BetterTransformer by default for inference
-            namespace["__openllm_bettertransformer__"] = namespace.get(
-                "__openllm_bettertransformer__", implementation in ("pt",)
-            )
+            namespace["load_in_mha"] = namespace.get("load_in_mha", implementation in ("pt",))

             if namespace["__openllm_requires_gpu__"]:
                 # For all models that require a GPU, there is no need to offload to BetterTransformer;
                 # use bitsandbytes instead
-                namespace["__openllm_bettertransformer__"] = False
+                namespace["load_in_mha"] = False

             # NOTE: import_model branch
             if "import_model" not in namespace:
@@ -328,11 +335,11 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     __openllm_start_name__: str
     __openllm_requires_gpu__: bool
-    __openllm_bettertransformer__: bool
     __openllm_post_init__: t.Callable[[t.Self], None] | None

     # NOTE: the following will be populated by __init__
     config: openllm.LLMConfig
+    load_in_mha: bool

     # NOTE: the following is a similar interface to the HuggingFace pretrained protocol.
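For reference, a rough sketch of what the relocated convert_transformers_model_name helper above does. The inputs are illustrative only: the hub-ID case assumes no local directory named 'THUDM' exists, and the path case assumes '/path/to/fine-tuning' does exist on disk, since the helper first checks os.path.exists on the dirname:

>>> from openllm._llm import convert_transformers_model_name
>>> convert_transformers_model_name("THUDM/chatglm-6b-int4")  # hub IDs are slugified
'THUDM-chatglm-6b-int4'
>>> convert_transformers_model_name("/path/to/fine-tuning/model")  # existing local paths collapse to their basename
'model'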
@@ -346,7 +353,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         self,
         pretrained: str | None = None,
         llm_config: openllm.LLMConfig | None = None,
-        load_as_bttransformer: bool | None = None,
+        load_in_mha: bool | None = None,
         *args: t.Any,
         **attrs: t.Any,
     ):
@@ -411,9 +418,9 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
             pretrained: The pretrained model to use. Defaults to None. It will use 'self.default_model' if None.
             llm_config: The config to use for this LLM. Defaults to None. If not passed, we will use
                 'self.config_class' to construct default configuration.
-            load_as_bttransformer: Whether to apply BetterTransformer during inference load.
-                See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
-                for more information.
+            load_in_mha: Whether to apply BetterTransformer (or Torch MultiHeadAttention) during inference load.
+                See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
+                for more information.
             *args: The args to be passed to the model.
             **attrs: The kwargs to be passed to the model.
         """
@@ -441,11 +448,11 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         self.__llm_args__ = args
         self.__llm_kwargs__ = attrs

-        if load_as_bttransformer is None:
-            load_as_bttransformer = self.__openllm_bettertransformer__
-        self._load_as_bttransformer = os.environ.get(
-            self.config.__openllm_env__.bettertransformer, load_as_bttransformer
-        )
+        if load_in_mha is None:
+            load_in_mha = self.load_in_mha  # class default, set by the metaclass
+        self.load_in_mha = (
+            os.environ.get(self.config.__openllm_env__.bettertransformer, str(load_in_mha)).upper() in ENV_VARS_TRUE_VALUES
+        )

         if self.__openllm_post_init__:
             self.__openllm_post_init__(self)
@@ -463,7 +470,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
     # NOTE: The section below defines a loose contract with langchain's LLM interface.
     @property
     def llm_type(self) -> str:
-        return openllm.utils.convert_transformers_model_name(self._pretrained)
+        return convert_transformers_model_name(self._pretrained)

     @property
     def identifying_params(self) -> dict[str, t.Any]:
@@ -476,31 +483,106 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         return self.__llm_bentomodel__

     @t.overload
-    def make_tag(self, return_unused_kwargs: t.Literal[False] = ..., trust_remote_code: bool = ...) -> bentoml.Tag:
+    def make_tag(
+        self,
+        model_name_or_path: str | None = None,
+        return_unused_kwargs: t.Literal[False] = ...,
+        trust_remote_code: bool = ...,
+        **attrs: t.Any,
+    ) -> bentoml.Tag:
         ...

     @t.overload
     def make_tag(
-        self, return_unused_kwargs: t.Literal[True] = ..., trust_remote_code: bool = ...
+        self,
+        model_name_or_path: str | None = None,
+        return_unused_kwargs: t.Literal[True] = ...,
+        trust_remote_code: bool = ...,
+        **attrs: t.Any,
     ) -> tuple[bentoml.Tag, dict[str, t.Any]]:
         ...

     def make_tag(
-        self, return_unused_kwargs: bool = False, trust_remote_code: bool = False
+        self,
+        model_name_or_path: str | None = None,
+        return_unused_kwargs: bool = False,
+        trust_remote_code: bool = False,
+        **attrs: t.Any,
     ) -> bentoml.Tag | tuple[bentoml.Tag, dict[str, t.Any]]:
-        tag, kwds = openllm.utils.generate_tags(
-            self._pretrained,
-            prefix=self.__llm_implementation__,
-            trust_remote_code=trust_remote_code,
-            **self.__llm_kwargs__,
+        """Generate a ``bentoml.Tag`` from a given transformers model name.
+
+        Note that this depends on your model having a config class available.
+
+        Args:
+            model_name_or_path: The transformers model name or path to load the model from.
+                If it is a path, then `openllm_model_version` must be passed in as a kwarg.
+            return_unused_kwargs: Whether to return unused kwargs. Defaults to False. If set, it will return a tuple
+                of ``bentoml.Tag`` and a dict of unused kwargs.
+            trust_remote_code: Whether to trust the remote code. Defaults to False.
+            **attrs: Additional kwargs to pass to the ``transformers.AutoConfig`` constructor.
+                If you pass ``return_unused_kwargs=True``, it will be ignored.
+
+        Returns:
+            A ``bentoml.Tag``, or a tuple of the tag and a dict of unused kwargs if ``return_unused_kwargs=True``.
+        """
+        if model_name_or_path is None:
+            model_name_or_path = self._pretrained
+
+        if "return_unused_kwargs" in attrs:
+            logger.debug("Ignoring 'return_unused_kwargs' in 'make_tag'.")
+            attrs.pop("return_unused_kwargs", None)
+
+        config, attrs = t.cast(
+            "tuple[transformers.PretrainedConfig, dict[str, t.Any]]",
+            transformers.AutoConfig.from_pretrained(
+                model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **attrs
+            ),
         )
-        if not return_unused_kwargs:
-            return tag
-        return tag, kwds
+        name = convert_transformers_model_name(model_name_or_path)
+
+        if os.path.exists(os.path.dirname(model_name_or_path)):
+            # If the model_name_or_path is a path, we assume it's a local path,
+            # then users must pass a version for this.
+            model_version = attrs.pop("openllm_model_version", None)
+            if model_version is None:
+                logger.warning(
+                    """\
+                    When passing a path, it is recommended to also pass 'openllm_model_version'
+                    into Runner/AutoLLM initialization.
+
+                    For example:
+
+                    >>> import openllm
+                    >>> runner = openllm.Runner('/path/to/fine-tuning/model', openllm_model_version='lora-version')
+
+                    Example with AutoLLM:
+
+                    >>> import openllm
+                    >>> model = openllm.AutoLLM.for_model('/path/to/fine-tuning/model', openllm_model_version='lora-version')
+
+                    No worries, OpenLLM will generate one for you. But for your own convenience, make sure to
+                    specify 'openllm_model_version'.
+                    """
+                )
+                model_version = bentoml.Tag.from_taglike(name).make_new_version().version
+        else:
+            model_version = getattr(config, "_commit_hash", None)
+            if model_version is None:
+                logger.warning(
+                    "Given %s from '%s' doesn't contain a commit hash. We will generate"
+                    " the tag without a specific version.",
+                    t.cast("type[transformers.PretrainedConfig]", config.__class__),
+                    model_name_or_path,
+                )
+        tag = bentoml.Tag.from_taglike(f"{self.__llm_implementation__}-{name}:{model_version}")
+
+        if return_unused_kwargs:
+            return tag, attrs
+        return tag

     def ensure_pretrained_exists(self):
         trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
-        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code)
+        tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self.__llm_kwargs__)
         try:
             return bentoml.transformers.get(tag)
         except bentoml.exceptions.BentoMLException:
@@ -546,21 +628,22 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
         else:
             kwds = self.__llm_kwargs__
         kwds["trust_remote_code"] = trust_remote_code

-        if self._load_as_bttransformer and "_pretrained_class" not in self._bentomodel.info.metadata:
+        if self.load_in_mha and "_pretrained_class" not in self._bentomodel.info.metadata:
             # This is a pipeline, so provide an accelerator argument
             kwds["accelerator"] = "bettertransformer"

         if self.__llm_model__ is None:
             # Hmm, bentoml.transformers.load_model doesn't yet support args.
             self.__llm_model__ = self._bentomodel.load_model(*self.__llm_args__, **kwds)
+
         if (
-            self._load_as_bttransformer
-            and "_framework" in self._bentomodel.info.metadata
+            self.load_in_mha
+            and all(i in self._bentomodel.info.metadata for i in ("_framework", "_pretrained_class"))
             and self._bentomodel.info.metadata["_framework"] == "torch"
         ):
             from optimum.bettertransformer import BetterTransformer

-            return BetterTransformer.transform(self.__llm_model__)
+            self.__llm_model__ = BetterTransformer.transform(self.__llm_model__)
         return self.__llm_model__

     @property
diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py
index 73b3ee05..fc2d2420 100644
--- a/src/openllm/models/chatglm/modeling_chatglm.py
+++ b/src/openllm/models/chatglm/modeling_chatglm.py
@@ -40,7 +40,8 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):

 class ChatGLM(openllm.LLM):
     __openllm_internal__ = True
-    __openllm_bettertransformer__ = False  # NOTE: disable BetterTransformer for ChatGLM since it is already quantized
+
+    load_in_mha = False  # NOTE: disable BetterTransformer for ChatGLM since it is already quantized

     default_model = "THUDM/chatglm-6b-int4"

diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py
index ec404db1..ef54c77b 100644
--- a/src/openllm/models/stablelm/modeling_stablelm.py
+++ b/src/openllm/models/stablelm/modeling_stablelm.py
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)

 class StableLM(openllm.LLM):
     __openllm_internal__ = True
-    __openllm_bettertransformer__ = False
+    load_in_mha = False

     default_model = "StabilityAI/stablelm-tuned-alpha-3b"

diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py
index c51d68c7..3bbc9944 100644
--- a/src/openllm/utils/__init__.py
+++ b/src/openllm/utils/__init__.py
@@ -22,12 +22,10 @@ import importlib.machinery
 import itertools
 import logging
 import os
-import re
 import types
 import typing as t

 import attrs
-import bentoml
 import inflection

 from bentoml._internal.types import LazyType as LazyType
@@ -42,6 +40,7 @@ from bentoml._internal.utils import resolve_user_filepath as resolve_user_filepath
 from . import analytics as analytics
 from . import codegen as codegen
 from . import dantic as dantic
+from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
 from .import_utils import DummyMetaclass as DummyMetaclass
 from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
 from .import_utils import is_einops_available as is_einops_available
@@ -104,76 +103,6 @@ class ModelEnv:
         return envvar


-def convert_transformers_model_name(name: str) -> str:
-    if os.path.exists(os.path.dirname(name)):
-        name = os.path.basename(name)
-        logger.debug("Given name is a path, only returning the basename %s")
-        return name
-    return re.sub("[^a-zA-Z0-9]+", "-", name)
-
-
-def generate_tags(
-    model_name_or_path: str, prefix: str | None = None, **attrs: t.Any
-) -> tuple[bentoml.Tag, dict[str, t.Any]]:
-    """Generate a ``bentoml.Tag`` from a given transformers model name.
-
-    Note that this depends on your model to have a config class available.
-
-    Args:
-        model_name_or_path: The transformers model name or path to load the model from.
-            If it is a path, then `openllm_model_version` must be passed in as a kwarg.
-        prefix: The prefix to prepend to the tag. If None, then no prefix will be prepended.
-        **attrs: Additional kwargs to pass to the ``transformers.AutoConfig`` constructor.
-            If your pass ``return_unused_kwargs=True``, it will be ignored.
-
-    Returns:
-        A tuple of ``bentoml.Tag`` and a dict of unused kwargs.
-    """
-    if "return_unused_kwargs" in attrs:
-        logger.debug("Ignoring 'return_unused_kwargs' in 'generate_tag_from_model_name'.")
-        attrs.pop("return_unused_kwargs", None)
-
-    config, attrs = t.cast(
-        "tuple[transformers.PretrainedConfig, dict[str, t.Any]]",
-        transformers.AutoConfig.from_pretrained(model_name_or_path, return_unused_kwargs=True, **attrs),
-    )
-    name = convert_transformers_model_name(model_name_or_path)
-
-    if os.path.exists(os.path.dirname(model_name_or_path)):
-        # If the model_name_or_path is a path, we assume it's a local path,
-        # then users must pass a version for this.
-        model_version = attrs.pop("openllm_model_version", None)
-        if model_version is None:
-            logger.warning(
-                """\
-                When passing a path, it is recommended to also pass 'openllm_model_version' into Runner/AutoLLM intialization.
-
-                For example:
-
-                >>> import openllm
-                >>> runner = openllm.Runner('/path/to/fine-tuning/model', openllm_model_version='lora-version')
-
-                Example with AutoLLM:
-
-                >>> import openllm
-                >>> model = openllm.AutoLLM.for_model('/path/to/fine-tuning/model', openllm_model_version='lora-version')
-
-                No worries, OpenLLM will generate one for you. But for your own convenience, make sure to
-                specify 'openllm_model_version'.
-                """
-            )
-            model_version = bentoml.Tag.from_taglike(name).make_new_version().version
-    else:
-        model_version = getattr(config, "_commit_hash", None)
-        if model_version is None:
-            logger.warning(
-                "Given %s from '%s' doesn't contain a commit hash. We will generate the tag without specific version.",
-                t.cast("type[transformers.PretrainedConfig]", config.__class__),
-                model_name_or_path,
-            )
-    return bentoml.Tag.from_taglike((f"{prefix}-" if prefix is not None else "") + f"{name}:{model_version}"), attrs
-
-
 class LazyModule(types.ModuleType):
     """
     Module class that surfaces all objects but only performs associated imports when the objects are requested.
diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py
index 7b8749af..e573f570 100644
--- a/src/openllm/utils/dantic.py
+++ b/src/openllm/utils/dantic.py
@@ -14,10 +14,17 @@ import click
 import orjson
 from click import ParamType
 from pydantic import BaseModel
-from pydantic._internal._utils import lenient_issubclass
+from pydantic.version import VERSION as PYDANTIC_VERSION

 from . import LazyType

+PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
+
+if PYDANTIC_V2:
+    from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
+else:
+    from pydantic.utils import lenient_issubclass as lenient_issubclass
+

 def parse_default(default: t.Any, field_type: t.Any) -> t.Any:
     """Converts pydantic defaults into click default types.
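A minimal sketch of how downstream code might consume the new pydantic shim; the dump_model helper below is hypothetical (not part of this patch) and only illustrates branching on PYDANTIC_V2 across the v1/v2 API split:

    from __future__ import annotations

    import typing as t

    from pydantic import BaseModel

    from openllm.utils.dantic import PYDANTIC_V2


    def dump_model(model: BaseModel) -> dict[str, t.Any]:
        # pydantic 2 renamed BaseModel.dict() to BaseModel.model_dump()
        return model.model_dump() if PYDANTIC_V2 else model.dict()

The same pattern extends to the other entry points pydantic 2 renamed (e.g. validation and schema generation) if the shim needs to grow beyond lenient_issubclass.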