mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-24 16:44:39 -04:00
fix: generate attrs class internally to conform with interface
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -38,10 +38,10 @@ class FlanT5Config(openllm.LLMConfig):
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
from operator import itemgetter
|
||||
|
||||
import attr
|
||||
import click
|
||||
@@ -59,6 +59,7 @@ if t.TYPE_CHECKING:
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import transformers
|
||||
from attr import _CountingAttr, _make_init
|
||||
from transformers.generation.beam_constraints import Constraint
|
||||
|
||||
P = t.ParamSpec("P")
|
||||
@@ -68,9 +69,15 @@ if t.TYPE_CHECKING:
|
||||
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
|
||||
|
||||
DictStrAny = dict[str, t.Any]
|
||||
ItemgetterAny = itemgetter[t.Any]
|
||||
else:
|
||||
Constraint = t.Any
|
||||
DictStrAny = dict
|
||||
ItemgetterAny = itemgetter
|
||||
# NOTE: Using internal API from attr here, since we are actually
|
||||
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
|
||||
from attr._make import _CountingAttr, _make_init
|
||||
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
tf = openllm.utils.LazyLoader("tf", globals(), "tensorflow")
|
||||
@@ -115,9 +122,6 @@ def attrs_to_options(
|
||||
)
|
||||
|
||||
|
||||
_IGNORE_FIELDS = ("__openllm_name_type__", "generation_config")
|
||||
|
||||
|
||||
@attr.define
|
||||
class GenerationConfig:
|
||||
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
|
||||
@@ -128,7 +132,9 @@ class GenerationConfig:
|
||||
|
||||
# NOTE: parameters for controlling the length of the output
|
||||
max_new_tokens: int = dantic.Field(
|
||||
20, ge=0, description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."
|
||||
20,
|
||||
ge=0,
|
||||
description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
|
||||
)
|
||||
min_length: int = dantic.Field(
|
||||
0,
|
||||
@@ -137,7 +143,7 @@ class GenerationConfig:
|
||||
input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.""",
|
||||
)
|
||||
min_new_tokens: int = dantic.Field(
|
||||
None, description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt."
|
||||
description="The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
|
||||
)
|
||||
early_stopping: bool = dantic.Field(
|
||||
False,
|
||||
@@ -149,7 +155,6 @@ class GenerationConfig:
|
||||
""",
|
||||
)
|
||||
max_time: float = dantic.Field(
|
||||
None,
|
||||
description="""The maximum amount of time you allow the computation to run for in seconds. generation will
|
||||
still finish the current pass after allocated time has been passed.""",
|
||||
)
|
||||
@@ -162,7 +167,6 @@ class GenerationConfig:
|
||||
groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.""",
|
||||
)
|
||||
penalty_alpha: float = dantic.Field(
|
||||
None,
|
||||
description="""The values balance the model confidence and the degeneration penalty in
|
||||
contrastive search decoding.""",
|
||||
)
|
||||
@@ -241,7 +245,6 @@ class GenerationConfig:
|
||||
0, description="If set to int > 0, all ngrams of that size can only occur once."
|
||||
)
|
||||
bad_words_ids: t.List[t.List[int]] = dantic.Field(
|
||||
None,
|
||||
description="""List of token ids that are not allowed to be generated. In order to get the token ids
|
||||
of the words that should not appear in the generated text, use
|
||||
`tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`.
|
||||
@@ -250,7 +253,6 @@ class GenerationConfig:
|
||||
|
||||
# NOTE: t.Union is not yet supported on CLI, but the environment variable should already be available.
|
||||
force_words_ids: t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]] = dantic.Field(
|
||||
None,
|
||||
description="""List of token ids that must be generated. If given a `List[List[int]]`, this is treated
|
||||
as a simple list of words that must be included, the opposite to `bad_words_ids`.
|
||||
If given `List[List[List[int]]]`, this triggers a
|
||||
@@ -266,13 +268,11 @@ class GenerationConfig:
|
||||
""",
|
||||
)
|
||||
constraints: t.List[Constraint] = dantic.Field(
|
||||
None,
|
||||
description="""Custom constraints that can be added to the generation to ensure that the output
|
||||
description="""Custom constraints that can be added to the generation to ensure that the output
|
||||
will contain the use of certain tokens as defined by ``Constraint`` objects, in the most sensible way possible.
|
||||
""",
|
||||
)
|
||||
forced_bos_token_id: int = dantic.Field(
|
||||
None,
|
||||
description="""The id of the token to force as the first generated token after the
|
||||
``decoder_start_token_id``. Useful for multilingual models like
|
||||
[mBART](https://huggingface.co/docs/transformers/model_doc/mbart) where the first generated token needs
|
||||
@@ -280,7 +280,6 @@ class GenerationConfig:
|
||||
""",
|
||||
)
|
||||
forced_eos_token_id: t.Union[int, t.List[int]] = dantic.Field(
|
||||
None,
|
||||
description="""The id of the token to force as the last generated token when `max_length` is reached.
|
||||
Optionally, use a list to set multiple *end-of-sequence* tokens.""",
|
||||
)
|
||||
@@ -290,26 +289,22 @@ class GenerationConfig:
|
||||
generation method to crash. Note that using `remove_invalid_values` can slow down generation.""",
|
||||
)
|
||||
exponential_decay_length_penalty: t.Tuple[int, float] = dantic.Field(
|
||||
None,
|
||||
description="""This tuple adds an exponentially increasing length penalty, after a certain amount of tokens
|
||||
have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index`
|
||||
indicates where penalty starts and `decay_factor` represents the factor of exponential decay
|
||||
""",
|
||||
)
|
||||
suppress_tokens: t.List[int] = dantic.Field(
|
||||
None,
|
||||
description="""A list of tokens that will be suppressed at generation. The `SupressTokens` logit
|
||||
processor will set their log probs to `-inf` so that they are not sampled.
|
||||
""",
|
||||
)
|
||||
begin_suppress_tokens: t.List[int] = dantic.Field(
|
||||
None,
|
||||
description="""A list of tokens that will be suppressed at the beginning of the generation. The
|
||||
`SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
|
||||
""",
|
||||
)
|
||||
forced_decoder_ids: t.List[t.List[int]] = dantic.Field(
|
||||
None,
|
||||
description="""A list of pairs of integers which indicates a mapping from generation indices to token indices
|
||||
that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
|
||||
be a token of index 123.
|
||||
@@ -338,10 +333,9 @@ class GenerationConfig:
|
||||
)
|
||||
|
||||
# NOTE: Special tokens that can be used at generation time
|
||||
pad_token_id: int = dantic.Field(None, description="The id of the *padding* token.")
|
||||
bos_token_id: int = dantic.Field(None, description="The id of the *beginning-of-sequence* token.")
|
||||
pad_token_id: int = dantic.Field(description="The id of the *padding* token.")
|
||||
bos_token_id: int = dantic.Field(description="The id of the *beginning-of-sequence* token.")
|
||||
eos_token_id: t.Union[int, t.List[int]] = dantic.Field(
|
||||
None,
|
||||
description="""The id of the *end-of-sequence* token. Optionally, use a list to set
|
||||
multiple *end-of-sequence* tokens.""",
|
||||
)
|
||||
@@ -354,7 +348,6 @@ class GenerationConfig:
|
||||
""",
|
||||
)
|
||||
decoder_start_token_id: int = dantic.Field(
|
||||
None,
|
||||
description="""If an encoder-decoder model starts decoding with a
|
||||
different token than *bos*, the id of that token.
|
||||
""",
|
||||
@@ -373,43 +366,223 @@ class GenerationConfig:
|
||||
)
|
||||
self.__attrs_init__(**attrs)
|
||||
|
||||
def model_dump(self, exclude_none: bool = False, **_: t.Any) -> dict[str, t.Any]:
|
||||
target: dict[str, t.Any] = {}
|
||||
for k in attr.fields_dict(self.__class__):
|
||||
v = getattr(self, k, None)
|
||||
if exclude_none and v is None:
|
||||
continue
|
||||
if not k.startswith("_"):
|
||||
target[k] = v
|
||||
return target
|
||||
|
||||
def generation_config_dump(genconf: GenerationConfig) -> dict[str, t.Any]:
|
||||
target: dict[str, t.Any] = {}
|
||||
for k in attr.fields_dict(genconf.__class__):
|
||||
v = getattr(genconf, k, None)
|
||||
if v is None:
|
||||
continue
|
||||
if not k.startswith("_"):
|
||||
target[k] = v
|
||||
return target
|
||||
|
||||
|
||||
bentoml_cattr.register_unstructure_hook(
|
||||
GenerationConfig, functools.partial(GenerationConfig.model_dump, exclude_none=True)
|
||||
bentoml_cattr.register_unstructure_hook_func(
|
||||
lambda cls: lenient_issubclass(cls, GenerationConfig),
|
||||
generation_config_dump,
|
||||
)
|
||||
|
||||
|
||||
def _populate_value_from_env_var(
|
||||
key: str, transform: t.Callable[[str], str] | None = None, fallback: t.Any = None
|
||||
) -> t.Any:
|
||||
if transform is not None and callable(transform):
|
||||
key = transform(key)
|
||||
|
||||
return os.environ.get(key, fallback)
|
||||
|
||||
|
||||
def env_transformers(cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
|
||||
transformed: list[attr.Attribute[t.Any]] = []
|
||||
for f in fields:
|
||||
default = os.environ.get(f.metadata["env"], None)
|
||||
if default is not None:
|
||||
try:
|
||||
default = orjson.loads(default)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise ValueError(f"Failed to load from environment variables: {err}")
|
||||
else:
|
||||
default = f.default
|
||||
transformed.append(f.evolve(default=default))
|
||||
if "env" not in f.metadata:
|
||||
raise ValueError(
|
||||
"Make sure to setup the field with 'cls.Field' or 'attr.field(..., metadata={\"env\": \"...\"})'"
|
||||
)
|
||||
_from_env = _populate_value_from_env_var(f.metadata["env"])
|
||||
if _from_env is not None:
|
||||
f = f.evolve(default=_from_env)
|
||||
transformed.append(f)
|
||||
return transformed
|
||||
|
||||
|
||||
# sentinel object for unequivocal object() getattr
|
||||
_sentinel = object()
|
||||
|
||||
|
||||
def _has_own_attribute(cls: type[t.Any], attrib_name: t.Any):
|
||||
"""
|
||||
Check whether *cls* defines *attrib_name* (and doesn't just inherit it).
|
||||
"""
|
||||
attr = getattr(cls, attrib_name, _sentinel)
|
||||
if attr is _sentinel:
|
||||
return False
|
||||
|
||||
for base_cls in cls.__mro__[1:]:
|
||||
a = getattr(base_cls, attrib_name, None)
|
||||
if attr is a:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _get_annotations(cls: type[t.Any]) -> DictStrAny:
|
||||
"""
|
||||
Get annotations for *cls*.
|
||||
"""
|
||||
if _has_own_attribute(cls, "__annotations__"):
|
||||
return cls.__annotations__
|
||||
|
||||
return DictStrAny()
|
||||
|
||||
|
||||
# The below is vendorred from attrs
|
||||
def _collect_base_attrs(
|
||||
cls: type[LLMConfig], taken_attr_names: set[str]
|
||||
) -> tuple[list[attr.Attribute[t.Any]], dict[str, type[t.Any]]]:
|
||||
"""
|
||||
Collect attr.ibs from base classes of *cls*, except *taken_attr_names*.
|
||||
"""
|
||||
base_attrs: list[attr.Attribute[t.Any]] = []
|
||||
base_attr_map: dict[str, type[t.Any]] = {} # A dictionary of base attrs to their classes.
|
||||
|
||||
# Traverse the MRO and collect attributes.
|
||||
for base_cls in reversed(cls.__mro__[1:-1]):
|
||||
for a in getattr(base_cls, "__attrs_attrs__", []):
|
||||
if a.inherited or a.name in taken_attr_names:
|
||||
continue
|
||||
|
||||
a = a.evolve(inherited=True)
|
||||
base_attrs.append(a)
|
||||
base_attr_map[a.name] = base_cls
|
||||
|
||||
# For each name, only keep the freshest definition i.e. the furthest at the back.
|
||||
filtered: list[attr.Attribute[t.Any]] = []
|
||||
seen: set[str] = set()
|
||||
for a in reversed(base_attrs):
|
||||
if a.name in seen:
|
||||
continue
|
||||
filtered.insert(0, a)
|
||||
seen.add(a.name)
|
||||
|
||||
return filtered, base_attr_map
|
||||
|
||||
|
||||
_classvar_prefixes = (
|
||||
"typing.ClassVar",
|
||||
"t.ClassVar",
|
||||
"ClassVar",
|
||||
"typing_extensions.ClassVar",
|
||||
)
|
||||
|
||||
|
||||
def _is_class_var(annot: str | t.Any) -> bool:
|
||||
"""
|
||||
Check whether *annot* is a typing.ClassVar.
|
||||
|
||||
The string comparison hack is used to avoid evaluating all string
|
||||
annotations which would put attrs-based classes at a performance
|
||||
disadvantage compared to plain old classes.
|
||||
"""
|
||||
annot = str(annot)
|
||||
|
||||
# Annotation can be quoted.
|
||||
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
|
||||
annot = annot[1:-1]
|
||||
|
||||
return annot.startswith(_classvar_prefixes)
|
||||
|
||||
|
||||
def _add_method_dunders(cls: type[t.Any], method: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
|
||||
"""
|
||||
Add __module__ and __qualname__ to a *method* if possible.
|
||||
"""
|
||||
try:
|
||||
method.__module__ = cls.__module__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
method.__qualname__ = ".".join((cls.__qualname__, method.__name__))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
method.__doc__ = "Method generated by attrs for class " f"{cls.__qualname__}."
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return method
|
||||
|
||||
|
||||
# NOTE: vendorred from attrs
|
||||
def _compile_and_eval(script: str, globs: dict[str, t.Any], locs: dict[str, t.Any] | None = None, filename: str = ""):
|
||||
"""
|
||||
"Exec" the script with the given global (globs) and local (locs) variables.
|
||||
"""
|
||||
bytecode = compile(script, filename, "exec")
|
||||
eval(bytecode, globs, locs)
|
||||
|
||||
|
||||
def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[tuple[attr.Attribute[t.Any], ...]]:
|
||||
"""
|
||||
Create a tuple subclass to hold `Attribute`s for an `attrs` class.
|
||||
|
||||
The subclass is a bare tuple with properties for names.
|
||||
|
||||
class MyClassAttributes(tuple):
|
||||
__slots__ = ()
|
||||
x = property(itemgetter(0))
|
||||
"""
|
||||
attr_class_name = f"{cls_name}Attributes"
|
||||
attr_class_template = [
|
||||
f"class {attr_class_name}(tuple):",
|
||||
" __slots__ = ()",
|
||||
]
|
||||
if attr_names:
|
||||
for i, attr_name in enumerate(attr_names):
|
||||
attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
|
||||
else:
|
||||
attr_class_template.append(" pass")
|
||||
globs: dict[str, t.Any] = {"_attrs_itemgetter": ItemgetterAny, "_attrs_property": property}
|
||||
_compile_and_eval("\n".join(attr_class_template), globs)
|
||||
return globs[attr_class_name]
|
||||
|
||||
|
||||
def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]:
|
||||
attribs: DictStrAny = {}
|
||||
_has_gen_class = _has_own_attribute(cls, "GenerationConfig")
|
||||
for key, field in attr.fields_dict(GenerationConfig).items():
|
||||
class_gen_value = cls.GenerationConfig.__dict__.get(key, attr.NOTHING) if _has_gen_class else attr.NOTHING
|
||||
attribs[key] = cls.Field(
|
||||
class_gen_value if class_gen_value is not attr.NOTHING else field.default,
|
||||
description=field.metadata.get("description"),
|
||||
env=f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{key.upper()}",
|
||||
validator=field.validator,
|
||||
)
|
||||
if _has_gen_class:
|
||||
delattr(cls, "GenerationConfig")
|
||||
|
||||
return attr.make_class(
|
||||
cls.__name__.replace("Config", "GenerationConfig"), attribs, field_transformer=env_transformers
|
||||
)
|
||||
|
||||
|
||||
@attr.define
|
||||
class LLMConfig:
|
||||
Field = dantic.Field
|
||||
"""Field is a alias to the internal dantic utilities to easily create
|
||||
attrs.fields with pydantic-compatible interface.
|
||||
"""
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
def __attrs_init__(self, **attrs: t.Any):
|
||||
...
|
||||
|
||||
# The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
|
||||
__attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = tuple()
|
||||
|
||||
__openllm_attrs__: tuple[str, ...] = tuple()
|
||||
|
||||
__openllm_timeout__: int = 3600
|
||||
@@ -420,9 +593,9 @@ class LLMConfig:
|
||||
__openllm_start_name__: str = ""
|
||||
__openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
|
||||
|
||||
__openllm_env__: openllm.utils.ModelEnv = dantic.Field(None, init=False)
|
||||
__openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
|
||||
|
||||
generation_class: type[GenerationConfig] = dantic.Field(None, init=False)
|
||||
generation_class: type[GenerationConfig] = Field(None, init=False)
|
||||
|
||||
GenerationConfig: type = type
|
||||
|
||||
@@ -442,36 +615,6 @@ class LLMConfig:
|
||||
start_name = model_name
|
||||
|
||||
cls.__openllm_name_type__ = name_type
|
||||
|
||||
attributes = {
|
||||
key: dantic.Field(
|
||||
field.default,
|
||||
description=field.metadata.get("description"),
|
||||
env=f"OPENLLM_{model_name.upper()}_GENERATION_{key.upper()}",
|
||||
validator=field.validator,
|
||||
)
|
||||
for key, field in attr.fields_dict(GenerationConfig).items()
|
||||
}
|
||||
|
||||
generation_class: type[GenerationConfig]
|
||||
if hasattr(cls, "GenerationConfig"):
|
||||
generation_class = attr.make_class(
|
||||
cls.__name__.replace("Config", "GenerationConfig"),
|
||||
attributes,
|
||||
field_transformer=env_transformers,
|
||||
)
|
||||
delattr(cls, "GenerationConfig")
|
||||
else:
|
||||
generation_class = attr.make_class(
|
||||
"GenerationConfig",
|
||||
attributes,
|
||||
field_transformer=env_transformers,
|
||||
)
|
||||
generation_class.model_dump = GenerationConfig.model_dump # type: ignore
|
||||
|
||||
# Set the generation_config attributes here.
|
||||
cls.generation_class = generation_class
|
||||
|
||||
cls.__openllm_requires_gpu__ = requires_gpu
|
||||
cls.__openllm_timeout__ = default_timeout or 3600
|
||||
cls.__openllm_trust_remote_code__ = trust_remote_code
|
||||
@@ -480,51 +623,100 @@ class LLMConfig:
|
||||
cls.__openllm_start_name__ = start_name
|
||||
cls.__openllm_env__ = openllm.utils.ModelEnv(model_name)
|
||||
|
||||
if hasattr(cls, "__annotations__"):
|
||||
anns = cls.__annotations__
|
||||
else:
|
||||
anns = {}
|
||||
# NOTE: Since we want to enable a pydantic-like experience
|
||||
# this means we will have to hide the attr abstraction, and generate
|
||||
# all of the Field from __init_subclass__
|
||||
# Some of the logics here are from attr._make._transform_attrs
|
||||
anns = _get_annotations(cls)
|
||||
cd = cls.__dict__
|
||||
|
||||
def field_env_key(key: str) -> str:
|
||||
return f"OPENLLM_{model_name.upper()}_{key.upper()}"
|
||||
|
||||
ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)}
|
||||
ca_list: list[tuple[str, _CountingAttr[t.Any]]] = []
|
||||
annotated_names: set[str] = set()
|
||||
for attr_name, typ in anns.items():
|
||||
if _is_class_var(typ):
|
||||
continue
|
||||
annotated_names.add(attr_name)
|
||||
val = cd.get(attr_name, attr.NOTHING)
|
||||
if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val):
|
||||
if val is attr.NOTHING:
|
||||
val = cls.Field(env=field_env_key(attr_name))
|
||||
else:
|
||||
val = cls.Field(default=val, env=field_env_key(attr_name))
|
||||
ca_list.append((attr_name, val))
|
||||
unannotated = ca_names - annotated_names
|
||||
|
||||
if len(unannotated) > 0:
|
||||
missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter)
|
||||
raise openllm.exceptions.MissingAnnotationAttributeError(
|
||||
f"The following field doesn't have a type annotation: {missing_annotated}"
|
||||
)
|
||||
|
||||
# NOTE: we know need to determine the list of the attrs
|
||||
# by mro to at the very least support inheritance. Tho it is not recommended.
|
||||
own_attrs: list[attr.Attribute[t.Any]] = [
|
||||
attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=anns.get(attr_name))
|
||||
for attr_name, ca in ca_list
|
||||
]
|
||||
base_attrs, base_attr_map = _collect_base_attrs(cls, {a.name for a in own_attrs})
|
||||
attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs
|
||||
|
||||
# Mandatory vs non-mandatory attr order only matters when they are part of
|
||||
# the __init__ signature and when they aren't kw_only (which are moved to
|
||||
# the end and can be mandatory or non-mandatory in any order, as they will
|
||||
# be specified as keyword args anyway). Check the order of those attrs:
|
||||
had_default = False
|
||||
for a in (a for a in attrs if a.init is not False and a.kw_only is False):
|
||||
if had_default is True and a.default is attr.NOTHING:
|
||||
raise ValueError(
|
||||
"No mandatory attributes allowed after an attribute with a "
|
||||
f"default value or factory. Attribute in question: {a!r}"
|
||||
)
|
||||
|
||||
if had_default is False and a.default is not attr.NOTHING:
|
||||
had_default = True
|
||||
|
||||
# NOTE: Resolve the alias and default value from environment variable
|
||||
attrs = [
|
||||
a.evolve(
|
||||
alias=a.name.lstrip("_") if not a.alias else None,
|
||||
# NOTE: This is where we actually populate with the environment variable set for this attrs.
|
||||
default=_populate_value_from_env_var(a.name, transform=field_env_key, fallback=a.default),
|
||||
)
|
||||
for a in attrs
|
||||
]
|
||||
|
||||
_has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
|
||||
_has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
|
||||
|
||||
# __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]]
|
||||
# that we construct ourself.
|
||||
openllm_attrs: tuple[str, ...] = tuple()
|
||||
cur_attrs: tuple[attr.Attribute[t.Any]] = tuple()
|
||||
cls.__openllm_attrs__ = tuple(a.name for a in attrs)
|
||||
AttrsTuple = _make_attr_tuple_class(cls.__name__, cls.__openllm_attrs__)
|
||||
# NOTE: generate a __attrs_init__ for the subclass
|
||||
cls.__attrs_init__ = _add_method_dunders(
|
||||
cls,
|
||||
_make_init(
|
||||
cls,
|
||||
AttrsTuple(attrs),
|
||||
_has_pre_init,
|
||||
_has_post_init,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
base_attr_map,
|
||||
False,
|
||||
None,
|
||||
attrs_init=True,
|
||||
),
|
||||
)
|
||||
cls.__attrs_attrs__ = AttrsTuple(attrs)
|
||||
|
||||
for key, value in vars(cls).items():
|
||||
if key in _IGNORE_FIELDS:
|
||||
continue
|
||||
|
||||
# NOTE: we probably want to decorate all of the function in LLMConfig
|
||||
# with an internal flag, which then allow user to define their own
|
||||
# function and do whatever they want with the LLMConfig. Currently, this is
|
||||
# a limitation.
|
||||
if not key.startswith("_") and not callable(value):
|
||||
env_key = f"OPENLLM_{model_name.upper()}_{key.upper()}"
|
||||
default = os.environ.get(env_key, None)
|
||||
if default is not None:
|
||||
try:
|
||||
default = orjson.loads(default)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise ValueError(f"Failed to load from environment variables: {err}")
|
||||
else:
|
||||
default = value
|
||||
|
||||
annotation = anns.get(key, None)
|
||||
if annotation is not None:
|
||||
# NOTE: eval is dangerous, but we don't provide any specific imports here.
|
||||
annotation = eval(annotation, {}, {})
|
||||
|
||||
attribute: attr.Attribute[t.Any] = attr.Attribute.from_counting_attr(
|
||||
key, dantic.Field(default, env=env_key, alias=key), type=annotation
|
||||
)
|
||||
|
||||
openllm_attrs += (key,)
|
||||
cur_attrs += (attribute,)
|
||||
|
||||
setattr(cls, key, attribute.default)
|
||||
|
||||
cls.__openllm_attrs__ = openllm_attrs
|
||||
cls.__attrs_attrs__ = cur_attrs
|
||||
# NOTE: Finally, set the generation_class for this given config.
|
||||
cls.generation_class = _make_internal_generation_class(cls)
|
||||
|
||||
@property
|
||||
def name_type(self) -> t.Literal["dasherize", "lowercase"]:
|
||||
@@ -537,8 +729,6 @@ class LLMConfig:
|
||||
__openllm_extras__: dict[str, t.Any] | None = None,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
self.__openllm_env__ = openllm.utils.ModelEnv(self.__openllm_model_name__)
|
||||
|
||||
to_exclude = list(attr.fields_dict(self.generation_class)) + list(self.__openllm_attrs__)
|
||||
self.__openllm_extras__ = __openllm_extras__ or {k: v for k, v in attrs.items() if k not in to_exclude}
|
||||
|
||||
@@ -546,22 +736,18 @@ class LLMConfig:
|
||||
|
||||
if generation_config is None:
|
||||
generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.generation_class)}
|
||||
else:
|
||||
logger.debug("Overriding default 'generation_config' with '%s'", generation_config)
|
||||
|
||||
self.generation_config = self.generation_class(**generation_config)
|
||||
|
||||
# NOTE: since our subclass is not a dataclass-like, we need to set attr like this.
|
||||
for k, fields in attr.fields_dict(self.__class__).items():
|
||||
if k in attrs and attrs[k] != fields.default:
|
||||
setattr(self, k, attrs[k])
|
||||
else:
|
||||
setattr(self, k, fields.default)
|
||||
attrs = {k: v for k, v in attrs.items() if k not in generation_config}
|
||||
|
||||
# set the remaning attrs to class
|
||||
for k, v in attrs.items():
|
||||
if k not in self.__openllm_attrs__:
|
||||
setattr(self, k, v)
|
||||
extras = set(attrs).difference(set(attr.fields_dict(self.__class__)))
|
||||
if len(extras) > 0:
|
||||
# Set all of the keys in extras to the class
|
||||
for k in extras:
|
||||
setattr(self, k, attrs[k])
|
||||
|
||||
self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras})
|
||||
|
||||
def __repr__(self) -> str:
|
||||
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={self.generation_config}"
|
||||
|
||||
@@ -272,14 +272,11 @@ class LLMMetaclass(ABCMeta):
|
||||
if cls_name.startswith("Flax"):
|
||||
implementation = "flax"
|
||||
prefix_class_name_config = cls_name[4:]
|
||||
namespace["__annotations__"].update({"model": "transformers.FlaxPreTrainedModel"})
|
||||
elif cls_name.startswith("TF"):
|
||||
implementation = "tf"
|
||||
prefix_class_name_config = cls_name[2:]
|
||||
namespace["__annotations__"].update({"model": "transformers.TFPreTrainedModel"})
|
||||
else:
|
||||
implementation = "pt"
|
||||
namespace["__annotations__"].update({"model": "transformers.PreTrainedModel"})
|
||||
namespace["__llm_implementation__"] = implementation
|
||||
|
||||
# NOTE: setup config class branch
|
||||
|
||||
@@ -166,6 +166,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
self._cached_grpc: dict[str, t.Any] = {}
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
breakpoint()
|
||||
cmd_name = self.resolve_alias(cmd_name)
|
||||
if ctx.command.name == "start":
|
||||
if cmd_name not in self._cached_http:
|
||||
@@ -465,213 +466,217 @@ output_option = click.option(
|
||||
)
|
||||
|
||||
|
||||
@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm")
|
||||
def cli():
|
||||
"""
|
||||
\b
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
|
||||
def cli_builder():
|
||||
@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm")
|
||||
def cli():
|
||||
"""
|
||||
\b
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
|
||||
|
||||
\b
|
||||
OpenLLM: Your one stop-and-go-solution for serving any Open Large-Language Model
|
||||
\b
|
||||
OpenLLM: Your one stop-and-go-solution for serving any Open Large-Language Model
|
||||
|
||||
- StableLM, Falcon, ChatGLM, Dolly, Flan-T5, and more
|
||||
- StableLM, Falcon, ChatGLM, Dolly, Flan-T5, and more
|
||||
|
||||
\b
|
||||
- Powered by BentoML 🍱 + HuggingFace 🤗
|
||||
"""
|
||||
|
||||
@cli.command(name="version")
|
||||
@output_option
|
||||
@click.pass_context
|
||||
def _(ctx: click.Context, output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""🚀 OpenLLM version."""
|
||||
from gettext import gettext
|
||||
|
||||
from .__about__ import __version__
|
||||
|
||||
message = gettext("%(prog)s, version %(version)s")
|
||||
version = __version__
|
||||
prog_name = ctx.find_root().info_name
|
||||
|
||||
if output == "pretty":
|
||||
click.echo(message % {"prog": prog_name, "version": version}, color=ctx.color)
|
||||
elif output == "json":
|
||||
click.echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
click.echo(version)
|
||||
|
||||
ctx.exit()
|
||||
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start")
|
||||
def _():
|
||||
"""
|
||||
Start any LLM as a REST server.
|
||||
|
||||
$ openllm start <model_name> --<options> ...
|
||||
"""
|
||||
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc")
|
||||
def _():
|
||||
"""
|
||||
Start any LLM as a gRPC server.
|
||||
|
||||
$ openllm start-grpc <model_name> --<options> ...
|
||||
"""
|
||||
|
||||
@cli.command(name="bundle", aliases=["build"])
|
||||
@click.argument(
|
||||
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
|
||||
)
|
||||
@click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].")
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
|
||||
@output_option
|
||||
def _(model_name: str, pretrained: str | None, overwrite: bool, output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""Package a given models into a Bento.
|
||||
|
||||
$ openllm bundle flan-t5
|
||||
"""
|
||||
from bentoml._internal.configuration import get_quiet_mode
|
||||
|
||||
bento, _previously_built = openllm.build(
|
||||
model_name, __cli__=True, pretrained=pretrained, _overwrite_existing_bento=overwrite
|
||||
)
|
||||
|
||||
if output == "pretty":
|
||||
if not get_quiet_mode():
|
||||
click.echo("\n" + OPENLLM_FIGLET)
|
||||
if not _previously_built:
|
||||
click.secho(f"Successfully built {bento}.", fg="green")
|
||||
else:
|
||||
click.secho(
|
||||
f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
click.secho(
|
||||
"\nPossible next steps:\n\n * Push to BentoCloud with `bentoml push`:\n "
|
||||
+ f"$ bentoml push {bento.tag}",
|
||||
fg="blue",
|
||||
)
|
||||
click.secho(
|
||||
"\n * Containerize your Bento with `bentoml containerize`:\n "
|
||||
+ f"$ bentoml containerize {bento.tag}",
|
||||
fg="blue",
|
||||
)
|
||||
elif output == "json":
|
||||
click.secho(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
click.echo(bento.tag)
|
||||
return bento
|
||||
|
||||
@cli.command(name="models")
|
||||
@output_option
|
||||
def _(output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""List all supported models."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
failed_initialized: list[tuple[str, Exception]] = []
|
||||
if output == "pretty":
|
||||
import rich
|
||||
import rich.box
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
console = rich.get_console()
|
||||
table = Table(title="Supported LLMs", box=rich.box.SQUARE, show_lines=True)
|
||||
table.add_column("LLM")
|
||||
table.add_column("Description")
|
||||
table.add_column("Variants")
|
||||
for m in models:
|
||||
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
|
||||
try:
|
||||
model = openllm.AutoLLM.for_model(m)
|
||||
table.add_row(m, docs, f"{model.variants}")
|
||||
except Exception as err:
|
||||
failed_initialized.append((m, err))
|
||||
console.print(table)
|
||||
if len(failed_initialized) > 0:
|
||||
console.print(
|
||||
"\n[bold yellow] The following models are supported but failed to initialize:[/bold yellow]\n"
|
||||
)
|
||||
for m, err in failed_initialized:
|
||||
console.print(Text(f"- {m}: ") + Text(f"{err}\n", style="bold red"))
|
||||
elif output == "json":
|
||||
result_json: dict[str, dict[t.Literal["variants", "description"], t.Any]] = {}
|
||||
for m in models:
|
||||
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
|
||||
try:
|
||||
model = openllm.AutoLLM.for_model(m)
|
||||
result_json[m] = {"variants": model.variants, "description": docs}
|
||||
except Exception as err:
|
||||
logger.debug("Exception caught while parsing model %s", m, exc_info=err)
|
||||
result_json[m] = {"variants": None, "description": docs}
|
||||
|
||||
click.secho(orjson.dumps(result_json, option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
click.echo("\n".join(models))
|
||||
sys.exit(0)
|
||||
|
||||
@cli.command(name="download_models")
|
||||
@click.argument(
|
||||
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
|
||||
)
|
||||
@click.option(
|
||||
"--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight."
|
||||
)
|
||||
@output_option
|
||||
def _(model_name: str, pretrained: str | None, output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""Setup LLM interactively.
|
||||
|
||||
Note: This is useful for development and setup for fine-tune.
|
||||
"""
|
||||
config = openllm.AutoConfig.for_model(model_name)
|
||||
env = config.__openllm_env__.get_framework_env()
|
||||
if env == "flax":
|
||||
model = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
|
||||
elif env == "tf":
|
||||
model = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
|
||||
else:
|
||||
model = openllm.AutoLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
|
||||
|
||||
tag = model.make_tag()
|
||||
|
||||
if len(bentoml.models.list(tag)) == 0:
|
||||
if output == "pretty":
|
||||
click.secho(f"{tag} does not exists yet!. Downloading...", nl=True)
|
||||
m = model.ensure_pretrained_exists()
|
||||
click.secho(f"Saved model: {m.tag}")
|
||||
elif output == "json":
|
||||
m = model.ensure_pretrained_exists()
|
||||
click.secho(
|
||||
orjson.dumps(
|
||||
{"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
|
||||
).decode()
|
||||
)
|
||||
else:
|
||||
m = model.ensure_pretrained_exists()
|
||||
click.secho(m.tag)
|
||||
else:
|
||||
m = model.ensure_pretrained_exists()
|
||||
if output == "pretty":
|
||||
click.secho(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True)
|
||||
elif output == "json":
|
||||
click.secho(
|
||||
orjson.dumps(
|
||||
{"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
|
||||
).decode()
|
||||
)
|
||||
else:
|
||||
click.echo(m.tag)
|
||||
return m
|
||||
|
||||
\b
|
||||
- Powered by BentoML 🍱 + HuggingFace 🤗
|
||||
"""
|
||||
if psutil.WINDOWS:
|
||||
sys.stdout.reconfigure(encoding="utf-8") # type: ignore
|
||||
|
||||
|
||||
@cli.command()
|
||||
@output_option
|
||||
@click.pass_context
|
||||
def version(ctx: click.Context, output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""🚀 OpenLLM version."""
|
||||
from gettext import gettext
|
||||
|
||||
from .__about__ import __version__
|
||||
|
||||
message = gettext("%(prog)s, version %(version)s")
|
||||
version = __version__
|
||||
prog_name = ctx.find_root().info_name
|
||||
|
||||
if output == "pretty":
|
||||
click.echo(message % {"prog": prog_name, "version": version}, color=ctx.color)
|
||||
elif output == "json":
|
||||
click.echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
click.echo(version)
|
||||
|
||||
ctx.exit()
|
||||
return cli
|
||||
|
||||
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start")
|
||||
def start_cli():
|
||||
"""
|
||||
Start any LLM as a REST server.
|
||||
|
||||
$ openllm start <model_name> --<options> ...
|
||||
"""
|
||||
|
||||
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc")
|
||||
def start_grpc_cli():
|
||||
"""
|
||||
Start any LLM as a gRPC server.
|
||||
|
||||
$ openllm start-grpc <model_name> --<options> ...
|
||||
"""
|
||||
|
||||
|
||||
@cli.command(aliases=["build"])
|
||||
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
|
||||
@click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].")
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
|
||||
@output_option
|
||||
def bundle(model_name: str, pretrained: str | None, overwrite: bool, output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""Package a given models into a Bento.
|
||||
|
||||
$ openllm bundle flan-t5
|
||||
"""
|
||||
from bentoml._internal.configuration import get_quiet_mode
|
||||
|
||||
bento, _previously_built = openllm.build(
|
||||
model_name, __cli__=True, pretrained=pretrained, _overwrite_existing_bento=overwrite
|
||||
)
|
||||
|
||||
if output == "pretty":
|
||||
if not get_quiet_mode():
|
||||
click.echo("\n" + OPENLLM_FIGLET)
|
||||
if not _previously_built:
|
||||
click.secho(f"Successfully built {bento}.", fg="green")
|
||||
else:
|
||||
click.secho(
|
||||
f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
click.secho(
|
||||
"\nPossible next steps:\n\n * Push to BentoCloud with `bentoml push`:\n "
|
||||
+ f"$ bentoml push {bento.tag}",
|
||||
fg="blue",
|
||||
)
|
||||
click.secho(
|
||||
"\n * Containerize your Bento with `bentoml containerize`:\n "
|
||||
+ f"$ bentoml containerize {bento.tag}",
|
||||
fg="blue",
|
||||
)
|
||||
elif output == "json":
|
||||
click.secho(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
click.echo(bento.tag)
|
||||
return bento
|
||||
|
||||
|
||||
@cli.command()
|
||||
@output_option
|
||||
def models(output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""List all supported models."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
failed_initialized: list[tuple[str, Exception]] = []
|
||||
if output == "pretty":
|
||||
import rich
|
||||
import rich.box
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
console = rich.get_console()
|
||||
table = Table(title="Supported LLMs", box=rich.box.SQUARE, show_lines=True)
|
||||
table.add_column("LLM")
|
||||
table.add_column("Description")
|
||||
table.add_column("Variants")
|
||||
for m in models:
|
||||
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
|
||||
try:
|
||||
model = openllm.AutoLLM.for_model(m)
|
||||
table.add_row(m, docs, f"{model.variants}")
|
||||
except Exception as err:
|
||||
failed_initialized.append((m, err))
|
||||
console.print(table)
|
||||
if len(failed_initialized) > 0:
|
||||
console.print(
|
||||
"\n[bold yellow] The following models are supported but failed to initialize:[/bold yellow]\n"
|
||||
)
|
||||
for m, err in failed_initialized:
|
||||
console.print(Text(f"- {m}: ") + Text(f"{err}\n", style="bold red"))
|
||||
elif output == "json":
|
||||
result_json: dict[str, dict[t.Literal["variants", "description"], t.Any]] = {}
|
||||
for m in models:
|
||||
docs = inspect.cleandoc(openllm.AutoConfig.for_model(m).__doc__ or "(No description)")
|
||||
try:
|
||||
model = openllm.AutoLLM.for_model(m)
|
||||
result_json[m] = {"variants": model.variants, "description": docs}
|
||||
except Exception as err:
|
||||
logger.debug("Exception caught while parsing model %s", m, exc_info=err)
|
||||
result_json[m] = {"variants": None, "description": docs}
|
||||
|
||||
click.secho(orjson.dumps(result_json, option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
click.echo("\n".join(models))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
|
||||
@click.option(
|
||||
"--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight."
|
||||
)
|
||||
@output_option
|
||||
def download_models(model_name: str, pretrained: str | None, output: t.Literal["json", "pretty", "porcelain"]):
|
||||
"""Setup LLM interactively.
|
||||
|
||||
Note: This is useful for development and setup for fine-tune.
|
||||
"""
|
||||
config = openllm.AutoConfig.for_model(model_name)
|
||||
env = config.__openllm_env__.get_framework_env()
|
||||
if env == "flax":
|
||||
model = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
|
||||
elif env == "tf":
|
||||
model = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
|
||||
else:
|
||||
model = openllm.AutoLLM.for_model(model_name, pretrained=pretrained, llm_config=config)
|
||||
|
||||
tag = model.make_tag()
|
||||
|
||||
if len(bentoml.models.list(tag)) == 0:
|
||||
if output == "pretty":
|
||||
click.secho(f"{tag} does not exists yet!. Downloading...", nl=True)
|
||||
m = model.ensure_pretrained_exists()
|
||||
click.secho(f"Saved model: {m.tag}")
|
||||
elif output == "json":
|
||||
m = model.ensure_pretrained_exists()
|
||||
click.secho(
|
||||
orjson.dumps(
|
||||
{"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
|
||||
).decode()
|
||||
)
|
||||
else:
|
||||
m = model.ensure_pretrained_exists()
|
||||
click.secho(m.tag)
|
||||
else:
|
||||
m = model.ensure_pretrained_exists()
|
||||
if output == "pretty":
|
||||
click.secho(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True)
|
||||
elif output == "json":
|
||||
click.secho(
|
||||
orjson.dumps(
|
||||
{"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
|
||||
).decode()
|
||||
)
|
||||
else:
|
||||
click.echo(m.tag)
|
||||
return m
|
||||
|
||||
cli = cli_builder()
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -35,5 +35,9 @@ class ForbiddenAttributeError(OpenLLMException):
|
||||
"""Raised when using an _internal field."""
|
||||
|
||||
|
||||
class MissingAnnotationAttributeError(OpenLLMException):
|
||||
"""Raised when a field under openllm.LLMConfig is missing annotations."""
|
||||
|
||||
|
||||
class MissingDependencyError(BaseException):
|
||||
"""Raised when a dependency is missing."""
|
||||
|
||||
@@ -34,7 +34,9 @@ class DollyV2Config(openllm.LLMConfig, default_timeout=3600000):
|
||||
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
|
||||
"""
|
||||
|
||||
return_full_text: bool = False
|
||||
return_full_text: bool = openllm.LLMConfig.Field(
|
||||
False, description="Whether to return the full prompt to the users."
|
||||
)
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
@@ -30,8 +32,16 @@ if t.TYPE_CHECKING:
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
def _default_converter(value: t.Any, env: str | None) -> t.Any:
|
||||
if env is not None:
|
||||
value = os.environ.get(env, value)
|
||||
if value is not None and isinstance(value, str):
|
||||
return eval(value, {"__builtins__": {}}, {})
|
||||
return value
|
||||
|
||||
|
||||
def Field(
|
||||
value: t.Any,
|
||||
default: t.Any = None,
|
||||
*,
|
||||
ge: int | float | None = None,
|
||||
le: int | float | None = None,
|
||||
@@ -52,10 +62,6 @@ def Field(
|
||||
**kwargs: The rest of the arguments are passed to attr.field
|
||||
"""
|
||||
metadata = attrs.pop("metadata", {})
|
||||
default = attrs.pop("default", value)
|
||||
if default is not value:
|
||||
raise ValueError("Either specify 'default=value' or provide 'value' as the only argument")
|
||||
|
||||
if description is None:
|
||||
description = "(No description is available)"
|
||||
metadata["description"] = description
|
||||
@@ -63,6 +69,8 @@ def Field(
|
||||
metadata["env"] = env
|
||||
piped: list[_ValidatorType[t.Any]] = []
|
||||
|
||||
converter = attrs.pop("converter", functools.partial(_default_converter, env=env))
|
||||
|
||||
if ge is not None:
|
||||
piped.append(attr.validators.ge(ge))
|
||||
if le is not None:
|
||||
@@ -77,7 +85,7 @@ def Field(
|
||||
else:
|
||||
_validator = attr.validators.and_(*piped)
|
||||
|
||||
return attr.field(default=default, metadata=metadata, validator=_validator, **attrs)
|
||||
return attr.field(default=default, metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
|
||||
|
||||
def allows_multiple(field_type: t.Any) -> bool:
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import typing as t
|
||||
|
||||
class SourceGenerator(ast.NodeVisitor):
|
||||
def __init__(
|
||||
self,
|
||||
indent_width: str,
|
||||
add_line_information: bool = ...,
|
||||
pretty_string: t.Callable[..., t.Any] = ...,
|
||||
len: t.Callable[[t.Any], int] = ...,
|
||||
isinstance: t.Callable[[t.Any, t.Any], bool] = ...,
|
||||
callable: t.Callable[[t.Any], bool] = ...,
|
||||
) -> None: ...
|
||||
def newline(self, node: ast.AST | None = ..., extra: t.Any = ...) -> None: ...
|
||||
def write(*params: t.Any) -> None: ...
|
||||
@@ -9,6 +9,7 @@ from typing import (
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@@ -42,6 +43,7 @@ __license__: str
|
||||
__copyright__: str
|
||||
_T = TypeVar("_T")
|
||||
_C = TypeVar("_C", bound=type)
|
||||
_P = ParamSpec("_P")
|
||||
_EqOrderType = Union[bool, Callable[[Any], Any]]
|
||||
_ValidatorType = Callable[[Any, "Attribute[_T]", _T], Any]
|
||||
_ConverterType = Callable[[Any], Any]
|
||||
@@ -484,3 +486,17 @@ def get_run_validators() -> bool: ...
|
||||
attributes = ...
|
||||
attr = ...
|
||||
dataclass = ...
|
||||
|
||||
def _make_init(
|
||||
cls: type[AttrsInstance],
|
||||
attrs: tuple[Attribute[_T]],
|
||||
pre_init: bool,
|
||||
post_init: bool,
|
||||
frozen: bool,
|
||||
slots: bool,
|
||||
cache_hash: bool,
|
||||
base_attr_map: dict[str, Any],
|
||||
is_exc: bool,
|
||||
cls_on_setattr: Any,
|
||||
attrs_init: bool,
|
||||
) -> Callable[_P, Any]: ...
|
||||
|
||||
Reference in New Issue
Block a user