From dfca956fad608edaceba41c1b47eceeefd0a64d1 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Fri, 23 Jun 2023 10:07:15 -0400 Subject: [PATCH] feat: serve adapter layers (#52) --- .gitattributes | 2 + README.md | 7 +- changelog.d/52.feature.md | 45 + ....generated.txt => nightly-requirements.txt | 3 +- pyproject.toml | 4 +- src/openllm/__init__.py | 20 + src/openllm/_configuration.py | 851 +++++++++++++----- src/openllm/_llm.py | 662 ++++++++++---- src/openllm/_package.py | 139 ++- src/openllm/_service.py | 53 +- src/openllm/_trainer.py | 25 + src/openllm/cli.py | 168 +++- src/openllm/exceptions.py | 4 + src/openllm/models/auto/factory.py | 18 +- .../models/dolly_v2/modeling_dolly_v2.py | 3 - src/openllm/models/falcon/modeling_falcon.py | 4 - src/openllm/models/opt/configuration_opt.py | 10 + src/openllm/models/opt/modeling_opt.py | 15 +- .../models/starcoder/modeling_starcoder.py | 3 - src/openllm/playground/README | 2 + src/openllm/playground/ft_opt_lora.py | 11 +- src/openllm/utils/__init__.py | 12 + src/openllm/utils/codegen.py | 33 +- src/openllm/utils/dantic.py | 11 +- src/openllm/utils/import_utils.py | 96 +- src/openllm/utils/representation.py | 61 ++ src/openllm_client/runtimes/base.py | 15 +- tools/assert-model-table-latest | 14 +- tools/update-config-stubs.py | 90 ++ tools/update-optional-dependencies.py | 5 +- typings/attr/__init__.pyi | 2 +- typings/click_option_group/_decorators.pyi | 2 +- typings/deepmerge/merger.pyi | 2 +- 33 files changed, 1896 insertions(+), 496 deletions(-) create mode 100644 .gitattributes create mode 100644 changelog.d/52.feature.md rename nightly-requirements.generated.txt => nightly-requirements.txt (87%) create mode 100644 src/openllm/_trainer.py create mode 100644 src/openllm/utils/representation.py create mode 100755 tools/update-config-stubs.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..9932d692 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +nightly-requirements.txt linguist-generated=true +* text=auto eol=lf diff --git a/README.md b/README.md index 5fab5f3e..4cb30428 100644 --- a/README.md +++ b/README.md @@ -346,10 +346,9 @@ async def prompt(input_text: str) -> str: OpenLLM seamlessly integrates with Hugging Face Agents. -> **Warning** The Hugging Face Agent is still in the experimental stage. It is -> recommended to OpenLLM with -> `pip install -r nightly-requirements.generated.txt` to get the latest API -> update for Hugging Face agent. +> **Warning** The HuggingFace Agent is still at experimental stage. It is +> recommended to OpenLLM with `pip install -r nightly-requirements.txt` to get +> the latest API update for HuggingFace agent. 
 ```python
 import transformers
diff --git a/changelog.d/52.feature.md b/changelog.d/52.feature.md
new file mode 100644
index 00000000..b0f52761
--- /dev/null
+++ b/changelog.d/52.feature.md
@@ -0,0 +1,45 @@
+#### Serving LLMs with fine-tuned LoRA and QLoRA adapter layers
+
+The fine-tuned adapter weights can then be served alongside the model via
+`openllm start`:
+
+```bash
+openllm start opt --model-id facebook/opt-6.7b --adapter-id /path/to/adapters
+```
+
+If you just wish to try a pretrained adapter checkpoint, you can use
+`--adapter-id`:
+
+```bash
+openllm start opt --model-id facebook/opt-6.7b --adapter-id aarnphm/opt-6.7b-lora
+```
+
+To use multiple adapters, use the following format:
+
+```bash
+openllm start opt --model-id facebook/opt-6.7b --adapter-id aarnphm/opt-6.7b-lora --adapter-id aarnphm/opt-6.7b-lora:french_lora
+```
+
+By default, the first `adapter-id` will be the default LoRA layer, but
+users can optionally change which LoRA layer to use for inference via
+`/v1/adapters`:
+
+```bash
+curl -X POST http://localhost:3000/v1/adapters --json '{"adapter_name": "vn_lora"}'
+```
+
+> Note that when using multiple `adapter-name` and `adapter-id` pairs, it is recommended to
+> switch back to the default adapter before sending inference requests, to avoid any
+> performance degradation.
+
+To include this in the Bento, one can also provide a `--adapter-id` to
+`openllm build`:
+
+```bash
+openllm build opt --model-id facebook/opt-6.7b --adapter-id ...
+```
+
+### Rework
+
+Separate out the configuration builder, to make it more flexible for future
+configuration generation.
diff --git a/nightly-requirements.generated.txt b/nightly-requirements.txt
similarity index 87%
rename from nightly-requirements.generated.txt
rename to nightly-requirements.txt
index 76201d8c..684d14f8 100644
--- a/nightly-requirements.generated.txt
+++ b/nightly-requirements.txt
@@ -1,9 +1,10 @@
 # This file is generated by `./tools/update-optional-dependencies.py`
 # DO NOT EDIT
--e .
+-e .[all]
 bentoml[grpc,io] @ git+https://github.com/bentoml/bentoml.git@main
 peft @ git+https://github.com/huggingface/peft.git@main
 transformers[torch,tokenizers,accelerate] @ git+https://github.com/huggingface/transformers.git@main
 optimum @ git+https://github.com/huggingface/optimum.git@main
 accelerate @ git+https://github.com/huggingface/accelerate.git@main
 bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git@main
+deepspeed @ git+https://github.com/microsoft/deepspeed.git@master
diff --git a/pyproject.toml b/pyproject.toml
index 9508c386..64d11ec8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,15 +58,15 @@ requires-python = ">=3.8"
 # NOTE: Don't modify project.optional-dependencies
 # as it is managed by ./tools/update-optional-dependencies.py
 [project.optional-dependencies]
-agents = ["transformers[agents]", "diffusers", "soundfile"]
+agents = ["transformers[agents]>=4.30", "diffusers", "soundfile"]
 all = [
   "openllm[chatglm]",
   "openllm[starcoder]",
   "openllm[falcon]",
   "openllm[agents]",
+  "openllm[flan-t5]",
   "openllm[fine-tune]",
   "openllm[openai]",
-  "openllm[flan-t5]",
 ]
 chatglm = ["cpm_kernels", "sentencepiece"]
 falcon = ["einops", "xformers", "safetensors"]
diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py
index 0a6eb995..94a66f88 100644
--- a/src/openllm/__init__.py
+++ b/src/openllm/__init__.py
@@ -26,7 +26,9 @@ deploy, and monitor any LLMs with ease.
 from __future__ import annotations
 
 import logging
+import os
 import typing as t
+import warnings
 
 from .
import utils as utils from .__about__ import __version__ as __version__ @@ -39,6 +41,24 @@ if utils.DEBUG: utils.configure_logging() logging.basicConfig(level=logging.NOTSET) +else: + # configuration for bitsandbytes before import + os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") + # The following warnings from bitsandbytes, and probably not that important + # for users to see when DEBUG is False + warnings.filterwarnings( + "ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization" + ) + warnings.filterwarnings( + "ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization" + ) + warnings.filterwarnings( + "ignore", + message=( + "The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization" + " are unavailable." + ), + ) _import_structure = { diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index 9015ec3f..0ff67a0e 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -49,6 +49,7 @@ Refer to ``openllm.LLMConfig`` docstring for more information. from __future__ import annotations import copy +import enum import functools import inspect import logging @@ -69,7 +70,6 @@ from deepmerge.merger import Merger import openllm from .exceptions import ForbiddenAttributeError -from .utils import DEBUG from .utils import ENV_VARS_TRUE_VALUES from .utils import LazyType from .utils import bentoml_cattr @@ -78,6 +78,7 @@ from .utils import dantic from .utils import first_not_none from .utils import lenient_issubclass from .utils import non_intrusive_setattr +from .utils import requires_dependencies if hasattr(t, "Required"): @@ -95,14 +96,23 @@ if hasattr(t, "dataclass_transform"): else: from typing_extensions import dataclass_transform +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if hasattr(t, "get_overloads"): + from typing import overload +else: + from typing_extensions import overload + _T = t.TypeVar("_T") if t.TYPE_CHECKING: + import peft from attr import _CountingAttr # type: ignore from attr import _make_init # type: ignore from attr import _make_repr # type: ignore from attr import _transform_attrs # type: ignore + from attr._compat import set_closure_cell import transformers from transformers.generation.beam_constraints import Constraint @@ -112,8 +122,6 @@ if t.TYPE_CHECKING: from ._types import O_co from ._types import P - ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]] - DictStrAny = dict[str, t.Any] ListStr = list[str] ItemgetterAny = itemgetter[t.Any] @@ -125,12 +133,14 @@ else: ItemgetterAny = itemgetter # NOTE: Using internal API from attr here, since we are actually # allowing subclass of openllm.LLMConfig to become 'attrs'-ish + from attr._compat import set_closure_cell from attr._make import _CountingAttr from attr._make import _make_init from attr._make import _make_repr from attr._make import _transform_attrs transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") + peft = openllm.utils.LazyLoader("peft", globals(), "peft") __all__ = ["LLMConfig"] @@ -146,6 +156,209 @@ config_merger = Merger( ) +# case insensitive, but rename to conform with type +class _PeftEnumMeta(enum.EnumMeta): + def __getitem__(self, __key: str | t.Any) -> PeftType: + if isinstance(__key, str): + __key = inflection.underscore(__key).upper() + return super().__getitem__(__key) + + +# 
vendored from peft.utils.config.PeftType
+# since we don't hard depend on peft
+class PeftType(enum.Enum, metaclass=_PeftEnumMeta):
+    PROMPT_TUNING = "PROMPT_TUNING"
+    P_TUNING = "P_TUNING"
+    PREFIX_TUNING = "PREFIX_TUNING"
+    LORA = "LORA"
+    ADALORA = "ADALORA"
+    ADAPTION_PROMPT = "ADAPTION_PROMPT"
+
+    @classmethod
+    def _missing_(cls, value: object) -> PeftType | None:
+        if isinstance(value, str):
+            normalized = inflection.underscore(value).upper()
+            if normalized in cls._member_map_:
+                return cls[normalized]
+
+    @classmethod
+    def supported(cls) -> set[str]:
+        return {inflection.underscore(v.value) for v in cls}
+
+    def to_str(self) -> str:
+        return self.value
+
+
+_PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"}
+
+if t.TYPE_CHECKING:
+    AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning"]
+else:
+    AdapterType = str
+
+_object_setattr = object.__setattr__
+
+
+def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType:
+    if value is None:
+        raise ValueError("'AdapterType' cannot be None.")
+    if isinstance(value, PeftType):
+        return value
+    if isinstance(value, str) and value not in PeftType.supported():
+        raise ValueError(f"Given '{value}' is not a supported adapter type.")
+    return PeftType[value]
+
+
+@attr.define(slots=True)
+class FineTuneConfig:
+    """FineTuneConfig defines the default fine-tuning configuration for any given LLM.
+
+    This is a lower-level API that leverages `peft` as well as openllm.LLMConfig to provide
+    sensible defaults that can be customised per model.
+    """
+
+    if t.TYPE_CHECKING:
+        # The following type stubs make __init__ aware of attrs
+        # internal type converter.
+        @overload
+        def __init__(self, *args: str | PeftType | dict[str, t.Any] | bool | type[LLMConfig]) -> None:
+            ...
+
+        @overload
+        def __init__(
+            self,
+            adapter_type: PeftType = ...,
+            adapter_config: dict[str, t.Any] = ...,
+            inference_mode: bool = ...,
+            llm_config_class: type[LLMConfig] = ...,
+        ) -> None:
+            ...
+
+        @overload
+        def __init__(
+            self,
+            adapter_type: AdapterType = ...,
+            adapter_config: dict[str, t.Any] = ...,
+            inference_mode: bool = ...,
+            llm_config_class: type[LLMConfig] = ...,
+        ) -> None:
+            ...
+
+        # The below should be generated via attrs. Only here to conform with pyright strict checking.
+        def __init__(self, *args: t.Any, **kwargs: t.Any):
+            ...
+
+    adapter_type: PeftType = dantic.Field(
+        "lora",
+        description=f"The type of adapter to use for fine-tuning. Available supported methods: {PeftType.supported()}, defaults to 'lora'",
+        use_default_converter=False,
+        converter=_adapter_converter,
+    )
+    adapter_config: t.Dict[str, t.Any] = dantic.Field(
+        None,
+        description="The configuration for the adapter. 
The content of the dict depends on the adapter type.", + validator=attr.validators.optional(attr.validators.instance_of(dict)), + converter=attr.converters.default_if_none(factory=dict), + use_default_converter=False, + ) + inference_mode: bool = dantic.Field( + False, + description="Whether to use this Adapter for inference", + use_default_converter=False, + ) + llm_config_class: type[LLMConfig] = dantic.Field( + None, + description="The reference class to openllm.LLMConfig", + use_default_converter=False, + ) + + @requires_dependencies("peft", extra="fine-tune") + def to_peft_config(self) -> peft.PeftConfig: + # makes a copy to correctly set all modules + adapter_config = self.adapter_config.copy() + if "peft_type" in adapter_config: + logger.debug( + "'peft_type' is not required as it is managed internally by '%s' and 'peft'.", + self.__class__.__name__, + ) + adapter_config.pop("peft_type") + + # respect user set task_type if it is passed, otherwise use one managed by OpenLLM + task_type = adapter_config.pop("task_type", peft.TaskType[self.llm_config_class.peft_task_type()]) + inference_mode = adapter_config.pop("inference_mode", self.inference_mode) + + return peft.PEFT_TYPE_TO_CONFIG_MAPPING[self.adapter_type.to_str()]( + task_type=task_type, + inference_mode=inference_mode, + **adapter_config, + ) + + def train(self) -> FineTuneConfig: + _object_setattr(self, "inference_mode", False) + return self + + def eval(self) -> FineTuneConfig: + _object_setattr(self, "inference_mode", True) + return self + + def with_config(self, **attrs: t.Any) -> FineTuneConfig: + """Create a new instance of FineTuneConfig with the given attributes.""" + adapter_type = attrs.pop("adapter_type", self.adapter_type) + inference_mode = attrs.get("inference_mode", self.inference_mode) + if "llm_config_class" in attrs: + raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.") + return attr.evolve( + self, + adapter_type=adapter_type, + inference_mode=inference_mode, + adapter_config=config_merger.merge(self.adapter_config, attrs), + ) + + @classmethod + def make_adapter_config_class( + cls, + adapter_type: AdapterType, + llm_config_class: type[LLMConfig], + /, + *, + docs: str | None = None, + **attrs: t.Any, + ) -> type[FineTuneConfig]: + """A loose codegen to create default subclass for given adapter config type""" + + _new_default = { + "adapter_type": PeftType[adapter_type], + "adapter_config": attrs, + "llm_config_class": llm_config_class, + } + + def transformers(_: type[t.Any], fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]: + transformed: list[attr.Attribute[t.Any]] = [] + for f in fields: + if f.name in _new_default: + transformed.append(f.evolve(default=_new_default[f.name])) + else: + transformed.append(f) + return transformed + + klass = attr.make_class( + f"{inflection.camelize(adapter_type)}{llm_config_class.__name__}", + [], + bases=(cls,), + slots=True, + weakref_slot=True, + frozen=True, + repr=True, + collect_by_mro=True, + field_transformer=transformers, + ) + + if docs is not None: + klass.__doc__ = docs + + return klass + + @attr.frozen(slots=True) class GenerationConfig: """Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``, @@ -403,6 +616,9 @@ bentoml_cattr.register_unstructure_hook_factory( lambda cls: make_dict_unstructure_fn( cls, bentoml_cattr, + # The below is the default, put here for strict annotations + _cattrs_omit_if_default=False, + _cattrs_use_linecache=True, **{k: 
override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)}, ), ) @@ -419,6 +635,9 @@ _object_getattribute = object.__getattribute__ class ModelSettings(t.TypedDict, total=False): """ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__. Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig. + + If the field below changes, make sure to run ./tools/update-config-stubs.py to generate correct __getitem__ + stubs for type-checking purposes. """ # NOTE: These required fields should be at the top, as it will be kw_only @@ -450,12 +669,13 @@ class ModelSettings(t.TypedDict, total=False): workers_per_resource: t.Union[int, float] # the target generation_config class to be used. - generation_class: t.Type[GenerationConfig] + fine_tune_strategies: t.Tuple[t.Dict[str, t.Any], ...] def _settings_field_transformer( _: type[attr.AttrsInstance], __: list[attr.Attribute[t.Any]] ) -> list[attr.Attribute[t.Any]]: + _transformed_type: dict[str, type[t.Any]] = {"fine_tune_strategies": t.Dict[AdapterType, FineTuneConfig]} return [ attr.Attribute.from_counting_attr( k, @@ -463,7 +683,7 @@ def _settings_field_transformer( kw_only=False if t.get_origin(ann) is not Required else True, auto_default=True, use_default_converter=False, - type=ann, + type=_transformed_type.get(k, ann), metadata={"target": f"__openllm_{k}__"}, description=f"ModelSettings field for {k}.", ), @@ -483,33 +703,30 @@ class _ModelSettingsAttr: @classmethod def default(cls) -> _ModelSettingsAttr: - _ = ModelSettings( - default_id="__default__", - model_ids=["__default__"], - name_type="dasherize", - requires_gpu=False, - url="", - use_pipeline=False, - model_type="causal_lm", - trust_remote_code=False, - requirements=None, - timeout=int(36e6), - service_name="", - workers_per_resource=1, - runtime="transformers", + return cls( + **t.cast( + DictStrAny, + ModelSettings( + default_id="__default__", + model_ids=["__default__"], + name_type="dasherize", + requires_gpu=False, + url="", + use_pipeline=False, + model_type="causal_lm", + trust_remote_code=False, + requirements=None, + timeout=int(36e6), + service_name="", + workers_per_resource=1, + runtime="transformers", + ), + ) ) - return cls(**t.cast(DictStrAny, _)) def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]): - if not lenient_issubclass(cl_, LLMConfig): - raise RuntimeError(f"Given '{cl_}' must be a subclass type of 'LLMConfig', got '{cl_}' instead.") - - if not hasattr(cl_, "__config__") or getattr(cl_, "__config__") is None: - raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.") - assert cl_.__config__ is not None - if "generation_class" in cl_.__config__: raise ValueError( "'generation_class' shouldn't be defined in '__config__', rather defining " @@ -517,7 +734,6 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]): ) _cl_name = cl_.__name__.replace("Config", "") - _settings_attr = _ModelSettingsAttr.default() try: cls(**t.cast(DictStrAny, cl_.__config__)) @@ -548,25 +764,28 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]): _final_value_dct["bettertransformer"] = False _final_value_dct["service_name"] = f"generated_{_final_value_dct['model_name']}_service.py" - _final_value_dct["generation_class"] = attr.make_class( - f"{_cl_name}GenerationConfig", - [], - bases=(GenerationConfig,), - slots=True, - weakref_slot=True, - frozen=True, - repr=True, - 
collect_by_mro=True, - field_transformer=_make_env_transformer( - cl_, - _final_value_dct["model_name"], - suffix="generation", - default_callback=lambda field_name, field_default: getattr(cl_.GenerationConfig, field_name, field_default) - if codegen.has_own_attribute(cl_, "GenerationConfig") - else field_default, - globs={"cl_": cl_}, - ), - ) + + # NOTE: The key for fine-tune strategies is 'fine_tune_strategies' + _fine_tune_strategies: tuple[dict[str, t.Any], ...] | None = getattr(_settings_attr, "fine_tune_strategies", None) + _converted: dict[AdapterType, FineTuneConfig] = {} + if _fine_tune_strategies is not None: + # the given value is a tuple[dict[str, t.Any] ,...] + for _possible_ft_config in _fine_tune_strategies: + _adapter_type: AdapterType | None = _possible_ft_config.pop("adapter_type", None) + if _adapter_type is None: + raise RuntimeError("'adapter_type' is required under config definition (currently missing)'.") + _llm_config_class = _possible_ft_config.pop("llm_config_class", cl_) + _doc = _possible_ft_config.pop( + "docs", + f"Default {inflection.camelize(_adapter_type)}Config for {_final_value_dct['model_name']}", + ) + _converted[_adapter_type] = codegen.add_method_dunders( + cl_, + FineTuneConfig.make_adapter_config_class( + _adapter_type, _llm_config_class, docs=_doc, **_possible_ft_config + ), + )() + _final_value_dct["fine_tune_strategies"] = _converted return attr.evolve(_settings_attr, **_final_value_dct) @@ -623,7 +842,7 @@ def _make_env_transformer( ) -def _setattr_class(attr_name: str, value_var: t.Any, add_dunder: bool = False): +def _setattr_class(attr_name: str, value_var: t.Any): """ Use the builtin setattr to set *attr_name* to *value_var*. We can't use the cached object.__setattr__ since we are setting @@ -633,11 +852,7 @@ def _setattr_class(attr_name: str, value_var: t.Any, add_dunder: bool = False): value that will be used to add the dunder methods to the class for given value_var """ - val = f"__add_dunder(cls, {value_var})" if add_dunder else value_var - return f"setattr(cls, '{attr_name}', {val})" - - -_dunder_add = {"generation_class"} + return f"setattr(cls, '{attr_name}', {value_var})" def _make_assignment_script( @@ -649,7 +864,6 @@ def _make_assignment_script( "cls": cls, "_cached_attribute": attributes, "_cached_getattribute_get": _object_getattribute.__get__, - "__add_dunder": codegen.add_method_dunders, } annotations: DictStrAny = {"return": None} @@ -657,7 +871,7 @@ def _make_assignment_script( for attr_name, field in attr.fields_dict(attributes.__class__).items(): arg_name = field.metadata.get("target", f"__{_prefix}_{inflection.underscore(attr_name)}__") args.append(f"{attr_name}=getattr(_cached_attribute, '{attr_name}')") - lines.append(_setattr_class(arg_name, attr_name, add_dunder=attr_name in _dunder_add)) + lines.append(_setattr_class(arg_name, attr_name)) annotations[attr_name] = field.type return codegen.generate_function( @@ -686,63 +900,7 @@ def __llm_config_transform__(cls: type[LLMConfig]) -> type[LLMConfig]: @attr.define(slots=True) -class LLMConfig: - """ - ``openllm.LLMConfig`` is somewhat a hybrid combination between the performance of `attrs` with the - easy-to-use interface that pydantic offer. It lives in between where it allows users to quickly formulate - a LLMConfig for any LLM without worrying too much about performance. 
It does a few things: - - - Automatic environment conversion: Each fields will automatically be provisioned with an environment - variable, make it easy to work with ahead-of-time or during serving time - - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, - i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options`` - - Automatic CLI generation: It can identify each fields and convert it to compatible Click options. - This means developers can use any of the LLMConfig to create CLI with compatible-Python - CLI library (click, typer, ...) - - > Internally, LLMConfig is an attrs class. All subclass of LLMConfig contains "attrs-like" features, - > which means LLMConfig will actually generate subclass to have attrs-compatible API, so that the subclass - > can be written as any normal Python class. - - To directly configure GenerationConfig for any given LLM, create a GenerationConfig under the subclass: - - ```python - class FlanT5Config(openllm.LLMConfig): - class GenerationConfig: - temperature: float = 0.75 - max_new_tokens: int = 3000 - top_k: int = 50 - top_p: float = 0.4 - repetition_penalty = 1.0 - ``` - By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted - to ``transformers.GenerationConfig``. These attribute can be accessed via ``LLMConfig.generation_config``. - - By default, all LLMConfig must provide a __config__ with 'default_id' and 'model_ids'. - - All other fields are optional, and will be use default value if not set. - - ```python - class FalconConfig(openllm.LLMConfig): - __config__ = { - "name_type": "lowercase", - "trust_remote_code": True, - "requires_gpu": True, - "timeout": 3600000, - "url": "https://falconllm.tii.ae/", - "requirements": ["einops", "xformers", "safetensors"], - # NOTE: The below are always required - "default_id": "tiiuae/falcon-7b", - "model_ids": [ - "tiiuae/falcon-7b", - "tiiuae/falcon-40b", - "tiiuae/falcon-7b-instruct", - "tiiuae/falcon-40b-instruct", - ], - } - ``` - """ - +class _ConfigAttr: Field = dantic.Field """Field is a alias to the internal dantic utilities to easily create attrs.fields with pydantic-compatible interface. For example: @@ -796,7 +954,7 @@ class LLMConfig: """Extra metadata for this LLMConfig.""" # NOTE: The following will be populated from __config__ and also - # considered to be public API. + # considered to be public API. Users can also access these via self[key] __openllm_default_id__: str = Field(None) """Return the default model to use when using 'openllm start '. This could be one of the keys in 'self.model_ids' or custom users model. @@ -884,7 +1042,264 @@ class LLMConfig: to create the generation_config argument that can be used throughout the lifecycle. This class will also be managed internally by OpenLLM.""" - def __init_subclass__(cls): + __openllm_fine_tune_strategies__: dict[AdapterType, FineTuneConfig] = Field(None) + """The fine-tune strategies for this given LLM.""" + + +@attr.define(slots=True) +class LLMConfig(_ConfigAttr): + """ + ``openllm.LLMConfig`` is somewhat a hybrid combination between the performance of `attrs` with the + easy-to-use interface that pydantic offer. It lives in between where it allows users to quickly formulate + a LLMConfig for any LLM without worrying too much about performance. 
It does a few things: + + - Automatic environment conversion: Each fields will automatically be provisioned with an environment + variable, make it easy to work with ahead-of-time or during serving time + - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, + i.e: ``model_construct_env``, ``to_generation_config``, ``to_click_options`` + - Automatic CLI generation: It can identify each fields and convert it to compatible Click options. + This means developers can use any of the LLMConfig to create CLI with compatible-Python + CLI library (click, typer, ...) + + > Internally, LLMConfig is an attrs class. All subclass of LLMConfig contains "attrs-like" features, + > which means LLMConfig will actually generate subclass to have attrs-compatible API, so that the subclass + > can be written as any normal Python class. + + To directly configure GenerationConfig for any given LLM, create a GenerationConfig under the subclass: + + ```python + class FlanT5Config(openllm.LLMConfig): + class GenerationConfig: + temperature: float = 0.75 + max_new_tokens: int = 3000 + top_k: int = 50 + top_p: float = 0.4 + repetition_penalty = 1.0 + ``` + By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted + to ``transformers.GenerationConfig``. These attribute can be accessed via ``LLMConfig.generation_config``. + + By default, all LLMConfig must provide a __config__ with 'default_id' and 'model_ids'. + + All other fields are optional, and will be use default value if not set. + + ```python + class FalconConfig(openllm.LLMConfig): + __config__ = { + "name_type": "lowercase", + "trust_remote_code": True, + "requires_gpu": True, + "timeout": 3600000, + "url": "https://falconllm.tii.ae/", + "requirements": ["einops", "xformers", "safetensors"], + # NOTE: The below are always required + "default_id": "tiiuae/falcon-7b", + "model_ids": [ + "tiiuae/falcon-7b", + "tiiuae/falcon-40b", + "tiiuae/falcon-7b-instruct", + "tiiuae/falcon-40b-instruct", + ], + } + ``` + + > **changelog** + > Since 0.1.7, one can also define given fine-tune strategies for given LLM via its config: + ```python + class OPTConfig(openllm.LLMConfig): + __config__ = { + "name_type": "lowercase", + "trust_remote_code": False, + "url": "https://huggingface.co/docs/transformers/model_doc/opt", + "default_id": "facebook/opt-1.3b", + "model_ids": [ + "facebook/opt-125m", + "facebook/opt-350m", + "facebook/opt-1.3b", + "facebook/opt-2.7b", + "facebook/opt-6.7b", + "facebook/opt-66b", + ], + "fine_tune_strategies": ( + { + "adapter_type": "lora", + "r": 16, + "lora_alpha": 32, + "target_modules": ["q_proj", "v_proj"], + "lora_dropout": 0.05, + "bias": "none", + }, + ), + } + ``` + """ + + class _ConfigBuilder: + """A modified version of attrs internal _ClassBuilder, should only be called + within __init_subclass__ of LLMConfig. + + Where: + - has_custom_setattr=True + - getstate_setstate=None (config class will always be a slotted class.) + - slots=True + - auto_attribs=False (We should handle it before _ConfigBuilder is invoked) + - cache_hash=False (We don't need to cache the hash code of this object for now.) 
+ - collect_by_mro=True (The correct behaviour to resolve inheritance) + - field_transformer=_make_env_transformer (We need to transform the field to have env variable) + + It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__ + """ + + __slots__ = ( + "_cls", + "_cls_dict", + "_attr_names", + "_attrs", + "_model_name", + "_base_attr_map", + "_base_names", + "_has_pre_init", + "_has_post_init", + ) + + def __init__( + self, + cls: type[LLMConfig], + model_name: str, + these: dict[str, _CountingAttr[t.Any]], + auto_attribs: bool = False, + kw_only: bool = False, + collect_by_mro: bool = True, + ): + attrs, base_attrs, base_attr_map = _transform_attrs( + cls, + these, + auto_attribs, + kw_only, + collect_by_mro, + field_transformer=_make_env_transformer(cls, model_name), + ) + + self._cls = cls + self._model_name = model_name + self._cls_dict = dict(cls.__dict__) + self._attrs = attrs + self._base_names = {a.name for a in base_attrs} + self._base_attr_map = base_attr_map + self._attr_names = tuple(a.name for a in attrs) + self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) + self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) + + self._cls_dict["__attrs_attrs__"] = self._attrs + + def build_class(self) -> type[LLMConfig]: + """ + Finalize class based on the accumulated configuration. + + Builder cannot be used after calling this method. + + > A difference between this and attrs._ClassBuilder is that we don't + > create a new class after constructing all __dict__. This has to do + > with recursive called within __init_subclass__ + """ + cd = { + k: v + for k, v in self._cls_dict.items() + if k not in tuple(self._attr_names) + ("__dict__", "__weakref__") + } + # Traverse the MRO to collect existing slots + # and check for an existing __weakref__. + existing_slots: DictStrAny = {} + weakref_inherited = False + for base_cls in self._cls.__mro__[1:-1]: + if base_cls.__dict__.get("__weakref__", None) is not None: + weakref_inherited = True + existing_slots.update({name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])}) + + base_names = set(self._base_names) + names = self._attr_names + if ( + "__weakref__" not in getattr(self._cls, "__slots__", ()) + and "__weakref__" not in names + and not weakref_inherited + ): + names += ("__weakref__",) + + # We only add the names of attributes that aren't inherited. + # Setting __slots__ to inherited attributes wastes memory. + slot_names = [name for name in names if name not in base_names] + # There are slots for attributes from current class + # that are defined in parent classes. + # As their descriptors may be overridden by a child class, + # we collect them here and update the class dict + reused_slots = { + slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names + } + # We only add the names of attributes that aren't inherited. + # Setting __slots__ to inherited attributes wastes memory. + # __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots + slot_names = [name for name in slot_names if name not in reused_slots] + cd.update(reused_slots) + cd["__slots__"] = tuple(slot_names) + + for k, value in cd.items(): + setattr(self._cls, k, value) + + # The following is a fix for + # . + # If a method mentions `__class__` or uses the no-arg super(), the + # compiler will bake a reference to the class in the method itself + # as `method.__closure__`. 
Since we replace the class with a + # clone, we rewrite these references so it keeps working. + for item in self._cls.__dict__.values(): + if isinstance(item, (classmethod, staticmethod)): + # Class- and staticmethods hide their functions inside. + # These might need to be rewritten as well. + closure_cells = getattr(item.__func__, "__closure__", None) + elif isinstance(item, property): + # Workaround for property `super()` shortcut (PY3-only). + # There is no universal way for other descriptors. + closure_cells = getattr(item.fget, "__closure__", None) + else: + closure_cells = getattr(item, "__closure__", None) + + if not closure_cells: # Catch None or the empty list. + continue + for cell in closure_cells: + try: + match = cell.cell_contents is self._cls + except ValueError: # ValueError: Cell is empty + pass + else: + if match: + set_closure_cell(cell, self._cls) + + return __llm_config_transform__(self._cls) + + def add_attrs_init(self) -> t.Self: + self._cls_dict["__attrs_init__"] = codegen.add_method_dunders( + self._cls, + _make_init( + self._cls, + self._attrs, + self._has_pre_init, + self._has_post_init, + False, # frozen + True, # slots + False, # cache_hash + self._base_attr_map, + False, # This is not an exception + None, # no on_setattr + attrs_init=True, + ), + ) + return self + + def add_repr(self, ns: str | None): + self._cls_dict["__repr__"] = codegen.add_method_dunders(self._cls, _make_repr(self._attrs, ns, self._cls)) + return self + + def __init_subclass__(cls: type[LLMConfig]): """The purpose of this __init_subclass__ is that we want all subclass of LLMConfig to adhere to the attrs contract, and have pydantic-like interface. This means we will construct all fields and metadata and hack into how attrs use some of the 'magic' construction @@ -897,12 +1312,44 @@ class LLMConfig: logger.warning("LLMConfig subclass should end with 'Config'. 
Updating to %sConfig", cls.__name__) cls.__name__ = f"{cls.__name__}Config" - # NOTE: auto assignment attributes generated from __config__ - _make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls) + if not hasattr(cls, "__config__") or getattr(cls, "__config__") is None: + raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.") + + assert cls.__config__ is not None + name_type = cls.__config__.get( + "name_type", + _ModelSettingsAttr.default().name_type, # type: ignore (dynamic attribute) + ) + + _cl_name = cls.__name__.replace("Config", "") + model_name = inflection.underscore(_cl_name) if name_type == "dasherize" else _cl_name.lower() + + cls.__openllm_generation_class__ = attr.make_class( + f"{_cl_name}GenerationConfig", + [], + bases=(GenerationConfig,), + slots=True, + weakref_slot=True, + frozen=True, + repr=True, + collect_by_mro=True, + field_transformer=_make_env_transformer( + cls, + model_name, + suffix="generation", + default_callback=lambda field_name, field_default: getattr( + cls.GenerationConfig, field_name, field_default + ) + if codegen.has_own_attribute(cls, "GenerationConfig") + else field_default, + globs={"cl_": cls}, + ), + ) + # process a fields under cls.__dict__ and auto convert them with dantic.Field + # this is similar logic to attr._make._transform_attrs cd = cls.__dict__ anns = codegen.get_annotations(cls) - # _CountingAttr is the underlying representation of attr.field ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)} these: dict[str, _CountingAttr[t.Any]] = {} @@ -914,9 +1361,9 @@ class LLMConfig: val = cd.get(attr_name, attr.NOTHING) if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val): if val is attr.NOTHING: - val = cls.Field(env=_field_env_key(cls.__openllm_model_name__, attr_name)) + val = cls.Field(env=_field_env_key(model_name, attr_name)) else: - val = cls.Field(default=val, env=_field_env_key(cls.__openllm_model_name__, attr_name)) + val = cls.Field(default=val, env=_field_env_key(model_name, attr_name)) these[attr_name] = val unannotated = ca_names - annotated_names if len(unannotated) > 0: @@ -924,12 +1371,12 @@ class LLMConfig: raise openllm.exceptions.MissingAnnotationAttributeError( f"The following field doesn't have a type annotation: {missing_annotated}" ) - + # We need to set the accepted key before generation_config + # as generation_config is a special field that users shouldn't pass. cls.__openllm_accepted_keys__ = set(these.keys()) | { a.name for a in attr.fields(cls.__openllm_generation_class__) } - - # 'generation_config' is a special fields that wraps the GenerationConfig class + # 'generation_config' wraps the GenerationConfig class # which is handled in _make_assignment_script these["generation_config"] = cls.Field( default=cls.__openllm_generation_class__(), @@ -937,78 +1384,13 @@ class LLMConfig: type=GenerationConfig, ) - # Generate the base __attrs_attrs__ transformation here. 
- attrs, base_attrs, base_attr_map = _transform_attrs( - cls, # the current class - these, # the parsed attributes we previous did - False, # disable auto_attribs, since we already handle these - False, # disable kw_only - True, # collect_by_mro - field_transformer=_make_env_transformer(cls, cls.__openllm_model_name__), - ) - _weakref_slot = True # slots = True - _base_names = {a.name for a in base_attrs} - _attr_names = tuple(a.name for a in attrs) - _has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) - _has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) - # the protocol for attrs-decorated class - cls.__attrs_attrs__ = attrs - # generate a __attrs_init__ for the subclass, since we will - # implement a custom __init__ - cls.__attrs_init__ = codegen.add_method_dunders( - cls, - _make_init( - cls, # cls (the attrs-decorated class) - attrs, # tuple of attr.Attribute of cls - _has_pre_init, # pre_init - _has_post_init, # post_init - False, # frozen - True, # slots - False, # cache_hash - base_attr_map, # base_attr_map - False, # is_exc (check if it is exception) - None, # cls_on_setattr (essentially attr.setters) - attrs_init=True, # whether to create __attrs_init__ instead of __init__ - ), - ) - # __repr__ function with the updated fields. - cls.__repr__ = codegen.add_method_dunders(cls, _make_repr(cls.__attrs_attrs__, None, cls)) - # Traverse the MRO to collect existing slots - # and check for an existing __weakref__. - existing_slots: DictStrAny = {} - weakref_inherited = False - for base_cls in cls.__mro__[1:-1]: - if base_cls.__dict__.get("__weakref__", None) is not None: - weakref_inherited = True - existing_slots.update( - {name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])} - ) + cls = cls._ConfigBuilder(cls, model_name, these).add_attrs_init().add_repr(None).build_class() + # auto assignment attributes generated from __config__ after create the new slot class. + _make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls) - if ( - _weakref_slot - and "__weakref__" not in getattr(cls, "__slots__", ()) - and "__weakref__" not in _attr_names - and not weakref_inherited - ): - _attr_names += ("__weakref__",) - - # We only add the names of attributes that aren't inherited. - # Setting __slots__ to inherited attributes wastes memory. - slot_names = [name for name in _attr_names if name not in _base_names] - # There are slots for attributes from current class - # that are defined in parent classes. - # As their descriptors may be overridden by a child class, - # we collect them here and update the class dict - reused_slots = { - slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names - } - # __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots - slot_names = [name for name in slot_names if name not in reused_slots] + ["__openllm_extras__"] - cls.__slots__ = tuple(slot_names) # Finally, resolve the types if getattr(cls, "__attrs_types_resolved__", None) != cls: # NOTE: We will try to resolve type here, and cached it for faster use - # It will be about 15-20ms slower comparing not resolving types. 
globs: DictStrAny = {"t": t, "typing": t, "Constraint": Constraint} if cls.__module__ in sys.modules: globs.update(sys.modules[cls.__module__].__dict__) @@ -1019,7 +1401,6 @@ class LLMConfig: cls.__openllm_hints__ = { f.name: f.type for ite in map(attr.fields, (cls, cls.__openllm_generation_class__)) for f in ite } - cls = __llm_config_transform__(cls) def __setattr__(self, attr: str, value: t.Any): if attr in _reserved_namespace: @@ -1045,43 +1426,85 @@ class LLMConfig: if generation_config is None: generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict} else: - config_merger.merge(generation_config, {k: v for k, v in attrs.items() if k in _generation_cl_dict}) + generation_config = config_merger.merge(generation_config, {k: v for k, v in attrs.items() if k in _generation_cl_dict}) for k in _cached_keys: if k in generation_config or attrs.get(k) is None: del attrs[k] _cached_keys = tuple(k for k in _cached_keys if k in attrs) - self.__openllm_extras__ = first_not_none(__openllm_extras__, default={}) - config_merger.merge( - self.__openllm_extras__, {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__} - ) + self.__openllm_extras__ = config_merger.merge( first_not_none(__openllm_extras__, default={}), {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}) for k in _cached_keys: if k in self.__openllm_extras__: del attrs[k] _cached_keys = tuple(k for k in _cached_keys if k in attrs) - if DEBUG: - logger.info( - "Creating %s with the following attributes: %s, generation_config=%s", - self.__class__.__name__, - _cached_keys, - generation_config, - ) - # The rest of attrs should only be the attributes to be passed to __attrs_init__ self.__attrs_init__(generation_config=self["generation_class"](**generation_config), **attrs) - def __getitem__(self, item: str | t.Any) -> t.Any: + # NOTE: These required fields should be at the top, as it will be kw_only + + # fmt: off + + # update-config-stubs.py: start + if t.TYPE_CHECKING: + @overload + def __getitem__(self, item: t.Literal["default_id"] = ...) -> str: ... + @overload + def __getitem__(self, item: t.Literal["model_ids"] = ...) -> ListStr: ... + @overload + def __getitem__(self, item: t.Literal["url"] = ...) -> str: ... + @overload + def __getitem__(self, item: t.Literal["requires_gpu"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["trust_remote_code"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["service_name"] = ...) -> str: ... + @overload + def __getitem__(self, item: t.Literal["requirements"] = ...) -> t.Optional[ListStr]: ... + @overload + def __getitem__(self, item: t.Literal["use_pipeline"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["bettertransformer"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["model_type"] = ...) -> t.Literal['causal_lm', 'seq2seq_lm']: ... + @overload + def __getitem__(self, item: t.Literal["runtime"] = ...) -> t.Literal['transformers', 'cpp']: ... + @overload + def __getitem__(self, item: t.Literal["name_type"] = ...) -> t.Literal['dasherize', 'lowercase']: ... + @overload + def __getitem__(self, item: t.Literal["model_name"] = ...) -> str: ... + @overload + def __getitem__(self, item: t.Literal["start_name"] = ...) -> str: ... + @overload + def __getitem__(self, item: t.Literal["env"] = ...) -> openllm.utils.ModelEnv: ... + @overload + def __getitem__(self, item: t.Literal["timeout"] = ...) -> int: ... 
+ @overload + def __getitem__(self, item: t.Literal["workers_per_resource"] = ...) -> t.Union[int, float]: ... + @overload + def __getitem__(self, item: t.Literal["fine_tune_strategies"] = ...) -> t.Dict[AdapterType, FineTuneConfig]: ... + @overload + def __getitem__(self, item: t.Literal["generation_class"] = ...) -> t.Type[GenerationConfig]: ... + @overload + def __getitem__(self, item: t.Literal["extras"] = ...) -> t.Dict[str, t.Any]: ... + # update-config-stubs.py: stop + + # fmt: on + + def __getitem__(self, item: t.LiteralString | t.Any = None) -> t.Any: """Allowing access LLMConfig as a dictionary. The order will always evaluate as - __openllm_*__ > self.key > __openllm_generation_class__ > __openllm_extras__ + __openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__ This method is purely for convenience, and should not be used for performance critical code. """ + if item is None: + raise TypeError(f"{self} doesn't understand how to index None.") if not isinstance(item, str): raise TypeError(f"LLM only supports string indexing, not {item.__class__.__name__}") + item = inflection.underscore(item) if item in _reserved_namespace: raise ForbiddenAttributeError( f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified." @@ -1093,6 +1516,8 @@ class LLMConfig: return getattr(self, item) elif hasattr(self.__openllm_generation_class__, item): return getattr(self.generation_config, item) + elif item in self.__class__.__openllm_fine_tune_strategies__: + return self.__class__.__openllm_fine_tune_strategies__[item] elif item in self.__openllm_extras__: return self.__openllm_extras__[item] else: @@ -1209,11 +1634,11 @@ class LLMConfig: return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove} - @t.overload + @overload def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig: ... - @t.overload + @overload def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny: ... @@ -1222,14 +1647,14 @@ class LLMConfig: return config.to_dict() if return_as_dict else config @classmethod - @t.overload + @overload def to_click_options( cls, f: t.Callable[..., openllm.LLMConfig] ) -> F[P, ClickFunctionWrapper[..., openllm.LLMConfig]]: ... @classmethod - @t.overload + @overload def to_click_options(cls, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]: ... 
@@ -1264,6 +1689,12 @@ class LLMConfig: return cog.optgroup.group(f"{cls.__name__} options")(f) + # All fine-tune related starts here + @classmethod + def peft_task_type(cls) -> str: + # holds a mapping from self.__openllm_model_type__ to peft.TaskType + return _PEFT_TASK_TYPE_TARGET_MAPPING[cls.__openllm_model_type__] + bentoml_cattr.register_unstructure_hook_factory( lambda cls: lenient_issubclass(cls, LLMConfig), diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 58ab61d2..d2cee810 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -16,6 +16,7 @@ from __future__ import annotations import copy import functools +import inspect import logging import os import re @@ -29,34 +30,48 @@ from abc import abstractmethod import attr import inflection import orjson +from huggingface_hub import hf_hub_download import bentoml import openllm from bentoml._internal.models.model import ModelSignature from bentoml._internal.types import ModelSignatureDict +from ._configuration import FineTuneConfig from .exceptions import ForbiddenAttributeError from .exceptions import GpuNotAvailableError from .exceptions import OpenLLMException from .utils import DEBUG from .utils import LazyLoader from .utils import ModelEnv +from .utils import ReprMixin from .utils import bentoml_cattr from .utils import first_not_none from .utils import get_debug_mode from .utils import is_bitsandbytes_available +from .utils import is_peft_available from .utils import is_torch_available from .utils import is_transformers_supports_kbit from .utils import non_intrusive_setattr from .utils import pkg +from .utils import requires_dependencies +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if hasattr(t, "get_overloads"): + from typing import overload +else: + from typing_extensions import overload + if t.TYPE_CHECKING: + import peft import torch import transformers from bentoml._internal.runner.strategy import Strategy + from ._configuration import AdapterType from .models.auto.factory import _BaseAutoLLMClass class LLMRunner(bentoml.Runner): @@ -74,6 +89,7 @@ else: LLMRunner = bentoml.Runner transformers = LazyLoader("transformers", globals(), "transformers") torch = LazyLoader("torch", globals(), "torch") + peft = LazyLoader("peft", globals(), "peft") logger = logging.getLogger(__name__) @@ -96,6 +112,41 @@ def convert_transformers_model_name(name: str | None) -> str: return re.sub("[^a-zA-Z0-9]+", "-", name) +# the below is similar to peft.utils.other.CONFIG_NAME +PEFT_CONFIG_NAME = "adapter_config.json" + + +def resolve_peft_config_type(adapter_map: dict[str, str | None] | None): + """Resolve the type of the PeftConfig given the adapter_map. + This is similar to how PeftConfig resolve its config type. 
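+    For example (illustrative): `{"aarnphm/opt-6.7b-lora": "default"}` resolves to
+    `{"lora": (("aarnphm/opt-6.7b-lora", "default", {<contents of adapter_config.json>}),)}`,
+    assuming that adapter's `adapter_config.json` declares `"peft_type": "LORA"`.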
+ """ + if adapter_map is None: + return + + resolved: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] = {} + _has_set_default = False + for path_or_adapter_id, name in adapter_map.items(): + if _has_set_default: + raise ValueError("Only one adapter can be set as default.") + if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)): + config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME) + else: + try: + config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME) + except Exception: + raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") + with open(config_file, "r") as file: + resolved_config = orjson.loads(file.read()) + # all peft_type should be available in PEFT_CONFIG_NAME + _peft_type: AdapterType = resolved_config["peft_type"].lower() + if _peft_type not in resolved: + resolved[_peft_type] = () + resolved[_peft_type] += ((path_or_adapter_id, name, resolved_config),) + if name == "default": + _has_set_default = True + return resolved + + def import_model( model_id: str, tag: bentoml.Tag, @@ -179,10 +230,6 @@ def import_model( try: return bentoml.transformers.save_model(tag, model, custom_objects={"tokenizer": tokenizer}) finally: - import gc - - gc.collect() - # NOTE: We need to free up the cache after importing the model # in the case where users first run openllm start without the model # available locally. @@ -193,11 +240,12 @@ def import_model( _reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"} -class LLMInterface(ABC): - """This defines the loose contract for all openllm.LLM implementations.""" +_M = t.TypeVar("_M") +_T = t.TypeVar("_T") - config_class: type[openllm.LLMConfig] - """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class.""" + +class LLMInterface(ABC, t.Generic[_M, _T]): + """This defines the loose contract for all openllm.LLM implementations.""" @property def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]] | None: @@ -270,32 +318,76 @@ class LLMInterface(ABC): example implementation.""" raise NotImplementedError + # NOTE: All fields below are attributes that can be accessed by users. + config_class: type[openllm.LLMConfig] + """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class.""" -_M = t.TypeVar("_M") -_T = t.TypeVar("_T") + config: openllm.LLMConfig + """The config instance to use for this LLM. This will be created based on config_class""" + + bettertransformer: bool + """Whether to load this LLM with FasterTransformer enabled. The order of loading is: + - If pass within `for_model`, `from_pretrained` or `__init__`. + - If `self.bettertransformer` is set within `llm_post_init`. + - Finally, if none of the above, default to self.config['bettertransformer'] + + > **Note** that if LoRA is enabled, bettertransformer will be disabled. + """ + + # NOTE: The following will be populated by __init_subclass__, note that these should not + # be mutated by users. + __llm_trust_remote_code__: bool + """This is used to determine during 'import_model' whether to trust remote code or not. + This works synonymous with `trust_remote_code` kwarg in transformers Auto classes. If not passed, + then by default fallback to config_class['trust_remote_code'] + """ + + __llm_implementation__: t.Literal["pt", "tf", "flax"] + """This is used to determine which implementation that this LLM has. 
Usually, this will be inferred from
+    the class name, which follows HuggingFace's naming convention:
+
+    - `OPTForConditionalGeneration` -> `pt`
+    - `TFOPTForConditionalGeneration` -> `tf`
+    - `FlaxOPTForConditionalGeneration` -> `flax`
+    """
+
+    __llm_model__: _M | peft.PeftModel | torch.nn.Module | None
+    """A reference to the actual model. Instead of accessing this directly, use the `model` property."""
+
+    __llm_tokenizer__: _T | None
+    """A reference to the actual tokenizer. Instead of accessing this directly, use the `tokenizer` property."""
+
+    __llm_tag__: bentoml.Tag | None
+    """A reference to the tag used for this LLM. Instead of accessing this directly, use the `tag` property."""
+
+    __llm_bentomodel__: bentoml.Model | None
+    """A reference to the bentomodel used for this LLM. Instead of accessing this directly, use the `_bentomodel` property."""
+
+    __llm_trainer__: transformers.Trainer | None
+    """A reference to the Trainer used to fine-tune this LLM."""
+
+    __llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], peft.PeftConfig]] | None
+    """A reference to the cached LoRA adapter mapping."""
+
+    __llm_post_init__: t.Callable[[t.Self], None] | None
+    """A callable that will be called after the LLM is initialized. This is set if the subclass implements 'llm_post_init'."""
+
+    __llm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
+    """A callable that will be called after the model is loaded. This is set when 'load_model' is implemented."""
+
+    __llm_init_kwargs__: property | None
+    """A check for whether 'import_kwargs' is implemented in the subclass."""
+
+    # The following are internal; users shouldn't access these directly.
+    _model_args: tuple[t.Any, ...]
+    _model_attrs: dict[str, t.Any]
+    _tokenizer_attrs: dict[str, t.Any]
+
+    _adapters_mapping: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None
 
 
 @attr.define(slots=True, repr=False)
-class LLM(LLMInterface, t.Generic[_M, _T]):
-    if t.TYPE_CHECKING:
-        # The following will be populated by metaclass
-        __llm_trust_remote_code__: bool
-        __llm_implementation__: t.Literal["pt", "tf", "flax"]
-        __llm_model__: _M | None
-        __llm_tokenizer__: _T | None
-        __llm_tag__: bentoml.Tag | None
-        __llm_bentomodel__: bentoml.Model | None
-
-        __llm_post_init__: t.Callable[[t.Self], None] | None
-        __llm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
-        __llm_init_kwargs__: property | None
-
-        _model_args: tuple[t.Any, ...]
-        _model_attrs: dict[str, t.Any]
-        _tokenizer_attrs: dict[str, t.Any]
-
-        bettertransformer: bool
-
+class LLM(LLMInterface[_M, _T], ReprMixin):
     def __init_subclass__(cls):
         cd = cls.__dict__
         prefix_class_name_config = cls.__name__
@@ -321,19 +413,33 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
                 "Missing required key 'config_class'. Make sure to define it within the LLM subclass."
) - if cls.import_model is LLMInterface.import_model: - # using the default import model + if cls.import_model is LLMInterface[_M, _T].import_model: + # using the default import model if no custom import is set setattr(cls, "import_model", functools.partial(import_model, _model_framework=implementation)) - else: - logger.debug("Custom 'import_model' will be used when loading model %s", cls.__name__) - cls.__llm_post_init__ = None if cls.llm_post_init is LLMInterface.llm_post_init else cls.llm_post_init - cls.__llm_custom_load__ = None if cls.load_model is LLMInterface.load_model else cls.load_model - cls.__llm_init_kwargs__ = None if cls.import_kwargs is LLMInterface.import_kwargs else cls.import_kwargs + cls.__llm_post_init__ = None if cls.llm_post_init is LLMInterface[_M, _T].llm_post_init else cls.llm_post_init + cls.__llm_custom_load__ = None if cls.load_model is LLMInterface[_M, _T].load_model else cls.load_model + cls.__llm_init_kwargs__ = ( + None if cls.import_kwargs is LLMInterface[_M, _T].import_kwargs else cls.import_kwargs + ) - for at in {"bentomodel", "tag", "model", "tokenizer"}: + for at in {"bentomodel", "tag", "model", "tokenizer", "adapter_map", "trainer"}: setattr(cls, f"__llm_{at}__", None) + # update docstring for given entrypoint + for fn in {"generate", "generate_one", "generate_iterator"}: + original_fn = getattr(cls, fn, getattr(LLMInterface, fn)) + original_fn.__doc__ = ( + original_fn.__doc__ + or f"""\ + '{fn}' implementation {cls.__name__}. + + Note that if LoRA is enabled (via either SDK or CLI), `self.model` will become a `peft.PeftModel` + The original can then be accessed with 'self.model.get_base_model()'. + """ + ) + setattr(cls, fn, original_fn) + # The following is the similar interface to HuggingFace pretrained protocol. @classmethod def from_pretrained( @@ -343,14 +449,146 @@ class LLM(LLMInterface, t.Generic[_M, _T]): *args: t.Any, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, + adapter_id: str | None = None, + adapter_name: str | None = None, + adapter_map: dict[str, str | None] | None = None, **attrs: t.Any, ) -> LLM[_M, _T]: + """Instantiate a pretrained LLM. + it follows the same design principle as HuggingFace's `from_pretrained` method, plus the following: + + Optimization options: + + > This is most notable during serving time. + + - quantize: quantize the model with the given quantization method. Currently supported int8, int4 quantization + - bettertransformer: Apply FasterTransformer to given pretrained weight + + > Currently, the above two options are mutually exclusive. + + Adapter options: + + > This is used in conjunction with the fine-tuning features + + - adapter_id: Optional [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to apply to said model. + - adapter_name: Optional name of the adapter to apply to said model. If not provided, it will be handled internally by OpenLLM. + - adapter_map: optional dictionary of adapter_id to adapter_name. Note that this is mutually exclusive with adapter_id/adapter_name arguments. + + Args: + model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used. + llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM + will use `config_class` to construct default configuration. + quantize: The quantization to use for this LLM. Defaults to None. Possible values + include int8, int4 and gptq. + bettertransformer: Whether to use BetterTransformer with this model. 
Defaults to False. + adapter_id: The [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to use for this LLM. Defaults to None. + adapter_name: The adapter name to use for this LLM. Defaults to None. + adapter_map: The adapter map to use for this LLM. Defaults to None. Note that this is mutually exclusive with adapter_id/adapter_name arguments. + *args: The args to be passed to the model. + **attrs: The kwargs to be passed to the model. + """ + quantization_config = attrs.pop("quantization_config", None) + if quantization_config and quantize: + raise ValueError( + """'quantization_config' and 'quantize' are mutually exclusive. Either customise + your quantization_config or use the quantize argument.""" + ) + + # quantization setup + quantization_config = attrs.pop("quantization_config", None) + # 8 bit configuration + int8_threshold = attrs.pop("llm_int8_threshhold", 6.0) + cpu_offloading = attrs.pop("llm_int8_enable_fp32_cpu_offload", False) + int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None) + int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False) + # 4 bit configuration + int4_compute_dtype = attrs.pop("llm_bnb_4bit_compute_dtype", torch.bfloat16) + int4_quant_type = attrs.pop("llm_bnb_4bit_quant_type", "nf4") + int4_use_double_quant = attrs.pop("llm_bnb_4bit_use_double_quant", True) + + # NOTE: Quantization setup + if quantization_config is None: + # quantize is a openllm.LLM feature, where we can quantize the model + # with bitsandbytes or quantization aware training. + if quantize is not None: + if not is_bitsandbytes_available(): + raise RuntimeError( + "Quantization requires bitsandbytes to be installed. Make " + "sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'" + ) + logger.debug( + "'quantize' is not None. %s will use a default 'quantization_config' for %s. " + "If you want to customise the quantization config, make sure to pass your " + "own 'quantization_config'", + cls.__name__, + quantize, + ) + if quantize == "int8": + if int8_skip_modules is None: + int8_skip_modules = [] + if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm": + logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__) + int8_skip_modules.append("lm_head") + quantization_config = transformers.BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_enable_fp32_cpu_offload=cpu_offloading, + llm_int8_threshhold=int8_threshold, + llm_int8_skip_modules=int8_skip_modules, + llm_int8_has_fp16_weight=int8_has_fp16_weight, + ) + elif quantize == "int4": + if is_transformers_supports_kbit(): + quantization_config = transformers.BitsAndBytesConfig( + load_in_4bit=True, + llm_bnb_4bit_compute_dtype=int4_compute_dtype, + llm_bnb_4bit_quant_type=int4_quant_type, + llm_bnb_4bit_use_double_quant=int4_use_double_quant, + ) + else: + logger.warning( + "'quantize' is set to int4, while the current transformers version %s does not support " + "k-bit quantization. 
k-bit quantization is supported since transformers 4.30, therefore " + "make sure to install the latest version of transformers either via PyPI or " + "from git source: 'pip install git+https://github.com/huggingface/transformers'.", + pkg.pkg_version_info("transformers"), + ) + elif quantize == "gptq": + # TODO: support GPTQ loading quantization + raise NotImplementedError("GPTQ is not supported yet.") + if model_id is None: + raise RuntimeError( + "'quantize=%s' requires passing custom path to quantized weights as we are unable to load " + "the model on the fly. See https://github.com/qwopqwop200/GPTQ-for-LLaMa for " + "instruction on how to quantize '%s' with GPTQ.", + quantize, + cls.__name__, + ) + else: + raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantize} instead.") + + # NOTE: Fine-tuning setup + if adapter_map and adapter_id: + raise ValueError( + """'adapter_map' and 'adapter_id' are mutually exclusive. Either provide a + 'adapter_map' ({adapter_id: adapter_name | None, ...}) or use + the combination of adapter_id/adapter_name arguments. + """ + ) + if adapter_map is None and adapter_id is not None: + adapter_map = {adapter_id: adapter_name} + + if adapter_map is not None and not is_peft_available(): + raise RuntimeError( + "LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'" + ) + return cls( model_id=model_id, llm_config=llm_config, *args, - quantize=quantize, bettertransformer=bettertransformer, + _adapters_mapping=resolve_peft_config_type(adapter_map), + quantization_config=quantization_config, **attrs, ) @@ -359,12 +597,17 @@ class LLM(LLMInterface, t.Generic[_M, _T]): model_id: str | None = None, llm_config: openllm.LLMConfig | None = None, *args: t.Any, - quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, + _adapters_mapping: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] + | None = None, **attrs: t.Any, ): """Initialize the LLM with given pretrained model. + > **Warning** + > To initializing any LLM, you should use `openllm.AutoLLM` or `openllm.LLM.from_pretrained` instead. + > `__init__` initialization is only for internal use. + Note: - *args to be passed to the model. - **attrs will first be parsed to the AutoConfig, then the rest will be parsed to the import_model @@ -437,8 +680,6 @@ class LLM(LLMInterface, t.Generic[_M, _T]): model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used. llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM will use `config_class` to construct default configuration. - quantize: The quantization to use for this LLM. Defaults to None. Possible values - include int8, int4 and gptq. bettertransformer: Whether to use BetterTransformer with this model. Defaults to False. *args: The args to be passed to the model. **attrs: The kwargs to be passed to the model. 
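The optimization and adapter arguments documented above are also reachable from the Python SDK. A minimal sketch of how a caller might combine them (a sketch only: the adapter ids below are illustrative, and it assumes the PyTorch backend with the 'fine-tune' extra installed):

```python
import openllm

# Serve OPT in 8-bit with two LoRA adapters attached; the unnamed entry becomes 'default'.
# Note that 'adapter_map' is mutually exclusive with the 'adapter_id'/'adapter_name' pair.
llm = openllm.AutoLLM.for_model(
    "opt",
    model_id="facebook/opt-6.7b",
    quantize="int8",
    adapter_map={
        "aarnphm/opt-6.7b-lora": None,                   # becomes the 'default' adapter
        "aarnphm/opt-6.7b-lora-french": "french_lora",   # illustrative second adapter id
    },
)
# generated = llm.generate("What is the meaning of life?")
```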
@@ -454,18 +695,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]): # low_cpu_mem_usage is only available for model # this is helpful on system with low memory to avoid OOM low_cpu_mem_usage = attrs.pop("low_cpu_mem_usage", True) - - # quantization setup quantization_config = attrs.pop("quantization_config", None) - # 8 bit configuration - int8_threshold = attrs.pop("llm_int8_threshhold", 6.0) - cpu_offloading = attrs.pop("llm_int8_enable_fp32_cpu_offload", False) - int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None) - int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False) - # 4 bit configuration - int4_compute_dtype = attrs.pop("llm_bnb_4bit_compute_dtype", torch.bfloat16) - int4_quant_type = attrs.pop("llm_bnb_4bit_quant_type", "nf4") - int4_use_double_quant = attrs.pop("llm_bnb_4bit_use_double_quant", True) if llm_config is not None: logger.debug("Using provided LLMConfig to initialize LLM instead of from default: %r", llm_config) @@ -475,69 +705,10 @@ class LLM(LLMInterface, t.Generic[_M, _T]): # The rests of the kwargs that is not used by the config class should be stored into __openllm_extras__. attrs = self.config["extras"] - if quantization_config and quantize: - raise ValueError( - """'quantization_config' and 'quantize' are mutually exclusive. Either customise - your quantization_config or use the quantize argument.""" - ) - if quantization_config is None: - # quantize is a openllm.LLM feature, where we can quantize the model - # with bitsandbytes or quantization aware training. - if quantize is not None: - if not is_bitsandbytes_available(): - raise RuntimeError( - "Quantization requires bitsandbytes to be installed. Make " - "sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'" - ) - logger.debug( - "'quantize' is not None. %s will use a default 'quantization_config' for %s. " - "If you want to customise the quantization config, make sure to pass your " - "own 'quantization_config'", - self, - quantize, - ) - if quantize == "int8": - if int8_skip_modules is None: - int8_skip_modules = [] - if "lm_head" not in int8_skip_modules and self.config["model_type"] == "causal_lm": - logger.debug("Skipping 'lm_head' for quantization for %s", self) - int8_skip_modules.append("lm_head") - quantization_config = transformers.BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_enable_fp32_cpu_offload=cpu_offloading, - llm_int8_threshhold=int8_threshold, - llm_int8_skip_modules=int8_skip_modules, - llm_int8_has_fp16_weight=int8_has_fp16_weight, - ) - elif quantize == "int4": - if is_transformers_supports_kbit(): - quantization_config = transformers.BitsAndBytesConfig( - load_in_4bit=True, - llm_bnb_4bit_compute_dtype=int4_compute_dtype, - llm_bnb_4bit_quant_type=int4_quant_type, - llm_bnb_4bit_use_double_quant=int4_use_double_quant, - ) - else: - logger.warning( - "'quantize' is set to int4, while the current transformers version %s does not support " - "k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore " - "make sure to install the latest version of transformers either via PyPI or " - "from git source: 'pip install git+https://github.com/huggingface/transformers'.", - pkg.pkg_version_info("transformers"), - ) - elif quantize == "gptq": - # TODO: support GPTQ loading quantization - if model_id is None: - raise RuntimeError( - "'quantize=%s' requires passing custom path to quantized weights as we are unable to load " - "the model on the fly. 
See https://github.com/qwopqwop200/GPTQ-for-LLaMa for " - "instruction on how to quantize '%s' with GPTQ.", - quantize, - self, - ) - raise NotImplementedError("GPTQ is not supported yet.") - else: - raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantize} instead.") + if self.config["use_pipeline"] and _adapters_mapping: + raise ValueError(f"{self} will be used as a Pipeline, which is not yet compatible with LoRA adapter.") + + self._adapters_mapping = _adapters_mapping if self.__llm_implementation__ == "pt": if not self.config["use_pipeline"]: @@ -546,10 +717,8 @@ class LLM(LLMInterface, t.Generic[_M, _T]): model_kwds, tokenizer_kwds = {}, {} if self.__llm_init_kwargs__: - if t.TYPE_CHECKING: - # the above meta value should determine that this LLM has custom kwargs - assert self.import_kwargs - model_kwds, tokenizer_kwds = self.import_kwargs + # NOTE: recast here for type safety + model_kwds, tokenizer_kwds = t.cast("tuple[dict[str, t.Any], dict[str, t.Any]]", self.__llm_init_kwargs__) logger.debug( "'%s' default kwargs for model: '%s', tokenizer: '%s'", self.__class__.__name__, @@ -564,7 +733,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]): assert model_id is not None self._model_id = model_id - # parsing tokenizer and model kwargs + # parsing tokenizer and model kwargs, as the hierachy is param pass > default tokenizer_kwds.update( {k[len(TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(TOKENIZER_PREFIX)} ) @@ -588,6 +757,10 @@ class LLM(LLMInterface, t.Generic[_M, _T]): self.bettertransformer = bettertransformer else: non_intrusive_setattr(self, "bettertransformer", self.config["bettertransformer"]) + # If lora is passed, the disable bettertransformer + if _adapters_mapping and self.bettertransformer is True: + logger.debug("LoRA is visible for %s, disabling BetterTransformer", self) + self.bettertransformer = False def __setattr__(self, attr: str, value: t.Any): if attr in _reserved_namespace: @@ -599,9 +772,21 @@ class LLM(LLMInterface, t.Generic[_M, _T]): super().__setattr__(attr, value) - def __repr__(self) -> str: - keys = {"model_id", "runner_name", "llm_type", "config"} - return f"{self.__class__.__name__}({', '.join(f'{k}={getattr(self, k)!r}' for k in keys)})" + @property + def adapters_mapping( + self, + ) -> dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None: + return self._adapters_mapping + + @adapters_mapping.setter + def adapters_mapping( + self, value: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None + ): + self._adapters_mapping = value + + @property + def __repr_keys__(self) -> set[str]: + return {"model_id", "runner_name", "config"} @property def model_id(self) -> str: @@ -712,9 +897,9 @@ class LLM(LLMInterface, t.Generic[_M, _T]): @property def _bentomodel(self) -> bentoml.Model: if self.__llm_bentomodel__ is None: - # NOTE: Since PR#28, self.__llm_bentomodel__ changed from + # NOTE: Since #28, self.__llm_bentomodel__ changed from # ensure_model_id_exists() into just returning the model ref. - # This is because we want to save a few seconds of loading time, + # This is purely a performance reason. # as openllm.Runner and openllm.AutoLLM initialisation is around 700ms # before #28. 
# If users want to make sure to have the model downloaded, @@ -741,23 +926,24 @@ class LLM(LLMInterface, t.Generic[_M, _T]): if self.config["requires_gpu"] and len(openllm.utils.gpu_count()) < 1: raise GpuNotAvailableError(f"{self} only supports running with GPU (None available).") from None - kwds = self._model_attrs - kwds["trust_remote_code"] = self.__llm_trust_remote_code__ - - is_pipeline = "_pretrained_class" in self._bentomodel.info.metadata - # differentiate when saving tokenizer or other pretrained type. - is_pretrained_model = is_pipeline and "_framework" in self._bentomodel.info.metadata - - if self.bettertransformer and is_pipeline and self.config["use_pipeline"]: - # This is a pipeline, provide a accelerator args - kwds["accelerator"] = "bettertransformer" - if self.__llm_model__ is None: + kwds = self._model_attrs + kwds["trust_remote_code"] = self.__llm_trust_remote_code__ + + is_pipeline = "_pretrained_class" in self._bentomodel.info.metadata + # differentiate when saving tokenizer or other pretrained type. + is_pretrained_model = is_pipeline and "_framework" in self._bentomodel.info.metadata + + if self.bettertransformer and is_pipeline and self.config["use_pipeline"]: + # This is a pipeline, provide a accelerator args + kwds["accelerator"] = "bettertransformer" + if self.__llm_custom_load__: self.__llm_model__ = self.load_model(self.tag, *self._model_args, **kwds) else: self.__llm_model__ = self._bentomodel.load_model(*self._model_args, **kwds) + # This branch shouldn't hit when LoRA is visible. if ( self.bettertransformer and is_pretrained_model @@ -770,6 +956,143 @@ class LLM(LLMInterface, t.Generic[_M, _T]): self.__llm_model__ = BetterTransformer.transform(self.__llm_model__) return t.cast(_M, self.__llm_model__) + def _transpose_adapter_mapping( + self, + inference_mode: bool = True, + use_cache: bool = True, + ) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]]: + assert self._adapters_mapping is not None, "LoRA mapping is not set up correctly." + + if not use_cache: + logger.debug( + "'use_cache' is set to False. This means the adapter mapping resolution will not be cached. This should only be used during training." + ) + + if self.__llm_adapter_map__ is not None and use_cache: + # early out if we already serialized everything. + return self.__llm_adapter_map__ + + adapter_map: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] = {} + # this is a temporary check to accept the first option name as 'default' + # then we will raise Error when the optional_name is set to None in next iteration. 
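+        # Illustrative shape of the transpose performed below (example data, not OpenLLM defaults):
+        #   self._adapters_mapping == {"lora": (("aarnphm/opt-6.7b-lora", None, {}),
+        #                                       ("aarnphm/opt-6.7b-lora", "french_lora", {"r": 8}))}
+        # resolves into
+        #   {"lora": {"default":     (<peft.PeftConfig>, "aarnphm/opt-6.7b-lora"),
+        #             "french_lora": (<peft.PeftConfig>, "aarnphm/opt-6.7b-lora")}}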
+ _converted_first_none = False + for _adapter_type, _adapter_tuple in self._adapters_mapping.items(): + if _adapter_type not in adapter_map: + adapter_map[_adapter_type] = {} + default_config = self.config["fine_tune_strategies"].get( + _adapter_type, FineTuneConfig(adapter_type=_adapter_type, llm_config_class=self.config_class) + ) + default_config = default_config.eval() if inference_mode else default_config.train() + for pretrained_or_peft_id, optional_name, resolved_mapping in _adapter_tuple: + if not optional_name: + if not _converted_first_none: + _converted_first_none = True + optional_name = "default" + else: + raise ValueError( + f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {pretrained_or_peft_id, resolved_mapping}" + ) + assert isinstance(optional_name, str) # optional_name should all be resolved here + if optional_name == "default": + adapter_map[_adapter_type][optional_name] = ( + default_config.with_config(**resolved_mapping).to_peft_config(), + pretrained_or_peft_id, + ) + else: + adapter_map[_adapter_type][optional_name] = ( + FineTuneConfig( + adapter_type=_adapter_type, + adapter_config=resolved_mapping, + inference_mode=inference_mode, + llm_config_class=self.config_class, + ).to_peft_config(), + pretrained_or_peft_id, + ) + + if self.__llm_adapter_map__ is None and use_cache: + self.__llm_adapter_map__ = adapter_map + + return self.__llm_adapter_map__ + + return adapter_map + + @requires_dependencies("peft", extra="fine-tune") + def apply_adapter( + self, + inference_mode: bool = True, + adapter_type: AdapterType = "lora", + load_adapters: t.Literal["all"] | list[str] | None = None, + use_cache: bool = True, + ) -> peft.PeftModel | _M | torch.nn.Module: + """Apply given LoRA mapping to the model. Note that the base model can still + be accessed via self.model.get_base_model(). + """ + assert self.model, "Internal error: Model is not loaded correctly." + assert self.__llm_model__ is not None + + # early out if _adapters_mapping is empty or it is already wrapped + # with peft. + if not self._adapters_mapping: + logger.debug("No adapter mapping is found. Skip applying adapter.") + return self.__llm_model__ + + _mapping = self._transpose_adapter_mapping(inference_mode=inference_mode, use_cache=use_cache) + if adapter_type not in _mapping: + raise ValueError( + f"Given adapter type {adapter_type} is not supported. Please choose from {list(_mapping.keys())}" + ) + adapter_mapping = _mapping[adapter_type] + default_config, peft_model_id = adapter_mapping.pop("default", None) + if default_config is None: + raise ValueError( + "There is no 'default' mapping. Please check the adapter mapping and report this bug to the OpenLLM team." + ) + + # the below shared similar logics with `get_peft_model` + # TODO: Support PromptLearningConfig + if default_config.task_type not in peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( + default_config, peft.PromptLearningConfig + ): + logger.debug( + "Given task type '%s' is not supported by peft. This means it can be a custom PeftModel implementation. 
Make sure the adapter is loaded manually before running inference.", + default_config.task_type, + ) + self.__llm_model__ = peft.PeftModel(self.__llm_model__, default_config) + else: + # this is not ideal to serialize like this, wait until https://github.com/huggingface/peft/pull/612 + # is merged + peft_class = peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING[default_config.task_type] + if t.cast("str | None", default_config.base_model_name_or_path) is not None: + kwargs: dict[str, t.Any] = {"is_trainable": not inference_mode} + if "config" in inspect.signature(peft_class.from_pretrained).parameters: + kwargs["config"] = default_config + else: + kwargs.update(dict(default_config.to_dict().items())) + self.__llm_model__ = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs) + else: + # in this case, the given base_model_name_or_path is None. This will be hit during training + self.__llm_model__ = peft_class(self.__llm_model__, default_config) + + # now we loop through the rest with add_adapter + if len(adapter_mapping) > 0: + for adapter_name, _peft_config in adapter_mapping.items(): + self.__llm_model__.add_adapter(adapter_name, _peft_config) + + # optionally load adapters. In case of multiple adapters, or on Runner, + # we will need to set load_adapters='all' + if load_adapters is not None: + adapters_to_load = adapter_mapping.keys() if load_adapters == "all" else load_adapters + for adapter_name in adapters_to_load: + _peft_config, _peft_model_id = adapter_mapping[adapter_name] + self.__llm_model__.load_adapter( + _peft_model_id, + adapter_name=adapter_name, + is_trainable=not inference_mode, + **dict(_peft_config.to_dict()), + ) + + return self.__llm_model__ + @property def tokenizer(self) -> _T: """The tokenizer to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it.""" @@ -837,18 +1160,45 @@ class LLM(LLMInterface, t.Generic[_M, _T]): SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu") SUPPORTS_CPU_MULTI_THREADING = True - llm_type: str - identifying_params: dict[str, t.Any] - - def __init_subclass__(cls, llm_type: str, identifying_params: dict[str, t.Any], **_: t.Any): - cls.llm_type = llm_type - cls.identifying_params = identifying_params - def __init__(__self: _Runnable): # NOTE: The side effect of this line # is that it will load the imported model during - # runner creation. So don't remove it!! - self.model + # runner startup. So don't remove it!! + assert self.model, "Internal error: Model is not loaded" + if self.adapters_mapping is not None: + logger.info("Applying LoRA to %s...", self.runner_name) + self.apply_adapter(inference_mode=True, load_adapters="all") + + @bentoml.Runnable.method(batchable=False) + def list_adapter(__self) -> dict[str, t.Any]: + if not is_peft_available(): + return { + "success": False, + "result": {}, + "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'", + } + if not isinstance(self.model, peft.PeftModel): + return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"} + return {"success": True, "result": self.model.peft_config, "error_msg": ""} + + @bentoml.Runnable.method(batchable=False) + def set_adapter(__self, adapter_name: str) -> dict[t.Literal["success", "error_msg"], bool | str]: + if not is_peft_available(): + return { + "success": False, + "error_msg": "peft is not available. 
Make sure to install: 'pip install \"openllm[fine-tune]\"'", + } + if not isinstance(self.model, peft.PeftModel): + return {"success": False, "error_msg": "Model is not a PeftModel"} + try: + self.model.set_adapter(adapter_name) + return {"success": True, "error_msg": ""} + except ValueError: + logger.info("Adapter %s not found", adapter_name) + return { + "success": False, + "error_msg": f"Adapter {adapter_name} not found. Available adapters: {list(self.model.peft_config)}", + } @bentoml.Runnable.method( batchable=generate_sig.batchable, @@ -916,19 +1266,21 @@ class LLM(LLMInterface, t.Generic[_M, _T]): "__call__": _wrapped_generate_run, "__module__": f"openllm.models.{self.config['model_name']}", "__doc__": self.config["env"].start_docstring, + "__repr_keys__": lambda _: {"llm", "config", "llm_type", "identifying_params"}, } ), )( types.new_class( inflection.camelize(self.config["model_name"]) + "Runnable", (_Runnable,), - { - "SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu") - if self.config["requires_gpu"] - else ("nvidia.com/gpu",), - "llm_type": self.llm_type, - "identifying_params": self.identifying_params, - }, + {}, + lambda ns: ns.update( + { + "SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu") + if self.config["requires_gpu"] + else ("nvidia.com/gpu",), + } + ), ), name=self.runner_name, embedded=False, @@ -962,7 +1314,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]): return self.postprocess_generate(prompt, generated_result, **postprocess_kwargs) -@t.overload +@overload def Runner( model_name: str, *, @@ -973,7 +1325,7 @@ def Runner( ... -@t.overload +@overload def Runner( model_name: str, *, diff --git a/src/openllm/_package.py b/src/openllm/_package.py index a23926c8..cf9829b0 100644 --- a/src/openllm/_package.py +++ b/src/openllm/_package.py @@ -23,7 +23,9 @@ import typing as t from pathlib import Path import fs +import fs.copy import inflection +import orjson import bentoml import openllm @@ -38,8 +40,16 @@ from .utils import is_flax_available from .utils import is_tf_available from .utils import is_torch_available from .utils import pkg +from .utils import resolve_user_filepath +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if hasattr(t, "get_overloads"): + from typing import overload +else: + from typing_extensions import overload + if t.TYPE_CHECKING: from fs.base import FS @@ -83,19 +93,23 @@ def construct_python_options( llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, + adapter_map: dict[str, str | None] | None = None, ) -> PythonOptions: packages = ["openllm"] + if adapter_map is not None: + packages += ["openllm[fine-tune]"] # NOTE: add openllm to the default dependencies # if users has openllm custom built wheels, it will still respect # that since bentoml will always install dependencies from requirements.txt # first, then proceed to install everything inside the wheels/ folder. 
if extra_dependencies is not None: - packages += [f"openllm[{k}]" for k in extra_dependencies] + filtered = set(extra_dependencies + ("fine-tune",)) + packages += [f"openllm[{k}]" for k in filtered] if llm.config["requirements"] is not None: packages.extend(llm.config["requirements"]) - if not (str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false"): + if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}") env: ModelEnv = llm.config["env"] @@ -149,6 +163,7 @@ def construct_docker_options( workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, + adapter_map: dict[str, str | None] | None, ) -> DockerOptions: _bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "") _bentoml_config_options_opts = [ @@ -164,10 +179,14 @@ def construct_docker_options( env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_MODEL_ID": llm.model_id, + "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_CONFIG_OPTIONS": _bentoml_config_options, } + if adapter_map: + env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") + # We need to handle None separately here, as env from subprocess doesn't # accept None value. _env = ModelEnv(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize) @@ -181,38 +200,6 @@ def construct_docker_options( return DockerOptions(cuda_version="11.6", env=env_dict, system_packages=["git"]) -@t.overload -def build( - model_name: str, - *, - model_id: str | None = ..., - quantize: t.LiteralString | None = ..., - bettertransformer: bool | None = ..., - _extra_dependencies: tuple[str, ...] | None = ..., - _workers_per_resource: int | float | None = ..., - _overwrite_existing_bento: bool = ..., - __cli__: t.Literal[False] = ..., - **attrs: t.Any, -) -> bentoml.Bento: - ... - - -@t.overload -def build( - model_name: str, - *, - model_id: str | None = ..., - quantize: t.LiteralString | None = ..., - bettertransformer: bool | None = ..., - _extra_dependencies: tuple[str, ...] | None = ..., - _workers_per_resource: int | float | None = ..., - _overwrite_existing_bento: bool = ..., - __cli__: t.Literal[True] = ..., - **attrs: t.Any, -) -> tuple[bentoml.Bento, bool]: - ... - - def _build_bento( bento_tag: bentoml.Tag, service_name: str, @@ -221,34 +208,87 @@ def _build_bento( workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, + adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] 
| None = None, + build_ctx: str | None = None, ) -> bentoml.Bento: framework_envvar = llm.config["env"]["framework_value"] labels = dict(llm.identifying_params) labels.update({"_type": llm.llm_type, "_framework": framework_envvar}) logger.info("Building Bento for LLM '%s'", llm.config["start_name"]) + + if adapter_map is not None: + assert build_ctx is not None, "build_ctx is required when 'adapter_map' is not None" + updated_mapping: dict[str, str | None] = {} + for adapter_id, name in adapter_map.items(): + try: + resolve_user_filepath(adapter_id, build_ctx) + src_folder_name = os.path.basename(adapter_id) + src_fs = fs.open_fs(build_ctx) + llm_fs.makedir(src_folder_name, recreate=True) + fs.copy.copy_dir(src_fs, adapter_id, llm_fs, src_folder_name) + updated_mapping[src_folder_name] = name + except FileNotFoundError: + # this is the remote adapter, then just added back + # note that there is a drawback here. If the path of the local adapter + # path have the same name as the remote, then we currently don't support + # that edge case. + updated_mapping[adapter_id] = name + adapter_map = updated_mapping + + # add service.py definition to this temporary folder + codegen.write_service(llm.config["model_name"], llm.model_id, adapter_map, llm.config["service_name"], llm_fs) + return bentoml.bentos.build( f"{service_name}:svc", name=bento_tag.name, labels=labels, description=f"OpenLLM service for {llm.config['start_name']}", - include=list( - llm_fs.walk.files(filter=["*.py"]) - ), # NOTE: By default, we are using _service.py as the default service, for now. - exclude=["/venv", "__pycache__/", "*.py[cod]", "*$py.class"], - python=construct_python_options(llm, llm_fs, extra_dependencies), - docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer), + include=list(llm_fs.walk.files()), + exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"], + python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map), + docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map), version=bento_tag.version, build_ctx=llm_fs.getsyspath("/"), ) +@overload +def build( + model_name: str, + *, + model_id: str | None = ..., + quantize: t.LiteralString | None = ..., + bettertransformer: bool | None = ..., + adapter_map: dict[str, str | None] | None = ..., + __cli__: t.Literal[False] = False, + **attrs: t.Any, +) -> bentoml.Bento: + ... + + +@overload +def build( + model_name: str, + *, + model_id: str | None = ..., + quantize: t.LiteralString | None = ..., + bettertransformer: bool | None = ..., + adapter_map: dict[str, str | None] | None = ..., + __cli__: t.Literal[True] = ..., + **attrs: t.Any, +) -> tuple[bentoml.Bento, bool]: + ... + + def build( model_name: str, *, model_id: str | None = None, quantize: t.LiteralString | None = None, bettertransformer: bool | None = None, + adapter_map: dict[str, str | None] | None = None, + _build_ctx: str | None = None, _extra_dependencies: tuple[str, ...] | None = None, _workers_per_resource: int | float | None = None, _overwrite_existing_bento: bool = False, @@ -270,14 +310,14 @@ def build( llm_config = openllm.AutoConfig.for_model(model_name) - logger.info("Packing '%s' into a Bento with kwargs=%s...", model_name, attrs) + logger.info("Packing '%s' into a Bento%s...", model_name, f" with 'kwargs={attrs}' " if attrs else "") # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError # during build. 
This is a current limitation of bentoml build where we actually import the service.py into sys.path try: os.environ["OPENLLM_MODEL"] = inflection.underscore(model_name) - framework_envvar = llm_config["env"]["framework_value"] + framework_envvar = llm_config["env"].framework_value llm = t.cast( "_BaseAutoLLMClass", openllm[framework_envvar], # type: ignore (internal API) @@ -286,7 +326,9 @@ def build( model_id=model_id, llm_config=llm_config, quantize=quantize, + adapter_map=adapter_map, bettertransformer=bettertransformer, + return_runner_kwargs=False, **attrs, ) @@ -294,41 +336,40 @@ def build( labels = dict(llm.identifying_params) labels.update({"_type": llm.llm_type, "_framework": framework_envvar}) - service_name = f"generated_{llm_config['model_name']}_service.py" workers_per_resource = first_not_none(_workers_per_resource, default=llm_config["workers_per_resource"]) with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs: - # add service.py definition to this temporary folder - codegen.write_service(model_name, llm.model_id, service_name, llm_fs) - bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}") try: bento = bentoml.get(bento_tag) if _overwrite_existing_bento: - logger.info("Overwriting previously saved Bento.") bentoml.delete(bento_tag) bento = _build_bento( bento_tag, - service_name, + llm.config["service_name"], llm_fs, llm, workers_per_resource=workers_per_resource, + adapter_map=adapter_map, quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=_extra_dependencies, + build_ctx=_build_ctx, ) _previously_built = True except bentoml.exceptions.NotFound: logger.info("Building Bento for LLM '%s'", llm_config["start_name"]) bento = _build_bento( bento_tag, - service_name, + llm.config["service_name"], llm_fs, llm, workers_per_resource=workers_per_resource, + adapter_map=adapter_map, quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=_extra_dependencies, + build_ctx=_build_ctx, ) return (bento, _previously_built) if __cli__ else bento except Exception as e: diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 8a9a4789..48d4c081 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -25,6 +25,7 @@ from __future__ import annotations import os import typing as t +import warnings import attr import orjson @@ -40,8 +41,25 @@ if t.TYPE_CHECKING: from starlette.requests import Request from starlette.responses import Response +# The following warnings from bitsandbytes, and probably not that important +# for users to see +warnings.filterwarnings( + "ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization" +) +warnings.filterwarnings( + "ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization" +) +warnings.filterwarnings( + "ignore", + message=( + "The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization" + " are unavailable." 
+ ), +) + model = os.environ.get("OPENLLM_MODEL", "{__model_name__}") # openllm: model name model_id = os.environ.get("OPENLLM_MODEL_ID", "{__model_id__}") # openllm: model id +adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}""") # openllm: model adapter map llm_config = openllm.AutoConfig.for_model(model) @@ -51,6 +69,7 @@ runner = openllm.Runner( llm_config=llm_config, bettertransformer=llm_config["env"]["bettertransformer_value"], quantize=llm_config["env"]["quantize_value"], + adapter_map=orjson.loads(adapter_map), ensure_available=False, init_local=False, ) @@ -70,7 +89,19 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(responses=responses, configuration=config) -@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON(), route="/v1/metadata") +@svc.api( + input=bentoml.io.Text(), + output=bentoml.io.JSON.from_sample( + sample={ + "model_id": model_id, + "timeout": 3600, + "model_name": llm_config["model_name"], + "framework": "pt", + "configuration": "", + } + ), + route="/v1/metadata", +) def metadata_v1(_: str) -> openllm.MetadataOutput: return openllm.MetadataOutput( model_id=model_id, @@ -81,6 +112,15 @@ def metadata_v1(_: str) -> openllm.MetadataOutput: ) +@svc.api( + input=bentoml.io.Text.from_sample(sample="default"), + output=bentoml.io.JSON.from_sample(sample={"success": True, "error_msg": "some error message"}), + route="/v1/adapters", +) +async def adapters_v1(adapter_name: str) -> dict[str, bool | str]: + return await runner.set_adapter.async_run(adapter_name) + + @attr.define class HfAgentInput: inputs: str @@ -105,3 +145,14 @@ async def hf_agent(request: Request) -> Response: hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])]) svc.mount_asgi_app(hf_app, path="/hf") + + +async def list_adapter_v1(_: Request) -> Response: + res = await runner.list_adapter.async_run() + if res["success"]: + res["result"] = {k: v.to_dict() for k, v in res["result"].items()} + return JSONResponse(res, status_code=200) + + +metadata_app = Starlette(debug=True, routes=[Route("/adapters", list_adapter_v1, methods=["GET"])]) +svc.mount_asgi_app(metadata_app, path="/v1") diff --git a/src/openllm/_trainer.py b/src/openllm/_trainer.py new file mode 100644 index 00000000..905f2bea --- /dev/null +++ b/src/openllm/_trainer.py @@ -0,0 +1,25 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import transformers + + +class PeftTrainer(transformers.Trainer): + ... + + +class PeftSaveCallback(transformers.TrainerCallback): + ... 
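Once a server built from this service definition is running, the new adapter routes can be exercised over plain HTTP. A rough sketch (assuming the default local port 3000 and an adapter registered as 'french_lora'; payload details may differ slightly depending on the BentoML IO descriptors):

```python
import requests

BASE = "http://localhost:3000"  # assumed local OpenLLM server

# Switch the active LoRA adapter; the /v1/adapters API above takes the adapter name as text.
resp = requests.post(f"{BASE}/v1/adapters", data="french_lora")
print(resp.json())  # e.g. {"success": True, "error_msg": ""}

# List adapters currently attached to the model (GET route mounted under /v1 above).
print(requests.get(f"{BASE}/v1/adapters").json())
```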
diff --git a/src/openllm/cli.py b/src/openllm/cli.py index a8877b09..6a11e65b 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -56,8 +56,10 @@ from .utils import first_not_none from .utils import get_debug_mode from .utils import get_quiet_mode from .utils import gpu_count +from .utils import is_peft_available from .utils import is_torch_available from .utils import is_transformers_supports_agent +from .utils import resolve_user_filepath from .utils import set_debug_mode from .utils import set_quiet_mode @@ -105,13 +107,14 @@ class NargsOptions(cog.GroupedOption): options. """ + _nargs_parser: click.parser.Option + _prev_parser_process: t.Callable[[t.Any, click.parser.ParsingState], None] + def __init__(self, *args: t.Any, **attrs: t.Any): nargs = attrs.pop("nargs", -1) if nargs != -1: raise OpenLLMException(f"'nargs' is set, and must be -1 instead of {nargs}") super(NargsOptions, self).__init__(*args, **attrs) - self._prev_parser_process: t.Callable[[t.Any, click.parser.ParsingState], None] | None = None - self._nargs_parser: click.parser.Option | None = None def add_to_parser(self, parser: click.OptionParser, ctx: click.Context) -> None: def _parser(value: t.Any, state: click.parser.ParsingState): @@ -238,7 +241,7 @@ def quantize_option(factory: t.Any, build: bool = False): ) -def bettertransformer_option(factory: t.Any, model_env: ModelEnv | None = None): +def bettertransformer_option(factory: t.Any, build: bool = False, model_env: ModelEnv | None = None): envvar = None if model_env is not None: envvar = model_env.bettertransformer @@ -246,14 +249,35 @@ def bettertransformer_option(factory: t.Any, model_env: ModelEnv | None = None): "--bettertransformer", is_flag=True, default=None, - help="Use BetterTransformer wrapper to serve model. This will applies during serving time.", + help="Apply FasterTransformer wrapper to serve model. This will applies during serving time." + if not build + else "Set defaul environment variable whether to serve this model with FasterTransformer in build time.", envvar=envvar, show_envvar=True if envvar is not None else False, ) +_adapter_mapping_key = "adapter_map" + + +def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> dict[str, str] | None: + if not value: + return + if _adapter_mapping_key not in ctx.params: + ctx.params[_adapter_mapping_key] = {} + for v in value: + adapter_id, *adapter_name = v.rsplit(":", maxsplit=1) + try: + # try to resolve the full path if users pass in relative, + # currently only support one level of resolve path. + adapter_id = resolve_user_filepath(adapter_id, os.getcwd()) + except FileNotFoundError: + pass + ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None + + class OpenLLMCommandGroup(BentoMLCommandGroup): - NUMBER_OF_COMMON_PARAMS = 3 + NUMBER_OF_COMMON_PARAMS = 4 # parameters in common_params + 1 faked group option header @staticmethod def common_params(f: F[P, t.Any]) -> ClickFunctionWrapper[..., t.Any]: @@ -265,11 +289,14 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): from bentoml._internal.configuration import DEBUG_ENV_VAR from bentoml._internal.configuration import QUIET_ENV_VAR - @click.option("-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output.") - @click.option( + @cog.optgroup.group("Miscellaneous options") + @cog.optgroup.option( + "-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output." 
+ ) + @cog.optgroup.option( "--debug", "--verbose", envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help="Print out debug logs." ) - @click.option( + @cog.optgroup.option( "--do-not-track", is_flag=True, default=False, @@ -507,7 +534,7 @@ _http_server_args = parse_serve_args(False) _grpc_server_args = parse_serve_args(True) -def start_model_command( +def start_command_factory( model_name: str, _context_settings: dict[str, t.Any] | None = None, _serve_grpc: bool = False, @@ -531,7 +558,7 @@ def start_model_command( configure_logging() llm_config = openllm.AutoConfig.for_model(model_name) - env: ModelEnv = llm_config["env"] + env = llm_config["env"] docstring = f"""\ {env.start_docstring} @@ -576,7 +603,7 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ @group.command(**command_attrs) @llm_config.to_click_options @serve_decorator - @cog.optgroup.group("General LLM Options") + @cog.optgroup.group("General LLM Options", help="The following options are related to running the LLM Server.") @cog.optgroup.option( "--server-timeout", type=int, @@ -591,7 +618,25 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ default=False, help="Bypass auto model checks and setup. This option is ahead-of-serving time.", ) - @cog.optgroup.group("LLM Optimization Options.") + @cog.optgroup.group( + "LLM Optimization Options", + help="""\ + These options are related for dynamic optimization on the fly. Current supported strategies: + + - int8: Quantize the model with 8bit (bitsandbytes required) + + - int4: Quantize the model with 4bit (bitsandbytes required) + + - bettertransformer: Convert given model to FastTransformer + + The following are currently being worked on: + + - GPTQ: [paper](https://arxiv.org/abs/2210.17323) + + - DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/) + + """, + ) @cog.optgroup.option( "--device", type=tuple, @@ -604,6 +649,32 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ ) @quantize_option(cog.optgroup) @bettertransformer_option(cog.optgroup, model_env=env) + @cog.optgroup.group( + "Fine-tuning related options", + help="""\ + Note that the argument `--adapter-id` can accept the following format: + + - `--adapter-id /path/to/adapter` (local adapter) + + - `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub) + + - `--adapter-id remote/adapter:eng_lora` (two previous adapter options with the given adapter_name) + + ```bash + + openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora + + ``` + """, + ) + @cog.optgroup.option( + "--adapter-id", + default=None, + help="Optional name or path for given LoRA adapter" + f" to wrap '{model_name}'", + multiple=True, + callback=_id_callback, + metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]", + ) @click.pass_context def model_start( ctx: click.Context, @@ -616,6 +687,10 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ fast: bool, **attrs: t.Any, ) -> openllm.LLMConfig: + adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None) + # remove adapter_id + attrs.pop("adapter_id", None) + config, server_attrs = llm_config.model_validate_click(**attrs) # Create a new model env to work with the envvar during CLI invocation @@ -631,6 +706,13 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ _echo("Quantization is currently only available for PyTorch models.", fg="red") ctx.exit(1) + if 
adapter_map and not is_peft_available(): + _echo( + "Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'", + fg="red", + ) + ctx.exit(1) + # We need to handle None separately here, as env from subprocess doesn't # accept None value. env = ModelEnv(env.model_name, bettertransformer=bettertransformer, quantize=quantize) @@ -686,22 +768,29 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ if fast and not get_quiet_mode(): _echo( - f"Make sure to download the model before 'start': 'openllm download {model_name}{'--model-id ' + model_id if model_id else ''}'", + f"Fast mode is enabled. Make sure to download the model before 'start': 'openllm download {model_name}{'--model-id ' + model_id if model_id else ''}'", fg="yellow", ) - automodel_attrs = { - "model_id": model_id, - "llm_config": config, - "ensure_available": not fast, - } + automodel_attrs: dict[str, t.Any] = {} if framework_envvar == "pt": automodel_attrs.update({"quantize": quantize, "bettertransformer": bettertransformer}) + if adapter_map: + _echo(f"OpenLLM will convert '{model_name}' to use provided adapters layers: {list(adapter_map)}") + automodel_attrs.update({"adapter_map": adapter_map}) + llm = t.cast( "_BaseAutoLLMClass", openllm[framework_envvar], # type: ignore (internal API) - ).for_model(model_name, **automodel_attrs) + ).for_model( + model_name, + model_id=model_id, + llm_config=config, + ensure_available=not fast, + return_runner_kwargs=False, + **automodel_attrs, + ) start_env.update( { @@ -709,6 +798,7 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ env.config: llm.config.model_dump_json().decode(), "OPENLLM_MODEL": model_name, "OPENLLM_MODEL_ID": llm.model_id, + "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map), "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_CONFIG_OPTIONS": _bentoml_config_options_env, "BENTOML_HOME": os.environ.get("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), @@ -746,9 +836,9 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_ return model_start -_cached_http = {key: start_model_command(key, _context_settings=_CONTEXT_SETTINGS) for key in openllm.CONFIG_MAPPING} +_cached_http = {key: start_command_factory(key, _context_settings=_CONTEXT_SETTINGS) for key in openllm.CONFIG_MAPPING} _cached_grpc = { - key: start_model_command(key, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=True) + key: start_command_factory(key, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=True) for key in openllm.CONFIG_MAPPING } @@ -765,7 +855,7 @@ def _start( if framework is not None: os.environ[_ModelEnv.framework] = framework - start_model_command(model_name, _serve_grpc=_serve_grpc)(standalone_mode=False, **attrs) + start_command_factory(model_name, _serve_grpc=_serve_grpc)(standalone_mode=False, **attrs) start = functools.partial(_start, _serve_grpc=False) @@ -792,7 +882,17 @@ start_grpc = functools.partial(_start, _serve_grpc=True) nargs=1, metavar="FEATURE[,FEATURE]", ) +@click.option( + "--adapter-id", + default=None, + help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.", + multiple=True, + metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]", +) +@click.option("--build-ctx", default=".", help="Build context. 
This is required if --adapter-id uses relative path") +@click.pass_context def build( + ctx: click.Context, model_name: str, model_id: str | None, overwrite: bool, @@ -801,6 +901,9 @@ def build( enable_features: tuple[str] | None, bettertransformer: bool | None, workers_per_resource: float | None, + adapter_id: tuple[str, ...], + build_ctx: str | None, + **attrs: t.Any, ): """Package a given models into a Bento. @@ -813,6 +916,20 @@ def build( > NOTE: To run a container built from this Bento with GPU support, make sure > to have https://github.com/NVIDIA/nvidia-container-toolkit install locally. """ + adapter_map: dict[str, str | None] | None = None + + if adapter_id: + if not build_ctx: + _echo("'build_ctx' must not be None when '--adapter-id' is passsed.", fg="red") + ctx.exit(1) + + adapter_map = {} + for v in adapter_id: + _adapter_id, *adapter_name = v.rsplit(":", maxsplit=1) + # We don't resolve full path here, leave it to build + # we are just doing the parsing here. + adapter_map[_adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None + if output == "porcelain": set_quiet_mode(True) configure_logging() @@ -824,12 +941,15 @@ def build( if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(",") for s in enable_features))) + # TODO: xxx bento, _previously_built = openllm.build( model_name, __cli__=True, model_id=model_id, quantize=quantize, bettertransformer=bettertransformer, + adapter_map=adapter_map, + _build_ctx=build_ctx, _extra_dependencies=enable_features, _workers_per_resource=workers_per_resource, _overwrite_existing_bento=overwrite, @@ -851,8 +971,8 @@ def build( + "* Push to BentoCloud with `bentoml push`:\n" + f" $ bentoml push {bento.tag}\n" + "* Containerize your Bento with `bentoml containerize`:\n" - + f" $ bentoml containerize {bento.tag}\n" - + " Tip: To enable additional BentoML feature for 'containerize', " + + f" $ bentoml containerize {bento.tag}\n\n" + + " Tip: To enable additional BentoML features for 'containerize', " + "use '--enable-features=FEATURE[,FEATURE]' " + "[see 'bentoml containerize -h' for more advanced usage]\n", fg="blue", diff --git a/src/openllm/exceptions.py b/src/openllm/exceptions.py index 05371869..c386c28d 100644 --- a/src/openllm/exceptions.py +++ b/src/openllm/exceptions.py @@ -41,3 +41,7 @@ class MissingAnnotationAttributeError(OpenLLMException): class MissingDependencyError(BaseException): """Raised when a dependency is missing.""" + + +class FineTuneStrategyNotSupportedError(OpenLLMException): + """Raised when a fine-tune strategy is not supported for given LLM.""" diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py index a277c4d4..bc183f74 100644 --- a/src/openllm/models/auto/factory.py +++ b/src/openllm/models/auto/factory.py @@ -27,6 +27,14 @@ import openllm from .configuration_auto import AutoConfig +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if hasattr(t, "get_overloads"): + from typing import overload +else: + from typing_extensions import overload + + if t.TYPE_CHECKING: from collections import _odict_items from collections import _odict_keys @@ -54,7 +62,7 @@ class _BaseAutoLLMClass: "Please use '{self.__class__.__name__}.Runner(model_name)' instead." ) - @t.overload + @overload @classmethod def for_model( cls, @@ -67,7 +75,7 @@ class _BaseAutoLLMClass: ) -> openllm.LLM[t.Any, t.Any]: ... 
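The `--adapter-id` handling in the `start` and `build` commands above reduces to splitting an optional `:adapter_name` suffix off each value and collecting the results into one mapping. A small standalone sketch of that parsing rule (sample values are illustrative):

```python
def parse_adapter_ids(values: tuple[str, ...]) -> dict[str, str | None]:
    """Collapse '[path | remote/adapter][:adapter_name]' values into {adapter_id: adapter_name}."""
    mapping: dict[str, str | None] = {}
    for value in values:
        adapter_id, *adapter_name = value.rsplit(":", maxsplit=1)
        mapping[adapter_id] = adapter_name[0] if adapter_name else None
    return mapping


print(parse_adapter_ids(("./local-adapter", "aarnphm/opt-6.7b-lora:french_lora")))
# {'./local-adapter': None, 'aarnphm/opt-6.7b-lora': 'french_lora'}
```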
- @t.overload + @overload @classmethod def for_model( cls, @@ -116,7 +124,7 @@ class _BaseAutoLLMClass: llm = cls._model_mapping[type(llm_config)].from_pretrained( model_id, llm_config=llm_config, - **llm_config.__openllm_extras__, + **attrs, ) if ensure_available: logger.debug( @@ -234,13 +242,13 @@ class _LazyAutoMapping(ConfigModelOrderedDict): ] return t.cast(ConfigModelKeysView, mapping_keys + list(self._extra_content.keys())) - @t.overload + @overload def get( self, key: type[openllm.LLMConfig], default: t.Any, mapping_type: t.Literal["default"] = "default" ) -> type[openllm.LLM[t.Any, t.Any]]: ... - @t.overload + @overload def get(self, key: str, default: t.Any, mapping_type: t.Literal["name2model", "name2config"] = ...) -> str: ... diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py index f42fdbb8..cd8e856e 100644 --- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py @@ -70,9 +70,6 @@ class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedToken external_modules=[importlib.import_module(pipeline.__module__)], ) finally: - import gc - - gc.collect() if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/src/openllm/models/falcon/modeling_falcon.py b/src/openllm/models/falcon/modeling_falcon.py index 2118d93e..2a61c3ee 100644 --- a/src/openllm/models/falcon/modeling_falcon.py +++ b/src/openllm/models/falcon/modeling_falcon.py @@ -65,10 +65,6 @@ class Falcon(openllm.LLM["transformers.TextGenerationPipeline", "transformers.Pr external_modules=[importlib.import_module(model.__module__)], ) finally: - import gc - - # NOTE: We need to free the cache after saving here so that we can load it back later on. 
- gc.collect() torch.cuda.empty_cache() def sanitize_parameters( diff --git a/src/openllm/playground/README b/src/openllm/playground/README index 3c1f5d20..4fa4d7b0 100644 --- a/src/openllm/playground/README +++ b/src/openllm/playground/README @@ -7,3 +7,5 @@ Refer to each script docstring for more information. python -m openllm.playground.general --help python -m openllm.plaground.ft_opt_lora --help + +Note that this folder is considered 'breaking' in most cases, so use with care.
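The `fine_tune_strategies` entry added to `OPTConfig` above maps, roughly, onto the following `peft` configuration; the `task_type` here is an assumption for a causal-LM model, since OpenLLM resolves it internally through `FineTuneConfig`:

```python
from peft import LoraConfig, TaskType

# Approximate peft equivalent of OPTConfig's default LoRA fine-tune strategy.
opt_lora = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # assumption; not part of the config dict above
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
```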
diff --git a/src/openllm/playground/ft_opt_lora.py b/src/openllm/playground/ft_opt_lora.py index 3908f4f3..f48fb8b9 100644 --- a/src/openllm/playground/ft_opt_lora.py +++ b/src/openllm/playground/ft_opt_lora.py @@ -48,12 +48,15 @@ if openllm.utils.DEBUG: if os.system(f"pip install -U {' '.join(_deps)}") != 0: raise SystemExit(1) -os.environ["BITSANDBYTES_NOWELCOME"] = str(1) - from datasets import load_dataset from peft import LoraConfig from peft import get_peft_model -from peft import prepare_model_for_int8_training + + +if openllm.utils.pkg.pkg_version_info("peft")[:2] >= (0, 4): + from peft import prepare_model_for_kbit_training +else: + from peft import prepare_model_for_int8_training as prepare_model_for_kbit_training import transformers @@ -75,7 +78,7 @@ def load_model(model_id: str) -> tuple[PeftModel, transformers.GPT2TokenizerFast model, tokenizer = opt.model, opt.tokenizer # prep the model for int8 training - model = prepare_model_for_int8_training(model) + model = prepare_model_for_kbit_training(model) lora_config = LoraConfig( r=16, diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index a1fb418d..9d3dd6f5 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -103,6 +103,8 @@ _LOGGING_CONFIG["loggers"].update( def configure_logging() -> None: + """Configure logging for OpenLLM. Behaves similar to how BentoML loggers + are being configured.""" if get_quiet_mode(): _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR @@ -136,17 +138,22 @@ _import_structure = { "analytics": [], "codegen": [], "dantic": [], + "constants": [], + "representation": ["ReprMixin"], "import_utils": [ "OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", "DummyMetaclass", "ModelEnv", + "requires_dependencies", "is_cpm_kernels_available", "is_einops_available", "is_flax_available", "is_tf_available", "is_torch_available", "is_bitsandbytes_available", + "is_peft_available", + "is_datasets_available", "is_transformers_supports_kbit", "is_transformers_supports_agent", "require_backends", @@ -161,6 +168,7 @@ if t.TYPE_CHECKING: from . import bentoml_cattr as bentoml_cattr from . import codegen as codegen from . import configure_logging as configure_logging + from . import constants as constants from . import copy_file_to_fs_folder as copy_file_to_fs_folder from . import dantic as dantic from . 
import first_not_none as first_not_none @@ -180,14 +188,18 @@ if t.TYPE_CHECKING: from .import_utils import ModelEnv as ModelEnv from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available + from .import_utils import is_datasets_available as is_datasets_available from .import_utils import is_einops_available as is_einops_available from .import_utils import is_flax_available as is_flax_available + from .import_utils import is_peft_available as is_peft_available from .import_utils import is_tf_available as is_tf_available from .import_utils import is_torch_available as is_torch_available from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit from .import_utils import require_backends as require_backends + from .import_utils import requires_dependencies as requires_dependencies from .lazy import LazyModule as LazyModule + from .representation import ReprMixin as ReprMixin else: import sys diff --git a/src/openllm/utils/codegen.py b/src/openllm/utils/codegen.py index 68c11309..9fb8fbd8 100644 --- a/src/openllm/utils/codegen.py +++ b/src/openllm/utils/codegen.py @@ -15,10 +15,13 @@ from __future__ import annotations import logging +import os import string import typing as t from pathlib import Path +import orjson + if t.TYPE_CHECKING: from fs.base import FS @@ -33,12 +36,13 @@ else: DictStrAny = dict -_T = t.TypeVar("_T") +_T = t.TypeVar("_T", bound=t.Callable[..., t.Any]) logger = logging.getLogger(__name__) OPENLLM_MODEL_NAME = "# openllm: model name" OPENLLM_MODEL_ID = "# openllm: model id" +OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map" class ModelNameFormatter(string.Formatter): @@ -63,10 +67,16 @@ class ModelIdFormatter(ModelNameFormatter): model_keyword: t.LiteralString = "__model_id__" +class ModelAdapterMapFormatter(ModelNameFormatter): + model_keyword: t.LiteralString = "__model_adapter_map__" + + _service_file = Path(__file__).parent.parent / "_service.py" -def write_service(model_name: str, model_id: str, target_path: str, llm_fs: FS): +def write_service( + model_name: str, model_id: str, adapter_map: dict[str, str | None] | None, target_path: str, llm_fs: FS +): from . import DEBUG logger.debug("Generating service for %s to %s", model_name, target_path) @@ -76,11 +86,22 @@ def write_service(model_name: str, model_id: str, target_path: str, llm_fs: FS): # modify with model name for it in src_contents: if OPENLLM_MODEL_NAME in it: - src_contents[src_contents.index(it)] = ModelNameFormatter(model_name).vformat(it) + src_contents[src_contents.index(it)] = ( + ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + "\n" + ) elif OPENLLM_MODEL_ID in it: - src_contents[src_contents.index(it)] = ModelIdFormatter(model_id).vformat(it) + src_contents[src_contents.index(it)] = ( + ModelIdFormatter(model_id).vformat(it)[: -(len(OPENLLM_MODEL_ID) + 3)] + "\n" + ) + elif OPENLLM_MODEL_ADAPTER_MAP in it and adapter_map: + src_contents[src_contents.index(it)] = ( + ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[ + : -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3) + ] + + "\n" + ) - script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n" + "".join(src_contents) + script = f"# GENERATED BY 'openllm build {model_name}'. 
DO NOT EDIT\n\n" + "".join(src_contents) if DEBUG: logger.info("Generated script:\n%s", script) @@ -192,7 +213,7 @@ def generate_function( if annotations: meth.__annotations__ = annotations - if DEBUG: + if DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3: logger.info("Generated script for %s:\n\n%s", typ, script) return meth diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py index e73505ab..412abac2 100644 --- a/src/openllm/utils/dantic.py +++ b/src/openllm/utils/dantic.py @@ -31,6 +31,13 @@ from click import ParamType import openllm +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if hasattr(t, "get_overloads"): + from typing import overload +else: + from typing_extensions import overload + if t.TYPE_CHECKING: from attr import _ValidatorType @@ -42,7 +49,7 @@ if t.TYPE_CHECKING: _T = t.TypeVar("_T") -@t.overload +@overload def attrs_to_options( name: str, field: attr.Attribute[t.Any], @@ -53,7 +60,7 @@ def attrs_to_options( ... -@t.overload +@overload def attrs_to_options( # type: ignore (overlapping overload) name: str, field: attr.Attribute[O_co], diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py index 1d1a205e..6476b69b 100644 --- a/src/openllm/utils/import_utils.py +++ b/src/openllm/utils/import_utils.py @@ -17,6 +17,7 @@ Some imports utils are vendorred from transformers/utils/import_utils.py for per """ from __future__ import annotations +import functools import importlib import importlib.metadata import importlib.util @@ -32,9 +33,19 @@ from packaging import version from bentoml._internal.utils import LazyLoader from bentoml._internal.utils import pkg +from .representation import ReprMixin + + +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if hasattr(t, "get_overloads"): + from typing import overload +else: + from typing_extensions import overload if t.TYPE_CHECKING: BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]] + from .._types import P else: BackendOrderredDict = OrderedDict @@ -64,9 +75,12 @@ def _is_package_available(package: str) -> bool: _torch_available = importlib.util.find_spec("torch") is not None _tf_available = importlib.util.find_spec("tensorflow") is not None _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + +_peft_available = _is_package_available("peft") _einops_available = _is_package_available("einops") _cpm_kernel_available = _is_package_available("cpm_kernels") _bitsandbytes_available = _is_package_available("bitsandbytes") +_datasets_available = _is_package_available("datasets") def is_transformers_supports_kbit() -> bool: @@ -77,6 +91,14 @@ def is_transformers_supports_agent() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 29) +def is_datasets_available() -> bool: + return _datasets_available + + +def is_peft_available() -> bool: + return _peft_available + + def is_einops_available(): return _einops_available @@ -157,6 +179,33 @@ def is_flax_available(): return _flax_available +def requires_dependencies( + package: str | list[str], *, extra: str | list[str] | None = None +) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]: + import openllm.utils + + if isinstance(package, str): + package = [package] + if isinstance(extra, str): + extra = [extra] + + def decorator(func: t.Callable[P, t.Any]): + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: + for 
p in package: + cached_check: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{p}_available", None) + if not ((cached_check is not None and cached_check()) or _is_package_available(p)): + raise ImportError( + f"{func.__name__} requires '{p}' to be available locally (Currently missing)." + f"Make sure to have {p} to be installed: 'pip install \"{p if not extra else 'openllm['+', '.join(extra)+']'}\"'" + ) + return func(*args, **kwargs) + + return wrapper + + return decorator + + PYTORCH_IMPORT_ERROR_WITH_TF = """\ {0} requires the PyTorch library but it was not found in your environment. However, we were able to find a TensorFlow installation. TensorFlow classes begin @@ -249,19 +298,44 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]): raise ImportError("".join(failed)) -class ModelEnv: +class ModelEnv(ReprMixin): model_name: str - if t.TYPE_CHECKING: - config: property - model_id: property - quantize: property - framework: property - bettertransformer: property + @property + def __repr_keys__(self) -> set[str]: + return {"config", "model_id", "quantize", "framework", "bettertransformer"} - framework_value: property - quantize_value: property - bettertransformer_value: property + if t.TYPE_CHECKING: + config: str + model_id: str + quantize: str + framework: str + bettertransformer: str + + framework_value: t.Literal["pt", "tf", "flax"] + quantize_value: str | None + bettertransformer_value: str | None + + # fmt: off + + @overload + def __getitem__(self, item: t.Literal["config"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["model_id"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["quantize"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["framework"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ... + @overload + def __getitem__(self, item: t.Literal['framework_value']) -> t.Literal['pt', 'tf', 'flax']: ... + @overload + def __getitem__(self, item: t.Literal['quantize_value']) -> str | None: ... + @overload + def __getitem__(self, item: t.Literal['bettertransformer_value']) -> str | None: ... + + # fmt: on def __getitem__(self, item: str | t.Any) -> t.Any: if hasattr(self, item): @@ -284,7 +358,7 @@ class ModelEnv: # gen properties env value attributes_with_values = { - "quantize": (bool, quantize), + "quantize": (str, quantize), "bettertransformer": (bool, bettertransformer), "framework": (str, "pt"), } diff --git a/src/openllm/utils/representation.py b/src/openllm/utils/representation.py new file mode 100644 index 00000000..f953dad7 --- /dev/null +++ b/src/openllm/utils/representation.py @@ -0,0 +1,61 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import typing as t
+from abc import abstractmethod
+
+import attr
+import orjson
+
+
+if t.TYPE_CHECKING:
+    ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
+
+
+class ReprMixin:
+    """This class displays possible representations of a given class.
+    It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
+    Most subclasses need to implement a __repr_keys__ property.
+
+    Based on the design from Pydantic.
+    The __repr__ will display the JSON representation of the object for easier interaction.
+    The __str__ will display either __attrs_repr__ or __repr_str__.
+    """
+
+    @property
+    @abstractmethod
+    def __repr_keys__(self) -> set[str]:
+        """This can be overridden by classes using this mixin."""
+
+    def __repr__(self) -> str:
+        from . import bentoml_cattr
+
+        serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
+        return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
+
+    def __str__(self) -> str:
+        return self.__repr_str__(" ")
+
+    def __repr_name__(self) -> str:
+        """Name of the instance's class, used in __repr__."""
+        return self.__class__.__name__
+
+    def __repr_str__(self, join_str: str) -> str:
+        return join_str.join(repr(v) if a is None else f"{a}={repr(v)}" for a, v in self.__repr_args__())
+
+    def __repr_args__(self) -> ReprArgs:
+        attrs = ((k, getattr(self, k)) for k in self.__repr_keys__)
+        return tuple((k, v) for k, v in attrs if v)
diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py
index 577f750b..366e2ca7 100644
--- a/src/openllm_client/runtimes/base.py
+++ b/src/openllm_client/runtimes/base.py
@@ -26,6 +26,13 @@
 import openllm
 import transformers

+# NOTE: We need to do this so that overload can register
+# correct overloads to typing registry
+if hasattr(t, "get_overloads"):
+    from typing import overload
+else:
+    from typing_extensions import overload
+
 if t.TYPE_CHECKING:
     from openllm.models.auto.factory import _BaseAutoLLMClass

@@ -162,11 +169,11 @@ class BaseClient(ClientMixin):
     def health(self) -> t.Any:
         raise NotImplementedError

-    @t.overload
+    @overload
     def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
         ...

-    @t.overload
+    @overload
     def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
         ...

@@ -222,11 +229,11 @@ class BaseAsyncClient(ClientMixin):
     async def health(self) -> t.Any:
         raise NotImplementedError

-    @t.overload
+    @overload
     async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
         ...

-    @t.overload
+    @overload
     async def query(
         self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any
     ) -> dict[str, t.Any]:
diff --git a/tools/assert-model-table-latest b/tools/assert-model-table-latest
index bc9f438e..0bd78414 100755
--- a/tools/assert-model-table-latest
+++ b/tools/assert-model-table-latest
@@ -4,9 +4,11 @@ from __future__ import annotations

 import os
 import subprocess
+import sys

 from markdown_it import MarkdownIt
+

 md = MarkdownIt()

 ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -16,7 +18,17 @@ with open(os.path.join(ROOT, "README.md"), "r") as f:

 # NOTE: Currently, we only have one table in README, which is the Model readme.
table = [r for r in readme if r.type == "html_block" and r.content.startswith(" str: + if "NotRequired" in annotations: + return annotations[len("NotRequired[") : -1] + elif "Required" in annotations: + return annotations[len("Required[") : -1] + else: + return annotations + + +def main() -> int: + transformed = {"fine_tune_strategies": "t.Dict[AdapterType, FineTuneConfig]"} + with _TARGET_FILE.open("r") as f: + processed = f.readlines() + + start_idx, end_idx = processed.index(" " * 4 + START_COMMENT), processed.index(" " * 4 + END_COMMENT) + + # convention to use t.TYPE_CHECKING + lines = [" " * 4 + "if t.TYPE_CHECKING:\n"] + for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items(): + lines.extend( + list( + map( + lambda line: " " * 8 + line, + [ + "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", + f'def __getitem__(self, item: t.Literal["{keys}"] = ...) -> {transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n', + ], + ) + ) + ) + # special case variables: generation_class, extras + lines.extend( + list( + map( + lambda line: " " * 8 + line, + [ + "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", + 'def __getitem__(self, item: t.Literal["generation_class"] = ...) -> t.Type[GenerationConfig]: ...\n', + "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", + 'def __getitem__(self, item: t.Literal["extras"] = ...) -> t.Dict[str, t.Any]: ...\n', + ], + ) + ) + ) + + processed = ( + processed[:start_idx] + [" " * 4 + START_COMMENT] + lines + [" " * 4 + END_COMMENT] + processed[end_idx + 1 :] + ) + with _TARGET_FILE.open("w") as f: + f.writelines(processed) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/update-optional-dependencies.py b/tools/update-optional-dependencies.py index 30c2d7b3..4b6bf568 100755 --- a/tools/update-optional-dependencies.py +++ b/tools/update-optional-dependencies.py @@ -85,6 +85,7 @@ _NIGHTLY_MAPPING: dict[str, Dependencies] = { "optimum": Dependencies.from_tuple("optimum", "huggingface/optimum", "main", None), "accelerate": Dependencies.from_tuple("accelerate", "huggingface/accelerate", "main", None), "bitsandbytes": Dependencies.from_tuple("bitsandbytes", "TimDettmers/bitsandbytes", "main", None), + "deepspeed": Dependencies.from_tuple("deepspeed", "microsoft/deepspeed", "master", None), } FINE_TUNE_DEPS = ["peft", "bitsandbytes", "datasets", "accelerate", "deepspeed"] @@ -125,8 +126,8 @@ def main() -> int: with open(os.path.join(ROOT, "pyproject.toml"), "w") as f: f.write(tomlkit.dumps(pyproject)) - with open(os.path.join(ROOT, "nightly-requirements.generated.txt"), "w") as f: - f.write("# This file is generated by `./tools/update-optional-dependencies.py`\n# DO NOT EDIT\n-e .\n") + with open(os.path.join(ROOT, "nightly-requirements.txt"), "w") as f: + f.write("# This file is generated by `./tools/update-optional-dependencies.py`\n# DO NOT EDIT\n-e .[all]\n") f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values()]) if shutil.which("taplo"): diff --git a/typings/attr/__init__.pyi b/typings/attr/__init__.pyi index 5127759b..62d2fe13 100644 --- a/typings/attr/__init__.pyi +++ b/typings/attr/__init__.pyi @@ -493,7 +493,7 @@ dataclass = ... 
def _make_init( cls: type[AttrsInstance], - attrs: tuple[Attribute[_T]], + attrs: tuple[Attribute[Any]], pre_init: bool, post_init: bool, frozen: bool, diff --git a/typings/click_option_group/_decorators.pyi b/typings/click_option_group/_decorators.pyi index 09b4769b..f62b4ad8 100644 --- a/typings/click_option_group/_decorators.pyi +++ b/typings/click_option_group/_decorators.pyi @@ -99,7 +99,7 @@ class _OptGroup(Generic[O_co]): :param attrs: Additional parameters of option group class """ ... - def option(self, *param_decls: Any, **attrs: Any) -> F[P, ClickFunctionWrapper[P, O_co]]: + def option(self, *param_decls: Any, **attrs: Any) -> FC: """The decorator adds a new option to the group The decorator is lazy. It adds option decls and attrs. diff --git a/typings/deepmerge/merger.pyi b/typings/deepmerge/merger.pyi index 90d62a4a..c5a86d3f 100644 --- a/typings/deepmerge/merger.pyi +++ b/typings/deepmerge/merger.pyi @@ -21,6 +21,6 @@ class Merger: fallback_strategies: List[str], type_conflict_strategies: List[str], ) -> None: ... - def merge(self, base: ConfigDictType, nxt: ConfigDictType) -> None: ... + def merge(self, base: ConfigDictType, nxt: ConfigDictType) -> ConfigDictType: ... def type_conflict_strategy(self, *args: Any) -> Any: ... def value_strategy(self, path: str, base: StrategyList, nxt: StrategyList) -> None: ...
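One note on the `deepmerge` stub change above: `Merger.merge` mutates the base mapping and returns it, so annotating the return as `ConfigDictType` (rather than `None`) lets call sites assign or chain the merged result. Below is a minimal usage sketch against deepmerge's documented `Merger` API; the strategy lists and config keys are arbitrary examples, not OpenLLM's actual merger setup.

```python
from deepmerge import Merger

# Arbitrary example strategies; OpenLLM's own merger configuration may differ.
config_merger = Merger(
    [(list, ["append"]), (dict, ["merge"])],  # per-type strategies
    ["override"],  # fallback strategies
    ["override"],  # type-conflict strategies
)

base = {"generation": {"max_new_tokens": 128}, "stop": ["\n"]}
override = {"generation": {"temperature": 0.9}, "stop": ["###"]}

# merge() updates `base` in place and returns it, which is what the revised
# stub's ConfigDictType return annotation reflects.
merged = config_merger.merge(base, override)
assert merged is base
assert merged["generation"] == {"max_new_tokens": 128, "temperature": 0.9}
assert merged["stop"] == ["\n", "###"]
```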