Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-22 14:31:26 -05:00)
refactor(configuration): __config__ and perf
- Move model_ids and default_id to the config class declaration; clean up dependencies between config and the LLM implementation
- Lazily load modules during LLM creation in llm_post_init
- Fix post-init hooks to run load_in_mha

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
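In practice, the declaration style changes as follows (a before/after sketch distilled from the ChatGLM changes in this diff):

```python
# Before: metadata passed as class keyword arguments, IDs declared on the LLM class
class ChatGLMConfig(openllm.LLMConfig, name_type="lowercase", trust_remote_code=True):
    ...

# After: everything lives in __config__, including the new required keys
class ChatGLMConfig(openllm.LLMConfig):
    __config__ = {
        "name_type": "lowercase",
        "trust_remote_code": True,
        "default_id": "thudm/chatglm-6b-int4",
        "model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4"],
    }
```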
@@ -12,13 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-Configuration utilities for OpenLLM. All model configuration will inherit from openllm.configuration_utils.LLMConfig.
+Configuration utilities for OpenLLM. All model configurations will inherit from ``openllm.LLMConfig``.

-Note that ``openllm.LLMConfig`` is a subclass of ``pydantic.BaseModel``. It also
-has a ``to_click_options`` that returns a list of Click-compatible options for the model.
-Such options will then be parsed to ``openllm.__main__.cli``.

-Each fields in ``openllm.LLMConfig`` will also automatically generate a environment
+Highlight feature: Each field in ``openllm.LLMConfig`` will also automatically generate an environment
variable based on its name field.

For example, the following config class:
@@ -34,10 +30,14 @@ class FlanT5Config(openllm.LLMConfig):
    repetition_penalty = 1.0
```

+which generates the environment variable OPENLLM_FLAN_T5_GENERATION_TEMPERATURE for users to configure temperature
+dynamically at serve time, ahead of serving, or per request.

Refer to the ``openllm.LLMConfig`` docstring for more information.
"""
from __future__ import annotations

import inspect
import logging
import os
import typing as t
@@ -53,13 +53,13 @@ from deepmerge.merger import Merger
import openllm

from .exceptions import GpuNotAvailableError, OpenLLMException
-from .utils import LazyType, ModelEnv, bentoml_cattr, dantic, lenient_issubclass
+from .utils import LazyType, ModelEnv, bentoml_cattr, dantic, first_not_none, lenient_issubclass

if t.TYPE_CHECKING:
    import tensorflow as tf
    import torch
    import transformers
-    from attr import _CountingAttr, _make_init
+    from attr import _CountingAttr, _make_init, _make_method, _make_repr
    from transformers.generation.beam_constraints import Constraint

    from ._types import ClickFunctionWrapper, F, O_co, P
@@ -67,14 +67,16 @@ if t.TYPE_CHECKING:
    ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]

    DictStrAny = dict[str, t.Any]
    ListStr = list[str]
    ItemgetterAny = itemgetter[t.Any]
else:
    Constraint = t.Any
    ListStr = list
    DictStrAny = dict
    ItemgetterAny = itemgetter
    # NOTE: Using internal API from attr here, since we are actually
    # allowing subclasses of openllm.LLMConfig to become 'attrs'-ish
-    from attr._make import _CountingAttr, _make_init
+    from attr._make import _CountingAttr, _make_init, _make_method, _make_repr

transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -93,6 +95,8 @@ config_merger = Merger(
    type_conflict_strategies=["override"],
)

+_T = t.TypeVar("_T")


@t.overload
def attrs_to_options(
@@ -158,7 +162,9 @@ class GenerationConfig:
    """Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
    with some additional validation and environment constructor.

-    Note that we always set `do_sample=True`
+    Note that we always set `do_sample=True`. This class is not designed to be used directly, rather
+    to be used in conjunction with LLMConfig. The instance of the generation config can then be accessed
+    via ``LLMConfig.generation_config``.
    """

    # NOTE: parameters for controlling the length of the output
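Concretely, generation fields declared on a subclass end up on this nested instance and can be read back or converted (a sketch; `to_generation_config` is named in the API list above, and its exact signature is assumed):

```python
import openllm

config = openllm.FlanT5Config(temperature=0.75, max_new_tokens=256)

print(config.generation_config.temperature)  # 0.75, collected onto the nested instance

# Conversion for use with HuggingFace generate():
hf_config = config.to_generation_config()    # a transformers.GenerationConfig
```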
@@ -417,20 +423,6 @@ def _populate_value_from_env_var(
    return os.environ.get(key, fallback)


-def env_transformers(cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
-    transformed: list[attr.Attribute[t.Any]] = []
-    for f in fields:
-        if "env" not in f.metadata:
-            raise ValueError(
-                "Make sure to setup the field with 'cls.Field' or 'attr.field(..., metadata={\"env\": \"...\"})'"
-            )
-        _from_env = _populate_value_from_env_var(f.metadata["env"])
-        if _from_env is not None:
-            f = f.evolve(default=_from_env)
-        transformed.append(f)
-    return transformed


# sentinel object for unequivocal object() getattr
_sentinel = object()

@@ -590,7 +582,7 @@ def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConf
        )
        return transformed

-    generated_cls = attr.make_class(
+    _cl = attr.make_class(
        cls.__name__.replace("Config", "GenerationConfig"),
        [],
        bases=(GenerationConfig,),
@@ -599,28 +591,205 @@ def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConf
        repr=True,
        field_transformer=_evolve_with_base_default,
    )
+    _cl.__doc__ = GenerationConfig.__doc__

-    return generated_cls
+    if _has_gen_class:
+        delattr(cls, "GenerationConfig")
+
+    return _cl

# NOTE: This DEFAULT_LLMCONFIG_ATTRS is a way to dynamically generate attr.field
# and will be saved for future use in LLMConfig if we have some shared config.
DEFAULT_LLMCONFIG_ATTRS: tuple[tuple[str, t.Any, str, type[t.Any]], ...] = ()


# NOTE: This is the ModelConfig where we can control the behaviour of the LLM.
# Refer to the __openllm_*__ docstrings inside LLMConfig for more information.
class ModelConfig(t.TypedDict, total=False):
    # NOTE: meta
    url: str
    requires_gpu: bool
    trust_remote_code: bool
    requirements: t.Optional[t.List[str]]

    # NOTE: naming convention, only name_type is needed
    # as the three below it can be determined automatically
    name_type: t.Literal["dasherize", "lowercase"]
    model_name: str
    start_name: str
    env: openllm.utils.ModelEnv

    # NOTE: serving configuration
    timeout: int
    workers_per_resource: t.Union[int, float]

    # NOTE: use t.Required once we drop 3.8 support
    default_id: str
    model_ids: list[str]


def _gen_default_model_config(cls: type[LLMConfig]) -> ModelConfig:
    """Generate the default ModelConfig and delete __config__ in LLMConfig
    if defined in place."""

    _internal_config = t.cast(ModelConfig, getattr(cls, "__config__", {}))
    default_id = _internal_config.get("default_id", None)
    if default_id is None:
        raise RuntimeError("'default_id' is required under '__config__'.")
    model_ids = _internal_config.get("model_ids", None)
    if model_ids is None:
        raise RuntimeError("'model_ids' is required under '__config__'.")

    def _first_not_null(key: str, default: _T) -> _T:
        return first_not_none(_internal_config.get(key), default=default)

    llm_config_striped = cls.__name__.replace("Config", "")

    name_type: t.Literal["dasherize", "lowercase"] = _first_not_null("name_type", "dasherize")

    if name_type == "dasherize":
        default_model_name = inflection.underscore(llm_config_striped)
        default_start_name = inflection.dasherize(default_model_name)
    else:
        default_model_name = llm_config_striped.lower()
        default_start_name = default_model_name

    model_name = _first_not_null("model_name", default_model_name)

    _config = ModelConfig(
        name_type=name_type,
        model_name=model_name,
        default_id=default_id,
        model_ids=model_ids,
        start_name=_first_not_null("start_name", default_start_name),
        url=_first_not_null("url", "(not provided)"),
        requires_gpu=_first_not_null("requires_gpu", False),
        trust_remote_code=_first_not_null("trust_remote_code", False),
        requirements=_first_not_null("requirements", ListStr()),
        env=_first_not_null("env", openllm.utils.ModelEnv(model_name)),
        timeout=_first_not_null("timeout", 3600),
        workers_per_resource=_first_not_null("workers_per_resource", 1),
    )

    if hasattr(cls, "__config__"):
        delattr(cls, "__config__")

    return _config
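For instance, the derived defaults for a dasherized config name (a sketch mirroring the inflection calls above):

```python
import inflection

stripped = "FlanT5Config".replace("Config", "")   # "FlanT5"
model_name = inflection.underscore(stripped)       # "flan_t5"
start_name = inflection.dasherize(model_name)      # "flan-t5"

# With name_type="lowercase" (e.g. ChatGLMConfig), both names are simply "chatglm".
```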

def _generate_unique_filename(cls: type[t.Any], func_name: str):
    return f"<LLMConfig generated {func_name} {cls.__module__}." f"{getattr(cls, '__qualname__', cls.__name__)}>"


def _setattr_class(attr_name: str, value_var: t.Any):
    """
    Use the builtin setattr to set *attr_name* to *value_var*.
    We can't use the cached object.__setattr__ since we are setting
    attributes to a class.
    """
    return f"setattr(cls, '{attr_name}', {value_var})"


@t.overload
def _make_assignment_with_prefix_script(cls: type[LLMConfig], attributes: ModelConfig) -> t.Callable[..., None]:
    ...


@t.overload
def _make_assignment_with_prefix_script(cls: type[LLMConfig], attributes: dict[str, t.Any]) -> t.Callable[..., None]:
    ...


def _make_assignment_with_prefix_script(cls: type[LLMConfig], attributes: t.Any) -> t.Callable[..., None]:
    """Generate the assignment script with prefixed attributes __openllm_<value>__"""
    args: list[str] = []
    globs: dict[str, t.Any] = {"cls": cls, "attr_dict": attributes}
    annotations: dict[str, t.Any] = {"return": None}

    # Circumvent the __setattr__ descriptor to save one lookup per assignment
    lines: list[str] = []
    for attr_name in attributes:
        arg_name = f"__openllm_{inflection.underscore(attr_name)}__"
        args.append(f"{attr_name}=attr_dict['{attr_name}']")
        lines.append(_setattr_class(arg_name, attr_name))
        annotations[attr_name] = type(attributes[attr_name])

    script = "def __assign_attr(cls, %s):\n %s\n" % (", ".join(args), "\n ".join(lines) if lines else "pass")
    assign_method = _make_method(
        "__assign_attr",
        script,
        _generate_unique_filename(cls, "__assign_attr"),
        globs,
    )
    assign_method.__annotations__ = annotations

    return assign_method
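For a `__config__` such as ChatGLM's below, the assembled script looks roughly like this (a sketch of the generated source, not code that exists verbatim in the repo):

```python
# Each key of the attributes dict becomes a prefixed class attribute:
def __assign_attr(cls, timeout=attr_dict['timeout'], trust_remote_code=attr_dict['trust_remote_code']):
    setattr(cls, '__openllm_timeout__', timeout)
    setattr(cls, '__openllm_trust_remote_code__', trust_remote_code)
```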

@attr.define
class LLMConfig:
-    Field = dantic.Field
-    """Field is a alias to the internal dantic utilities to easily create
-    attrs.fields with pydantic-compatible interface.
-    """
    """``openllm.LLMConfig`` is somewhat a hybrid combination between the performance of `attrs` with the
    easy-to-use interface that pydantic offers. It lives in between, allowing users to quickly formulate
    a LLMConfig for any LLM without worrying too much about performance. It does a few things:

    - Automatic environment conversion: Each field will automatically be provisioned with an environment
      variable, making it easy to work with ahead-of-time or during serving time
    - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-v2-like APIs,
      i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options``
    - Automatic CLI generation: It can identify each field and convert it to compatible Click options.
      This means developers can use any of the LLMConfig to create a CLI with any compatible Python
      CLI library (click, typer, ...)

    > Internally, LLMConfig is an attrs class. All subclasses of LLMConfig contain "attrs-like" features,
    > which means LLMConfig will actually generate a subclass with an attrs-compatible API, so that the subclass
    > can be written as any normal Python class.

    To directly configure GenerationConfig for any given LLM, create a GenerationConfig under the subclass:

    ```python
    class FlanT5Config(openllm.LLMConfig):

        class GenerationConfig:
            temperature: float = 0.75
            max_new_tokens: int = 3000
            top_k: int = 50
            top_p: float = 0.4
            repetition_penalty = 1.0
    ```

    By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted
    to ``transformers.GenerationConfig``. These attributes can be accessed via ``LLMConfig.generation_config``.

    By default, every LLMConfig has a __config__ that contains default values. If any LLM requires customization,
    provide a ``__config__`` under the class declaration:

    ```python
    class FalconConfig(openllm.LLMConfig):
        __config__ = {"trust_remote_code": True, "timeout": 3600000}
    ```

    Note that ``model_name``, ``start_name``, and ``env`` are optional under ``__config__``. If set, then OpenLLM
    will respect those options for start and other components within the library.
    """

-    if t.TYPE_CHECKING:
-        # The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
+    Field = dantic.Field
+    """Field is an alias to the internal dantic utilities to easily create
+    attrs.fields with a pydantic-compatible interface. For example:
+
+    ```python
+    class MyModelConfig(openllm.LLMConfig):
+
+        field1 = openllm.LLMConfig.Field(...)
+    ```
+    """

+    # NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
+    if t.TYPE_CHECKING:
        # NOTE: Internal attributes that should only be used by OpenLLM. Users usually shouldn't
        # need to concern themselves with any of these.

        def __attrs_init__(self, **attrs: t.Any):
            """Generated __attrs_init__ for LLMConfig subclasses that follows the attrs contract."""

        __config__: ModelConfig | None = None
        """Internal configuration for this LLM model. Each of the fields in here will be populated
        and prefixed with __openllm_<value>__"""

        __attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = tuple()
        """Since we are writing our own __init_subclass__, which is an alternative way for __prepare__,
        we want openllm.LLMConfig to be an attrs-like dataclass with a pydantic-like interface.
@@ -630,8 +799,15 @@ class LLMConfig:
        __openllm_attrs__: tuple[str, ...] = tuple()
        """Internal attribute tracking to store converted LLMConfig attributes to correct attrs"""

-        __openllm_timeout__: int = 3600
-        """The default timeout to be set for this given LLM."""
+        __openllm_hints__: dict[str, t.Any] = Field(None, init=False)
+        """An internal cache of resolved types for this LLMConfig."""
+
+        __openllm_accepted_keys__: set[str] = Field(None, init=False)
+        """The accepted keys for this LLMConfig."""
+
+        # NOTE: The following will be populated from __config__
+        __openllm_url__: str = Field(None, init=False)
+        """The resolved url for this LLMConfig."""

        __openllm_requires_gpu__: bool = False
        """Determines if this model is only available on GPU. By default it supports GPU and falls back to CPU."""
@@ -639,6 +815,13 @@ class LLMConfig:
        __openllm_trust_remote_code__: bool = False
        """Whether to always trust remote code"""

+        __openllm_requirements__: list[str] | None = None
+        """The default PyPI requirements needed to run this given LLM. By default, we will depend on
+        bentoml, torch, transformers."""
+
+        __openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
+        """A ModelEnv instance for this LLMConfig."""

        __openllm_model_name__: str = ""
        """The normalized version of __openllm_start_name__, determined by __openllm_name_type__"""

@@ -649,21 +832,8 @@ class LLMConfig:
        """the default name type for this model. "dasherize" will convert the name to lowercase and
        replace spaces with dashes. "lowercase" will convert the name to lowercase."""

-        __openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
-        """A ModelEnv instance for this LLMConfig."""
-
-        __openllm_hints__: dict[str, t.Any] = Field(None, init=False)
-        """An internal cache of resolved types for this LLMConfig."""
-
-        __openllm_url__: str = Field(None, init=False)
-        """The resolved url for this LLMConfig."""
-
-        __openllm_accepted_keys__: set[str] = Field(None, init=False)
-        """The accepted keys for this LLMConfig."""
-
-        __openllm_requirements__: list[str] | None = None
-        """The default PyPI requirements needed to run this given LLM. By default, we will depend on
-        bentoml, torch, transformers."""
+        __openllm_timeout__: int = 3600
+        """The default timeout to be set for this given LLM."""

        __openllm_workers_per_resource__: int | float = 1
        """The default number of workers per resource. By default, we will use 1 worker per resource.
@@ -671,6 +841,18 @@ class LLMConfig:
        https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more details.
        """

+        __openllm_default_id__: str = Field(None)
+        """Return the default model to use when using 'openllm start <model_id>'.
+        This could be one of the keys in 'self.model_ids' or a custom user model."""
+
+        __openllm_model_ids__: list[str] = Field(None)
+        """A list of supported pretrained model tags for this given runnable.
+
+        For example:
+            For the FLAN-T5 implementation, this would be ["google/flan-t5-small", "google/flan-t5-base",
+            "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"]
+        """

        GenerationConfig: type = type
        """Users can override this subclass of any given LLMConfig to provide GenerationConfig
        default values. For example:
@@ -689,35 +871,9 @@ class LLMConfig:
    """The resulting generated GenerationConfig class for this LLMConfig. This will be used
    to create the generation_config argument that can be used throughout the lifecycle."""

-    def __init_subclass__(
-        cls,
-        *,
-        name_type: t.Literal["dasherize", "lowercase"] = "dasherize",
-        default_timeout: int | None = None,
-        trust_remote_code: bool = False,
-        requires_gpu: bool = False,
-        url: str | None = None,
-        requirements: list[str] | None = None,
-        workers_per_resource: int | float = 1,
-    ):
-        if name_type == "dasherize":
-            model_name = inflection.underscore(cls.__name__.replace("Config", ""))
-            start_name = inflection.dasherize(model_name)
-        else:
-            model_name = cls.__name__.replace("Config", "").lower()
-            start_name = model_name
-
-        cls.__openllm_name_type__ = name_type
-        cls.__openllm_requires_gpu__ = requires_gpu
-        cls.__openllm_timeout__ = default_timeout or 3600
-        cls.__openllm_trust_remote_code__ = trust_remote_code
-
-        cls.__openllm_model_name__ = model_name
-        cls.__openllm_start_name__ = start_name
-        cls.__openllm_env__ = openllm.utils.ModelEnv(model_name)
-        cls.__openllm_url__ = url or "(not set)"
-        cls.__openllm_requirements__ = requirements
-        cls.__openllm_workers_per_resource__ = workers_per_resource
+    def __init_subclass__(cls):
+        # NOTE: auto-assign attributes generated from __config__
+        _make_assignment_with_prefix_script(cls, _gen_default_model_config(cls))(cls)

        # NOTE: Since we want to enable a pydantic-like experience
        # this means we will have to hide the attr abstraction, and generate
@@ -727,7 +883,7 @@ class LLMConfig:
        cd = cls.__dict__

        def field_env_key(key: str) -> str:
-            return f"OPENLLM_{model_name.upper()}_{key.upper()}"
+            return f"OPENLLM_{cls.__openllm_model_name__.upper()}_{key.upper()}"

        ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)}
        ca_list: list[tuple[str, _CountingAttr[t.Any]]] = []
@@ -763,22 +919,14 @@ class LLMConfig:
            gen_attribute = gen_attribute.evolve(metadata=metadata)
            own_attrs.append(gen_attribute)

        # This is to handle subclasses of subclasses of all provided LLMConfig.
        # Refer to attrs for the original implementation.
        base_attrs, base_attr_map = _collect_base_attrs(cls, {a.name for a in own_attrs})

        # __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]]
        # that we construct ourselves.
        cls.__openllm_attrs__ = tuple(a.name for a in own_attrs)

-        # NOTE: Enable some default attributes that can be shared across all LLMConfig
-        if len(DEFAULT_LLMCONFIG_ATTRS) > 0:
-            # NOTE: update the hints for default variables we dynamically added.
-            hints.update({k: hints for k, _, _, hints in DEFAULT_LLMCONFIG_ATTRS})
-            base_attrs = [
-                attr.Attribute.from_counting_attr(k, cls.Field(default, env=field_env_key(k), description=docs), hints)
-                for k, default, docs, hints in DEFAULT_LLMCONFIG_ATTRS
-                if k not in cls.__openllm_attrs__
-            ] + base_attrs

        attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs

        # Mandatory vs non-mandatory attr order only matters when they are part of
@@ -810,38 +958,33 @@ class LLMConfig:
        _has_post_init = bool(getattr(cls, "__attrs_post_init__", False))

        AttrsTuple = _make_attr_tuple_class(cls.__name__, cls.__openllm_attrs__)
+        # NOTE: the protocol for attrs-decorated classes
+        cls.__attrs_attrs__ = AttrsTuple(attrs)
        # NOTE: generate a __attrs_init__ for the subclass
        cls.__attrs_init__ = _add_method_dunders(
            cls,
            _make_init(
-                cls,
-                AttrsTuple(attrs),
-                _has_pre_init,
-                _has_post_init,
-                False,
-                True,
-                True,
-                base_attr_map,
-                False,
-                None,
-                attrs_init=True,
+                cls,  # cls (the attrs-decorated class)
+                cls.__attrs_attrs__,  # tuple of attr.Attribute of cls
+                _has_pre_init,  # pre_init
+                _has_post_init,  # post_init
+                False,  # frozen
+                True,  # slots
+                True,  # cache_hash
+                base_attr_map,  # base_attr_map
+                False,  # is_exc (check if it is an exception)
+                None,  # cls_on_setattr (essentially attr.setters)
+                attrs_init=True,  # whether to create __attrs_init__ instead of __init__
            ),
        )
-        cls.__attrs_attrs__ = AttrsTuple(attrs)

        # NOTE: Finally, set the generation_class for this given config.
        cls.generation_class = _make_internal_generation_class(cls)

        hints.update(t.get_type_hints(cls.generation_class))

        cls.__openllm_hints__ = hints

        cls.__openllm_accepted_keys__ = set(cls.__openllm_attrs__) | set(attr.fields_dict(cls.generation_class))

-    @property
-    def name_type(self) -> t.Literal["dasherize", "lowercase"]:
-        return self.__openllm_name_type__

    def __init__(
        self,
        *,
@@ -849,33 +992,56 @@ class LLMConfig:
        __openllm_extras__: dict[str, t.Any] | None = None,
        **attrs: t.Any,
    ):
-        self.__openllm_extras__ = openllm.utils.first_not_none(__openllm_extras__, default={})
+        # create a copy of the list of keys as a cache
+        _cached_keys = tuple(attrs.keys())
+
+        self.__openllm_extras__ = first_not_none(__openllm_extras__, default={})
        config_merger.merge(
            self.__openllm_extras__, {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}
        )

-        attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_extras__ and v is not None}
+        for k in _cached_keys:
+            if k in self.__openllm_extras__ or attrs.get(k) is None:
+                del attrs[k]
+        _cached_keys = tuple(k for k in _cached_keys if k in attrs)

+        _generation_cl_dict = attr.fields_dict(self.generation_class)
        if generation_config is None:
-            generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.generation_class)}
+            generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict}
+        else:
+            generation_keys = {k for k in attrs if k in _generation_cl_dict}
+            if len(generation_keys) > 0:
+                logger.warning(
+                    "When 'generation_config' is passed, \
+the following keys are ignored and won't be used: %s. If you wish to use those values, \
+pass them into 'generation_config'.",
+                    ", ".join(generation_keys),
+                )
+            for k in _cached_keys:
+                if k in generation_keys:
+                    del attrs[k]
+            _cached_keys = tuple(k for k in _cached_keys if k in attrs)

        self.generation_config = self.generation_class(**generation_config)
+        base_attrs: tuple[attr.Attribute[t.Any], ...] = attr.fields(self.__class__)
+        base_attrs += (
+            attr.Attribute.from_counting_attr(
+                name="generation_config",
+                ca=dantic.Field(
+                    self.generation_config, description=inspect.cleandoc(self.generation_class.__doc__ or "")
+                ),
+                type=self.generation_class,
+            ),
+        )
+        # make the class __repr__ function with the updated fields.
+        self.__class__.__repr__ = _add_method_dunders(self.__class__, _make_repr(base_attrs, None, self.__class__))

-        attrs = {k: v for k, v in attrs.items() if k not in generation_config}
+        for k in _cached_keys:
+            if k in generation_config:
+                del attrs[k]

-        self.__attrs_init__(**{k: v for k, v in attrs.items() if k in self.__openllm_attrs__})

-        # The rest update to extras
-        attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_attrs__}
-        config_merger.merge(self.__openllm_extras__, attrs)

-    def __repr__(self) -> str:
-        bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={repr(self.generation_class())}"
-        if len(self.__openllm_attrs__) > 0:
-            bases += ", " + ", ".join([f"{k}={getattr(self, k)}" for k in self.__openllm_attrs__]) + ")"
-        else:
-            bases += ")"
-        return bases
+        # The rest of attrs should only be the attributes to be passed to __attrs_init__
+        self.__attrs_init__(**attrs)
    def __getattr__(self, item: str) -> t.Any:
        if hasattr(self.generation_config, item):
@@ -975,43 +1141,46 @@ class LLMConfig:
        config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config))
        return config.to_dict() if return_as_dict else config

+    @classmethod
    @t.overload
    def to_click_options(
-        self, f: t.Callable[..., openllm.LLMConfig]
+        cls, f: t.Callable[..., openllm.LLMConfig]
    ) -> F[P, ClickFunctionWrapper[..., openllm.LLMConfig]]:
        ...

+    @classmethod
    @t.overload
-    def to_click_options(self, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]:
+    def to_click_options(cls, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]:
        ...

-    def to_click_options(self, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
+    @classmethod
+    def to_click_options(cls, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
        """
        Convert the current model to click options. This can be used as a decorator for click commands.
        Note that the identifiers for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
        will be prefixed with '<model_name>_generation_*'.
        """
-        for name, field in attr.fields_dict(self.generation_class).items():
-            ty = self.__openllm_hints__.get(name)
+        for name, field in attr.fields_dict(cls.generation_class).items():
+            ty = cls.__openllm_hints__.get(name)
            if t.get_origin(ty) is t.Union:
                # NOTE: Union types are not yet supported; we probably just need to use the environment instead.
                continue
-            f = attrs_to_options(name, field, self.__openllm_model_name__, typ=ty, suffix_generation=True)(f)
-        f = optgroup.group(f"{self.generation_class.__name__} generation options")(f)
+            f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f)
+        f = optgroup.group(f"{cls.generation_class.__name__} generation options")(f)

-        if len(self.__class__.__openllm_attrs__) == 0:
+        if len(cls.__openllm_attrs__) == 0:
            # NOTE: in this case, the function is already a ClickFunctionWrapper
            # hence the casting
            return f

-        for name, field in attr.fields_dict(self.__class__).items():
-            ty = self.__openllm_hints__.get(name)
+        for name, field in attr.fields_dict(cls).items():
+            ty = cls.__openllm_hints__.get(name)
            if t.get_origin(ty) is t.Union:
                # NOTE: Union types are not yet supported; we probably just need to use the environment instead.
                continue
-            f = attrs_to_options(name, field, self.__openllm_model_name__, typ=ty)(f)
+            f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f)

-        return optgroup.group(f"{self.__class__.__name__} options")(f)
+        return optgroup.group(f"{cls.__name__} options")(f)
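Since `to_click_options` is now a classmethod, it can decorate a command without a config instance (a hypothetical command; the flag names follow the '<model_name>_*' convention from the docstring, and the `model_construct_env` call assumes the API named earlier in this file):

```python
import click
import openllm

@click.command()
@openllm.FlanT5Config.to_click_options
def serve(**attrs):
    # attrs contains e.g. flan_t5_generation_temperature, flan_t5_generation_top_k, ...
    config = openllm.FlanT5Config.model_construct_env(**attrs)
```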


bentoml_cattr.register_unstructure_hook_factory(


@@ -32,6 +32,7 @@ from bentoml.types import ModelSignature, ModelSignatureDict
import openllm

from .exceptions import ForbiddenAttributeError, OpenLLMException
+from .models.auto import AutoConfig
from .utils import ENV_VARS_TRUE_VALUES, LazyLoader, bentoml_cattr

if t.TYPE_CHECKING:
@@ -180,37 +181,19 @@ def import_model(
    torch.cuda.empty_cache()


-_required_namespace = {"default_id", "model_ids"}
-
-_reserved_namespace = _required_namespace | {
-    "config_class",
-    "model",
-    "tokenizer",
-    "import_kwargs",
-}
+_reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"}


class LLMInterface(ABC):
    """This defines the loose contract for all openllm.LLM implementations."""

-    default_id: str
-    """Return the default model to use when using 'openllm start <model_id>'.
-    This could be one of the keys in 'self.model_ids' or custom users model."""
-
-    model_ids: list[str]
-    """A list of supported pretrained models tag for this given runnable.
-
-    For example:
-        For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base",
-        "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"]
-    """

    config_class: type[openllm.LLMConfig]
    """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""

-    import_kwargs: dict[str, t.Any] | None = None
-    """The default import kwargs to used when importing the model.
-    This will be passed into 'openllm.LLM.import_model'."""
+    @property
+    def import_kwargs(self) -> dict[str, t.Any] | None:
+        """The default import kwargs to use when importing the model.
+        This will be passed into 'openllm.LLM.import_model'."""

    @abstractmethod
    def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any:
@@ -265,34 +248,29 @@ class LLMInterface(ABC):
        raise NotImplementedError


+def _default_post_init(self: LLM):
+    # load_in_mha: Whether to apply BetterTransformer (or Torch MultiHeadAttention) during inference load.
+    # See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
+    # for more information.
+    # NOTE: set a default variable to transform to BetterTransformer by default for inference
+    self.load_in_mha = (
+        os.environ.get(self.config_class.__openllm_env__.bettertransformer, str(False)).upper() in ENV_VARS_TRUE_VALUES
+    )
+    if self.config_class.__openllm_requires_gpu__:
+        # For all models that require GPU, there is no need to offload to BetterTransformer;
+        # use bitsandbytes instead
+        self.load_in_mha = False
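In effect, BetterTransformer loading becomes an instance-level, env-driven toggle (a sketch; the concrete variable name produced by `ModelEnv.bettertransformer` is an assumption patterned on the other OPENLLM_* variables):

```python
import os
import openllm

# Hypothetical variable name; any value in ENV_VARS_TRUE_VALUES enables it.
os.environ["OPENLLM_FLAN_T5_BETTERTRANSFORMER"] = "True"

llm = openllm.AutoLLM.for_model("flan-t5")
assert llm.load_in_mha  # set by _default_post_init via llm_post_init
```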


class LLMMetaclass(ABCMeta):
    def __new__(
        mcls, cls_name: str, bases: tuple[type[t.Any], ...], namespace: dict[str, t.Any], **attrs: t.Any
    ) -> type:
        """Metaclass for creating a LLM."""
        if LLMInterface not in bases:  # only actual openllm.LLM should hit this branch.
            if "__annotations__" not in namespace:
                annotations_dict: dict[str, t.Any] = {}
                namespace["__annotations__"] = annotations_dict

-            # NOTE: check for required attributes
-            if "__openllm_internal__" not in namespace:
-                _required_namespace.add("config_class")
-            for k in _required_namespace:
-                if k not in namespace:
-                    raise RuntimeError(f"Missing required key '{k}'. Make sure to define it within the LLM subclass.")

            # NOTE: set implementation branch
            prefix_class_name_config = cls_name
            if "__llm_implementation__" in namespace:
                raise RuntimeError(
                    f"""\
__llm_implementation__ should not be set directly. Instead make sure that your class
name follows the convention prefix:
- For TensorFlow implementation: 'TF{cls_name}'
- For Flax implementation: 'Flax{cls_name}'
- For PyTorch implementation: '{cls_name}'"""
                )
            if cls_name.startswith("Flax"):
                implementation = "flax"
                prefix_class_name_config = cls_name[4:]
@@ -302,39 +280,37 @@ class LLMMetaclass(ABCMeta):
            else:
                implementation = "pt"
            namespace["__llm_implementation__"] = implementation
+            config_class = AutoConfig.infer_class_from_name(prefix_class_name_config)

            # NOTE: setup config class branch
            if "__openllm_internal__" in namespace:
                # NOTE: we will automatically find the subclass for this given config class
                if "config_class" not in namespace:
-                    # this branch we will automatically get the class
-                    namespace["config_class"] = getattr(openllm, f"{prefix_class_name_config}Config")
+                    namespace["config_class"] = config_class
                else:
                    logger.debug(f"Using config class {namespace['config_class']} for {cls_name}.")
+            # NOTE: check for required attributes
+            else:
+                if "config_class" not in namespace:
+                    raise RuntimeError(
+                        "Missing required key 'config_class'. Make sure to define it within the LLM subclass."
+                    )

            config_class: type[openllm.LLMConfig] = namespace["config_class"]
            # NOTE: the llm_post_init branch
            if "llm_post_init" in namespace:
                original_llm_post_init = namespace["llm_post_init"]

-                # NOTE: update the annotations for self.config
-                namespace["__annotations__"]["config"] = t.get_type_hints(config_class)
+                def wrapped_llm_post_init(self: LLM) -> None:
+                    """We need to both initialize private attributes and call the user-defined model_post_init
+                    method.
+                    """
+                    _default_post_init(self)
+                    original_llm_post_init(self)

-            for key in ("__openllm_start_name__", "__openllm_requires_gpu__"):
-                namespace[key] = getattr(config_class, key)

-            # load_in_mha: Whether to apply BetterTransformer (or Torch MultiHeadAttention) during inference load.
-            # See https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/
-            # for more information.
-            # NOTE: set a default variable to transform to BetterTransformer by default for inference
-            if "load_in_mha" not in namespace:
-                load_in_mha = (
-                    os.environ.get(config_class().__openllm_env__.bettertransformer, str(False)).upper()
-                    in ENV_VARS_TRUE_VALUES
-                )
-                namespace["load_in_mha"] = load_in_mha

-                if namespace["__openllm_requires_gpu__"]:
-                    # For all models that requires GPU, no need to offload it to BetterTransformer
-                    # use bitsandbytes instead
-                    namespace["load_in_mha"] = False
+                namespace["llm_post_init"] = wrapped_llm_post_init
+            else:
+                namespace["llm_post_init"] = _default_post_init

            # NOTE: import_model branch
            if "import_model" not in namespace:
@@ -343,16 +319,10 @@ class LLMMetaclass(ABCMeta):
            else:
                logger.debug("Using custom 'import_model' for %s", cls_name)

-            # NOTE: populate with default cache.
-            namespace.update({k: None for k in ("__llm_bentomodel__", "__llm_model__", "__llm_tokenizer__")})

            cls: type[LLM] = super().__new__(t.cast("type[type[LLM]]", mcls), cls_name, bases, namespace, **attrs)

            cls.__openllm_post_init__ = None if cls.llm_post_init is LLMInterface.llm_post_init else cls.llm_post_init
            cls.__openllm_custom_load__ = None if cls.load_model is LLMInterface.load_model else cls.load_model

-            if getattr(cls, "config_class") is None:
-                raise RuntimeError(f"'config_class' must be defined for '{cls.__name__}'")
            return cls
        else:
            # the LLM class itself being created, no need to setup
@@ -362,23 +332,19 @@ class LLMMetaclass(ABCMeta):
class LLM(LLMInterface, metaclass=LLMMetaclass):
    if t.TYPE_CHECKING:
        # NOTE: the following will be populated by metaclass
-        __llm_bentomodel__: bentoml.Model | None = None
-        __llm_model__: LLMModel | None = None
-        __llm_tokenizer__: LLMTokenizer | None = None
        __llm_implementation__: t.Literal["pt", "tf", "flax"]

        __openllm_start_name__: str
        __openllm_requires_gpu__: bool
        __openllm_post_init__: t.Callable[[t.Self], None] | None
        __openllm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None

        load_in_mha: bool
        _llm_attrs: dict[str, t.Any]
        _llm_args: tuple[t.Any, ...]

-    # NOTE: the following is the similar interface to HuggingFace pretrained protocol.
+    __llm_bentomodel__: bentoml.Model | None = None
+    __llm_model__: LLMModel | None = None
+    __llm_tokenizer__: LLMTokenizer | None = None

    # NOTE: the following is the similar interface to HuggingFace pretrained protocol.
    @classmethod
    def from_pretrained(
        cls, model_id: str | None = None, llm_config: openllm.LLMConfig | None = None, *args: t.Any, **attrs: t.Any
@@ -446,15 +412,33 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
        Note that this tag will be generated based on `self.default_id` or the given `pretrained` kwds
        passed from the __init__ constructor.

-        ``llm_post_init`` can also be implemented if you need to do any
-        additional initialization after everything is setup.
+        ``llm_post_init`` can also be implemented if you need to do any additional
+        initialization after everything is set up.
+
+        Note: If you need to implement a custom `load_model`, the following is an example from the Falcon implementation:
+
+        ```python
+        def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
+            torch_dtype = attrs.pop("torch_dtype", torch.bfloat16)
+            device_map = attrs.pop("device_map", "auto")
+
+            _ref = bentoml.transformers.get(tag)
+
+            model = bentoml.transformers.load_model(_ref, device_map=device_map, torch_dtype=torch_dtype, **attrs)
+            return transformers.pipeline("text-generation", model=model, tokenizer=_ref.custom_objects["tokenizer"])
+        ```

        Args:
            model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used.
-            llm_config: The config to use for this LLM. Defaults to None. If not passed, we will use 'self.config_class'
-                to construct default configuration.
+            llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
+                will use `config_class` to construct the default configuration.
            *args: The args to be passed to the model.
            **attrs: The kwargs to be passed to the model.

            The following are optional:
                openllm_model_version: version for this `model_id`. By default, users can ignore this if using pretrained
                    weights, as OpenLLM will use the commit_hash of the given model_id.
                    However, if `model_id` is a path, it is recommended to include this argument.
        """

        if llm_config is not None:
@@ -466,10 +450,8 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
            attrs = self.config.__openllm_extras__

        if model_id is None:
-            model_id = os.environ.get(self.config.__openllm_env__.model_id, None)
-            if not model_id:
-                assert self.default_id, "A default model is required for any LLM."
-                model_id = self.default_id
+            model_id = os.environ.get(self.config.__openllm_env__.model_id, self.config.__openllm_default_id__)
+            assert model_id is not None

        # NOTE: This is the actual given path or pretrained weight for this LLM.
        self._model_id = model_id
@@ -479,7 +461,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
        self._llm_attrs = attrs

-        if self.__openllm_post_init__:
-            self.__openllm_post_init__(self)
+        self.llm_post_init()
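The net effect is a three-step fallback when resolving the weights (a sketch; passing `model_id` through `for_model` and the concrete env-var name are assumptions patterned on this diff):

```python
import openllm

# 1. An explicit argument wins.
llm = openllm.AutoLLM.for_model("flan-t5", model_id="google/flan-t5-xl")

# 2. Otherwise the ModelEnv-provided variable (e.g. OPENLLM_FLAN_T5_MODEL_ID) is consulted.
# 3. Finally __openllm_default_id__ ("google/flan-t5-large" for FLAN-T5) is used.
llm = openllm.AutoLLM.for_model("flan-t5")
```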

    def __setattr__(self, attr: str, value: t.Any):
        if attr in _reserved_namespace:
@@ -500,7 +482,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
    def identifying_params(self) -> dict[str, t.Any]:
        return {
            "configuration": self.config.model_dump_json().decode(),
-            "model_ids": orjson.dumps(self.model_ids).decode(),
+            "model_ids": orjson.dumps(self.config.__openllm_model_ids__).decode(),
        }

    @t.overload
@@ -643,7 +625,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
        return self._bentomodel.tag

    @property
-    def model(self) -> LLMModel | torch.nn.Module:
+    def model(self) -> LLMModel:
        """The model to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
        # Run check for GPU
        trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
@@ -818,7 +800,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
            (_Runnable,),
            {
                "SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu")
-                if self.__openllm_requires_gpu__
+                if self.config.__openllm_requires_gpu__
                else ("nvidia.com/gpu",),
                "llm_type": self.llm_type,
                "identifying_params": self.identifying_params,

@@ -76,16 +76,17 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:
    # first, then proceed to install everything inside the wheels/ folder.
    packages: list[str] = ["openllm"]

-    ModelEnv = openllm.utils.ModelEnv(llm.__openllm_start_name__)
    if llm.config.__openllm_requirements__ is not None:
        packages.extend(llm.config.__openllm_requirements__)

    if not (str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false"):
        packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")

-    to_use_framework = ModelEnv.get_framework_env()
+    to_use_framework = llm.config.__openllm_env__.get_framework_env()
    if to_use_framework == "flax":
-        assert utils.is_flax_available(), f"Flax is not available, while {ModelEnv.framework} is set to 'flax'"
+        assert (
+            utils.is_flax_available()
+        ), f"Flax is not available, while {llm.config.__openllm_env__.framework} is set to 'flax'"
        packages.extend(
            [
                f"flax>={importlib.metadata.version('flax')}",
@@ -94,7 +95,9 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:
            ]
        )
    elif to_use_framework == "tf":
-        assert utils.is_tf_available(), f"TensorFlow is not available, while {ModelEnv.framework} is set to 'tf'"
+        assert (
+            utils.is_tf_available()
+        ), f"TensorFlow is not available, while {llm.config.__openllm_env__.framework} is set to 'tf'"
        candidates = (
            "tensorflow",
            "tensorflow-cpu",
@@ -128,7 +131,6 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:


def construct_docker_options(llm: openllm.LLM, _: FS) -> DockerOptions:
-    ModelEnv = openllm.utils.ModelEnv(llm.__openllm_start_name__)
    _bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
    _bentoml_config_options += (
        " "
@@ -141,7 +143,7 @@ def construct_docker_options(llm: openllm.LLM, _: FS) -> DockerOptions:
    return DockerOptions(
        cuda_version="11.6",  # NOTE: Torch 2.0 currently only supports 11.6 as the latest CUDA version
        env={
-            ModelEnv.framework: ModelEnv.get_framework_env(),
+            llm.config.__openllm_env__.framework: llm.config.__openllm_env__.get_framework_env(),
            "OPENLLM_MODEL": llm.config.__openllm_model_name__,
            "BENTOML_DEBUG": str(get_debug_mode()),
            "BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
@@ -200,12 +202,12 @@ def build(model_name: str, *, __cli__: bool = False, **attrs: t.Any) -> tuple[be
        raise bentoml.exceptions.NotFound("Overwriting previously saved Bento.")
        _previously_built = True
    except bentoml.exceptions.NotFound:
-        logger.info("Building Bento for LLM '%s'", llm.__openllm_start_name__)
+        logger.info("Building Bento for LLM '%s'", llm.config.__openllm_start_name__)
        bento = bentoml.bentos.build(
            f"{service_name}:svc",
            name=bento_tag.name,
            labels=labels,
-            description=f"OpenLLM service for {llm.__openllm_start_name__}",
+            description=f"OpenLLM service for {llm.config.__openllm_start_name__}",
            include=[
                f for f in llm_fs.walk.files(filter=["*.py"])
            ],  # NOTE: By default, we are using _service.py as the default service, for now.
@@ -384,12 +384,10 @@ def start_model_command(
    ModelEnv = openllm.utils.ModelEnv(model_name)
    llm_config = openllm.AutoConfig.for_model(model_name)

-    for_doc = openllm.AutoLLM.for_model(model_name)
    docstring = f"""\
{ModelEnv.start_docstring}
\b
-Available model_id(s) to use with '{model_name}' are: {for_doc.model_ids} [default: {for_doc.default_id}]
Tip: One can pass one of the aforementioned to '--model-id' to use other pretrained weights.
+Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.__openllm_default_id__}]
"""
    command_attrs: dict[str, t.Any] = {
        "name": ModelEnv.model_name,
@@ -399,7 +397,7 @@ Tip: One can pass one of the aforementioned to '--model-id' to use other pretrai
    }

    aliases: list[str] = []
-    if llm_config.name_type == "dasherize":
+    if llm_config.__openllm_name_type__ == "dasherize":
        aliases.append(llm_config.__openllm_start_name__)

    command_attrs["aliases"] = aliases if len(aliases) > 0 else None
@@ -429,8 +427,9 @@ Tip: One can pass one of the aforementioned to '--model-id' to use other pretrai
    @group.command(**command_attrs)
    @llm_config.to_click_options
    @parse_serve_args(_serve_grpc)
-    @click.option("--server-timeout", type=int, default=3600, help="Server timeout in seconds")
-    @model_id_option
+    @cog.optgroup.group("General LLM Options")
+    @cog.optgroup.option("--server-timeout", type=int, default=3600, help="Server timeout in seconds")
+    @model_id_option(cog.optgroup, model_env=ModelEnv)
    @click.option(
        "--device",
        type=tuple,
@@ -439,6 +438,7 @@ Tip: One can pass one of the aforementioned to '--model-id' to use other pretrai
        envvar="CUDA_VISIBLE_DEVICES",
        callback=parse_device_callback,
        help=f"Assign GPU devices (if available) for {model_name}.",
+        show_envvar=True,
    )
    def model_start(
        server_timeout: int,
@@ -575,12 +575,20 @@ output_option = click.option(
    default="pretty",
    help="Showing output type. Defaults to 'pretty'",
)

-model_id_option = click.option(
-    "--model-id",
-    type=click.STRING,
-    default=None,
-    help="Optional model_id name or path for (fine-tune) weights.",
-)
+def model_id_option(factory: t.Any, model_env: openllm.utils.ModelEnv | None = None):
+    envvar = None
+    if model_env is not None:
+        envvar = model_env.model_id
+    return factory.option(
+        "--model-id",
+        type=click.STRING,
+        default=None,
+        help="Optional model_id name or path for (fine-tune) weights.",
+        envvar=envvar,
+        show_envvar=True if envvar is not None else False,
+    )
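Usage then differs only in the factory passed in; both forms appear in this diff (the command bodies here are illustrative stubs):

```python
import click
import openllm
from click_option_group import optgroup

# Plain Click, no environment variable bound:
@click.command()
@model_id_option(click)
def download_models(model_id):
    ...

# Grouped option with the env var taken from a ModelEnv instance:
@click.command()
@optgroup.group("General LLM Options")
@model_id_option(optgroup, model_env=openllm.utils.ModelEnv("flan_t5"))
def model_start(model_id):
    ...
```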


def cli_factory() -> click.Group:
@@ -627,7 +635,7 @@ def cli_factory() -> click.Group:
    @click.argument(
        "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
    )
-    @model_id_option
+    @model_id_option(click)
    @output_option
    @click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
    def build(model_name: str, model_id: str | None, overwrite: bool, output: OutputLiteral):
@@ -713,8 +721,12 @@ def cli_factory() -> click.Group:
                runtime_impl += ("flax",)
            if model.config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES:
                runtime_impl += ("tf",)
-            json_data[m] = {"model_id": model.model_ids, "description": docs, "runtime_impl": runtime_impl}
-            converted.extend([convert_transformers_model_name(i) for i in model.model_ids])
+            json_data[m] = {
+                "model_id": model.config.__openllm_model_ids__,
+                "description": docs,
+                "runtime_impl": runtime_impl,
+            }
+            converted.extend([convert_transformers_model_name(i) for i in model.config.__openllm_model_ids__])
        except Exception as err:
            failed_initialized.append((m, err))

@@ -783,7 +795,7 @@ def cli_factory() -> click.Group:
    @click.argument(
        "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
    )
-    @model_id_option
+    @model_id_option(click)
    @output_option
    def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
        """Set up an LLM interactively.

@@ -87,6 +87,7 @@ class _LazyConfigMapping(ConfigOrderedDict):


CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
+CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder"}


class AutoConfig:
@@ -102,3 +103,15 @@ class AutoConfig:
            f"Unrecognized configuration class for {model_name}. "
            f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
        )

+    @classmethod
+    def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
+        model_name = inflection.underscore(name)
+        if model_name in CONFIG_NAME_ALIASES:
+            model_name = CONFIG_NAME_ALIASES[model_name]
+        if model_name in CONFIG_MAPPING:
+            return CONFIG_MAPPING[model_name]
+        raise ValueError(
+            f"Unrecognized configuration class for {model_name}. "
+            f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
+        )
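For example (a sketch; it assumes `StarCoderConfig` is registered in `CONFIG_MAPPING`):

```python
import inflection

# "StarCoder" -> underscore -> "star_coder" -> alias table -> "starcoder"
assert inflection.underscore("StarCoder") == "star_coder"
config_cls = AutoConfig.infer_class_from_name("StarCoder")  # resolves via CONFIG_NAME_ALIASES

# The metaclass above strips the framework prefix first, so a class named
# "TFFlanT5" infers its config from "FlanT5".
```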

@@ -16,15 +16,7 @@ from __future__ import annotations
import openllm


-class ChatGLMConfig(
-    openllm.LLMConfig,
-    name_type="lowercase",
-    trust_remote_code=True,
-    default_timeout=3600000,
-    requires_gpu=True,
-    url="https://github.com/THUDM/ChatGLM-6B",
-    requirements=["cpm_kernels", "sentencepiece"],
-):
+class ChatGLMConfig(openllm.LLMConfig):
    """
    ChatGLM is an open bilingual language model based on
    the [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
@@ -41,11 +33,24 @@ class ChatGLMConfig(
    Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
    """

-    retain_history: bool = False
-    """Whether to retain history given to the model. If set to True, then the model will retain given history."""
+    __config__ = {
+        "name_type": "lowercase",
+        "trust_remote_code": True,
+        "timeout": 3600000,
+        "requires_gpu": True,
+        "url": "https://github.com/THUDM/ChatGLM-6B",
+        "requirements": ["cpm_kernels", "sentencepiece"],
+        "default_id": "thudm/chatglm-6b-int4",
+        "model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4"],
+    }

-    use_half_precision: bool = True
-    """Whether to use half precision for model."""
+    retain_history: bool = openllm.LLMConfig.Field(
+        False,
+        description="""Whether to retain history given to the model.
+        If set to True, then the model will retain given history.""",
+    )
+
+    use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")

    class GenerationConfig:
        max_new_tokens: int = 2048

@@ -41,11 +41,8 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
class ChatGLM(openllm.LLM):
    __openllm_internal__ = True

-    default_id = "thudm/chatglm-6b-int4"
-    model_ids = ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4"]
-    device = torch.device("cuda")
+    def llm_post_init(self):
+        self.device = torch.device("cuda")

    def import_model(
        self,
@@ -20,12 +20,7 @@ from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
|
||||
class DollyV2Config(
|
||||
openllm.LLMConfig,
|
||||
default_timeout=3600000,
|
||||
trust_remote_code=True,
|
||||
url="https://github.com/databrickslabs/dolly",
|
||||
):
|
||||
class DollyV2Config(openllm.LLMConfig):
|
||||
"""Databricks’ Dolly is an instruction-following large language model trained on the Databricks
|
||||
machine learning platform that is licensed for commercial use.
|
||||
|
||||
@@ -39,6 +34,14 @@ class DollyV2Config(
|
||||
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
"timeout": 3600000,
|
||||
"trust_remote_code": True,
|
||||
"url": "https://github.com/databrickslabs/dolly",
|
||||
"default_id": "databricks/dolly-v2-3b",
|
||||
"model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"],
|
||||
}
|
||||
|
||||
return_full_text: bool = openllm.LLMConfig.Field(
|
||||
False, description="Whether to return the full prompt to the users."
|
||||
)
|
||||
|
||||
@@ -38,17 +38,16 @@ class DollyV2(openllm.LLM):

    __openllm_internal__ = True

    default_id = "databricks/dolly-v2-3b"
    @property
    def import_kwargs(self):
        return {
            "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
            "torch_dtype": torch.bfloat16,
            "_tokenizer_padding_side": "left",
        }

    model_ids = ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]

    import_kwargs = {
        "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
        "torch_dtype": torch.bfloat16,
        "_tokenizer_padding_side": "left",
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def llm_post_init(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def import_model(
        self, model_id: str, tag: bentoml.Tag, *model_args: t.Any, tokenizer_kwds: dict[str, t.Any], **attrs: t.Any
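Note the second change in this hunk: `import_kwargs` turns from a class attribute into a `@property`. As a class attribute, the `torch.cuda.is_available()` calls ran at class-definition (that is, import) time and defeated the lazy torch loading; as a property they run only when the kwargs are requested. A before/after sketch, assuming nothing beyond what the diff shows:

```python
import torch

import openllm


class LazyKwargsLLM(openllm.LLM):
    # Before (removed): evaluated once at import time, querying CUDA the
    # moment the module is loaded:
    #
    #   import_kwargs = {"torch_dtype": torch.bfloat16, "device_map": ...}

    @property
    def import_kwargs(self):
        # After: evaluated on access, so CUDA is only queried when a model
        # is actually being imported.
        return {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
        }
```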
@@ -16,15 +16,7 @@ from __future__ import annotations
import openllm


class FalconConfig(
    openllm.LLMConfig,
    name_type="lowercase",
    trust_remote_code=True,
    requires_gpu=True,
    default_timeout=3600000,
    url="https://falconllm.tii.ae/",
    requirements=["einops", "xformers", "safetensors"],
):
class FalconConfig(openllm.LLMConfig):
    """Falcon-7B is a 7B parameters causal decoder-only model built by
    TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
    enhanced with curated corpora. It is made available under the TII Falcon LLM License.

@@ -32,6 +24,22 @@ class FalconConfig(
    Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
    """

    __config__ = {
        "name_type": "lowercase",
        "trust_remote_code": True,
        "requires_gpu": True,
        "timeout": 3600000,
        "url": "https://falconllm.tii.ae/",
        "requirements": ["einops", "xformers", "safetensors"],
        "default_id": "tiiuae/falcon-7b",
        "model_ids": [
            "tiiuae/falcon-7b",
            "tiiuae/falcon-40b",
            "tiiuae/falcon-7b-instruct",
            "tiiuae/falcon-40b-instruct",
        ],
    }

    class GenerationConfig:
        max_new_tokens: int = 200
        top_k: int = 10
@@ -34,14 +34,12 @@ else:
class Falcon(openllm.LLM):
    __openllm_internal__ = True

    default_id = "tiiuae/falcon-7b"

    model_ids = ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct"]

    import_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
    }
    @property
    def import_kwargs(self):
        return {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
        }

    def import_model(
        self, model_id: str, tag: bentoml.Tag, *model_args: t.Any, tokenizer_kwds: dict[str, t.Any], **attrs: t.Any
@@ -44,13 +44,25 @@ $ openllm start flan-t5 --model-id google/flan-t5-xxl
DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruction}\nAnswer:"""


class FlanT5Config(openllm.LLMConfig, url="https://huggingface.co/docs/transformers/model_doc/flan-t5"):
class FlanT5Config(openllm.LLMConfig):
    """FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf)
    - it is an enhanced version of T5 that has been finetuned in a mixture of tasks.

    Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
    """

    __config__ = {
        "url": "https://huggingface.co/docs/transformers/model_doc/flan-t5",
        "default_id": "google/flan-t5-large",
        "model_ids": [
            "google/flan-t5-small",
            "google/flan-t5-base",
            "google/flan-t5-large",
            "google/flan-t5-xl",
            "google/flan-t5-xxl",
        ],
    }

    class GenerationConfig:
        temperature: float = 0.9
        max_new_tokens: int = 2048
@@ -29,17 +29,8 @@ else:
class FlanT5(openllm.LLM):
    __openllm_internal__ = True

    default_id = "google/flan-t5-large"

    model_ids = [
        "google/flan-t5-small",
        "google/flan-t5-base",
        "google/flan-t5-large",
        "google/flan-t5-xl",
        "google/flan-t5-xxl",
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def llm_post_init(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def sanitize_parameters(
        self,
@@ -24,16 +24,6 @@ from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
class FlaxFlanT5(openllm.LLM):
    __openllm_internal__ = True

    default_id: str = "google/flan-t5-large"

    model_ids = [
        "google/flan-t5-small",
        "google/flan-t5-base",
        "google/flan-t5-large",
        "google/flan-t5-xl",
        "google/flan-t5-xxl",
    ]

    def sanitize_parameters(
        self,
        prompt: str,
@@ -24,16 +24,6 @@ from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
class TFFlanT5(openllm.LLM):
    __openllm_internal__ = True

    default_id: str = "google/flan-t5-large"

    model_ids = [
        "google/flan-t5-small",
        "google/flan-t5-base",
        "google/flan-t5-large",
        "google/flan-t5-xl",
        "google/flan-t5-xxl",
    ]

    def sanitize_parameters(
        self,
        prompt: str,
@@ -16,7 +16,7 @@ from __future__ import annotations
import openllm


class StableLMConfig(openllm.LLMConfig, name_type="lowercase", url="https://github.com/Stability-AI/StableLM"):
class StableLMConfig(openllm.LLMConfig):
    """StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models
    pre-trained on a diverse collection of English datasets with a sequence
    length of 4096 to push beyond the context window limitations of existing open-source language models.

@@ -30,6 +30,18 @@ class StableLMConfig(openllm.LLMConfig, name_type="lowercase", url="https://gith
    for more information.
    """

    __config__ = {
        "name_type": "lowercase",
        "url": "https://github.com/Stability-AI/StableLM",
        "default_id": "stabilityai/stablelm-tuned-alpha-3b",
        "model_ids": [
            "stabilityai/stablelm-tuned-alpha-3b",
            "stabilityai/stablelm-tuned-alpha-7b",
            "stabilityai/stablelm-base-alpha-3b",
            "stabilityai/stablelm-base-alpha-7b",
        ],
    }

    class GenerationConfig:
        temperature: float = 0.9
        max_new_tokens: int = 128
@@ -42,23 +42,17 @@ logger = logging.getLogger(__name__)
class StableLM(openllm.LLM):
    __openllm_internal__ = True

    load_in_mha = True if not torch.cuda.is_available() else False
    default_id = "stabilityai/stablelm-tuned-alpha-3b"
    def llm_post_init(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_in_mha = True if not torch.cuda.is_available() else False

    model_ids = [
        "stabilityai/stablelm-tuned-alpha-3b",
        "stabilityai/stablelm-tuned-alpha-7b",
        "stabilityai/stablelm-base-alpha-3b",
        "stabilityai/stablelm-base-alpha-7b",
    ]

    import_kwargs = {
        "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
        "load_in_8bit": False,
        "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    @property
    def import_kwargs(self):
        return {
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
            "load_in_8bit": False,
            "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
        }

    def sanitize_parameters(
        self,
@@ -98,7 +92,6 @@ class StableLM(openllm.LLM):
    def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
        return generation_result[0]

    @torch.inference_mode()
    def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
        generation_kwargs = {
            "do_sample": True,
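This is the hunk the commit message singles out ("fix post_init hooks to run load_in_mha"): `load_in_mha` used to be computed at class-definition time, so the per-instance post-init machinery that applies MHA loading never saw a live value. The hunk also adds `@torch.inference_mode()` to `generate`, which disables autograd for the whole call and is cheaper than `torch.no_grad()`. A condensed, illustrative sketch of both, with the body of `generate` elided:

```python
import torch

import openllm


class StableLMSketch(openllm.LLM):
    """Condensed from the diff above; illustrative only."""

    def llm_post_init(self):
        # Resolved per instance, after torch is importable, so the hook that
        # applies MHA loading sees this value instead of an import-time one.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_in_mha = not torch.cuda.is_available()

    @torch.inference_mode()
    def generate(self, prompt: str, **attrs) -> list[str]:
        ...  # body elided; everything here runs with autograd disabled
```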
@@ -16,14 +16,7 @@ from __future__ import annotations
import openllm


class StarCoderConfig(
    openllm.LLMConfig,
    name_type="lowercase",
    requires_gpu=True,
    url="https://github.com/bigcode-project/starcoder",
    requirements=["bitsandbytes"],
    workers_per_resource=0.5,
):
class StarCoderConfig(openllm.LLMConfig):
    """The StarCoder models are 15.5B parameter models trained on 80+ programming languages from
    [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.

@@ -34,6 +27,16 @@ class StarCoderConfig(
    Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information.
    """

    __config__ = {
        "name_type": "lowercase",
        "requires_gpu": True,
        "url": "https://github.com/bigcode-project/starcoder",
        "requirements": ["bitsandbytes"],
        "workers_per_resource": 0.5,
        "default_id": "bigcode/starcoder",
        "model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"],
    }

    class GenerationConfig:
        temperature: float = 0.2
        max_new_tokens: int = 256
@@ -40,18 +40,17 @@ FIM_INDICATOR = "<FILL_HERE>"
class StarCoder(openllm.LLM):
    __openllm_internal__ = True

    default_id = "bigcode/starcoder"
    def llm_post_init(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_ids = ["bigcode/starcoder", "bigcode/starcoderbase"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    import_kwargs = {
        "_tokenizer_padding_side": "left",
        "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
        "load_in_8bit": True if torch.cuda.device_count() > 1 else False,
        "torch_dtype": torch.float16,
    }
    @property
    def import_kwargs(self):
        return {
            "_tokenizer_padding_side": "left",
            "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
            "load_in_8bit": True if torch.cuda.device_count() > 1 else False,
            "torch_dtype": torch.float16,
        }

    def import_model(
        self,
@@ -38,7 +38,10 @@ def _default_converter(value: t.Any, env: str | None) -> t.Any:
    if env is not None:
        value = os.environ.get(env, value)
    if value is not None and isinstance(value, str):
        return eval(value, {"__builtins__": {}}, {})
        try:
            return orjson.loads(value)
        except orjson.JSONDecodeError as err:
            raise RuntimeError(f"Failed to parse '{value}' from '{env}': {err}")
    return value
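The converter change above swaps a sandboxed `eval` for `orjson.loads`: environment overrides must now be valid JSON (numbers, booleans, lists, strings), and malformed values fail loudly instead of being evaluated as Python. A standalone sketch of the new behaviour; `parse_env_value` is a hypothetical name standing in for `_default_converter`:

```python
import os
import typing as t

import orjson


def parse_env_value(value: t.Any, env: str | None) -> t.Any:
    # Hypothetical stand-in for _default_converter, mirroring the diff.
    if env is not None:
        value = os.environ.get(env, value)
    if value is not None and isinstance(value, str):
        try:
            return orjson.loads(value)  # "0.75" -> 0.75, "true" -> True
        except orjson.JSONDecodeError as err:
            raise RuntimeError(f"Failed to parse '{value}' from '{env}': {err}")
    return value


os.environ["MY_MODEL_TEMPERATURE"] = "0.75"  # illustrative env var
assert parse_env_value(None, "MY_MODEL_TEMPERATURE") == 0.75
```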
@@ -65,7 +68,7 @@ def Field(
    """
    metadata = attrs.pop("metadata", {})
    if description is None:
        description = "(No description is available)"
        description = "(No description provided)"
    metadata["description"] = description
    if env is not None:
        metadata["env"] = env
@@ -1,3 +1,5 @@
from __future__ import annotations

import enum
import sys
from typing import (
@@ -500,3 +502,5 @@ def _make_init(
    cls_on_setattr: Any,
    attrs_init: bool,
) -> Callable[_P, Any]: ...
def _make_method(name: str, script: str, filename: str, globs: dict[str, Any]) -> Callable[..., Any]: ...
def _make_repr(attrs: tuple[Attribute[Any]], ns: str | None, cls: AttrsInstance) -> Callable[[AttrsInstance], str]: ...
typings/attr/_compat.pyi (new file, +4)
@@ -0,0 +1,4 @@
import threading

set_closure_cell = ...
repr_context: threading.local = ...