refactor(configuration): __config__ and perf

move model_ids and default_id to the config class declaration and
clean up dependencies between the config and the LLM implementation

move lazy module loading during LLM creation into llm_post_init

fix post_init hooks to run load_in_mha

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
aarnphm-ec2-dev
2023-06-11 12:53:15 +00:00
parent 17241292da
commit 2e453fb005
22 changed files with 565 additions and 376 deletions
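
For context, a condensed before/after sketch of the declaration change this commit introduces (identifiers taken from the FLAN-T5 hunks further down; the snippet itself is an illustration, not part of the diff):

```python
import openllm

# Before: model ids lived on the LLM implementation class.
#
#   class FlanT5(openllm.LLM):
#       default_id = "google/flan-t5-large"
#       model_ids = ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", ...]
#
# After: they are declared under __config__ on the config class.
class FlanT5Config(openllm.LLMConfig):
    __config__ = {
        "url": "https://huggingface.co/docs/transformers/model_doc/flan-t5",
        "default_id": "google/flan-t5-large",
        "model_ids": ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large"],
    }
```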

View File

@@ -12,13 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration utilities for OpenLLM. All model configuration will inherit from openllm.configuration_utils.LLMConfig.
Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``.
Note that ``openllm.LLMConfig`` is a subclass of ``pydantic.BaseModel``. It also
has a ``to_click_options`` that returns a list of Click-compatible options for the model.
Such options will then be parsed to ``openllm.__main__.cli``.
Each fields in ``openllm.LLMConfig`` will also automatically generate a environment
Highlight feature: Each field in ``openllm.LLMConfig`` will also automatically generate an environment
variable based on its name field.
For example, the following config class:
@@ -34,10 +30,14 @@ class FlanT5Config(openllm.LLMConfig):
repetition_penalty = 1.0
```
which generates the environment variable OPENLLM_FLAN_T5_GENERATION_TEMPERATURE for users to configure temperature
dynamically during serving, ahead of serving, or per request.
Refer to ``openllm.LLMConfig`` docstring for more information.
"""
from __future__ import annotations
import inspect
import logging
import os
import typing as t
@@ -53,13 +53,13 @@ from deepmerge.merger import Merger
import openllm
from .exceptions import GpuNotAvailableError, OpenLLMException
from .utils import LazyType, ModelEnv, bentoml_cattr, dantic, lenient_issubclass
from .utils import LazyType, ModelEnv, bentoml_cattr, dantic, first_not_none, lenient_issubclass
if t.TYPE_CHECKING:
import tensorflow as tf
import torch
import transformers
from attr import _CountingAttr, _make_init
from attr import _CountingAttr, _make_init, _make_method, _make_repr
from transformers.generation.beam_constraints import Constraint
from ._types import ClickFunctionWrapper, F, O_co, P
@@ -67,14 +67,16 @@ if t.TYPE_CHECKING:
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
DictStrAny = dict[str, t.Any]
ListStr = list[str]
ItemgetterAny = itemgetter[t.Any]
else:
Constraint = t.Any
ListStr = list
DictStrAny = dict
ItemgetterAny = itemgetter
# NOTE: Using internal API from attr here, since we are actually
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
from attr._make import _CountingAttr, _make_init
from attr._make import _CountingAttr, _make_init, _make_method, _make_repr
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -93,6 +95,8 @@ config_merger = Merger(
type_conflict_strategies=["override"],
)
_T = t.TypeVar("_T")
@t.overload
def attrs_to_options(
@@ -158,7 +162,9 @@ class GenerationConfig:
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
with some additional validation and environment constructor.
Note that we always set `do_sample=True`
Note that we always set `do_sample=True`. This class is not designed to be used directly; rather,
it should be used in conjunction with LLMConfig. The instance of the generation config can then be accessed
via ``LLMConfig.generation_config``.
"""
# NOTE: parameters for controlling the length of the output
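
A small access sketch for the relationship described above; ``FlanT5Config`` and its defaults appear elsewhere in this diff, and the exact signature of ``to_generation_config`` is assumed from the API list in the module docstring:

```python
import openllm

config = openllm.FlanT5Config()
print(config.generation_config.temperature)  # instance of the generated attrs class
print(config.temperature)                    # fields also resolve through __getattr__
print(config.to_generation_config())         # converted transformers.GenerationConfig
```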
@@ -417,20 +423,6 @@ def _populate_value_from_env_var(
return os.environ.get(key, fallback)
def env_transformers(cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
transformed: list[attr.Attribute[t.Any]] = []
for f in fields:
if "env" not in f.metadata:
raise ValueError(
"Make sure to setup the field with 'cls.Field' or 'attr.field(..., metadata={\"env\": \"...\"})'"
)
_from_env = _populate_value_from_env_var(f.metadata["env"])
if _from_env is not None:
f = f.evolve(default=_from_env)
transformed.append(f)
return transformed
# sentinel object for unequivocal object() getattr
_sentinel = object()
@@ -590,7 +582,7 @@ def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConf
)
return transformed
generated_cls = attr.make_class(
_cl = attr.make_class(
cls.__name__.replace("Config", "GenerationConfig"),
[],
bases=(GenerationConfig,),
@@ -599,28 +591,205 @@ def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConf
repr=True,
field_transformer=_evolve_with_base_default,
)
_cl.__doc__ = GenerationConfig.__doc__
return generated_cls
if _has_gen_class:
delattr(cls, "GenerationConfig")
return _cl
# NOTE: This DEFAULT_LLMCONFIG_ATTRS is a way to dynamically generate attr.field
# and will be saved for future use in LLMConfig if we have some shared config.
DEFAULT_LLMCONFIG_ATTRS: tuple[tuple[str, t.Any, str, type[t.Any]], ...] = ()
# NOTE: This is the ModelConfig where we can control the behaviour of the LLM.
# refers to the __openllm_*__ docstring inside LLMConfig for more information.
class ModelConfig(t.TypedDict, total=False):
# NOTE: meta
url: str
requires_gpu: bool
trust_remote_code: bool
requirements: t.Optional[t.List[str]]
# NOTE: naming convention, only name_type is needed
# as the three below it can be determined automatically
name_type: t.Literal["dasherize", "lowercase"]
model_name: str
start_name: str
env: openllm.utils.ModelEnv
# NOTE: serving configuration
timeout: int
workers_per_resource: t.Union[int, float]
# NOTE: use t.Required once we drop 3.8 support
default_id: str
model_ids: list[str]
def _gen_default_model_config(cls: type[LLMConfig]) -> ModelConfig:
"""Generate the default ModelConfig and delete __config__ in LLMConfig
if defined inplace."""
_internal_config = t.cast(ModelConfig, getattr(cls, "__config__", {}))
default_id = _internal_config.get("default_id", None)
if default_id is None:
raise RuntimeError("'default_id' is required under '__config__'.")
model_ids = _internal_config.get("model_ids", None)
if model_ids is None:
raise RuntimeError("'model_ids' is required under '__config__'.")
def _first_not_null(key: str, default: _T) -> _T:
return first_not_none(_internal_config.get(key), default=default)
llm_config_striped = cls.__name__.replace("Config", "")
name_type: t.Literal["dasherize", "lowercase"] = _first_not_null("name_type", "dasherize")
if name_type == "dasherize":
default_model_name = inflection.underscore(llm_config_striped)
default_start_name = inflection.dasherize(default_model_name)
else:
default_model_name = llm_config_striped.lower()
default_start_name = default_model_name
model_name = _first_not_null("model_name", default_model_name)
_config = ModelConfig(
name_type=name_type,
model_name=model_name,
default_id=default_id,
model_ids=model_ids,
start_name=_first_not_null("start_name", default_start_name),
url=_first_not_null("url", "(not provided)"),
requires_gpu=_first_not_null("requires_gpu", False),
trust_remote_code=_first_not_null("trust_remote_code", False),
requirements=_first_not_null("requirements", ListStr()),
env=_first_not_null("env", openllm.utils.ModelEnv(model_name)),
timeout=_first_not_null("timeout", 3600),
workers_per_resource=_first_not_null("workers_per_resource", 1),
)
if hasattr(cls, "__config__"):
delattr(cls, "__config__")
return _config
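
For intuition, a hedged sketch of what this helper yields for a minimal, hypothetical subclass (only the two required keys are supplied; everything else falls back to the defaults above):

```python
import openllm

class MyConfig(openllm.LLMConfig):  # hypothetical
    __config__ = {"default_id": "org/my-model", "model_ids": ["org/my-model"]}

# The generated ModelConfig would then roughly be:
#   name_type="dasherize", model_name="my", start_name="my",
#   url="(not provided)", requires_gpu=False, trust_remote_code=False,
#   requirements=[], timeout=3600, workers_per_resource=1,
#   env=openllm.utils.ModelEnv("my")
```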
def _generate_unique_filename(cls: type[t.Any], func_name: str):
return f"<LLMConfig generated {func_name} {cls.__module__}." f"{getattr(cls, '__qualname__', cls.__name__)}>"
def _setattr_class(attr_name: str, value_var: t.Any):
"""
Use the builtin setattr to set *attr_name* to *value_var*.
We can't use the cached object.__setattr__ since we are setting
attributes to a class.
"""
return f"setattr(cls, '{attr_name}', {value_var})"
@t.overload
def _make_assignment_with_prefix_script(cls: type[LLMConfig], attributes: ModelConfig) -> t.Callable[..., None]:
...
@t.overload
def _make_assignment_with_prefix_script(cls: type[LLMConfig], attributes: dict[str, t.Any]) -> t.Callable[..., None]:
...
def _make_assignment_with_prefix_script(cls: type[LLMConfig], attributes: t.Any) -> t.Callable[..., None]:
"""Generate the assignment script with prefix attributes __openllm_<value>__"""
args: list[str] = []
globs: dict[str, t.Any] = {"cls": cls, "attr_dict": attributes}
annotations: dict[str, t.Any] = {"return": None}
# Circumvent __setattr__ descriptor to save one lookup per assignment
lines: list[str] = []
for attr_name in attributes:
arg_name = f"__openllm_{inflection.underscore(attr_name)}__"
args.append(f"{attr_name}=attr_dict['{attr_name}']")
lines.append(_setattr_class(arg_name, attr_name))
annotations[attr_name] = type(attributes[attr_name])
script = "def __assign_attr(cls, %s):\n %s\n" % (", ".join(args), "\n ".join(lines) if lines else "pass")
assign_method = _make_method(
"__assign_attr",
script,
_generate_unique_filename(cls, "__assign_attr"),
globs,
)
assign_method.__annotations__ = annotations
return assign_method
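
As a reconstruction of the template above, the generated assignment source for two example attributes would look roughly like the following runnable sketch (attribute values are illustrative):

```python
attr_dict = {"url": "https://example.com", "timeout": 3600}

# Equivalent of the source produced by the template for these two attributes:
def __assign_attr(cls, url=attr_dict['url'], timeout=attr_dict['timeout']):
    setattr(cls, '__openllm_url__', url)
    setattr(cls, '__openllm_timeout__', timeout)

class _Demo:  # stands in for an LLMConfig subclass
    pass

__assign_attr(_Demo)
assert _Demo.__openllm_url__ == "https://example.com"
assert _Demo.__openllm_timeout__ == 3600
```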
@attr.define
class LLMConfig:
Field = dantic.Field
"""Field is a alias to the internal dantic utilities to easily create
attrs.fields with pydantic-compatible interface.
"""
``openllm.LLMConfig`` is a hybrid that combines the performance of `attrs` with the
easy-to-use interface that pydantic offers. It sits in between, allowing users to quickly formulate
an LLMConfig for any LLM without worrying too much about performance. It does a few things:
- Automatic environment conversion: Each field will automatically be provisioned with an environment
variable, making it easy to work with ahead of time or during serving
- Familiar API: It is compatible with cattrs while also providing a few Pydantic-2-like APIs,
i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options``
- Automatic CLI generation: It can identify each field and convert it to a compatible Click option.
This means developers can use any LLMConfig to create a CLI with any compatible Python
CLI library (click, typer, ...)
> Internally, LLMConfig is an attrs class. All subclasses of LLMConfig contain "attrs-like" features,
> which means LLMConfig will generate the subclass to have an attrs-compatible API, so that the subclass
> can be written as any normal Python class.
To directly configure GenerationConfig for any given LLM, create a GenerationConfig under the subclass:
```python
class FlanT5Config(openllm.LLMConfig):
class GenerationConfig:
temperature: float = 0.75
max_new_tokens: int = 3000
top_k: int = 50
top_p: float = 0.4
repetition_penalty = 1.0
```
By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted
to ``transformers.GenerationConfig``. These attributes can be accessed via ``LLMConfig.generation_config``.
By default, every LLMConfig has a ``__config__`` that provides default values. If any LLM requires customization,
provide a ``__config__`` under the class declaration:
```python
class FalconConfig(openllm.LLMConfig):
__config__ = {"trust_remote_code": True, "timeout": 3600000}
```
Note that ``model_name``, ``start_name``, and ``env`` are optional under ``__config__``. If set, OpenLLM
will respect those options for start and other components within the library.
"""
if t.TYPE_CHECKING:
# The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
Field = dantic.Field
"""Field is a alias to the internal dantic utilities to easily create
attrs.fields with pydantic-compatible interface. For example:
```python
class MyModelConfig(openllm.LLMConfig):
field1 = openllm.LLMConfig.Field(...)
```
"""
# NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
if t.TYPE_CHECKING:
# NOTE: Internal attributes that should only be used by OpenLLM. Users usually shouldn't
# concern any of these.
def __attrs_init__(self, **attrs: t.Any):
"""Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract."""
__config__: ModelConfig | None = None
"""Internal configuration for this LLM model. Each of the field in here will be populated
and prefixed with __openllm_<value>__"""
__attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = tuple()
"""Since we are writing our own __init_subclass__, which is an alternative way for __prepare__,
we want openllm.LLMConfig to be attrs-like dataclass that has pydantic-like interface.
@@ -630,8 +799,15 @@ class LLMConfig:
__openllm_attrs__: tuple[str, ...] = tuple()
"""Internal attribute tracking to store converted LLMConfig attributes to correct attrs"""
__openllm_timeout__: int = 3600
"""The default timeout to be set for this given LLM."""
__openllm_hints__: dict[str, t.Any] = Field(None, init=False)
"""An internal cache of resolved types for this LLMConfig."""
__openllm_accepted_keys__: set[str] = Field(None, init=False)
"""The accepted keys for this LLMConfig."""
# NOTE: The following will be populated from __config__
__openllm_url__: str = Field(None, init=False)
"""The resolved url for this LLMConfig."""
__openllm_requires_gpu__: bool = False
"""Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU."""
@@ -639,6 +815,13 @@ class LLMConfig:
__openllm_trust_remote_code__: bool = False
"""Whether to always trust remote code"""
__openllm_requirements__: list[str] | None = None
"""The default PyPI requirements needed to run this given LLM. By default, we will depend on
bentoml, torch, transformers."""
__openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
"""A ModelEnv instance for this LLMConfig."""
__openllm_model_name__: str = ""
"""The normalized version of __openllm_start_name__, determined by __openllm_name_type__"""
@@ -649,21 +832,8 @@ class LLMConfig:
"""the default name typed for this model. "dasherize" will convert the name to lowercase and
replace spaces with dashes. "lowercase" will convert the name to lowercase."""
__openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
"""A ModelEnv instance for this LLMConfig."""
__openllm_hints__: dict[str, t.Any] = Field(None, init=False)
"""An internal cache of resolved types for this LLMConfig."""
__openllm_url__: str = Field(None, init=False)
"""The resolved url for this LLMConfig."""
__openllm_accepted_keys__: set[str] = Field(None, init=False)
"""The accepted keys for this LLMConfig."""
__openllm_requirements__: list[str] | None = None
"""The default PyPI requirements needed to run this given LLM. By default, we will depend on
bentoml, torch, transformers."""
__openllm_timeout__: int = 3600
"""The default timeout to be set for this given LLM."""
__openllm_workers_per_resource__: int | float = 1
"""The default number of workers per resource. By default, we will use 1 worker per resource.
@@ -671,6 +841,18 @@ class LLMConfig:
https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more details.
"""
__openllm_default_id__: str = Field(None)
"""Return the default model to use when using 'openllm start <model_id>'.
This could be one of the keys in 'self.model_ids' or custom users model."""
__openllm_model_ids__: list[str] = Field(None)
"""A list of supported pretrained models tag for this given runnable.
For example:
For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base",
"google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"]
"""
GenerationConfig: type = type
"""Users can override this subclass of any given LLMConfig to provide GenerationConfig
default value. For example:
@@ -689,35 +871,9 @@ class LLMConfig:
"""The result generated GenerationConfig class for this LLMConfig. This will be used
to create the generation_config argument that can be used throughout the lifecycle."""
def __init_subclass__(
cls,
*,
name_type: t.Literal["dasherize", "lowercase"] = "dasherize",
default_timeout: int | None = None,
trust_remote_code: bool = False,
requires_gpu: bool = False,
url: str | None = None,
requirements: list[str] | None = None,
workers_per_resource: int | float = 1,
):
if name_type == "dasherize":
model_name = inflection.underscore(cls.__name__.replace("Config", ""))
start_name = inflection.dasherize(model_name)
else:
model_name = cls.__name__.replace("Config", "").lower()
start_name = model_name
cls.__openllm_name_type__ = name_type
cls.__openllm_requires_gpu__ = requires_gpu
cls.__openllm_timeout__ = default_timeout or 3600
cls.__openllm_trust_remote_code__ = trust_remote_code
cls.__openllm_model_name__ = model_name
cls.__openllm_start_name__ = start_name
cls.__openllm_env__ = openllm.utils.ModelEnv(model_name)
cls.__openllm_url__ = url or "(not set)"
cls.__openllm_requirements__ = requirements
cls.__openllm_workers_per_resource__ = workers_per_resource
def __init_subclass__(cls):
# NOTE: auto assignment attributes generated from __config__
_make_assignment_with_prefix_script(cls, _gen_default_model_config(cls))(cls)
# NOTE: Since we want to enable a pydantic-like experience
# this means we will have to hide the attr abstraction, and generate
@@ -727,7 +883,7 @@ class LLMConfig:
cd = cls.__dict__
def field_env_key(key: str) -> str:
return f"OPENLLM_{model_name.upper()}_{key.upper()}"
return f"OPENLLM_{cls.__openllm_model_name__.upper()}_{key.upper()}"
ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)}
ca_list: list[tuple[str, _CountingAttr[t.Any]]] = []
@@ -763,22 +919,14 @@ class LLMConfig:
gen_attribute = gen_attribute.evolve(metadata=metadata)
own_attrs.append(gen_attribute)
# This is to handle subclass of subclass of all provided LLMConfig.
# refer to attrs for the original implementation.
base_attrs, base_attr_map = _collect_base_attrs(cls, {a.name for a in own_attrs})
# __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]]
# that we construct ourself.
cls.__openllm_attrs__ = tuple(a.name for a in own_attrs)
# NOTE: Enable some default attributes that can be shared across all LLMConfig
if len(DEFAULT_LLMCONFIG_ATTRS) > 0:
# NOTE: update the hints for default variables we dynamically added.
hints.update({k: hints for k, _, _, hints in DEFAULT_LLMCONFIG_ATTRS})
base_attrs = [
attr.Attribute.from_counting_attr(k, cls.Field(default, env=field_env_key(k), description=docs), hints)
for k, default, docs, hints in DEFAULT_LLMCONFIG_ATTRS
if k not in cls.__openllm_attrs__
] + base_attrs
attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs
# Mandatory vs non-mandatory attr order only matters when they are part of
@@ -810,38 +958,33 @@ class LLMConfig:
_has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
AttrsTuple = _make_attr_tuple_class(cls.__name__, cls.__openllm_attrs__)
# NOTE: the protocol for attrs-decorated class
cls.__attrs_attrs__ = AttrsTuple(attrs)
# NOTE: generate a __attrs_init__ for the subclass
cls.__attrs_init__ = _add_method_dunders(
cls,
_make_init(
cls,
AttrsTuple(attrs),
_has_pre_init,
_has_post_init,
False,
True,
True,
base_attr_map,
False,
None,
attrs_init=True,
cls, # cls (the attrs-decorated class)
cls.__attrs_attrs__, # tuple of attr.Attribute of cls
_has_pre_init, # pre_init
_has_post_init, # post_init
False, # frozen
True, # slots
True, # cache_hash
base_attr_map, # base_attr_map
False, # is_exc (check if it is exception)
None, # cls_on_setattr (essentially attr.setters)
attrs_init=True, # whether to create __attrs_init__ instead of __init__
),
)
cls.__attrs_attrs__ = AttrsTuple(attrs)
# NOTE: Finally, set the generation_class for this given config.
cls.generation_class = _make_internal_generation_class(cls)
hints.update(t.get_type_hints(cls.generation_class))
cls.__openllm_hints__ = hints
cls.__openllm_accepted_keys__ = set(cls.__openllm_attrs__) | set(attr.fields_dict(cls.generation_class))
@property
def name_type(self) -> t.Literal["dasherize", "lowercase"]:
return self.__openllm_name_type__
def __init__(
self,
*,
@@ -849,33 +992,56 @@ class LLMConfig:
__openllm_extras__: dict[str, t.Any] | None = None,
**attrs: t.Any,
):
self.__openllm_extras__ = openllm.utils.first_not_none(__openllm_extras__, default={})
# create a copy of the list of keys as cache
_cached_keys = tuple(attrs.keys())
self.__openllm_extras__ = first_not_none(__openllm_extras__, default={})
config_merger.merge(
self.__openllm_extras__, {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}
)
attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_extras__ and v is not None}
for k in _cached_keys:
if k in self.__openllm_extras__ or attrs.get(k) is None:
del attrs[k]
_cached_keys = tuple(k for k in _cached_keys if k in attrs)
_generation_cl_dict = attr.fields_dict(self.generation_class)
if generation_config is None:
generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.generation_class)}
generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict}
else:
generation_keys = {k for k in attrs if k in _generation_cl_dict}
if len(generation_keys) > 0:
logger.warning(
"When 'generation_config' is passed, \
the following keys are ignored and won't be used: %s. If you wish to use those values, \
pass it into 'generation_config'.",
", ".join(generation_keys),
)
for k in _cached_keys:
if k in generation_keys:
del attrs[k]
_cached_keys = tuple(k for k in _cached_keys if k in attrs)
self.generation_config = self.generation_class(**generation_config)
base_attrs: tuple[attr.Attribute[t.Any], ...] = attr.fields(self.__class__)
base_attrs += (
attr.Attribute.from_counting_attr(
name="generation_config",
ca=dantic.Field(
self.generation_config, description=inspect.cleandoc(self.generation_class.__doc__ or "")
),
type=self.generation_class,
),
)
# mk the class __repr__ function with the updated fields.
self.__class__.__repr__ = _add_method_dunders(self.__class__, _make_repr(base_attrs, None, self.__class__))
attrs = {k: v for k, v in attrs.items() if k not in generation_config}
for k in _cached_keys:
if k in generation_config:
del attrs[k]
self.__attrs_init__(**{k: v for k, v in attrs.items() if k in self.__openllm_attrs__})
# The rest update to extras
attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_attrs__}
config_merger.merge(self.__openllm_extras__, attrs)
def __repr__(self) -> str:
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={repr(self.generation_class())}"
if len(self.__openllm_attrs__) > 0:
bases += ", " + ", ".join([f"{k}={getattr(self, k)}" for k in self.__openllm_attrs__]) + ")"
else:
bases += ")"
return bases
# The rest of attrs should only be the attributes to be passed to __attrs_init__
self.__attrs_init__(**attrs)
def __getattr__(self, item: str) -> t.Any:
if hasattr(self.generation_config, item):
@@ -975,43 +1141,46 @@ class LLMConfig:
config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config))
return config.to_dict() if return_as_dict else config
@classmethod
@t.overload
def to_click_options(
self, f: t.Callable[..., openllm.LLMConfig]
cls, f: t.Callable[..., openllm.LLMConfig]
) -> F[P, ClickFunctionWrapper[..., openllm.LLMConfig]]:
...
@classmethod
@t.overload
def to_click_options(self, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]:
def to_click_options(cls, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]:
...
def to_click_options(self, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
@classmethod
def to_click_options(cls, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
"""
Convert current model to click options. This can be used as a decorator for click commands.
Note that the identifier for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
will be prefixed with '<model_name>_generation_*'.
"""
for name, field in attr.fields_dict(self.generation_class).items():
ty = self.__openllm_hints__.get(name)
for name, field in attr.fields_dict(cls.generation_class).items():
ty = cls.__openllm_hints__.get(name)
if t.get_origin(ty) is t.Union:
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
continue
f = attrs_to_options(name, field, self.__openllm_model_name__, typ=ty, suffix_generation=True)(f)
f = optgroup.group(f"{self.generation_class.__name__} generation options")(f)
f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f)
f = optgroup.group(f"{cls.generation_class.__name__} generation options")(f)
if len(self.__class__.__openllm_attrs__) == 0:
if len(cls.__openllm_attrs__) == 0:
# NOTE: in this case, the function is already a ClickFunctionWrapper
# hence the casting
return f
for name, field in attr.fields_dict(self.__class__).items():
ty = self.__openllm_hints__.get(name)
for name, field in attr.fields_dict(cls).items():
ty = cls.__openllm_hints__.get(name)
if t.get_origin(ty) is t.Union:
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
continue
f = attrs_to_options(name, field, self.__openllm_model_name__, typ=ty)(f)
f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f)
return optgroup.group(f"{self.__class__.__name__} options")(f)
return optgroup.group(f"{cls.__name__} options")(f)
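
A rough decorator usage sketch (the command body is illustrative; option names follow the '<model_name>_*' and '<model_name>_generation_*' prefixes documented above):

```python
import click
import openllm

@click.command()
@openllm.FlanT5Config.to_click_options
def serve(**attrs: object) -> None:
    # Options arrive as '<model_name>_*' and '<model_name>_generation_*' keyword
    # arguments; forwarding them back into the config is left out of this sketch.
    ...

if __name__ == "__main__":
    serve()
```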
bentoml_cattr.register_unstructure_hook_factory(

View File

@@ -32,6 +32,7 @@ from bentoml.types import ModelSignature, ModelSignatureDict
import openllm
from .exceptions import ForbiddenAttributeError, OpenLLMException
from .models.auto import AutoConfig
from .utils import ENV_VARS_TRUE_VALUES, LazyLoader, bentoml_cattr
if t.TYPE_CHECKING:
@@ -180,37 +181,19 @@ def import_model(
torch.cuda.empty_cache()
_required_namespace = {"default_id", "model_ids"}
_reserved_namespace = _required_namespace | {
"config_class",
"model",
"tokenizer",
"import_kwargs",
}
_reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"}
class LLMInterface(ABC):
"""This defines the loose contract for all openllm.LLM implementations."""
default_id: str
"""Return the default model to use when using 'openllm start <model_id>'.
This could be one of the keys in 'self.model_ids' or custom users model."""
model_ids: list[str]
"""A list of supported pretrained models tag for this given runnable.
For example:
For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base",
"google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"]
"""
config_class: type[openllm.LLMConfig]
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
import_kwargs: dict[str, t.Any] | None = None
"""The default import kwargs to used when importing the model.
This will be passed into 'openllm.LLM.import_model'."""
@property
def import_kwargs(self) -> dict[str, t.Any] | None:
"""The default import kwargs to used when importing the model.
This will be passed into 'openllm.LLM.import_model'."""
@abstractmethod
def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any:
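
With 'default_id' and 'model_ids' now living on the config, a user-defined (non-internal) LLM only declares its config class. A hedged sketch, with hypothetical names and showing only the pieces of the contract visible in this diff:

```python
import typing as t
import openllm

class MyConfig(openllm.LLMConfig):  # hypothetical config
    __config__ = {"default_id": "org/my-model", "model_ids": ["org/my-model"]}

class MyLLM(openllm.LLM):
    # For non-internal subclasses the metaclass requires 'config_class'.
    config_class = MyConfig

    def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any:
        ...
```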
@@ -265,34 +248,29 @@ class LLMInterface(ABC):
raise NotImplementedError
def _default_post_init(self: LLM):
# load_in_mha: Whether to apply BetterTransformer (or Torch MultiHeadAttention) during inference load.
# See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
# for more information.
# NOTE: set a default variable to transform to BetterTransformer by default for inference
self.load_in_mha = (
os.environ.get(self.config_class.__openllm_env__.bettertransformer, str(False)).upper() in ENV_VARS_TRUE_VALUES
)
if self.config_class.__openllm_requires_gpu__:
# For all models that requires GPU, no need to offload it to BetterTransformer
# use bitsandbytes instead
self.load_in_mha = False
class LLMMetaclass(ABCMeta):
def __new__(
mcls, cls_name: str, bases: tuple[type[t.Any], ...], namespace: dict[str, t.Any], **attrs: t.Any
) -> type:
"""Metaclass for creating a LLM."""
if LLMInterface not in bases: # only actual openllm.LLM should hit this branch.
if "__annotations__" not in namespace:
annotations_dict: dict[str, t.Any] = {}
namespace["__annotations__"] = annotations_dict
# NOTE: check for required attributes
if "__openllm_internal__" not in namespace:
_required_namespace.add("config_class")
for k in _required_namespace:
if k not in namespace:
raise RuntimeError(f"Missing required key '{k}'. Make sure to define it within the LLM subclass.")
# NOTE: set implementation branch
prefix_class_name_config = cls_name
if "__llm_implementation__" in namespace:
raise RuntimeError(
f"""\
__llm_implementation__ should not be set directly. Instead make sure that your class
name follows the convention prefix:
- For Tensorflow implementation: 'TF{cls_name}'
- For Flax implementation: 'Flax{cls_name}'
- For PyTorch implementation: '{cls_name}'"""
)
if cls_name.startswith("Flax"):
implementation = "flax"
prefix_class_name_config = cls_name[4:]
@@ -302,39 +280,37 @@ class LLMMetaclass(ABCMeta):
else:
implementation = "pt"
namespace["__llm_implementation__"] = implementation
config_class = AutoConfig.infer_class_from_name(prefix_class_name_config)
# NOTE: setup config class branch
if "__openllm_internal__" in namespace:
# NOTE: we will automatically find the subclass for this given config class
if "config_class" not in namespace:
# this branch we will automatically get the class
namespace["config_class"] = getattr(openllm, f"{prefix_class_name_config}Config")
namespace["config_class"] = config_class
else:
logger.debug(f"Using config class {namespace['config_class']} for {cls_name}.")
# NOTE: check for required attributes
else:
if "config_class" not in namespace:
raise RuntimeError(
"Missing required key 'config_class'. Make sure to define it within the LLM subclass."
)
config_class: type[openllm.LLMConfig] = namespace["config_class"]
# NOTE: the llm_post_init branch
if "llm_post_init" in namespace:
original_llm_post_init = namespace["llm_post_init"]
# NOTE: update the annotations for self.config
namespace["__annotations__"]["config"] = t.get_type_hints(config_class)
def wrapped_llm_post_init(self: LLM) -> None:
"""We need to both initialize private attributes and call the user-defined model_post_init
method.
"""
_default_post_init(self)
original_llm_post_init(self)
for key in ("__openllm_start_name__", "__openllm_requires_gpu__"):
namespace[key] = getattr(config_class, key)
# load_in_mha: Whether to apply BetterTransformer (or Torch MultiHeadAttention) during inference load.
# See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
# for more information.
# NOTE: set a default variable to transform to BetterTransformer by default for inference
if "load_in_mha" not in namespace:
load_in_mha = (
os.environ.get(config_class().__openllm_env__.bettertransformer, str(False)).upper()
in ENV_VARS_TRUE_VALUES
)
namespace["load_in_mha"] = load_in_mha
if namespace["__openllm_requires_gpu__"]:
# For all models that requires GPU, no need to offload it to BetterTransformer
# use bitsandbytes instead
namespace["load_in_mha"] = False
namespace["llm_post_init"] = wrapped_llm_post_init
else:
namespace["llm_post_init"] = _default_post_init
# NOTE: import_model branch
if "import_model" not in namespace:
@@ -343,16 +319,10 @@ class LLMMetaclass(ABCMeta):
else:
logger.debug("Using custom 'import_model' for %s", cls_name)
# NOTE: populate with default cache.
namespace.update({k: None for k in ("__llm_bentomodel__", "__llm_model__", "__llm_tokenizer__")})
cls: type[LLM] = super().__new__(t.cast("type[type[LLM]]", mcls), cls_name, bases, namespace, **attrs)
cls.__openllm_post_init__ = None if cls.llm_post_init is LLMInterface.llm_post_init else cls.llm_post_init
cls.__openllm_custom_load__ = None if cls.load_model is LLMInterface.load_model else cls.load_model
if getattr(cls, "config_class") is None:
raise RuntimeError(f"'config_class' must be defined for '{cls.__name__}'")
return cls
else:
# the LLM class itself being created, no need to setup
@@ -362,23 +332,19 @@ class LLMMetaclass(ABCMeta):
class LLM(LLMInterface, metaclass=LLMMetaclass):
if t.TYPE_CHECKING:
# NOTE: the following will be populated by metaclass
__llm_bentomodel__: bentoml.Model | None = None
__llm_model__: LLMModel | None = None
__llm_tokenizer__: LLMTokenizer | None = None
__llm_implementation__: t.Literal["pt", "tf", "flax"]
__openllm_start_name__: str
__openllm_requires_gpu__: bool
__openllm_post_init__: t.Callable[[t.Self], None] | None
__openllm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
load_in_mha: bool
_llm_attrs: dict[str, t.Any]
_llm_args: tuple[t.Any, ...]
# NOTE: the following is the similar interface to HuggingFace pretrained protocol.
__llm_bentomodel__: bentoml.Model | None = None
__llm_model__: LLMModel | None = None
__llm_tokenizer__: LLMTokenizer | None = None
# NOTE: the following is the similar interface to HuggingFace pretrained protocol.
@classmethod
def from_pretrained(
cls, model_id: str | None = None, llm_config: openllm.LLMConfig | None = None, *args: t.Any, **attrs: t.Any
@@ -446,15 +412,33 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
Note that this tag will be generated based on `self.default_id` or the given `pretrained` kwds
passed from the __init__ constructor.
``llm_post_init`` can also be implemented if you need to do any
additional initialization after everything is setup.
``llm_post_init`` can also be implemented if you need to do any additional
initialization after everything is set up.
Note: If you need to implement a custom `load_model`, the following is an example from Falcon implementation:
```python
def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
torch_dtype = attrs.pop("torch_dtype", torch.bfloat16)
device_map = attrs.pop("device_map", "auto")
_ref = bentoml.transformers.get(tag)
model = bentoml.transformers.load_model(_ref, device_map=device_map, torch_dtype=torch_dtype, **attrs)
return transformers.pipeline("text-generation", model=model, tokenizer=_ref.custom_objects["tokenizer"])
```
Args:
model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used.
llm_config: The config to use for this LLM. Defaults to None. If not passed, we will use 'self.config_class'
to construct default configuration.
llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
will use `config_class` to construct default configuration.
*args: The args to be passed to the model.
**attrs: The kwargs to be passed to the model.
The following are optional:
openllm_model_version: version for this `model_id`. By default, users can ignore this when using pretrained
weights, as OpenLLM will use the commit_hash of the given model_id.
However, if `model_id` is a path, it is recommended to include this argument.
"""
if llm_config is not None:
@@ -466,10 +450,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
attrs = self.config.__openllm_extras__
if model_id is None:
model_id = os.environ.get(self.config.__openllm_env__.model_id, None)
if not model_id:
assert self.default_id, "A default model is required for any LLM."
model_id = self.default_id
model_id = os.environ.get(self.config.__openllm_env__.model_id, self.config.__openllm_default_id__)
assert model_id is not None
# NOTE: This is the actual given path or pretrained weight for this LLM.
self._model_id = model_id
@@ -479,7 +461,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
self._llm_attrs = attrs
if self.__openllm_post_init__:
self.__openllm_post_init__(self)
self.llm_post_init()
def __setattr__(self, attr: str, value: t.Any):
if attr in _reserved_namespace:
@@ -500,7 +482,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
def identifying_params(self) -> dict[str, t.Any]:
return {
"configuration": self.config.model_dump_json().decode(),
"model_ids": orjson.dumps(self.model_ids).decode(),
"model_ids": orjson.dumps(self.config.__openllm_model_ids__).decode(),
}
@t.overload
@@ -643,7 +625,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
return self._bentomodel.tag
@property
def model(self) -> LLMModel | torch.nn.Module:
def model(self) -> LLMModel:
"""The model to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
# Run check for GPU
trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
@@ -818,7 +800,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
(_Runnable,),
{
"SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu")
if self.__openllm_requires_gpu__
if self.config.__openllm_requires_gpu__
else ("nvidia.com/gpu",),
"llm_type": self.llm_type,
"identifying_params": self.identifying_params,

View File

@@ -76,16 +76,17 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:
# first, then proceed to install everything inside the wheels/ folder.
packages: list[str] = ["openllm"]
ModelEnv = openllm.utils.ModelEnv(llm.__openllm_start_name__)
if llm.config.__openllm_requirements__ is not None:
packages.extend(llm.config.__openllm_requirements__)
if not (str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false"):
packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
to_use_framework = ModelEnv.get_framework_env()
to_use_framework = llm.config.__openllm_env__.get_framework_env()
if to_use_framework == "flax":
assert utils.is_flax_available(), f"Flax is not available, while {ModelEnv.framework} is set to 'flax'"
assert (
utils.is_flax_available()
), f"Flax is not available, while {llm.config.__openllm_env__.framework} is set to 'flax'"
packages.extend(
[
f"flax>={importlib.metadata.version('flax')}",
@@ -94,7 +95,9 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:
]
)
elif to_use_framework == "tf":
assert utils.is_tf_available(), f"TensorFlow is not available, while {ModelEnv.framework} is set to 'tf'"
assert (
utils.is_tf_available()
), f"TensorFlow is not available, while {llm.config.__openllm_env__.framework} is set to 'tf'"
candidates = (
"tensorflow",
"tensorflow-cpu",
@@ -128,7 +131,6 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:
def construct_docker_options(llm: openllm.LLM, _: FS) -> DockerOptions:
ModelEnv = openllm.utils.ModelEnv(llm.__openllm_start_name__)
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options += (
" "
@@ -141,7 +143,7 @@ def construct_docker_options(llm: openllm.LLM, _: FS) -> DockerOptions:
return DockerOptions(
cuda_version="11.6", # NOTE: Torch 2.0 currently only support 11.6 as the latest CUDA version
env={
ModelEnv.framework: ModelEnv.get_framework_env(),
llm.config.__openllm_env__.framework: llm.config.__openllm_env__.get_framework_env(),
"OPENLLM_MODEL": llm.config.__openllm_model_name__,
"BENTOML_DEBUG": str(get_debug_mode()),
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
@@ -200,12 +202,12 @@ def build(model_name: str, *, __cli__: bool = False, **attrs: t.Any) -> tuple[be
raise bentoml.exceptions.NotFound("Overwriting previously saved Bento.")
_previously_built = True
except bentoml.exceptions.NotFound:
logger.info("Building Bento for LLM '%s'", llm.__openllm_start_name__)
logger.info("Building Bento for LLM '%s'", llm.config.__openllm_start_name__)
bento = bentoml.bentos.build(
f"{service_name}:svc",
name=bento_tag.name,
labels=labels,
description=f"OpenLLM service for {llm.__openllm_start_name__}",
description=f"OpenLLM service for {llm.config.__openllm_start_name__}",
include=[
f for f in llm_fs.walk.files(filter=["*.py"])
], # NOTE: By default, we are using _service.py as the default service, for now.

View File

@@ -384,12 +384,10 @@ def start_model_command(
ModelEnv = openllm.utils.ModelEnv(model_name)
llm_config = openllm.AutoConfig.for_model(model_name)
for_doc = openllm.AutoLLM.for_model(model_name)
docstring = f"""\
{ModelEnv.start_docstring}
\b
Available model_id(s) to use with '{model_name}' are: {for_doc.model_ids} [default: {for_doc.default_id}]
Tip: One can pass one of the aforementioned to '--model-id' to use other pretrained weights.
Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.__openllm_default_id__}]
"""
command_attrs: dict[str, t.Any] = {
"name": ModelEnv.model_name,
@@ -399,7 +397,7 @@ Tip: One can pass one of the aforementioned to '--model-id' to use other pretrai
}
aliases: list[str] = []
if llm_config.name_type == "dasherize":
if llm_config.__openllm_name_type__ == "dasherize":
aliases.append(llm_config.__openllm_start_name__)
command_attrs["aliases"] = aliases if len(aliases) > 0 else None
@@ -429,8 +427,9 @@ Tip: One can pass one of the aforementioned to '--model-id' to use other pretrai
@group.command(**command_attrs)
@llm_config.to_click_options
@parse_serve_args(_serve_grpc)
@click.option("--server-timeout", type=int, default=3600, help="Server timeout in seconds")
@model_id_option
@cog.optgroup.group("General LLM Options")
@cog.optgroup.option("--server-timeout", type=int, default=3600, help="Server timeout in seconds")
@model_id_option(cog.optgroup, model_env=ModelEnv)
@click.option(
"--device",
type=tuple,
@@ -439,6 +438,7 @@ Tip: One can pass one of the aforementioned to '--model-id' to use other pretrai
envvar="CUDA_VISIBLE_DEVICES",
callback=parse_device_callback,
help=f"Assign GPU devices (if available) for {model_name}.",
show_envvar=True,
)
def model_start(
server_timeout: int,
@@ -575,12 +575,20 @@ output_option = click.option(
default="pretty",
help="Showing output type. Default to 'pretty'",
)
model_id_option = click.option(
"--model-id",
type=click.STRING,
default=None,
help="Optional model_id name or path for (fine-tune) weight.",
)
def model_id_option(factory: t.Any, model_env: openllm.utils.ModelEnv | None = None):
envvar = None
if model_env is not None:
envvar = model_env.model_id
return factory.option(
"--model-id",
type=click.STRING,
default=None,
help="Optional model_id name or path for (fine-tune) weight.",
envvar=envvar,
show_envvar=True if envvar is not None else False,
)
def cli_factory() -> click.Group:
@@ -627,7 +635,7 @@ def cli_factory() -> click.Group:
@click.argument(
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@model_id_option
@model_id_option(click)
@output_option
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
def build(model_name: str, model_id: str | None, overwrite: bool, output: OutputLiteral):
@@ -713,8 +721,12 @@ def cli_factory() -> click.Group:
runtime_impl += ("flax",)
if model.config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES:
runtime_impl += ("tf",)
json_data[m] = {"model_id": model.model_ids, "description": docs, "runtime_impl": runtime_impl}
converted.extend([convert_transformers_model_name(i) for i in model.model_ids])
json_data[m] = {
"model_id": model.config.__openllm_model_ids__,
"description": docs,
"runtime_impl": runtime_impl,
}
converted.extend([convert_transformers_model_name(i) for i in model.config.__openllm_model_ids__])
except Exception as err:
failed_initialized.append((m, err))
@@ -783,7 +795,7 @@ def cli_factory() -> click.Group:
@click.argument(
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@model_id_option
@model_id_option(click)
@output_option
def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
"""Setup LLM interactively.

View File

@@ -87,6 +87,7 @@ class _LazyConfigMapping(ConfigOrderedDict):
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder"}
class AutoConfig:
@@ -102,3 +103,15 @@ class AutoConfig:
f"Unrecognized configuration class for {model_name}. "
f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
)
@classmethod
def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
model_name = inflection.underscore(name)
if model_name in CONFIG_NAME_ALIASES:
model_name = CONFIG_NAME_ALIASES[model_name]
if model_name in CONFIG_MAPPING:
return CONFIG_MAPPING[model_name]
raise ValueError(
f"Unrecognized configuration class for {model_name}. "
f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
)
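
A hedged usage sketch of the new helper; the name-to-config resolution mirrors how the LLM metaclass strips the implementation prefix:

```python
import openllm

# "FlanT5"    -> underscore -> "flan_t5"                   -> openllm.FlanT5Config
openllm.AutoConfig.infer_class_from_name("FlanT5")
# "StarCoder" -> "star_coder" -> alias "starcoder" (above) -> openllm.StarCoderConfig
openllm.AutoConfig.infer_class_from_name("StarCoder")
```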

View File

@@ -16,15 +16,7 @@ from __future__ import annotations
import openllm
class ChatGLMConfig(
openllm.LLMConfig,
name_type="lowercase",
trust_remote_code=True,
default_timeout=3600000,
requires_gpu=True,
url="https://github.com/THUDM/ChatGLM-6B",
requirements=["cpm_kernels", "sentencepiece"],
):
class ChatGLMConfig(openllm.LLMConfig):
"""
ChatGLM is an open bilingual language model based on
[General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
@@ -41,11 +33,24 @@ class ChatGLMConfig(
Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
"""
retain_history: bool = False
"""Whether to retain history given to the model. If set to True, then the model will retain given history."""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": True,
"timeout": 3600000,
"requires_gpu": True,
"url": "https://github.com/THUDM/ChatGLM-6B",
"requirements": ["cpm_kernels", "sentencepiece"],
"default_id": "thudm/chatglm-6b-int4",
"model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4"],
}
use_half_precision: bool = True
"""Whether to use half precision for model."""
retain_history: bool = openllm.LLMConfig.Field(
False,
description="""Whether to retain history given to the model.
If set to True, then the model will retain given history.""",
)
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
class GenerationConfig:
max_new_tokens: int = 2048

View File

@@ -41,11 +41,8 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
class ChatGLM(openllm.LLM):
__openllm_internal__ = True
default_id = "thudm/chatglm-6b-int4"
model_ids = ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4"]
device = torch.device("cuda")
def llm_post_init(self):
self.device = torch.device("cuda")
def import_model(
self,

View File

@@ -20,12 +20,7 @@ from __future__ import annotations
import openllm
class DollyV2Config(
openllm.LLMConfig,
default_timeout=3600000,
trust_remote_code=True,
url="https://github.com/databrickslabs/dolly",
):
class DollyV2Config(openllm.LLMConfig):
"""Databricks Dolly is an instruction-following large language model trained on the Databricks
machine learning platform that is licensed for commercial use.
@@ -39,6 +34,14 @@ class DollyV2Config(
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
"""
__config__ = {
"timeout": 3600000,
"trust_remote_code": True,
"url": "https://github.com/databrickslabs/dolly",
"default_id": "databricks/dolly-v2-3b",
"model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"],
}
return_full_text: bool = openllm.LLMConfig.Field(
False, description="Whether to return the full prompt to the users."
)

View File

@@ -38,17 +38,16 @@ class DollyV2(openllm.LLM):
__openllm_internal__ = True
default_id = "databricks/dolly-v2-3b"
@property
def import_kwargs(self):
return {
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
"torch_dtype": torch.bfloat16,
"_tokenizer_padding_side": "left",
}
model_ids = ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]
import_kwargs = {
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
"torch_dtype": torch.bfloat16,
"_tokenizer_padding_side": "left",
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def llm_post_init(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def import_model(
self, model_id: str, tag: bentoml.Tag, *model_args: t.Any, tokenizer_kwds: dict[str, t.Any], **attrs: t.Any

View File

@@ -16,15 +16,7 @@ from __future__ import annotations
import openllm
class FalconConfig(
openllm.LLMConfig,
name_type="lowercase",
trust_remote_code=True,
requires_gpu=True,
default_timeout=3600000,
url="https://falconllm.tii.ae/",
requirements=["einops", "xformers", "safetensors"],
):
class FalconConfig(openllm.LLMConfig):
"""Falcon-7B is a 7B parameters causal decoder-only model built by
TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
enhanced with curated corpora. It is made available under the TII Falcon LLM License.
@@ -32,6 +24,22 @@ class FalconConfig(
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": True,
"requires_gpu": True,
"timeout": 3600000,
"url": "https://falconllm.tii.ae/",
"requirements": ["einops", "xformers", "safetensors"],
"default_id": "tiiuae/falcon-7b",
"model_ids": [
"tiiuae/falcon-7b",
"tiiuae/falcon-40b",
"tiiuae/falcon-7b-instruct",
"tiiuae/falcon-40b-instruct",
],
}
class GenerationConfig:
max_new_tokens: int = 200
top_k: int = 10

View File

@@ -34,14 +34,12 @@ else:
class Falcon(openllm.LLM):
__openllm_internal__ = True
default_id = "tiiuae/falcon-7b"
model_ids = ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct"]
import_kwargs = {
"torch_dtype": torch.bfloat16,
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
}
@property
def import_kwargs(self):
return {
"torch_dtype": torch.bfloat16,
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
}
def import_model(
self, model_id: str, tag: bentoml.Tag, *model_args: t.Any, tokenizer_kwds: dict[str, t.Any], **attrs: t.Any

View File

@@ -44,13 +44,25 @@ $ openllm start flan-t5 --model-id google/flan-t5-xxl
DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruction}\nAnswer:"""
class FlanT5Config(openllm.LLMConfig, url="https://huggingface.co/docs/transformers/model_doc/flan-t5"):
class FlanT5Config(openllm.LLMConfig):
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf)
- it is an enhanced version of T5 that has been finetuned in a mixture of tasks.
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
"""
__config__ = {
"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5",
"default_id": "google/flan-t5-large",
"model_ids": [
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
],
}
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 2048

View File

@@ -29,17 +29,8 @@ else:
class FlanT5(openllm.LLM):
__openllm_internal__ = True
default_id = "google/flan-t5-large"
model_ids = [
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def llm_post_init(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def sanitize_parameters(
self,

View File

@@ -24,16 +24,6 @@ from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
class FlaxFlanT5(openllm.LLM):
__openllm_internal__ = True
default_id: str = "google/flan-t5-large"
model_ids = [
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
]
def sanitize_parameters(
self,
prompt: str,

View File

@@ -24,16 +24,6 @@ from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
class TFFlanT5(openllm.LLM):
__openllm_internal__ = True
default_id: str = "google/flan-t5-large"
model_ids = [
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
]
def sanitize_parameters(
self,
prompt: str,

View File

@@ -16,7 +16,7 @@ from __future__ import annotations
import openllm
class StableLMConfig(openllm.LLMConfig, name_type="lowercase", url="https://github.com/Stability-AI/StableLM"):
class StableLMConfig(openllm.LLMConfig):
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models
pre-trained on a diverse collection of English datasets with a sequence
length of 4096 to push beyond the context window limitations of existing open-source language models.
@@ -30,6 +30,18 @@ class StableLMConfig(openllm.LLMConfig, name_type="lowercase", url="https://gith
for more information.
"""
__config__ = {
"name_type": "lowercase",
"url": "https://github.com/Stability-AI/StableLM",
"default_id": "stabilityai/stablelm-tuned-alpha-3b",
"model_ids": [
"stabilityai/stablelm-tuned-alpha-3b",
"stabilityai/stablelm-tuned-alpha-7b",
"stabilityai/stablelm-base-alpha-3b",
"stabilityai/stablelm-base-alpha-7b",
],
}
class GenerationConfig:
temperature: float = 0.9
max_new_tokens: int = 128

View File

@@ -42,23 +42,17 @@ logger = logging.getLogger(__name__)
class StableLM(openllm.LLM):
__openllm_internal__ = True
load_in_mha = True if not torch.cuda.is_available() else False
default_id = "stabilityai/stablelm-tuned-alpha-3b"
def llm_post_init(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.load_in_mha = True if not torch.cuda.is_available() else False
model_ids = [
"stabilityai/stablelm-tuned-alpha-3b",
"stabilityai/stablelm-tuned-alpha-7b",
"stabilityai/stablelm-base-alpha-3b",
"stabilityai/stablelm-base-alpha-7b",
]
import_kwargs = {
"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
"load_in_8bit": False,
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@property
def import_kwargs(self):
return {
"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
"load_in_8bit": False,
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
}
def sanitize_parameters(
self,
@@ -98,7 +92,6 @@ class StableLM(openllm.LLM):
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
return generation_result[0]
@torch.inference_mode()
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
generation_kwargs = {
"do_sample": True,

View File

@@ -16,14 +16,7 @@ from __future__ import annotations
import openllm
class StarCoderConfig(
openllm.LLMConfig,
name_type="lowercase",
requires_gpu=True,
url="https://github.com/bigcode-project/starcoder",
requirements=["bitsandbytes"],
workers_per_resource=0.5,
):
class StarCoderConfig(openllm.LLMConfig):
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from
[The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
@@ -34,6 +27,16 @@ class StarCoderConfig(
Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information.
"""
__config__ = {
"name_type": "lowercase",
"requires_gpu": True,
"url": "https://github.com/bigcode-project/starcoder",
"requirements": ["bitsandbytes"],
"workers_per_resource": 0.5,
"default_id": "bigcode/starcoder",
"model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"],
}
class GenerationConfig:
temperature: float = 0.2
max_new_tokens: int = 256

View File

@@ -40,18 +40,17 @@ FIM_INDICATOR = "<FILL_HERE>"
class StarCoder(openllm.LLM):
__openllm_internal__ = True
default_id = "bigcode/starcoder"
def llm_post_init(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ids = ["bigcode/starcoder", "bigcode/starcoderbase"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import_kwargs = {
"_tokenizer_padding_side": "left",
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
"load_in_8bit": True if torch.cuda.device_count() > 1 else False,
"torch_dtype": torch.float16,
}
@property
def import_kwargs(self):
return {
"_tokenizer_padding_side": "left",
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
"load_in_8bit": True if torch.cuda.device_count() > 1 else False,
"torch_dtype": torch.float16,
}
def import_model(
self,

View File

@@ -38,7 +38,10 @@ def _default_converter(value: t.Any, env: str | None) -> t.Any:
if env is not None:
value = os.environ.get(env, value)
if value is not None and isinstance(value, str):
return eval(value, {"__builtins__": {}}, {})
try:
return orjson.loads(value)
except orjson.JSONDecodeError as err:
raise RuntimeError(f"Failed to parse '{value}' from '{env}': {err}")
return value
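
The converter now parses environment-provided values as JSON via orjson instead of eval'ing them; a small sketch of what that accepts (values illustrative):

```python
import orjson

print(orjson.loads("0.75"))        # 0.75
print(orjson.loads("2048"))        # 2048
print(orjson.loads('["a", "b"]'))  # ['a', 'b']

try:
    orjson.loads("not-json")
except orjson.JSONDecodeError:
    # The converter above surfaces this as a RuntimeError that names the env var.
    pass
```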
@@ -65,7 +68,7 @@ def Field(
"""
metadata = attrs.pop("metadata", {})
if description is None:
description = "(No description is available)"
description = "(No description provided)"
metadata["description"] = description
if env is not None:
metadata["env"] = env

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import enum
import sys
from typing import (
@@ -500,3 +502,5 @@ def _make_init(
cls_on_setattr: Any,
attrs_init: bool,
) -> Callable[_P, Any]: ...
def _make_method(name: str, script: str, filename: str, globs: dict[str, Any]) -> Callable[..., Any]: ...
def _make_repr(attrs: tuple[Attribute[Any]], ns: str | None, cls: AttrsInstance) -> Callable[[AttrsInstance], str]: ...

typings/attr/_compat.pyi Normal file
View File

@@ -0,0 +1,4 @@
import threading
set_closure_cell = ...
repr_context: threading.local = ...