fix: generate attrs class internally to conform with interface

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-06-01 19:06:06 +00:00
parent 84358b28cd
commit a94294bc65
8 changed files with 557 additions and 356 deletions

View File

@@ -38,10 +38,10 @@ class FlanT5Config(openllm.LLMConfig):
"""
from __future__ import annotations
import functools
import logging
import os
import typing as t
from operator import itemgetter
import attr
import click
@@ -59,6 +59,7 @@ if t.TYPE_CHECKING:
import tensorflow as tf
import torch
import transformers
from attr import _CountingAttr, _make_init
from transformers.generation.beam_constraints import Constraint
P = t.ParamSpec("P")
@@ -68,9 +69,15 @@ if t.TYPE_CHECKING:
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
DictStrAny = dict[str, t.Any]
ItemgetterAny = itemgetter[t.Any]
else:
Constraint = t.Any
DictStrAny = dict
ItemgetterAny = itemgetter
# NOTE: Using internal API from attr here, since we are actually
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
from attr._make import _CountingAttr, _make_init
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
tf = openllm.utils.LazyLoader("tf", globals(), "tensorflow")
@@ -115,9 +122,6 @@ def attrs_to_options(
)
_IGNORE_FIELDS = ("__openllm_name_type__", "generation_config")
@attr.define
class GenerationConfig:
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
@@ -128,7 +132,9 @@ class GenerationConfig:
# NOTE: parameters for controlling the length of the output
max_new_tokens: int = dantic.Field(
20, ge=0, description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."
20,
ge=0,
description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
)
min_length: int = dantic.Field(
0,
@@ -137,7 +143,7 @@ class GenerationConfig:
input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.""",
)
min_new_tokens: int = dantic.Field(
None, description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt."
description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
)
early_stopping: bool = dantic.Field(
False,
@@ -149,7 +155,6 @@ class GenerationConfig:
""",
)
max_time: float = dantic.Field(
None,
description="""The maximum amount of time you allow the computation to run for in seconds. generation will
still finish the current pass after allocated time has been passed.""",
)
@@ -162,7 +167,6 @@ class GenerationConfig:
groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.""",
)
penalty_alpha: float = dantic.Field(
None,
description="""The values balance the model confidence and the degeneration penalty in
contrastive search decoding.""",
)
@@ -241,7 +245,6 @@ class GenerationConfig:
0, description="If set to int > 0, all ngrams of that size can only occur once."
)
bad_words_ids: t.List[t.List[int]] = dantic.Field(
None,
description="""List of token ids that are not allowed to be generated. In order to get the token ids
of the words that should not appear in the generated text, use
`tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`.
@@ -250,7 +253,6 @@ class GenerationConfig:
# NOTE: t.Union is not yet supported on CLI, but the environment variable should already be available.
force_words_ids: t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]] = dantic.Field(
None,
description="""List of token ids that must be generated. If given a `List[List[int]]`, this is treated
as a simple list of words that must be included, the opposite to `bad_words_ids`.
If given `List[List[List[int]]]`, this triggers a
@@ -266,13 +268,11 @@ class GenerationConfig:
""",
)
constraints: t.List[Constraint] = dantic.Field(
None,
description="""Custom constraints that can be added to the generation to ensure that the output
description="""Custom constraints that can be added to the generation to ensure that the output
will contain the use of certain tokens as defined by ``Constraint`` objects, in the most sensible way possible.
""",
)
forced_bos_token_id: int = dantic.Field(
None,
description="""The id of the token to force as the first generated token after the
``decoder_start_token_id``. Useful for multilingual models like
[mBART](https://huggingface.co/docs/transformers/model_doc/mbart) where the first generated token needs
@@ -280,7 +280,6 @@ class GenerationConfig:
""",
)
forced_eos_token_id: t.Union[int, t.List[int]] = dantic.Field(
None,
description="""The id of the token to force as the last generated token when `max_length` is reached.
Optionally, use a list to set multiple *end-of-sequence* tokens.""",
)
@@ -290,26 +289,22 @@ class GenerationConfig:
generation method to crash. Note that using `remove_invalid_values` can slow down generation.""",
)
exponential_decay_length_penalty: t.Tuple[int, float] = dantic.Field(
None,
description="""This tuple adds an exponentially increasing length penalty, after a certain amount of tokens
have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index`
indicates where penalty starts and `decay_factor` represents the factor of exponential decay
""",
)
suppress_tokens: t.List[int] = dantic.Field(
None,
description="""A list of tokens that will be suppressed at generation. The `SupressTokens` logit
processor will set their log probs to `-inf` so that they are not sampled.
""",
)
begin_suppress_tokens: t.List[int] = dantic.Field(
None,
description="""A list of tokens that will be suppressed at the beginning of the generation. The
`SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
""",
)
forced_decoder_ids: t.List[t.List[int]] = dantic.Field(
None,
description="""A list of pairs of integers which indicates a mapping from generation indices to token indices
that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
@@ -338,10 +333,9 @@ class GenerationConfig:
)
# NOTE: Special tokens that can be used at generation time
pad_token_id: int = dantic.Field(None, description="The id of the *padding* token.")
bos_token_id: int = dantic.Field(None, description="The id of the *beginning-of-sequence* token.")
pad_token_id: int = dantic.Field(description="The id of the *padding* token.")
bos_token_id: int = dantic.Field(description="The id of the *beginning-of-sequence* token.")
eos_token_id: t.Union[int, t.List[int]] = dantic.Field(
None,
description="""The id of the *end-of-sequence* token. Optionally, use a list to set
multiple *end-of-sequence* tokens.""",
)
@@ -354,7 +348,6 @@ class GenerationConfig:
""",
)
decoder_start_token_id: int = dantic.Field(
None,
description="""If an encoder-decoder model starts decoding with a
different token than *bos*, the id of that token.
""",
@@ -373,43 +366,223 @@ class GenerationConfig:
)
self.__attrs_init__(**attrs)
def model_dump(self, exclude_none: bool = False, **_: t.Any) -> dict[str, t.Any]:
target: dict[str, t.Any] = {}
for k in attr.fields_dict(self.__class__):
v = getattr(self, k, None)
if exclude_none and v is None:
continue
if not k.startswith("_"):
target[k] = v
return target
def generation_config_dump(genconf: GenerationConfig) -> dict[str, t.Any]:
target: dict[str, t.Any] = {}
for k in attr.fields_dict(genconf.__class__):
v = getattr(genconf, k, None)
if v is None:
continue
if not k.startswith("_"):
target[k] = v
return target
bentoml_cattr.register_unstructure_hook(
GenerationConfig, functools.partial(GenerationConfig.model_dump, exclude_none=True)
bentoml_cattr.register_unstructure_hook_func(
lambda cls: lenient_issubclass(cls, GenerationConfig),
generation_config_dump,
)
def _populate_value_from_env_var(
key: str, transform: t.Callable[[str], str] | None = None, fallback: t.Any = None
) -> t.Any:
if transform is not None and callable(transform):
key = transform(key)
return os.environ.get(key, fallback)
def env_transformers(cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
transformed: list[attr.Attribute[t.Any]] = []
for f in fields:
default = os.environ.get(f.metadata["env"], None)
if default is not None:
try:
default = orjson.loads(default)
except orjson.JSONDecodeError as err:
raise ValueError(f"Failed to load from environment variables: {err}")
else:
default = f.default
transformed.append(f.evolve(default=default))
if "env" not in f.metadata:
raise ValueError(
"Make sure to setup the field with 'cls.Field' or 'attr.field(..., metadata={\"env\": \"...\"})'"
)
_from_env = _populate_value_from_env_var(f.metadata["env"])
if _from_env is not None:
f = f.evolve(default=_from_env)
transformed.append(f)
return transformed
# sentinel object for unequivocal object() getattr
_sentinel = object()
def _has_own_attribute(cls: type[t.Any], attrib_name: t.Any):
"""
Check whether *cls* defines *attrib_name* (and doesn't just inherit it).
"""
attr = getattr(cls, attrib_name, _sentinel)
if attr is _sentinel:
return False
for base_cls in cls.__mro__[1:]:
a = getattr(base_cls, attrib_name, None)
if attr is a:
return False
return True
def _get_annotations(cls: type[t.Any]) -> DictStrAny:
"""
Get annotations for *cls*.
"""
if _has_own_attribute(cls, "__annotations__"):
return cls.__annotations__
return DictStrAny()
# The below is vendorred from attrs
def _collect_base_attrs(
cls: type[LLMConfig], taken_attr_names: set[str]
) -> tuple[list[attr.Attribute[t.Any]], dict[str, type[t.Any]]]:
"""
Collect attr.ibs from base classes of *cls*, except *taken_attr_names*.
"""
base_attrs: list[attr.Attribute[t.Any]] = []
base_attr_map: dict[str, type[t.Any]] = {} # A dictionary of base attrs to their classes.
# Traverse the MRO and collect attributes.
for base_cls in reversed(cls.__mro__[1:-1]):
for a in getattr(base_cls, "__attrs_attrs__", []):
if a.inherited or a.name in taken_attr_names:
continue
a = a.evolve(inherited=True)
base_attrs.append(a)
base_attr_map[a.name] = base_cls
# For each name, only keep the freshest definition i.e. the furthest at the back.
filtered: list[attr.Attribute[t.Any]] = []
seen: set[str] = set()
for a in reversed(base_attrs):
if a.name in seen:
continue
filtered.insert(0, a)
seen.add(a.name)
return filtered, base_attr_map
_classvar_prefixes = (
"typing.ClassVar",
"t.ClassVar",
"ClassVar",
"typing_extensions.ClassVar",
)
def _is_class_var(annot: str | t.Any) -> bool:
"""
Check whether *annot* is a typing.ClassVar.
The string comparison hack is used to avoid evaluating all string
annotations which would put attrs-based classes at a performance
disadvantage compared to plain old classes.
"""
annot = str(annot)
# Annotation can be quoted.
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
annot = annot[1:-1]
return annot.startswith(_classvar_prefixes)
def _add_method_dunders(cls: type[t.Any], method: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
"""
Add __module__ and __qualname__ to a *method* if possible.
"""
try:
method.__module__ = cls.__module__
except AttributeError:
pass
try:
method.__qualname__ = ".".join((cls.__qualname__, method.__name__))
except AttributeError:
pass
try:
method.__doc__ = "Method generated by attrs for class " f"{cls.__qualname__}."
except AttributeError:
pass
return method
# NOTE: vendorred from attrs
def _compile_and_eval(script: str, globs: dict[str, t.Any], locs: dict[str, t.Any] | None = None, filename: str = ""):
"""
"Exec" the script with the given global (globs) and local (locs) variables.
"""
bytecode = compile(script, filename, "exec")
eval(bytecode, globs, locs)
def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[tuple[attr.Attribute[t.Any], ...]]:
"""
Create a tuple subclass to hold `Attribute`s for an `attrs` class.
The subclass is a bare tuple with properties for names.
class MyClassAttributes(tuple):
__slots__ = ()
x = property(itemgetter(0))
"""
attr_class_name = f"{cls_name}Attributes"
attr_class_template = [
f"class {attr_class_name}(tuple):",
" __slots__ = ()",
]
if attr_names:
for i, attr_name in enumerate(attr_names):
attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
else:
attr_class_template.append(" pass")
globs: dict[str, t.Any] = {"_attrs_itemgetter": ItemgetterAny, "_attrs_property": property}
_compile_and_eval("\n".join(attr_class_template), globs)
return globs[attr_class_name]
def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]:
attribs: DictStrAny = {}
_has_gen_class = _has_own_attribute(cls, "GenerationConfig")
for key, field in attr.fields_dict(GenerationConfig).items():
class_gen_value = cls.GenerationConfig.__dict__.get(key, attr.NOTHING) if _has_gen_class else attr.NOTHING
attribs[key] = cls.Field(
class_gen_value if class_gen_value is not attr.NOTHING else field.default,
description=field.metadata.get("description"),
env=f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{key.upper()}",
validator=field.validator,
)
if _has_gen_class:
delattr(cls, "GenerationConfig")
return attr.make_class(
cls.__name__.replace("Config", "GenerationConfig"), attribs, field_transformer=env_transformers
)
@attr.define
class LLMConfig:
Field = dantic.Field
"""Field is a alias to the internal dantic utilities to easily create
attrs.fields with pydantic-compatible interface.
"""
if t.TYPE_CHECKING:
def __attrs_init__(self, **attrs: t.Any):
...
# The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
__attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = tuple()
__openllm_attrs__: tuple[str, ...] = tuple()
__openllm_timeout__: int = 3600
@@ -420,9 +593,9 @@ class LLMConfig:
__openllm_start_name__: str = ""
__openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
__openllm_env__: openllm.utils.ModelEnv = dantic.Field(None, init=False)
__openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
generation_class: type[GenerationConfig] = dantic.Field(None, init=False)
generation_class: type[GenerationConfig] = Field(None, init=False)
GenerationConfig: type = type
@@ -442,36 +615,6 @@ class LLMConfig:
start_name = model_name
cls.__openllm_name_type__ = name_type
attributes = {
key: dantic.Field(
field.default,
description=field.metadata.get("description"),
env=f"OPENLLM_{model_name.upper()}_GENERATION_{key.upper()}",
validator=field.validator,
)
for key, field in attr.fields_dict(GenerationConfig).items()
}
generation_class: type[GenerationConfig]
if hasattr(cls, "GenerationConfig"):
generation_class = attr.make_class(
cls.__name__.replace("Config", "GenerationConfig"),
attributes,
field_transformer=env_transformers,
)
delattr(cls, "GenerationConfig")
else:
generation_class = attr.make_class(
"GenerationConfig",
attributes,
field_transformer=env_transformers,
)
generation_class.model_dump = GenerationConfig.model_dump # type: ignore
# Set the generation_config attributes here.
cls.generation_class = generation_class
cls.__openllm_requires_gpu__ = requires_gpu
cls.__openllm_timeout__ = default_timeout or 3600
cls.__openllm_trust_remote_code__ = trust_remote_code
@@ -480,51 +623,100 @@ class LLMConfig:
cls.__openllm_start_name__ = start_name
cls.__openllm_env__ = openllm.utils.ModelEnv(model_name)
if hasattr(cls, "__annotations__"):
anns = cls.__annotations__
else:
anns = {}
# NOTE: Since we want to enable a pydantic-like experience
# this means we will have to hide the attr abstraction, and generate
# all of the Field from __init_subclass__
# Some of the logics here are from attr._make._transform_attrs
anns = _get_annotations(cls)
cd = cls.__dict__
def field_env_key(key: str) -> str:
return f"OPENLLM_{model_name.upper()}_{key.upper()}"
ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)}
ca_list: list[tuple[str, _CountingAttr[t.Any]]] = []
annotated_names: set[str] = set()
for attr_name, typ in anns.items():
if _is_class_var(typ):
continue
annotated_names.add(attr_name)
val = cd.get(attr_name, attr.NOTHING)
if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val):
if val is attr.NOTHING:
val = cls.Field(env=field_env_key(attr_name))
else:
val = cls.Field(default=val, env=field_env_key(attr_name))
ca_list.append((attr_name, val))
unannotated = ca_names - annotated_names
if len(unannotated) > 0:
missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter)
raise openllm.exceptions.MissingAnnotationAttributeError(
f"The following field doesn't have a type annotation: {missing_annotated}"
)
# NOTE: we know need to determine the list of the attrs
# by mro to at the very least support inheritance. Tho it is not recommended.
own_attrs: list[attr.Attribute[t.Any]] = [
attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=anns.get(attr_name))
for attr_name, ca in ca_list
]
base_attrs, base_attr_map = _collect_base_attrs(cls, {a.name for a in own_attrs})
attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs
# Mandatory vs non-mandatory attr order only matters when they are part of
# the __init__ signature and when they aren't kw_only (which are moved to
# the end and can be mandatory or non-mandatory in any order, as they will
# be specified as keyword args anyway). Check the order of those attrs:
had_default = False
for a in (a for a in attrs if a.init is not False and a.kw_only is False):
if had_default is True and a.default is attr.NOTHING:
raise ValueError(
"No mandatory attributes allowed after an attribute with a "
f"default value or factory. Attribute in question: {a!r}"
)
if had_default is False and a.default is not attr.NOTHING:
had_default = True
# NOTE: Resolve the alias and default value from environment variable
attrs = [
a.evolve(
alias=a.name.lstrip("_") if not a.alias else None,
# NOTE: This is where we actually populate with the environment variable set for this attrs.
default=_populate_value_from_env_var(a.name, transform=field_env_key, fallback=a.default),
)
for a in attrs
]
_has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
_has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
# __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]]
# that we construct ourself.
openllm_attrs: tuple[str, ...] = tuple()
cur_attrs: tuple[attr.Attribute[t.Any]] = tuple()
cls.__openllm_attrs__ = tuple(a.name for a in attrs)
AttrsTuple = _make_attr_tuple_class(cls.__name__, cls.__openllm_attrs__)
# NOTE: generate a __attrs_init__ for the subclass
cls.__attrs_init__ = _add_method_dunders(
cls,
_make_init(
cls,
AttrsTuple(attrs),
_has_pre_init,
_has_post_init,
False,
True,
True,
base_attr_map,
False,
None,
attrs_init=True,
),
)
cls.__attrs_attrs__ = AttrsTuple(attrs)
for key, value in vars(cls).items():
if key in _IGNORE_FIELDS:
continue
# NOTE: we probably want to decorate all of the function in LLMConfig
# with an internal flag, which then allow user to define their own
# function and do whatever they want with the LLMConfig. Currently, this is
# a limitation.
if not key.startswith("_") and not callable(value):
env_key = f"OPENLLM_{model_name.upper()}_{key.upper()}"
default = os.environ.get(env_key, None)
if default is not None:
try:
default = orjson.loads(default)
except orjson.JSONDecodeError as err:
raise ValueError(f"Failed to load from environment variables: {err}")
else:
default = value
annotation = anns.get(key, None)
if annotation is not None:
# NOTE: eval is dangerous, but we don't provide any specific imports here.
annotation = eval(annotation, {}, {})
attribute: attr.Attribute[t.Any] = attr.Attribute.from_counting_attr(
key, dantic.Field(default, env=env_key, alias=key), type=annotation
)
openllm_attrs += (key,)
cur_attrs += (attribute,)
setattr(cls, key, attribute.default)
cls.__openllm_attrs__ = openllm_attrs
cls.__attrs_attrs__ = cur_attrs
# NOTE: Finally, set the generation_class for this given config.
cls.generation_class = _make_internal_generation_class(cls)
@property
def name_type(self) -> t.Literal["dasherize", "lowercase"]:
@@ -537,8 +729,6 @@ class LLMConfig:
__openllm_extras__: dict[str, t.Any] | None = None,
**attrs: t.Any,
):
self.__openllm_env__ = openllm.utils.ModelEnv(self.__openllm_model_name__)
to_exclude = list(attr.fields_dict(self.generation_class)) + list(self.__openllm_attrs__)
self.__openllm_extras__ = __openllm_extras__ or {k: v for k, v in attrs.items() if k not in to_exclude}
@@ -546,22 +736,18 @@ class LLMConfig:
if generation_config is None:
generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.generation_class)}
else:
logger.debug("Overriding default 'generation_config' with '%s'", generation_config)
self.generation_config = self.generation_class(**generation_config)
# NOTE: since our subclass is not a dataclass-like, we need to set attr like this.
for k, fields in attr.fields_dict(self.__class__).items():
if k in attrs and attrs[k] != fields.default:
setattr(self, k, attrs[k])
else:
setattr(self, k, fields.default)
attrs = {k: v for k, v in attrs.items() if k not in generation_config}
# set the remaning attrs to class
for k, v in attrs.items():
if k not in self.__openllm_attrs__:
setattr(self, k, v)
extras = set(attrs).difference(set(attr.fields_dict(self.__class__)))
if len(extras) > 0:
# Set all of the keys in extras to the class
for k in extras:
setattr(self, k, attrs[k])
self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras})
def __repr__(self) -> str:
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={self.generation_config}"

View File

@@ -272,14 +272,11 @@ class LLMMetaclass(ABCMeta):
if cls_name.startswith("Flax"):
implementation = "flax"
prefix_class_name_config = cls_name[4:]
namespace["__annotations__"].update({"model": "transformers.FlaxPreTrainedModel"})
elif cls_name.startswith("TF"):
implementation = "tf"
prefix_class_name_config = cls_name[2:]
namespace["__annotations__"].update({"model": "transformers.TFPreTrainedModel"})
else:
implementation = "pt"
namespace["__annotations__"].update({"model": "transformers.PreTrainedModel"})
namespace["__llm_implementation__"] = implementation
# NOTE: setup config class branch

View File

@@ -166,6 +166,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
self._cached_grpc: dict[str, t.Any] = {}
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
breakpoint()
cmd_name = self.resolve_alias(cmd_name)
if ctx.command.name == "start":
if cmd_name not in self._cached_http:
@@ -465,213 +466,217 @@ output_option = click.option(
)
@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm")
def cli():
"""
\b
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
██╔═══██╗██╔══██╗██╔════╝████╗ ████ ██ ███╗ ████║
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ████████║
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔██║
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
def cli_builder():
@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm")
def cli():
"""
\b
██████╗ ██████╗ ███████╗███╗ ████ ██ ███╗ ███
██╔═══████╔══██╗██╔════╝████╗ ██║██║ ██║ ████████║
██║ ██║██████╔╝█████╗ ████╗ ██║██║ ██║ ██╔████╔██║
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗████║ ██║ ██║╚██╔╝██║
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
\b
OpenLLM: Your one stop-and-go-solution for serving any Open Large-Language Model
\b
OpenLLM: Your one stop-and-go-solution for serving any Open Large-Language Model
- StableLM, Falcon, ChatGLM, Dolly, Flan-T5, and more
- StableLM, Falcon, ChatGLM, Dolly, Flan-T5, and more
\b
- Powered by BentoML 🍱 + HuggingFace 🤗
"""
@cli.command(name="version")
@output_option
@click.pass_context
def _(ctx: click.Context, output: t.Literal["json", "pretty", "porcelain"]):
"""🚀 OpenLLM version."""
from gettext import gettext
from .__about__ import __version__
message = gettext("%(prog)s, version %(version)s")
version = __version__
prog_name = ctx.find_root().info_name
if output == "pretty":
click.echo(message % {"prog": prog_name, "version": version}, color=ctx.color)
elif output == "json":
click.echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode())
else:
click.echo(version)
ctx.exit()
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start")
def _():
"""
Start any LLM as a REST server.
$ openllm start <model_name> --<options> ...
"""
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc")
def _():
"""
Start any LLM as a gRPC server.
$ openllm start-grpc <model_name> --<options> ...
"""
@cli.command(name="bundle", aliases=["build"])
@click.argument(
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].")
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
@output_option
def _(model_name: str, pretrained: str | None, overwrite: bool, output: t.Literal["json", "pretty", "porcelain"]):
"""Package a given models into a Bento.
$ openllm bundle flan-t5
"""
from bentoml._internal.configuration import get_quiet_mode
bento, _previously_built = openllm.build(
model_name, __cli__=True, pretrained=pretrained, _overwrite_existing_bento=overwrite
)
if output == "pretty":
if not get_quiet_mode():
click.echo("\n" + OPENLLM_FIGLET)
if not _previously_built:
click.secho(f"Successfully built {bento}.", fg="green")
else:
click.secho(
f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.",
fg="yellow",
)
click.secho(
"\nPossible next steps:\n\n * Push to BentoCloud with `bentoml push`:\n "
+ f"$ bentoml push {bento.tag}",
fg="blue",
)
click.secho(
"\n * Containerize your Bento with `bentoml containerize`:\n "
+ f"$ bentoml containerize {bento.tag}",
fg="blue",
)
elif output == "json":
click.secho(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
else:
click.echo(bento.tag)
return bento
@cli.command(name="models")
@output_option
def _(output: t.Literal["json", "pretty", "porcelain"]):
"""List all supported models."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
failed_initialized: list[tuple[str, Exception]] = []
if output == "pretty":
import rich
import rich.box
from rich.table import Table
from rich.text import Text
console = rich.get_console()
table = Table(title="Supported LLMs", box=rich.box.SQUARE, show_lines=True)
table.add_column("LLM")
table.add_column("Description")
table.add_column("Variants")
for m in models:
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
try:
model = openllm.AutoLLM.for_model(m)
table.add_row(m, docs, f"{model.variants}")
except Exception as err:
failed_initialized.append((m, err))
console.print(table)
if len(failed_initialized) > 0:
console.print(
"\n[bold yellow] The following models are supported but failed to initialize:[/bold yellow]\n"
)
for m, err in failed_initialized:
console.print(Text(f"- {m}: ") + Text(f"{err}\n", style="bold red"))
elif output == "json":
result_json: dict[str, dict[t.Literal["variants", "description"], t.Any]] = {}
for m in models:
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
try:
model = openllm.AutoLLM.for_model(m)
result_json[m] = {"variants": model.variants, "description": docs}
except Exception as err:
logger.debug("Exception caught while parsing model %s", m, exc_info=err)
result_json[m] = {"variants": None, "description": docs}
click.secho(orjson.dumps(result_json, option=orjson.OPT_INDENT_2).decode())
else:
click.echo("\n".join(models))
sys.exit(0)
@cli.command(name="download_models")
@click.argument(
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@click.option(
"--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight."
)
@output_option
def _(model_name: str, pretrained: str | None, output: t.Literal["json", "pretty", "porcelain"]):
"""Setup LLM interactively.
Note: This is useful for development and setup for fine-tune.
"""
config = openllm.AutoConfig.for_model(model_name)
env = config.__openllm_env__.get_framework_env()
if env == "flax":
model = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
elif env == "tf":
model = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
else:
model = openllm.AutoLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
tag = model.make_tag()
if len(bentoml.models.list(tag)) == 0:
if output == "pretty":
click.secho(f"{tag} does not exists yet!. Downloading...", nl=True)
m = model.ensure_pretrained_exists()
click.secho(f"Saved model: {m.tag}")
elif output == "json":
m = model.ensure_pretrained_exists()
click.secho(
orjson.dumps(
{"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
).decode()
)
else:
m = model.ensure_pretrained_exists()
click.secho(m.tag)
else:
m = model.ensure_pretrained_exists()
if output == "pretty":
click.secho(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True)
elif output == "json":
click.secho(
orjson.dumps(
{"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
).decode()
)
else:
click.echo(m.tag)
return m
\b
- Powered by BentoML 🍱 + HuggingFace 🤗
"""
if psutil.WINDOWS:
sys.stdout.reconfigure(encoding="utf-8") # type: ignore
@cli.command()
@output_option
@click.pass_context
def version(ctx: click.Context, output: t.Literal["json", "pretty", "porcelain"]):
"""🚀 OpenLLM version."""
from gettext import gettext
from .__about__ import __version__
message = gettext("%(prog)s, version %(version)s")
version = __version__
prog_name = ctx.find_root().info_name
if output == "pretty":
click.echo(message % {"prog": prog_name, "version": version}, color=ctx.color)
elif output == "json":
click.echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode())
else:
click.echo(version)
ctx.exit()
return cli
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start")
def start_cli():
"""
Start any LLM as a REST server.
$ openllm start <model_name> --<options> ...
"""
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc")
def start_grpc_cli():
"""
Start any LLM as a gRPC server.
$ openllm start-grpc <model_name> --<options> ...
"""
@cli.command(aliases=["build"])
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
@click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].")
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
@output_option
def bundle(model_name: str, pretrained: str | None, overwrite: bool, output: t.Literal["json", "pretty", "porcelain"]):
"""Package a given models into a Bento.
$ openllm bundle flan-t5
"""
from bentoml._internal.configuration import get_quiet_mode
bento, _previously_built = openllm.build(
model_name, __cli__=True, pretrained=pretrained, _overwrite_existing_bento=overwrite
)
if output == "pretty":
if not get_quiet_mode():
click.echo("\n" + OPENLLM_FIGLET)
if not _previously_built:
click.secho(f"Successfully built {bento}.", fg="green")
else:
click.secho(
f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.",
fg="yellow",
)
click.secho(
"\nPossible next steps:\n\n * Push to BentoCloud with `bentoml push`:\n "
+ f"$ bentoml push {bento.tag}",
fg="blue",
)
click.secho(
"\n * Containerize your Bento with `bentoml containerize`:\n "
+ f"$ bentoml containerize {bento.tag}",
fg="blue",
)
elif output == "json":
click.secho(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
else:
click.echo(bento.tag)
return bento
@cli.command()
@output_option
def models(output: t.Literal["json", "pretty", "porcelain"]):
"""List all supported models."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
failed_initialized: list[tuple[str, Exception]] = []
if output == "pretty":
import rich
import rich.box
from rich.table import Table
from rich.text import Text
console = rich.get_console()
table = Table(title="Supported LLMs", box=rich.box.SQUARE, show_lines=True)
table.add_column("LLM")
table.add_column("Description")
table.add_column("Variants")
for m in models:
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
try:
model = openllm.AutoLLM.for_model(m)
table.add_row(m, docs, f"{model.variants}")
except Exception as err:
failed_initialized.append((m, err))
console.print(table)
if len(failed_initialized) > 0:
console.print(
"\n[bold yellow] The following models are supported but failed to initialize:[/bold yellow]\n"
)
for m, err in failed_initialized:
console.print(Text(f"- {m}: ") + Text(f"{err}\n", style="bold red"))
elif output == "json":
result_json: dict[str, dict[t.Literal["variants", "description"], t.Any]] = {}
for m in models:
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
try:
model = openllm.AutoLLM.for_model(m)
result_json[m] = {"variants": model.variants, "description": docs}
except Exception as err:
logger.debug("Exception caught while parsing model %s", m, exc_info=err)
result_json[m] = {"variants": None, "description": docs}
click.secho(orjson.dumps(result_json, option=orjson.OPT_INDENT_2).decode())
else:
click.echo("\n".join(models))
sys.exit(0)
@cli.command()
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
@click.option(
"--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight."
)
@output_option
def download_models(model_name: str, pretrained: str | None, output: t.Literal["json", "pretty", "porcelain"]):
"""Setup LLM interactively.
Note: This is useful for development and setup for fine-tune.
"""
config = openllm.AutoConfig.for_model(model_name)
env = config.__openllm_env__.get_framework_env()
if env == "flax":
model = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
elif env == "tf":
model = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
else:
model = openllm.AutoLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
tag = model.make_tag()
if len(bentoml.models.list(tag)) == 0:
if output == "pretty":
click.secho(f"{tag} does not exists yet!. Downloading...", nl=True)
m = model.ensure_pretrained_exists()
click.secho(f"Saved model: {m.tag}")
elif output == "json":
m = model.ensure_pretrained_exists()
click.secho(
orjson.dumps(
{"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
).decode()
)
else:
m = model.ensure_pretrained_exists()
click.secho(m.tag)
else:
m = model.ensure_pretrained_exists()
if output == "pretty":
click.secho(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True)
elif output == "json":
click.secho(
orjson.dumps(
{"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
).decode()
)
else:
click.echo(m.tag)
return m
cli = cli_builder()
if __name__ == "__main__":
cli()

View File

@@ -35,5 +35,9 @@ class ForbiddenAttributeError(OpenLLMException):
"""Raised when using an _internal field."""
class MissingAnnotationAttributeError(OpenLLMException):
"""Raised when a field under openllm.LLMConfig is missing annotations."""
class MissingDependencyError(BaseException):
"""Raised when a dependency is missing."""

View File

@@ -34,7 +34,9 @@ class DollyV2Config(openllm.LLMConfig, default_timeout=3600000):
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
"""
return_full_text: bool = False
return_full_text: bool = openllm.LLMConfig.Field(
False, description="Whether to return the full prompt to the users."
)
class GenerationConfig:
temperature: float = 0.9

View File

@@ -15,6 +15,8 @@
from __future__ import annotations
import functools
import os
import typing as t
import attr
@@ -30,8 +32,16 @@ if t.TYPE_CHECKING:
_T = t.TypeVar("_T")
def _default_converter(value: t.Any, env: str | None) -> t.Any:
if env is not None:
value = os.environ.get(env, value)
if value is not None and isinstance(value, str):
return eval(value, {"__builtins__": {}}, {})
return value
def Field(
value: t.Any,
default: t.Any = None,
*,
ge: int | float | None = None,
le: int | float | None = None,
@@ -52,10 +62,6 @@ def Field(
**kwargs: The rest of the arguments are passed to attr.field
"""
metadata = attrs.pop("metadata", {})
default = attrs.pop("default", value)
if default is not value:
raise ValueError("Either specify 'default=value' or provide 'value' as the only argument")
if description is None:
description = "(No description is available)"
metadata["description"] = description
@@ -63,6 +69,8 @@ def Field(
metadata["env"] = env
piped: list[_ValidatorType[t.Any]] = []
converter = attrs.pop("converter", functools.partial(_default_converter, env=env))
if ge is not None:
piped.append(attr.validators.ge(ge))
if le is not None:
@@ -77,7 +85,7 @@ def Field(
else:
_validator = attr.validators.and_(*piped)
return attr.field(default=default, metadata=metadata, validator=_validator, **attrs)
return attr.field(default=default, metadata=metadata, validator=_validator, converter=converter, **attrs)
def allows_multiple(field_type: t.Any) -> bool:

View File

@@ -1,17 +0,0 @@
from __future__ import annotations
import ast
import typing as t
class SourceGenerator(ast.NodeVisitor):
def __init__(
self,
indent_width: str,
add_line_information: bool = ...,
pretty_string: t.Callable[..., t.Any] = ...,
len: t.Callable[[t.Any], int] = ...,
isinstance: t.Callable[[t.Any, t.Any], bool] = ...,
callable: t.Callable[[t.Any], bool] = ...,
) -> None: ...
def newline(self, node: ast.AST | None = ..., extra: t.Any = ...) -> None: ...
def write(*params: t.Any) -> None: ...

View File

@@ -9,6 +9,7 @@ from typing import (
Literal,
Mapping,
Optional,
ParamSpec,
Protocol,
Sequence,
Tuple,
@@ -42,6 +43,7 @@ __license__: str
__copyright__: str
_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)
_P = ParamSpec("_P")
_EqOrderType = Union[bool, Callable[[Any], Any]]
_ValidatorType = Callable[[Any, "Attribute[_T]", _T], Any]
_ConverterType = Callable[[Any], Any]
@@ -484,3 +486,17 @@ def get_run_validators() -> bool: ...
attributes = ...
attr = ...
dataclass = ...
def _make_init(
cls: type[AttrsInstance],
attrs: tuple[Attribute[_T]],
pre_init: bool,
post_init: bool,
frozen: bool,
slots: bool,
cache_hash: bool,
base_attr_map: dict[str, Any],
is_exc: bool,
cls_on_setattr: Any,
attrs_init: bool,
) -> Callable[_P, Any]: ...