compat: provide shim for pydantic 1 and 2

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron
Date: 2023-05-28 19:14:51 -07:00
parent aff23584a5
commit d5f213e49c
5 changed files with 129 additions and 109 deletions


@@ -19,6 +19,7 @@ import enum
import functools
import logging
import os
import re
import typing as t
from abc import ABC, ABCMeta, abstractmethod
@@ -29,7 +30,7 @@ import openllm
from ._configuration import ModelSignature
from .exceptions import ForbiddenAttributeError, OpenLLMException
from .utils import LazyLoader, bentoml_cattr
from .utils import ENV_VARS_TRUE_VALUES, LazyLoader, bentoml_cattr
if t.TYPE_CHECKING:
import torch
@@ -44,13 +45,21 @@ else:
logger = logging.getLogger(__name__)
# NOTE: tuple indices 0 and 1 map to text-generation and text2text-generation (TaskType values 1 and 2)
_return_tensors_to_framework_map = {
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
"tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
}
def convert_transformers_model_name(name: str) -> str:
if os.path.exists(os.path.dirname(name)):
name = os.path.basename(name)
logger.debug("Given name is a path, only returning the basename %s")
return name
return re.sub("[^a-zA-Z0-9]+", "-", name)
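For clarity, an illustrative check of what the helper above returns; the hub id and local path are made-up examples, and the path branch only triggers when the directory actually exists on disk:

assert convert_transformers_model_name("stabilityai/stablelm-tuned-alpha-3b") == "stabilityai-stablelm-tuned-alpha-3b"
# For an existing local directory, only the basename is kept:
# convert_transformers_model_name("/path/to/fine-tuned/model") -> "model"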
class TypeMeta(enum.EnumMeta):
def __getitem__(self, key: str) -> enum.Enum:
# Type safe getters
@@ -133,7 +142,7 @@ def import_model(
return bentoml.transformers.save_model(
tag,
getattr(
transformers, _return_tensors_to_framework_map[_model_framework][TaskType[task_type].value - 1]
transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[_model_framework][TaskType[task_type].value - 1]
).from_pretrained(
model_name, *model_args, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
),
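To make the indexing above concrete, here is a self-contained sketch; the TaskType shape shown is an assumption based on the NOTE next to the mapping (values 1 and 2 for text-generation and text2text-generation), not the actual enum from this module:

import enum

class TaskType(enum.IntEnum):  # assumed shape, for illustration only
    text_generation = 1
    text2text_generation = 2

FRAMEWORK_TO_AUTOCLASS_MAPPING = {
    "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
    "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
    "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
}

# 'value - 1' turns the 1-based task id into a 0-based tuple index.
name = FRAMEWORK_TO_AUTOCLASS_MAPPING["pt"][TaskType.text2text_generation.value - 1]
assert name == "AutoModelForSeq2SeqLM"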
@@ -286,14 +295,12 @@ class LLMMetaclass(ABCMeta):
namespace[key] = getattr(config_class, key)
# NOTE: set a default variable to transform to BetterTransformer by default for inference
namespace["__openllm_bettertransformer__"] = namespace.get(
"__openllm_bettertransformer__", implementation in ("pt",)
)
namespace["load_in_mha"] = namespace.get("load_in_mha", implementation in ("pt",))
if namespace["__openllm_requires_gpu__"]:
# For all models that require a GPU, there is no need to convert them to BetterTransformer;
# use bitsandbytes instead
namespace["__openllm_bettertransformer__"] = False
namespace["load_in_mha"] = False
# NOTE: import_model branch
if "import_model" not in namespace:
@@ -328,11 +335,11 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
__openllm_start_name__: str
__openllm_requires_gpu__: bool
__openllm_bettertransformer__: bool
__openllm_post_init__: t.Callable[[t.Self], None] | None
# NOTE: the following will be populated by __init__
config: openllm.LLMConfig
load_in_mha: bool
# NOTE: the following is the similar interface to HuggingFace pretrained protocol.
@@ -346,7 +353,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
self,
pretrained: str | None = None,
llm_config: openllm.LLMConfig | None = None,
load_as_bttransformer: bool | None = None,
load_in_mha: bool | None = None,
*args: t.Any,
**attrs: t.Any,
):
@@ -411,9 +418,9 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
pretrained: The pretrained model to use. Defaults to None. It will use 'self.default_model' if None.
llm_config: The config to use for this LLM. Defaults to None. If not passed, we will use 'self.config_class'
to construct default configuration.
load_as_bttransformer: Whether to apply BetterTransformer during inference load.
See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
for more information.
load_in_mha: Whether to apply BetterTransformer (or Torch MultiHeadAttention) during inference load.
See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
for more information.
*args: The args to be passed to the model.
**attrs: The kwargs to be passed to the model.
"""
@@ -441,11 +448,11 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
self.__llm_args__ = args
self.__llm_kwargs__ = attrs
if load_as_bttransformer is None:
load_as_bttransformer = self.__openllm_bettertransformer__
self._load_as_bttransformer = os.environ.get(
self.config.__openllm_env__.bettertransformer, load_as_bttransformer
)
if load_in_mha is not None:
self.load_in_mha = (
os.environ.get(self.config.__openllm_env__.bettertransformer, str(load_in_mha)).upper()
in ENV_VARS_TRUE_VALUES
)
if self.__openllm_post_init__:
self.__openllm_post_init__(self)
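Roughly, the new block above lets an environment variable (whose name comes from config.__openllm_env__.bettertransformer and is not shown in this diff) override the load_in_mha value passed to __init__. A minimal sketch of the equivalent logic, assuming ENV_VARS_TRUE_VALUES is the usual truthy set:

import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed definition; the real constant is imported from .utils

def resolve_load_in_mha(env_var: str, fallback: bool) -> bool:
    # The environment variable wins; otherwise fall back to the value passed to __init__.
    return os.environ.get(env_var, str(fallback)).upper() in ENV_VARS_TRUE_VALUES

os.environ["OPENLLM_DEMO_BETTERTRANSFORMER"] = "False"           # hypothetical variable name
assert resolve_load_in_mha("OPENLLM_DEMO_BETTERTRANSFORMER", True) is False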
@@ -463,7 +470,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
# NOTE: The section below defines a loose contract with langchain's LLM interface.
@property
def llm_type(self) -> str:
return openllm.utils.convert_transformers_model_name(self._pretrained)
return convert_transformers_model_name(self._pretrained)
@property
def identifying_params(self) -> dict[str, t.Any]:
@@ -476,31 +483,106 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
return self.__llm_bentomodel__
@t.overload
def make_tag(self, return_unused_kwargs: t.Literal[False] = ..., trust_remote_code: bool = ...) -> bentoml.Tag:
def make_tag(
self,
model_name_or_path: str | None = None,
return_unused_kwargs: t.Literal[False] = ...,
trust_remote_code: bool = ...,
**attrs: t.Any,
) -> bentoml.Tag:
...
@t.overload
def make_tag(
self, return_unused_kwargs: t.Literal[True] = ..., trust_remote_code: bool = ...
self,
model_name_or_path: str | None = None,
return_unused_kwargs: t.Literal[True] = ...,
trust_remote_code: bool = ...,
**attrs: t.Any,
) -> tuple[bentoml.Tag, dict[str, t.Any]]:
...
def make_tag(
self, return_unused_kwargs: bool = False, trust_remote_code: bool = False
self,
model_name_or_path: str | None = None,
return_unused_kwargs: bool = False,
trust_remote_code: bool = False,
**attrs: t.Any,
) -> bentoml.Tag | tuple[bentoml.Tag, dict[str, t.Any]]:
tag, kwds = openllm.utils.generate_tags(
self._pretrained,
prefix=self.__llm_implementation__,
trust_remote_code=trust_remote_code,
**self.__llm_kwargs__,
"""Generate a ``bentoml.Tag`` from a given transformers model name.
Note that this depends on your model having a config class available.
Args:
model_name_or_path: The transformers model name or path to load the model from.
If it is a path, `openllm_model_version` should also be passed as a kwarg; otherwise one is generated with a warning.
return_unused_kwargs: Whether to return unused kwargs. Defaults to False. If set, it will return a tuple
of ``bentoml.Tag`` and a dict of unused kwargs.
trust_remote_code: Whether to trust the remote code. Defaults to False.
**attrs: Additional kwargs to pass to the ``transformers.AutoConfig`` constructor.
If you pass ``return_unused_kwargs=True`` here, it will be ignored.
Returns:
A ``bentoml.Tag``, or a tuple of ``bentoml.Tag`` and a dict of unused kwargs when ``return_unused_kwargs=True``.
"""
if model_name_or_path is None:
model_name_or_path = self._pretrained
if "return_unused_kwargs" in attrs:
logger.debug("Ignoring 'return_unused_kwargs' in 'generate_tag_from_model_name'.")
attrs.pop("return_unused_kwargs", None)
config, attrs = t.cast(
"tuple[transformers.PretrainedConfig, dict[str, t.Any]]",
transformers.AutoConfig.from_pretrained(
model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **attrs
),
)
if not return_unused_kwargs:
return tag
return tag, kwds
name = convert_transformers_model_name(model_name_or_path)
if os.path.exists(os.path.dirname(model_name_or_path)):
# If the model_name_or_path is a path, we assume it's a local path,
# in which case users should pass a version for it.
model_version = attrs.pop("openllm_model_version", None)
if model_version is None:
logger.warning(
"""\
When passing a path, it is recommended to also pass 'openllm_model_version'
into Runner/AutoLLM initialization.
For example:
>>> import openllm
>>> runner = openllm.Runner('/path/to/fine-tuning/model', openllm_model_version='lora-version')
Example with AutoLLM:
>>> import openllm
>>> model = openllm.AutoLLM.for_model('/path/to/fine-tuning/model', openllm_model_version='lora-version')
No worries, OpenLLM will generate one for you. But for your own convenience, make sure to
specify 'openllm_model_version'.
"""
)
model_version = bentoml.Tag.from_taglike(name).make_new_version().version
else:
model_version = getattr(config, "_commit_hash", None)
if model_version is None:
logger.warning(
"Given %s from '%s' doesn't contain a commit hash. We will generate"
" the tag without specific version.",
t.cast("type[transformers.PretrainedConfig]", config.__class__),
model_name_or_path,
)
tag = bentoml.Tag.from_taglike(f"{self.__llm_implementation__}-{name}:{model_version}")
if return_unused_kwargs:
return tag, attrs
return tag
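A hedged usage sketch of the reworked make_tag; the checkpoint names, version string, and resulting tag format are illustrative assumptions, not values from this diff:

# Assume 'llm' is an instantiated openllm.LLM subclass (e.g. obtained through openllm.AutoLLM).
# Hub id: the version is derived from the config's commit hash when one is available.
tag = llm.make_tag("google/flan-t5-large")                       # e.g. pt-google-flan-t5-large:<commit-hash>
# Local path: pass 'openllm_model_version', otherwise one is generated with a warning.
local_tag = llm.make_tag("/path/to/fine-tuned/model", openllm_model_version="lora-v1")
# return_unused_kwargs=True also hands back whatever AutoConfig.from_pretrained did not consume.
tag, unused = llm.make_tag("google/flan-t5-large", return_unused_kwargs=True)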
def ensure_pretrained_exists(self):
trust_remote_code = self.__llm_kwargs__.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code)
tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self.__llm_kwargs__)
try:
return bentoml.transformers.get(tag)
except bentoml.exceptions.BentoMLException:
@@ -546,21 +628,22 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
else:
kwds = self.__llm_kwargs__
kwds["trust_remote_code"] = trust_remote_code
if self._load_as_bttransformer and "_pretrained_class" not in self._bentomodel.info.metadata:
if self.load_in_mha and "_pretrained_class" not in self._bentomodel.info.metadata:
# This is a pipeline, provide an 'accelerator' argument
kwds["accelerator"] = "bettertransformer"
if self.__llm_model__ is None:
# Hmm, bentoml.transformers.load_model doesn't yet support args.
self.__llm_model__ = self._bentomodel.load_model(*self.__llm_args__, **kwds)
if (
self._load_as_bttransformer
and "_framework" in self._bentomodel.info.metadata
self.load_in_mha
and all(i in self._bentomodel.info.metadata for i in ("_framework", "_pretrained_class"))
and self._bentomodel.info.metadata["_framework"] == "torch"
):
from optimum.bettertransformer import BetterTransformer
return BetterTransformer.transform(self.__llm_model__)
self.__llm_model__ = BetterTransformer.transform(self.__llm_model__)
return self.__llm_model__
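For reference, the standalone optimum pattern this block wires in; the checkpoint is just an example, and transform() returning a new model object is why the diff now assigns the result back to self.__llm_model__ instead of returning it directly:

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")   # illustrative checkpoint
model = BetterTransformer.transform(model)             # attention now uses torch's fused MHA fast path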
@property


@@ -40,7 +40,8 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
class ChatGLM(openllm.LLM):
__openllm_internal__ = True
__openllm_bettertransformer__ = False # NOTE: disable bettertransformer for ChatGLM since it is already quantized
load_in_mha = False # NOTE: disable bettertransformer for ChatGLM since it is already quantized
default_model = "THUDM/chatglm-6b-int4"


@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class StableLM(openllm.LLM):
__openllm_internal__ = True
__openllm_bettertransformer__ = False
load_in_mha = False
default_model = "StabilityAI/stablelm-tuned-alpha-3b"


@@ -22,12 +22,10 @@ import importlib.machinery
import itertools
import logging
import os
import re
import types
import typing as t
import attrs
import bentoml
import inflection
from bentoml._internal.types import LazyType as LazyType
@@ -42,6 +40,7 @@ from bentoml._internal.utils import resolve_user_filepath as resolve_user_filepa
from . import analytics as analytics
from . import codegen as codegen
from . import dantic as dantic
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
from .import_utils import DummyMetaclass as DummyMetaclass
from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
from .import_utils import is_einops_available as is_einops_available
@@ -104,76 +103,6 @@ class ModelEnv:
return envvar
def convert_transformers_model_name(name: str) -> str:
if os.path.exists(os.path.dirname(name)):
name = os.path.basename(name)
logger.debug("Given name is a path, only returning the basename %s")
return name
return re.sub("[^a-zA-Z0-9]+", "-", name)
def generate_tags(
model_name_or_path: str, prefix: str | None = None, **attrs: t.Any
) -> tuple[bentoml.Tag, dict[str, t.Any]]:
"""Generate a ``bentoml.Tag`` from a given transformers model name.
Note that this depends on your model to have a config class available.
Args:
model_name_or_path: The transformers model name or path to load the model from.
If it is a path, then `openllm_model_version` must be passed in as a kwarg.
prefix: The prefix to prepend to the tag. If None, then no prefix will be prepended.
**attrs: Additional kwargs to pass to the ``transformers.AutoConfig`` constructor.
If your pass ``return_unused_kwargs=True``, it will be ignored.
Returns:
A tuple of ``bentoml.Tag`` and a dict of unused kwargs.
"""
if "return_unused_kwargs" in attrs:
logger.debug("Ignoring 'return_unused_kwargs' in 'generate_tag_from_model_name'.")
attrs.pop("return_unused_kwargs", None)
config, attrs = t.cast(
"tuple[transformers.PretrainedConfig, dict[str, t.Any]]",
transformers.AutoConfig.from_pretrained(model_name_or_path, return_unused_kwargs=True, **attrs),
)
name = convert_transformers_model_name(model_name_or_path)
if os.path.exists(os.path.dirname(model_name_or_path)):
# If the model_name_or_path is a path, we assume it's a local path,
# then users must pass a version for this.
model_version = attrs.pop("openllm_model_version", None)
if model_version is None:
logger.warning(
"""\
When passing a path, it is recommended to also pass 'openllm_model_version' into Runner/AutoLLM intialization.
For example:
>>> import openllm
>>> runner = openllm.Runner('/path/to/fine-tuning/model', openllm_model_version='lora-version')
Example with AutoLLM:
>>> import openllm
>>> model = openllm.AutoLLM.for_model('/path/to/fine-tuning/model', openllm_model_version='lora-version')
No worries, OpenLLM will generate one for you. But for your own convenience, make sure to
specify 'openllm_model_version'.
"""
)
model_version = bentoml.Tag.from_taglike(name).make_new_version().version
else:
model_version = getattr(config, "_commit_hash", None)
if model_version is None:
logger.warning(
"Given %s from '%s' doesn't contain a commit hash. We will generate the tag without specific version.",
t.cast("type[transformers.PretrainedConfig]", config.__class__),
model_name_or_path,
)
return bentoml.Tag.from_taglike((f"{prefix}-" if prefix is not None else "") + f"{name}:{model_version}"), attrs
class LazyModule(types.ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.


@@ -14,10 +14,17 @@ import click
import orjson
from click import ParamType
from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from pydantic.version import VERSION as PYDANTIC_VERSION
from . import LazyType
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
if PYDANTIC_V2:
from pydantic._internal._utils import lenient_issubclass
else:
from pydantic.utils import lenient_issubclass as lenient_issubclass
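A small sketch of what the shim buys downstream code: lenient_issubclass behaves the same under either import, so callers need no version checks of their own (the Example class is illustrative):

from pydantic import BaseModel

class Example(BaseModel):        # illustrative model
    value: int = 0

# Unlike issubclass(), lenient_issubclass tolerates non-class arguments.
assert lenient_issubclass(Example, BaseModel)
assert not lenient_issubclass(42, BaseModel)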
def parse_default(default: t.Any, field_type: t.Any) -> t.Any:
"""Converts pydantic defaults into click default types.