From a94294bc657a992ed290c20fa6aa62b089400e18 Mon Sep 17 00:00:00 2001 From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> Date: Thu, 1 Jun 2023 19:06:06 +0000 Subject: [PATCH] fix: generate attrs class internally to conform with interface Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> --- src/openllm/_configuration.py | 446 +++++++++++++----- src/openllm/_llm.py | 3 - src/openllm/cli.py | 403 ++++++++-------- src/openllm/exceptions.py | 4 + .../models/dolly_v2/configuration_dolly_v2.py | 4 +- src/openllm/utils/dantic.py | 20 +- typings/astor/__init__.pyi | 17 - typings/attr/__init__.pyi | 16 + 8 files changed, 557 insertions(+), 356 deletions(-) delete mode 100644 typings/astor/__init__.pyi diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index e9d4c137..9246f2d9 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -38,10 +38,10 @@ class FlanT5Config(openllm.LLMConfig): """ from __future__ import annotations -import functools import logging import os import typing as t +from operator import itemgetter import attr import click @@ -59,6 +59,7 @@ if t.TYPE_CHECKING: import tensorflow as tf import torch import transformers + from attr import _CountingAttr, _make_init from transformers.generation.beam_constraints import Constraint P = t.ParamSpec("P") @@ -68,9 +69,15 @@ if t.TYPE_CHECKING: ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]] DictStrAny = dict[str, t.Any] + ItemgetterAny = itemgetter[t.Any] else: Constraint = t.Any DictStrAny = dict + ItemgetterAny = itemgetter + # NOTE: Using internal API from attr here, since we are actually + # allowing subclass of openllm.LLMConfig to become 'attrs'-ish + from attr._make import _CountingAttr, _make_init + transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") torch = openllm.utils.LazyLoader("torch", globals(), "torch") tf = openllm.utils.LazyLoader("tf", globals(), "tensorflow") @@ -115,9 +122,6 @@ def attrs_to_options( ) -_IGNORE_FIELDS = ("__openllm_name_type__", "generation_config") - - @attr.define class GenerationConfig: """Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``, @@ -128,7 +132,9 @@ class GenerationConfig: # NOTE: parameters for controlling the length of the output max_new_tokens: int = dantic.Field( - 20, ge=0, description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt." + 20, + ge=0, + description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.", ) min_length: int = dantic.Field( 0, @@ -137,7 +143,7 @@ class GenerationConfig: input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.""", ) min_new_tokens: int = dantic.Field( - None, description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt." + description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.", ) early_stopping: bool = dantic.Field( False, @@ -149,7 +155,6 @@ class GenerationConfig: """, ) max_time: float = dantic.Field( - None, description="""The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed.""", ) @@ -162,7 +167,6 @@ class GenerationConfig: groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.""", ) penalty_alpha: float = dantic.Field( - None, description="""The values balance the model confidence and the degeneration penalty in contrastive search decoding.""", ) @@ -241,7 +245,6 @@ class GenerationConfig: 0, description="If set to int > 0, all ngrams of that size can only occur once." ) bad_words_ids: t.List[t.List[int]] = dantic.Field( - None, description="""List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`. @@ -250,7 +253,6 @@ class GenerationConfig: # NOTE: t.Union is not yet supported on CLI, but the environment variable should already be available. force_words_ids: t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]] = dantic.Field( - None, description="""List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this triggers a @@ -266,13 +268,11 @@ class GenerationConfig: """, ) constraints: t.List[Constraint] = dantic.Field( - None, - description="""Custom constraints that can be added to the generation to ensure that the output + description="""Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by ``Constraint`` objects, in the most sensible way possible. """, ) forced_bos_token_id: int = dantic.Field( - None, description="""The id of the token to force as the first generated token after the ``decoder_start_token_id``. Useful for multilingual models like [mBART](https://huggingface.co/docs/transformers/model_doc/mbart) where the first generated token needs @@ -280,7 +280,6 @@ class GenerationConfig: """, ) forced_eos_token_id: t.Union[int, t.List[int]] = dantic.Field( - None, description="""The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a list to set multiple *end-of-sequence* tokens.""", ) @@ -290,26 +289,22 @@ class GenerationConfig: generation method to crash. Note that using `remove_invalid_values` can slow down generation.""", ) exponential_decay_length_penalty: t.Tuple[int, float] = dantic.Field( - None, description="""This tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay """, ) suppress_tokens: t.List[int] = dantic.Field( - None, description="""A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled. """, ) begin_suppress_tokens: t.List[int] = dantic.Field( - None, description="""A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. """, ) forced_decoder_ids: t.List[t.List[int]] = dantic.Field( - None, description="""A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123. @@ -338,10 +333,9 @@ class GenerationConfig: ) # NOTE: Special tokens that can be used at generation time - pad_token_id: int = dantic.Field(None, description="The id of the *padding* token.") - bos_token_id: int = dantic.Field(None, description="The id of the *beginning-of-sequence* token.") + pad_token_id: int = dantic.Field(description="The id of the *padding* token.") + bos_token_id: int = dantic.Field(description="The id of the *beginning-of-sequence* token.") eos_token_id: t.Union[int, t.List[int]] = dantic.Field( - None, description="""The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.""", ) @@ -354,7 +348,6 @@ class GenerationConfig: """, ) decoder_start_token_id: int = dantic.Field( - None, description="""If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. """, @@ -373,43 +366,223 @@ class GenerationConfig: ) self.__attrs_init__(**attrs) - def model_dump(self, exclude_none: bool = False, **_: t.Any) -> dict[str, t.Any]: - target: dict[str, t.Any] = {} - for k in attr.fields_dict(self.__class__): - v = getattr(self, k, None) - if exclude_none and v is None: - continue - if not k.startswith("_"): - target[k] = v - return target + +def generation_config_dump(genconf: GenerationConfig) -> dict[str, t.Any]: + target: dict[str, t.Any] = {} + for k in attr.fields_dict(genconf.__class__): + v = getattr(genconf, k, None) + if v is None: + continue + if not k.startswith("_"): + target[k] = v + return target -bentoml_cattr.register_unstructure_hook( - GenerationConfig, functools.partial(GenerationConfig.model_dump, exclude_none=True) +bentoml_cattr.register_unstructure_hook_func( + lambda cls: lenient_issubclass(cls, GenerationConfig), + generation_config_dump, ) +def _populate_value_from_env_var( + key: str, transform: t.Callable[[str], str] | None = None, fallback: t.Any = None +) -> t.Any: + if transform is not None and callable(transform): + key = transform(key) + + return os.environ.get(key, fallback) + + def env_transformers(cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]: transformed: list[attr.Attribute[t.Any]] = [] for f in fields: - default = os.environ.get(f.metadata["env"], None) - if default is not None: - try: - default = orjson.loads(default) - except orjson.JSONDecodeError as err: - raise ValueError(f"Failed to load from environment variables: {err}") - else: - default = f.default - transformed.append(f.evolve(default=default)) + if "env" not in f.metadata: + raise ValueError( + "Make sure to setup the field with 'cls.Field' or 'attr.field(..., metadata={\"env\": \"...\"})'" + ) + _from_env = _populate_value_from_env_var(f.metadata["env"]) + if _from_env is not None: + f = f.evolve(default=_from_env) + transformed.append(f) return transformed +# sentinel object for unequivocal object() getattr +_sentinel = object() + + +def _has_own_attribute(cls: type[t.Any], attrib_name: t.Any): + """ + Check whether *cls* defines *attrib_name* (and doesn't just inherit it). + """ + attr = getattr(cls, attrib_name, _sentinel) + if attr is _sentinel: + return False + + for base_cls in cls.__mro__[1:]: + a = getattr(base_cls, attrib_name, None) + if attr is a: + return False + + return True + + +def _get_annotations(cls: type[t.Any]) -> DictStrAny: + """ + Get annotations for *cls*. + """ + if _has_own_attribute(cls, "__annotations__"): + return cls.__annotations__ + + return DictStrAny() + + +# The below is vendorred from attrs +def _collect_base_attrs( + cls: type[LLMConfig], taken_attr_names: set[str] +) -> tuple[list[attr.Attribute[t.Any]], dict[str, type[t.Any]]]: + """ + Collect attr.ibs from base classes of *cls*, except *taken_attr_names*. + """ + base_attrs: list[attr.Attribute[t.Any]] = [] + base_attr_map: dict[str, type[t.Any]] = {} # A dictionary of base attrs to their classes. + + # Traverse the MRO and collect attributes. + for base_cls in reversed(cls.__mro__[1:-1]): + for a in getattr(base_cls, "__attrs_attrs__", []): + if a.inherited or a.name in taken_attr_names: + continue + + a = a.evolve(inherited=True) + base_attrs.append(a) + base_attr_map[a.name] = base_cls + + # For each name, only keep the freshest definition i.e. the furthest at the back. + filtered: list[attr.Attribute[t.Any]] = [] + seen: set[str] = set() + for a in reversed(base_attrs): + if a.name in seen: + continue + filtered.insert(0, a) + seen.add(a.name) + + return filtered, base_attr_map + + +_classvar_prefixes = ( + "typing.ClassVar", + "t.ClassVar", + "ClassVar", + "typing_extensions.ClassVar", +) + + +def _is_class_var(annot: str | t.Any) -> bool: + """ + Check whether *annot* is a typing.ClassVar. + + The string comparison hack is used to avoid evaluating all string + annotations which would put attrs-based classes at a performance + disadvantage compared to plain old classes. + """ + annot = str(annot) + + # Annotation can be quoted. + if annot.startswith(("'", '"')) and annot.endswith(("'", '"')): + annot = annot[1:-1] + + return annot.startswith(_classvar_prefixes) + + +def _add_method_dunders(cls: type[t.Any], method: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: + """ + Add __module__ and __qualname__ to a *method* if possible. + """ + try: + method.__module__ = cls.__module__ + except AttributeError: + pass + + try: + method.__qualname__ = ".".join((cls.__qualname__, method.__name__)) + except AttributeError: + pass + + try: + method.__doc__ = "Method generated by attrs for class " f"{cls.__qualname__}." + except AttributeError: + pass + + return method + + +# NOTE: vendorred from attrs +def _compile_and_eval(script: str, globs: dict[str, t.Any], locs: dict[str, t.Any] | None = None, filename: str = ""): + """ + "Exec" the script with the given global (globs) and local (locs) variables. + """ + bytecode = compile(script, filename, "exec") + eval(bytecode, globs, locs) + + +def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[tuple[attr.Attribute[t.Any], ...]]: + """ + Create a tuple subclass to hold `Attribute`s for an `attrs` class. + + The subclass is a bare tuple with properties for names. + + class MyClassAttributes(tuple): + __slots__ = () + x = property(itemgetter(0)) + """ + attr_class_name = f"{cls_name}Attributes" + attr_class_template = [ + f"class {attr_class_name}(tuple):", + " __slots__ = ()", + ] + if attr_names: + for i, attr_name in enumerate(attr_names): + attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))") + else: + attr_class_template.append(" pass") + globs: dict[str, t.Any] = {"_attrs_itemgetter": ItemgetterAny, "_attrs_property": property} + _compile_and_eval("\n".join(attr_class_template), globs) + return globs[attr_class_name] + + +def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]: + attribs: DictStrAny = {} + _has_gen_class = _has_own_attribute(cls, "GenerationConfig") + for key, field in attr.fields_dict(GenerationConfig).items(): + class_gen_value = cls.GenerationConfig.__dict__.get(key, attr.NOTHING) if _has_gen_class else attr.NOTHING + attribs[key] = cls.Field( + class_gen_value if class_gen_value is not attr.NOTHING else field.default, + description=field.metadata.get("description"), + env=f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{key.upper()}", + validator=field.validator, + ) + if _has_gen_class: + delattr(cls, "GenerationConfig") + + return attr.make_class( + cls.__name__.replace("Config", "GenerationConfig"), attribs, field_transformer=env_transformers + ) + + @attr.define class LLMConfig: + Field = dantic.Field + """Field is a alias to the internal dantic utilities to easily create + attrs.fields with pydantic-compatible interface. + """ + if t.TYPE_CHECKING: + + def __attrs_init__(self, **attrs: t.Any): + ... + # The following is handled via __init_subclass__, and is only used for TYPE_CHECKING __attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = tuple() - __openllm_attrs__: tuple[str, ...] = tuple() __openllm_timeout__: int = 3600 @@ -420,9 +593,9 @@ class LLMConfig: __openllm_start_name__: str = "" __openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize" - __openllm_env__: openllm.utils.ModelEnv = dantic.Field(None, init=False) + __openllm_env__: openllm.utils.ModelEnv = Field(None, init=False) - generation_class: type[GenerationConfig] = dantic.Field(None, init=False) + generation_class: type[GenerationConfig] = Field(None, init=False) GenerationConfig: type = type @@ -442,36 +615,6 @@ class LLMConfig: start_name = model_name cls.__openllm_name_type__ = name_type - - attributes = { - key: dantic.Field( - field.default, - description=field.metadata.get("description"), - env=f"OPENLLM_{model_name.upper()}_GENERATION_{key.upper()}", - validator=field.validator, - ) - for key, field in attr.fields_dict(GenerationConfig).items() - } - - generation_class: type[GenerationConfig] - if hasattr(cls, "GenerationConfig"): - generation_class = attr.make_class( - cls.__name__.replace("Config", "GenerationConfig"), - attributes, - field_transformer=env_transformers, - ) - delattr(cls, "GenerationConfig") - else: - generation_class = attr.make_class( - "GenerationConfig", - attributes, - field_transformer=env_transformers, - ) - generation_class.model_dump = GenerationConfig.model_dump # type: ignore - - # Set the generation_config attributes here. - cls.generation_class = generation_class - cls.__openllm_requires_gpu__ = requires_gpu cls.__openllm_timeout__ = default_timeout or 3600 cls.__openllm_trust_remote_code__ = trust_remote_code @@ -480,51 +623,100 @@ class LLMConfig: cls.__openllm_start_name__ = start_name cls.__openllm_env__ = openllm.utils.ModelEnv(model_name) - if hasattr(cls, "__annotations__"): - anns = cls.__annotations__ - else: - anns = {} + # NOTE: Since we want to enable a pydantic-like experience + # this means we will have to hide the attr abstraction, and generate + # all of the Field from __init_subclass__ + # Some of the logics here are from attr._make._transform_attrs + anns = _get_annotations(cls) + cd = cls.__dict__ + + def field_env_key(key: str) -> str: + return f"OPENLLM_{model_name.upper()}_{key.upper()}" + + ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)} + ca_list: list[tuple[str, _CountingAttr[t.Any]]] = [] + annotated_names: set[str] = set() + for attr_name, typ in anns.items(): + if _is_class_var(typ): + continue + annotated_names.add(attr_name) + val = cd.get(attr_name, attr.NOTHING) + if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val): + if val is attr.NOTHING: + val = cls.Field(env=field_env_key(attr_name)) + else: + val = cls.Field(default=val, env=field_env_key(attr_name)) + ca_list.append((attr_name, val)) + unannotated = ca_names - annotated_names + + if len(unannotated) > 0: + missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter) + raise openllm.exceptions.MissingAnnotationAttributeError( + f"The following field doesn't have a type annotation: {missing_annotated}" + ) + + # NOTE: we know need to determine the list of the attrs + # by mro to at the very least support inheritance. Tho it is not recommended. + own_attrs: list[attr.Attribute[t.Any]] = [ + attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=anns.get(attr_name)) + for attr_name, ca in ca_list + ] + base_attrs, base_attr_map = _collect_base_attrs(cls, {a.name for a in own_attrs}) + attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs + + # Mandatory vs non-mandatory attr order only matters when they are part of + # the __init__ signature and when they aren't kw_only (which are moved to + # the end and can be mandatory or non-mandatory in any order, as they will + # be specified as keyword args anyway). Check the order of those attrs: + had_default = False + for a in (a for a in attrs if a.init is not False and a.kw_only is False): + if had_default is True and a.default is attr.NOTHING: + raise ValueError( + "No mandatory attributes allowed after an attribute with a " + f"default value or factory. Attribute in question: {a!r}" + ) + + if had_default is False and a.default is not attr.NOTHING: + had_default = True + + # NOTE: Resolve the alias and default value from environment variable + attrs = [ + a.evolve( + alias=a.name.lstrip("_") if not a.alias else None, + # NOTE: This is where we actually populate with the environment variable set for this attrs. + default=_populate_value_from_env_var(a.name, transform=field_env_key, fallback=a.default), + ) + for a in attrs + ] + + _has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) + _has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) # __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]] # that we construct ourself. - openllm_attrs: tuple[str, ...] = tuple() - cur_attrs: tuple[attr.Attribute[t.Any]] = tuple() + cls.__openllm_attrs__ = tuple(a.name for a in attrs) + AttrsTuple = _make_attr_tuple_class(cls.__name__, cls.__openllm_attrs__) + # NOTE: generate a __attrs_init__ for the subclass + cls.__attrs_init__ = _add_method_dunders( + cls, + _make_init( + cls, + AttrsTuple(attrs), + _has_pre_init, + _has_post_init, + False, + True, + True, + base_attr_map, + False, + None, + attrs_init=True, + ), + ) + cls.__attrs_attrs__ = AttrsTuple(attrs) - for key, value in vars(cls).items(): - if key in _IGNORE_FIELDS: - continue - - # NOTE: we probably want to decorate all of the function in LLMConfig - # with an internal flag, which then allow user to define their own - # function and do whatever they want with the LLMConfig. Currently, this is - # a limitation. - if not key.startswith("_") and not callable(value): - env_key = f"OPENLLM_{model_name.upper()}_{key.upper()}" - default = os.environ.get(env_key, None) - if default is not None: - try: - default = orjson.loads(default) - except orjson.JSONDecodeError as err: - raise ValueError(f"Failed to load from environment variables: {err}") - else: - default = value - - annotation = anns.get(key, None) - if annotation is not None: - # NOTE: eval is dangerous, but we don't provide any specific imports here. - annotation = eval(annotation, {}, {}) - - attribute: attr.Attribute[t.Any] = attr.Attribute.from_counting_attr( - key, dantic.Field(default, env=env_key, alias=key), type=annotation - ) - - openllm_attrs += (key,) - cur_attrs += (attribute,) - - setattr(cls, key, attribute.default) - - cls.__openllm_attrs__ = openllm_attrs - cls.__attrs_attrs__ = cur_attrs + # NOTE: Finally, set the generation_class for this given config. + cls.generation_class = _make_internal_generation_class(cls) @property def name_type(self) -> t.Literal["dasherize", "lowercase"]: @@ -537,8 +729,6 @@ class LLMConfig: __openllm_extras__: dict[str, t.Any] | None = None, **attrs: t.Any, ): - self.__openllm_env__ = openllm.utils.ModelEnv(self.__openllm_model_name__) - to_exclude = list(attr.fields_dict(self.generation_class)) + list(self.__openllm_attrs__) self.__openllm_extras__ = __openllm_extras__ or {k: v for k, v in attrs.items() if k not in to_exclude} @@ -546,22 +736,18 @@ class LLMConfig: if generation_config is None: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.generation_class)} - else: - logger.debug("Overriding default 'generation_config' with '%s'", generation_config) self.generation_config = self.generation_class(**generation_config) - # NOTE: since our subclass is not a dataclass-like, we need to set attr like this. - for k, fields in attr.fields_dict(self.__class__).items(): - if k in attrs and attrs[k] != fields.default: - setattr(self, k, attrs[k]) - else: - setattr(self, k, fields.default) + attrs = {k: v for k, v in attrs.items() if k not in generation_config} - # set the remaning attrs to class - for k, v in attrs.items(): - if k not in self.__openllm_attrs__: - setattr(self, k, v) + extras = set(attrs).difference(set(attr.fields_dict(self.__class__))) + if len(extras) > 0: + # Set all of the keys in extras to the class + for k in extras: + setattr(self, k, attrs[k]) + + self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras}) def __repr__(self) -> str: bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={self.generation_config}" diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index dec9f9f8..b900dff2 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -272,14 +272,11 @@ class LLMMetaclass(ABCMeta): if cls_name.startswith("Flax"): implementation = "flax" prefix_class_name_config = cls_name[4:] - namespace["__annotations__"].update({"model": "transformers.FlaxPreTrainedModel"}) elif cls_name.startswith("TF"): implementation = "tf" prefix_class_name_config = cls_name[2:] - namespace["__annotations__"].update({"model": "transformers.TFPreTrainedModel"}) else: implementation = "pt" - namespace["__annotations__"].update({"model": "transformers.PreTrainedModel"}) namespace["__llm_implementation__"] = implementation # NOTE: setup config class branch diff --git a/src/openllm/cli.py b/src/openllm/cli.py index db803b20..1d514044 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -166,6 +166,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): self._cached_grpc: dict[str, t.Any] = {} def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: + breakpoint() cmd_name = self.resolve_alias(cmd_name) if ctx.command.name == "start": if cmd_name not in self._cached_http: @@ -465,213 +466,217 @@ output_option = click.option( ) -@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm") -def cli(): - """ - \b - ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ - ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ - ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ - ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ - ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ - ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝ +def cli_builder(): + @click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm") + def cli(): + """ + \b + ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ + ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ + ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ + ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ + ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ + ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝ - \b - OpenLLM: Your one stop-and-go-solution for serving any Open Large-Language Model + \b + OpenLLM: Your one stop-and-go-solution for serving any Open Large-Language Model - - StableLM, Falcon, ChatGLM, Dolly, Flan-T5, and more + - StableLM, Falcon, ChatGLM, Dolly, Flan-T5, and more + + \b + - Powered by BentoML 🍱 + HuggingFace 🤗 + """ + + @cli.command(name="version") + @output_option + @click.pass_context + def _(ctx: click.Context, output: t.Literal["json", "pretty", "porcelain"]): + """🚀 OpenLLM version.""" + from gettext import gettext + + from .__about__ import __version__ + + message = gettext("%(prog)s, version %(version)s") + version = __version__ + prog_name = ctx.find_root().info_name + + if output == "pretty": + click.echo(message % {"prog": prog_name, "version": version}, color=ctx.color) + elif output == "json": + click.echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode()) + else: + click.echo(version) + + ctx.exit() + + @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start") + def _(): + """ + Start any LLM as a REST server. + + $ openllm start -- ... + """ + + @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc") + def _(): + """ + Start any LLM as a gRPC server. + + $ openllm start-grpc -- ... + """ + + @cli.command(name="bundle", aliases=["build"]) + @click.argument( + "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]) + ) + @click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].") + @click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.") + @output_option + def _(model_name: str, pretrained: str | None, overwrite: bool, output: t.Literal["json", "pretty", "porcelain"]): + """Package a given models into a Bento. + + $ openllm bundle flan-t5 + """ + from bentoml._internal.configuration import get_quiet_mode + + bento, _previously_built = openllm.build( + model_name, __cli__=True, pretrained=pretrained, _overwrite_existing_bento=overwrite + ) + + if output == "pretty": + if not get_quiet_mode(): + click.echo("\n" + OPENLLM_FIGLET) + if not _previously_built: + click.secho(f"Successfully built {bento}.", fg="green") + else: + click.secho( + f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", + fg="yellow", + ) + + click.secho( + "\nPossible next steps:\n\n * Push to BentoCloud with `bentoml push`:\n " + + f"$ bentoml push {bento.tag}", + fg="blue", + ) + click.secho( + "\n * Containerize your Bento with `bentoml containerize`:\n " + + f"$ bentoml containerize {bento.tag}", + fg="blue", + ) + elif output == "json": + click.secho(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) + else: + click.echo(bento.tag) + return bento + + @cli.command(name="models") + @output_option + def _(output: t.Literal["json", "pretty", "porcelain"]): + """List all supported models.""" + models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()) + failed_initialized: list[tuple[str, Exception]] = [] + if output == "pretty": + import rich + import rich.box + from rich.table import Table + from rich.text import Text + + console = rich.get_console() + table = Table(title="Supported LLMs", box=rich.box.SQUARE, show_lines=True) + table.add_column("LLM") + table.add_column("Description") + table.add_column("Variants") + for m in models: + docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)") + try: + model = openllm.AutoLLM.for_model(m) + table.add_row(m, docs, f"{model.variants}") + except Exception as err: + failed_initialized.append((m, err)) + console.print(table) + if len(failed_initialized) > 0: + console.print( + "\n[bold yellow] The following models are supported but failed to initialize:[/bold yellow]\n" + ) + for m, err in failed_initialized: + console.print(Text(f"- {m}: ") + Text(f"{err}\n", style="bold red")) + elif output == "json": + result_json: dict[str, dict[t.Literal["variants", "description"], t.Any]] = {} + for m in models: + docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)") + try: + model = openllm.AutoLLM.for_model(m) + result_json[m] = {"variants": model.variants, "description": docs} + except Exception as err: + logger.debug("Exception caught while parsing model %s", m, exc_info=err) + result_json[m] = {"variants": None, "description": docs} + + click.secho(orjson.dumps(result_json, option=orjson.OPT_INDENT_2).decode()) + else: + click.echo("\n".join(models)) + sys.exit(0) + + @cli.command(name="download_models") + @click.argument( + "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]) + ) + @click.option( + "--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight." + ) + @output_option + def _(model_name: str, pretrained: str | None, output: t.Literal["json", "pretty", "porcelain"]): + """Setup LLM interactively. + + Note: This is useful for development and setup for fine-tune. + """ + config = openllm.AutoConfig.for_model(model_name) + env = config.__openllm_env__.get_framework_env() + if env == "flax": + model = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained, llm_config=config) + elif env == "tf": + model = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained, llm_config=config) + else: + model = openllm.AutoLLM.for_model(model_name, pretrained=pretrained, llm_config=config) + + tag = model.make_tag() + + if len(bentoml.models.list(tag)) == 0: + if output == "pretty": + click.secho(f"{tag} does not exists yet!. Downloading...", nl=True) + m = model.ensure_pretrained_exists() + click.secho(f"Saved model: {m.tag}") + elif output == "json": + m = model.ensure_pretrained_exists() + click.secho( + orjson.dumps( + {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2 + ).decode() + ) + else: + m = model.ensure_pretrained_exists() + click.secho(m.tag) + else: + m = model.ensure_pretrained_exists() + if output == "pretty": + click.secho(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True) + elif output == "json": + click.secho( + orjson.dumps( + {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2 + ).decode() + ) + else: + click.echo(m.tag) + return m - \b - - Powered by BentoML 🍱 + HuggingFace 🤗 - """ if psutil.WINDOWS: sys.stdout.reconfigure(encoding="utf-8") # type: ignore - -@cli.command() -@output_option -@click.pass_context -def version(ctx: click.Context, output: t.Literal["json", "pretty", "porcelain"]): - """🚀 OpenLLM version.""" - from gettext import gettext - - from .__about__ import __version__ - - message = gettext("%(prog)s, version %(version)s") - version = __version__ - prog_name = ctx.find_root().info_name - - if output == "pretty": - click.echo(message % {"prog": prog_name, "version": version}, color=ctx.color) - elif output == "json": - click.echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode()) - else: - click.echo(version) - - ctx.exit() + return cli -@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start") -def start_cli(): - """ - Start any LLM as a REST server. - - $ openllm start -- ... - """ - - -@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc") -def start_grpc_cli(): - """ - Start any LLM as a gRPC server. - - $ openllm start-grpc -- ... - """ - - -@cli.command(aliases=["build"]) -@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])) -@click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].") -@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.") -@output_option -def bundle(model_name: str, pretrained: str | None, overwrite: bool, output: t.Literal["json", "pretty", "porcelain"]): - """Package a given models into a Bento. - - $ openllm bundle flan-t5 - """ - from bentoml._internal.configuration import get_quiet_mode - - bento, _previously_built = openllm.build( - model_name, __cli__=True, pretrained=pretrained, _overwrite_existing_bento=overwrite - ) - - if output == "pretty": - if not get_quiet_mode(): - click.echo("\n" + OPENLLM_FIGLET) - if not _previously_built: - click.secho(f"Successfully built {bento}.", fg="green") - else: - click.secho( - f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", - fg="yellow", - ) - - click.secho( - "\nPossible next steps:\n\n * Push to BentoCloud with `bentoml push`:\n " - + f"$ bentoml push {bento.tag}", - fg="blue", - ) - click.secho( - "\n * Containerize your Bento with `bentoml containerize`:\n " - + f"$ bentoml containerize {bento.tag}", - fg="blue", - ) - elif output == "json": - click.secho(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) - else: - click.echo(bento.tag) - return bento - - -@cli.command() -@output_option -def models(output: t.Literal["json", "pretty", "porcelain"]): - """List all supported models.""" - models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()) - failed_initialized: list[tuple[str, Exception]] = [] - if output == "pretty": - import rich - import rich.box - from rich.table import Table - from rich.text import Text - - console = rich.get_console() - table = Table(title="Supported LLMs", box=rich.box.SQUARE, show_lines=True) - table.add_column("LLM") - table.add_column("Description") - table.add_column("Variants") - for m in models: - docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)") - try: - model = openllm.AutoLLM.for_model(m) - table.add_row(m, docs, f"{model.variants}") - except Exception as err: - failed_initialized.append((m, err)) - console.print(table) - if len(failed_initialized) > 0: - console.print( - "\n[bold yellow] The following models are supported but failed to initialize:[/bold yellow]\n" - ) - for m, err in failed_initialized: - console.print(Text(f"- {m}: ") + Text(f"{err}\n", style="bold red")) - elif output == "json": - result_json: dict[str, dict[t.Literal["variants", "description"], t.Any]] = {} - for m in models: - docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)") - try: - model = openllm.AutoLLM.for_model(m) - result_json[m] = {"variants": model.variants, "description": docs} - except Exception as err: - logger.debug("Exception caught while parsing model %s", m, exc_info=err) - result_json[m] = {"variants": None, "description": docs} - - click.secho(orjson.dumps(result_json, option=orjson.OPT_INDENT_2).decode()) - else: - click.echo("\n".join(models)) - sys.exit(0) - - -@cli.command() -@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])) -@click.option( - "--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight." -) -@output_option -def download_models(model_name: str, pretrained: str | None, output: t.Literal["json", "pretty", "porcelain"]): - """Setup LLM interactively. - - Note: This is useful for development and setup for fine-tune. - """ - config = openllm.AutoConfig.for_model(model_name) - env = config.__openllm_env__.get_framework_env() - if env == "flax": - model = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained, llm_config=config) - elif env == "tf": - model = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained, llm_config=config) - else: - model = openllm.AutoLLM.for_model(model_name, pretrained=pretrained, llm_config=config) - - tag = model.make_tag() - - if len(bentoml.models.list(tag)) == 0: - if output == "pretty": - click.secho(f"{tag} does not exists yet!. Downloading...", nl=True) - m = model.ensure_pretrained_exists() - click.secho(f"Saved model: {m.tag}") - elif output == "json": - m = model.ensure_pretrained_exists() - click.secho( - orjson.dumps( - {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2 - ).decode() - ) - else: - m = model.ensure_pretrained_exists() - click.secho(m.tag) - else: - m = model.ensure_pretrained_exists() - if output == "pretty": - click.secho(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True) - elif output == "json": - click.secho( - orjson.dumps( - {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2 - ).decode() - ) - else: - click.echo(m.tag) - return m - +cli = cli_builder() if __name__ == "__main__": cli() diff --git a/src/openllm/exceptions.py b/src/openllm/exceptions.py index 370babb1..05371869 100644 --- a/src/openllm/exceptions.py +++ b/src/openllm/exceptions.py @@ -35,5 +35,9 @@ class ForbiddenAttributeError(OpenLLMException): """Raised when using an _internal field.""" +class MissingAnnotationAttributeError(OpenLLMException): + """Raised when a field under openllm.LLMConfig is missing annotations.""" + + class MissingDependencyError(BaseException): """Raised when a dependency is missing.""" diff --git a/src/openllm/models/dolly_v2/configuration_dolly_v2.py b/src/openllm/models/dolly_v2/configuration_dolly_v2.py index edbebfdb..b6a163ab 100644 --- a/src/openllm/models/dolly_v2/configuration_dolly_v2.py +++ b/src/openllm/models/dolly_v2/configuration_dolly_v2.py @@ -34,7 +34,9 @@ class DollyV2Config(openllm.LLMConfig, default_timeout=3600000): Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information. """ - return_full_text: bool = False + return_full_text: bool = openllm.LLMConfig.Field( + False, description="Whether to return the full prompt to the users." + ) class GenerationConfig: temperature: float = 0.9 diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py index d34597b9..f83f9198 100644 --- a/src/openllm/utils/dantic.py +++ b/src/openllm/utils/dantic.py @@ -15,6 +15,8 @@ from __future__ import annotations +import functools +import os import typing as t import attr @@ -30,8 +32,16 @@ if t.TYPE_CHECKING: _T = t.TypeVar("_T") +def _default_converter(value: t.Any, env: str | None) -> t.Any: + if env is not None: + value = os.environ.get(env, value) + if value is not None and isinstance(value, str): + return eval(value, {"__builtins__": {}}, {}) + return value + + def Field( - value: t.Any, + default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, @@ -52,10 +62,6 @@ def Field( **kwargs: The rest of the arguments are passed to attr.field """ metadata = attrs.pop("metadata", {}) - default = attrs.pop("default", value) - if default is not value: - raise ValueError("Either specify 'default=value' or provide 'value' as the only argument") - if description is None: description = "(No description is available)" metadata["description"] = description @@ -63,6 +69,8 @@ def Field( metadata["env"] = env piped: list[_ValidatorType[t.Any]] = [] + converter = attrs.pop("converter", functools.partial(_default_converter, env=env)) + if ge is not None: piped.append(attr.validators.ge(ge)) if le is not None: @@ -77,7 +85,7 @@ def Field( else: _validator = attr.validators.and_(*piped) - return attr.field(default=default, metadata=metadata, validator=_validator, **attrs) + return attr.field(default=default, metadata=metadata, validator=_validator, converter=converter, **attrs) def allows_multiple(field_type: t.Any) -> bool: diff --git a/typings/astor/__init__.pyi b/typings/astor/__init__.pyi deleted file mode 100644 index de85819d..00000000 --- a/typings/astor/__init__.pyi +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import ast -import typing as t - -class SourceGenerator(ast.NodeVisitor): - def __init__( - self, - indent_width: str, - add_line_information: bool = ..., - pretty_string: t.Callable[..., t.Any] = ..., - len: t.Callable[[t.Any], int] = ..., - isinstance: t.Callable[[t.Any, t.Any], bool] = ..., - callable: t.Callable[[t.Any], bool] = ..., - ) -> None: ... - def newline(self, node: ast.AST | None = ..., extra: t.Any = ...) -> None: ... - def write(*params: t.Any) -> None: ... diff --git a/typings/attr/__init__.pyi b/typings/attr/__init__.pyi index a6582a83..99aa31b7 100644 --- a/typings/attr/__init__.pyi +++ b/typings/attr/__init__.pyi @@ -9,6 +9,7 @@ from typing import ( Literal, Mapping, Optional, + ParamSpec, Protocol, Sequence, Tuple, @@ -42,6 +43,7 @@ __license__: str __copyright__: str _T = TypeVar("_T") _C = TypeVar("_C", bound=type) +_P = ParamSpec("_P") _EqOrderType = Union[bool, Callable[[Any], Any]] _ValidatorType = Callable[[Any, "Attribute[_T]", _T], Any] _ConverterType = Callable[[Any], Any] @@ -484,3 +486,17 @@ def get_run_validators() -> bool: ... attributes = ... attr = ... dataclass = ... + +def _make_init( + cls: type[AttrsInstance], + attrs: tuple[Attribute[_T]], + pre_init: bool, + post_init: bool, + frozen: bool, + slots: bool, + cache_hash: bool, + base_attr_map: dict[str, Any], + is_exc: bool, + cls_on_setattr: Any, + attrs_init: bool, +) -> Callable[_P, Any]: ...