feat: serve adapter layers (#52)

Aaron Pham
2023-06-23 10:07:15 -04:00
committed by GitHub
parent 5981e49342
commit dfca956fad
33 changed files with 1896 additions and 496 deletions

.gitattributes vendored Normal file

@@ -0,0 +1,2 @@
nightly-requirements.txt linguist-generated=true
* text=auto eol=lf


@@ -346,10 +346,9 @@ async def prompt(input_text: str) -> str:
OpenLLM seamlessly integrates with Hugging Face Agents.
> **Warning** The Hugging Face Agent is still in the experimental stage. It is
> recommended to install OpenLLM with
> `pip install -r nightly-requirements.generated.txt` to get the latest API
> update for the Hugging Face agent.
> **Warning** The HuggingFace Agent is still at an experimental stage. It is
> recommended to install OpenLLM with `pip install -r nightly-requirements.txt` to get
> the latest API update for the HuggingFace agent.
```python
import transformers

changelog.d/52.feature.md Normal file

@@ -0,0 +1,45 @@
#### Serving LLMs with fine-tuned LoRA, QLoRA adapter layers
The given fine-tuned weights can then be served with the model via
`openllm start`:
```bash
openllm start opt --model-id facebook/opt-6.7b --adapter-id /path/to/adapters
```
If you just wish to try a pretrained adapter checkpoint, you can pass its id to
`--adapter-id`:
```bash
openllm start opt --model-id facebook/opt-6.7b --adapter-id aarnphm/opt-6.7b-lora
```
To use multiple adapters, use the following format:
```bash
openllm start opt --model-id facebook/opt-6.7b --adapter-id aarnphm/opt-6.7b-lora --adapter-id aarnphm/opt-6.7b-lora:french_lora
```
By default, the first `adapter-id` will be the default LoRA layer, but users
can optionally change which LoRA layer to use for inference via
`/v1/adapters`:
```bash
curl -X POST http://localhost:3000/v1/adapters --json '{"adapter_name": "vn_lora"}'
```
> Note that when multiple `adapter-name` and `adapter-id` pairs are provided, it is
> recommended to switch back to the default adapter before sending inference
> requests, to avoid any performance degradation.
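The same adapter endpoints can also be called from Python; a minimal sketch,
assuming the server listens on `localhost:3000` and `requests` is installed:
```python
import requests

# List the adapters currently registered on the runner (GET /v1/adapters).
print(requests.get("http://localhost:3000/v1/adapters").json())

# Switch the active adapter before inference, mirroring the curl call above.
requests.post("http://localhost:3000/v1/adapters", json={"adapter_name": "vn_lora"})
```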
To include the adapters in the Bento, one can also provide `--adapter-id` to
`openllm build`:
```bash
openllm build opt --model-id facebook/opt-6.7b --adapter-id ...
```
### Rework
Separated out the configuration builder to make it more flexible for future
configuration generation.


@@ -1,9 +1,10 @@
# This file is generated by `./tools/update-optional-dependencies.py`
# DO NOT EDIT
-e .
-e .[all]
bentoml[grpc,io] @ git+https://github.com/bentoml/bentoml.git@main
peft @ git+https://github.com/huggingface/peft.git@main
transformers[torch,tokenizers,accelerate] @ git+https://github.com/huggingface/transformers.git@main
optimum @ git+https://github.com/huggingface/optimum.git@main
accelerate @ git+https://github.com/huggingface/accelerate.git@main
bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git@main
deepspeed @ git+https://github.com/microsoft/deepspeed.git@master


@@ -58,15 +58,15 @@ requires-python = ">=3.8"
# NOTE: Don't modify project.optional-dependencies
# as it is managed by ./tools/update-optional-dependencies.py
[project.optional-dependencies]
agents = ["transformers[agents]", "diffusers", "soundfile"]
agents = ["transformers[agents]>=4.30", "diffusers", "soundfile"]
all = [
"openllm[chatglm]",
"openllm[starcoder]",
"openllm[falcon]",
"openllm[agents]",
"openllm[flan-t5]",
"openllm[fine-tune]",
"openllm[openai]",
"openllm[flan-t5]",
]
chatglm = ["cpm_kernels", "sentencepiece"]
falcon = ["einops", "xformers", "safetensors"]


@@ -26,7 +26,9 @@ deploy, and monitor any LLMs with ease.
from __future__ import annotations
import logging
import os
import typing as t
import warnings
from . import utils as utils
from .__about__ import __version__ as __version__
@@ -39,6 +41,24 @@ if utils.DEBUG:
utils.configure_logging()
logging.basicConfig(level=logging.NOTSET)
else:
# configuration for bitsandbytes before import
os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
# The following warnings come from bitsandbytes and are probably not that
# important for users to see when DEBUG is False
warnings.filterwarnings(
"ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
)
warnings.filterwarnings(
"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
)
warnings.filterwarnings(
"ignore",
message=(
"The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization"
" are unavailable."
),
)
_import_structure = {


File diff suppressed because it is too large.


@@ -16,6 +16,7 @@ from __future__ import annotations
import copy
import functools
import inspect
import logging
import os
import re
@@ -29,34 +30,48 @@ from abc import abstractmethod
import attr
import inflection
import orjson
from huggingface_hub import hf_hub_download
import bentoml
import openllm
from bentoml._internal.models.model import ModelSignature
from bentoml._internal.types import ModelSignatureDict
from ._configuration import FineTuneConfig
from .exceptions import ForbiddenAttributeError
from .exceptions import GpuNotAvailableError
from .exceptions import OpenLLMException
from .utils import DEBUG
from .utils import LazyLoader
from .utils import ModelEnv
from .utils import ReprMixin
from .utils import bentoml_cattr
from .utils import first_not_none
from .utils import get_debug_mode
from .utils import is_bitsandbytes_available
from .utils import is_peft_available
from .utils import is_torch_available
from .utils import is_transformers_supports_kbit
from .utils import non_intrusive_setattr
from .utils import pkg
from .utils import requires_dependencies
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
import peft
import torch
import transformers
from bentoml._internal.runner.strategy import Strategy
from ._configuration import AdapterType
from .models.auto.factory import _BaseAutoLLMClass
class LLMRunner(bentoml.Runner):
@@ -74,6 +89,7 @@ else:
LLMRunner = bentoml.Runner
transformers = LazyLoader("transformers", globals(), "transformers")
torch = LazyLoader("torch", globals(), "torch")
peft = LazyLoader("peft", globals(), "peft")
logger = logging.getLogger(__name__)
@@ -96,6 +112,41 @@ def convert_transformers_model_name(name: str | None) -> str:
return re.sub("[^a-zA-Z0-9]+", "-", name)
# the below is similar to peft.utils.other.CONFIG_NAME
PEFT_CONFIG_NAME = "adapter_config.json"
def resolve_peft_config_type(adapter_map: dict[str, str | None] | None):
"""Resolve the type of the PeftConfig given the adapter_map.
This is similar to how PeftConfig resolves its config type.
"""
if adapter_map is None:
return
resolved: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] = {}
_has_set_default = False
for path_or_adapter_id, name in adapter_map.items():
if _has_set_default:
raise ValueError("Only one adapter can be set as default.")
if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
else:
try:
config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
except Exception:
raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'")
with open(config_file, "r") as file:
resolved_config = orjson.loads(file.read())
# all peft_type should be available in PEFT_CONFIG_NAME
_peft_type: AdapterType = resolved_config["peft_type"].lower()
if _peft_type not in resolved:
resolved[_peft_type] = ()
resolved[_peft_type] += ((path_or_adapter_id, name, resolved_config),)
if name == "default":
_has_set_default = True
return resolved
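# For illustration, a hedged sketch of the resolved structure: given a hypothetical
# adapter_map={"aarnphm/opt-6.7b-lora": "default"}, the return value would look roughly like
#   {"lora": (("aarnphm/opt-6.7b-lora", "default", {...contents of adapter_config.json...}),)}
# where each key is the lowercased 'peft_type' read from the adapter's adapter_config.json.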
def import_model(
model_id: str,
tag: bentoml.Tag,
@@ -179,10 +230,6 @@ def import_model(
try:
return bentoml.transformers.save_model(tag, model, custom_objects={"tokenizer": tokenizer})
finally:
import gc
gc.collect()
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model
# available locally.
@@ -193,11 +240,12 @@ def import_model(
_reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"}
class LLMInterface(ABC):
"""This defines the loose contract for all openllm.LLM implementations."""
_M = t.TypeVar("_M")
_T = t.TypeVar("_T")
config_class: type[openllm.LLMConfig]
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
class LLMInterface(ABC, t.Generic[_M, _T]):
"""This defines the loose contract for all openllm.LLM implementations."""
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]] | None:
@@ -270,32 +318,76 @@ class LLMInterface(ABC):
example implementation."""
raise NotImplementedError
# NOTE: All fields below are attributes that can be accessed by users.
config_class: type[openllm.LLMConfig]
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
_M = t.TypeVar("_M")
_T = t.TypeVar("_T")
config: openllm.LLMConfig
"""The config instance to use for this LLM. This will be created based on config_class"""
bettertransformer: bool
"""Whether to load this LLM with FasterTransformer enabled. The order of loading is:
- If pass within `for_model`, `from_pretrained` or `__init__`.
- If `self.bettertransformer` is set within `llm_post_init`.
- Finally, if none of the above, default to self.config['bettertransformer']
> **Note** that if LoRA is enabled, bettertransformer will be disabled.
"""
# NOTE: The following will be populated by __init_subclass__, note that these should not
# be mutated by users.
__llm_trust_remote_code__: bool
"""This is used to determine during 'import_model' whether to trust remote code or not.
This works synonymously with the `trust_remote_code` kwarg in transformers Auto classes. If not passed,
it falls back to config_class['trust_remote_code'] by default.
"""
__llm_implementation__: t.Literal["pt", "tf", "flax"]
"""This is used to determine which implementation that this LLM has. Usually, this will inferred from
class name, that follows the HuggingFace's naming convention:
- `OPTForConditionalGeneration` -> `pt`
- `TFOPTForConditionalGeneration` -> `tf`
- `FlaxOPTForConditionalGeneration` -> `flax`
"""
__llm_model__: _M | peft.PeftModel | torch.nn.Module | None
"""A reference to the actual model. Instead of accessing this directly, use the `model` property."""
__llm_tokenizer__: _T | None
"""A reference to the actual tokenizer. Instead of accessing this directly, use the `tokenizer` property."""
__llm_tag__: bentoml.Tag | None
"""A reference to the tag used for this LLM. Instead of accessing this directly, use the `tag` property."""
__llm_bentomodel__: bentoml.Model | None
"""A reference to the bentomodel used for this LLM. Instead of accessing this directly, use the `_bentomodel` property."""
__llm_trainer__: transformers.Trainer | None
"""A reference to the Trainer used to fine-tune this LLM."""
__llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], peft.PeftConfig]] | None
"""A reference to the cached LoRA adapter mapping."""
__llm_post_init__: t.Callable[[t.Self], None] | None
"""A callable that will be called after the LLM is initialized. This is set if subclass contains a 'llm_post_init'"""
__llm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
"""A callable that will be called after the model is loaded. This is set when 'load_model' is implemented"""
__llm_init_kwargs__: property | None
"""A check if 'import_kwargs' is implemented in subclass."""
# The following are internal, users shouldn't access this directly.
_model_args: tuple[t.Any, ...]
_model_attrs: dict[str, t.Any]
_tokenizer_attrs: dict[str, t.Any]
_adapters_mapping: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None
@attr.define(slots=True, repr=False)
class LLM(LLMInterface, t.Generic[_M, _T]):
if t.TYPE_CHECKING:
# The following will be populated by metaclass
__llm_trust_remote_code__: bool
__llm_implementation__: t.Literal["pt", "tf", "flax"]
__llm_model__: _M | None
__llm_tokenizer__: _T | None
__llm_tag__: bentoml.Tag | None
__llm_bentomodel__: bentoml.Model | None
__llm_post_init__: t.Callable[[t.Self], None] | None
__llm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
__llm_init_kwargs__: property | None
_model_args: tuple[t.Any, ...]
_model_attrs: dict[str, t.Any]
_tokenizer_attrs: dict[str, t.Any]
bettertransformer: bool
class LLM(LLMInterface[_M, _T], ReprMixin):
def __init_subclass__(cls):
cd = cls.__dict__
prefix_class_name_config = cls.__name__
@@ -321,19 +413,33 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
"Missing required key 'config_class'. Make sure to define it within the LLM subclass."
)
if cls.import_model is LLMInterface.import_model:
# using the default import model
if cls.import_model is LLMInterface[_M, _T].import_model:
# using the default import model if no custom import is set
setattr(cls, "import_model", functools.partial(import_model, _model_framework=implementation))
else:
logger.debug("Custom 'import_model' will be used when loading model %s", cls.__name__)
cls.__llm_post_init__ = None if cls.llm_post_init is LLMInterface.llm_post_init else cls.llm_post_init
cls.__llm_custom_load__ = None if cls.load_model is LLMInterface.load_model else cls.load_model
cls.__llm_init_kwargs__ = None if cls.import_kwargs is LLMInterface.import_kwargs else cls.import_kwargs
cls.__llm_post_init__ = None if cls.llm_post_init is LLMInterface[_M, _T].llm_post_init else cls.llm_post_init
cls.__llm_custom_load__ = None if cls.load_model is LLMInterface[_M, _T].load_model else cls.load_model
cls.__llm_init_kwargs__ = (
None if cls.import_kwargs is LLMInterface[_M, _T].import_kwargs else cls.import_kwargs
)
for at in {"bentomodel", "tag", "model", "tokenizer"}:
for at in {"bentomodel", "tag", "model", "tokenizer", "adapter_map", "trainer"}:
setattr(cls, f"__llm_{at}__", None)
# update docstring for given entrypoint
for fn in {"generate", "generate_one", "generate_iterator"}:
original_fn = getattr(cls, fn, getattr(LLMInterface, fn))
original_fn.__doc__ = (
original_fn.__doc__
or f"""\
'{fn}' implementation {cls.__name__}.
Note that if LoRA is enabled (via either SDK or CLI), `self.model` will become a `peft.PeftModel`
The original can then be accessed with 'self.model.get_base_model()'.
"""
)
setattr(cls, fn, original_fn)
# The following is the similar interface to HuggingFace pretrained protocol.
@classmethod
def from_pretrained(
@@ -343,14 +449,146 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
*args: t.Any,
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
bettertransformer: bool | None = None,
adapter_id: str | None = None,
adapter_name: str | None = None,
adapter_map: dict[str, str | None] | None = None,
**attrs: t.Any,
) -> LLM[_M, _T]:
"""Instantiate a pretrained LLM.
it follows the same design principle as HuggingFace's `from_pretrained` method, plus the following:
Optimization options:
> This is most notable during serving time.
- quantize: quantize the model with the given quantization method. Currently supported int8, int4 quantization
- bettertransformer: Apply FasterTransformer to given pretrained weight
> Currently, the above two options are mutually exclusive.
Adapter options:
> This is used in conjunction with the fine-tuning features
- adapter_id: Optional [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to apply to said model.
- adapter_name: Optional name of the adapter to apply to said model. If not provided, it will be handled internally by OpenLLM.
- adapter_map: optional dictionary of adapter_id to adapter_name. Note that this is mutually exclusive with adapter_id/adapter_name arguments.
Args:
model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used.
llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
will use `config_class` to construct default configuration.
quantize: The quantization to use for this LLM. Defaults to None. Possible values
include int8, int4 and gptq.
bettertransformer: Whether to use BetterTransformer with this model. Defaults to False.
adapter_id: The [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to use for this LLM. Defaults to None.
adapter_name: The adapter name to use for this LLM. Defaults to None.
adapter_map: The adapter map to use for this LLM. Defaults to None. Note that this is mutually exclusive with adapter_id/adapter_name arguments.
*args: The args to be passed to the model.
**attrs: The kwargs to be passed to the model.
"""
quantization_config = attrs.pop("quantization_config", None)
if quantization_config and quantize:
raise ValueError(
"""'quantization_config' and 'quantize' are mutually exclusive. Either customise
your quantization_config or use the quantize argument."""
)
# quantization setup
quantization_config = attrs.pop("quantization_config", None)
# 8 bit configuration
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
cpu_offloading = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
# 4 bit configuration
int4_compute_dtype = attrs.pop("llm_bnb_4bit_compute_dtype", torch.bfloat16)
int4_quant_type = attrs.pop("llm_bnb_4bit_quant_type", "nf4")
int4_use_double_quant = attrs.pop("llm_bnb_4bit_use_double_quant", True)
# NOTE: Quantization setup
if quantization_config is None:
# quantize is a openllm.LLM feature, where we can quantize the model
# with bitsandbytes or quantization aware training.
if quantize is not None:
if not is_bitsandbytes_available():
raise RuntimeError(
"Quantization requires bitsandbytes to be installed. Make "
"sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
)
logger.debug(
"'quantize' is not None. %s will use a default 'quantization_config' for %s. "
"If you want to customise the quantization config, make sure to pass your "
"own 'quantization_config'",
cls.__name__,
quantize,
)
if quantize == "int8":
if int8_skip_modules is None:
int8_skip_modules = []
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
int8_skip_modules.append("lm_head")
quantization_config = transformers.BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=cpu_offloading,
llm_int8_threshhold=int8_threshold,
llm_int8_skip_modules=int8_skip_modules,
llm_int8_has_fp16_weight=int8_has_fp16_weight,
)
elif quantize == "int4":
if is_transformers_supports_kbit():
quantization_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
llm_bnb_4bit_compute_dtype=int4_compute_dtype,
llm_bnb_4bit_quant_type=int4_quant_type,
llm_bnb_4bit_use_double_quant=int4_use_double_quant,
)
else:
logger.warning(
"'quantize' is set to int4, while the current transformers version %s does not support "
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
"make sure to install the latest version of transformers either via PyPI or "
"from git source: 'pip install git+https://github.com/huggingface/transformers'.",
pkg.pkg_version_info("transformers"),
)
elif quantize == "gptq":
# TODO: support GPTQ loading quantization
raise NotImplementedError("GPTQ is not supported yet.")
if model_id is None:
raise RuntimeError(
"'quantize=%s' requires passing custom path to quantized weights as we are unable to load "
"the model on the fly. See https://github.com/qwopqwop200/GPTQ-for-LLaMa for "
"instruction on how to quantize '%s' with GPTQ.",
quantize,
cls.__name__,
)
else:
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantize} instead.")
# NOTE: Fine-tuning setup
if adapter_map and adapter_id:
raise ValueError(
"""'adapter_map' and 'adapter_id' are mutually exclusive. Either provide a
'adapter_map' ({adapter_id: adapter_name | None, ...}) or use
the combination of adapter_id/adapter_name arguments.
"""
)
if adapter_map is None and adapter_id is not None:
adapter_map = {adapter_id: adapter_name}
if adapter_map is not None and not is_peft_available():
raise RuntimeError(
"LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
)
return cls(
model_id=model_id,
llm_config=llm_config,
*args,
quantize=quantize,
bettertransformer=bettertransformer,
_adapters_mapping=resolve_peft_config_type(adapter_map),
quantization_config=quantization_config,
**attrs,
)
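# For illustration, a hedged usage sketch of the API above (the model and adapter ids
# mirror the changelog example, and forwarding 'adapter_map' through 'for_model' follows
# how the CLI invokes it further below):
#
#   llm = openllm.AutoLLM.for_model(
#       "opt",
#       model_id="facebook/opt-6.7b",
#       adapter_map={"aarnphm/opt-6.7b-lora": None},  # adapter id -> optional adapter name
#   )
#
# 'adapter_map' and 'adapter_id'/'adapter_name' remain mutually exclusive, as enforced above.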
@@ -359,12 +597,17 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
model_id: str | None = None,
llm_config: openllm.LLMConfig | None = None,
*args: t.Any,
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
bettertransformer: bool | None = None,
_adapters_mapping: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]]
| None = None,
**attrs: t.Any,
):
"""Initialize the LLM with given pretrained model.
> **Warning**
> To initialize any LLM, you should use `openllm.AutoLLM` or `openllm.LLM.from_pretrained` instead.
> `__init__` initialization is only for internal use.
Note:
- *args to be passed to the model.
- **attrs will first be passed to the AutoConfig, then the rest will be passed to the import_model
@@ -437,8 +680,6 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used.
llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
will use `config_class` to construct default configuration.
quantize: The quantization to use for this LLM. Defaults to None. Possible values
include int8, int4 and gptq.
bettertransformer: Whether to use BetterTransformer with this model. Defaults to False.
*args: The args to be passed to the model.
**attrs: The kwargs to be passed to the model.
@@ -454,18 +695,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
# low_cpu_mem_usage is only available for model
# this is helpful on system with low memory to avoid OOM
low_cpu_mem_usage = attrs.pop("low_cpu_mem_usage", True)
# quantization setup
quantization_config = attrs.pop("quantization_config", None)
# 8 bit configuration
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
cpu_offloading = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
# 4 bit configuration
int4_compute_dtype = attrs.pop("llm_bnb_4bit_compute_dtype", torch.bfloat16)
int4_quant_type = attrs.pop("llm_bnb_4bit_quant_type", "nf4")
int4_use_double_quant = attrs.pop("llm_bnb_4bit_use_double_quant", True)
if llm_config is not None:
logger.debug("Using provided LLMConfig to initialize LLM instead of from default: %r", llm_config)
@@ -475,69 +705,10 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
# The rests of the kwargs that is not used by the config class should be stored into __openllm_extras__.
attrs = self.config["extras"]
if quantization_config and quantize:
raise ValueError(
"""'quantization_config' and 'quantize' are mutually exclusive. Either customise
your quantization_config or use the quantize argument."""
)
if quantization_config is None:
# quantize is a openllm.LLM feature, where we can quantize the model
# with bitsandbytes or quantization aware training.
if quantize is not None:
if not is_bitsandbytes_available():
raise RuntimeError(
"Quantization requires bitsandbytes to be installed. Make "
"sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
)
logger.debug(
"'quantize' is not None. %s will use a default 'quantization_config' for %s. "
"If you want to customise the quantization config, make sure to pass your "
"own 'quantization_config'",
self,
quantize,
)
if quantize == "int8":
if int8_skip_modules is None:
int8_skip_modules = []
if "lm_head" not in int8_skip_modules and self.config["model_type"] == "causal_lm":
logger.debug("Skipping 'lm_head' for quantization for %s", self)
int8_skip_modules.append("lm_head")
quantization_config = transformers.BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=cpu_offloading,
llm_int8_threshhold=int8_threshold,
llm_int8_skip_modules=int8_skip_modules,
llm_int8_has_fp16_weight=int8_has_fp16_weight,
)
elif quantize == "int4":
if is_transformers_supports_kbit():
quantization_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
llm_bnb_4bit_compute_dtype=int4_compute_dtype,
llm_bnb_4bit_quant_type=int4_quant_type,
llm_bnb_4bit_use_double_quant=int4_use_double_quant,
)
else:
logger.warning(
"'quantize' is set to int4, while the current transformers version %s does not support "
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
"make sure to install the latest version of transformers either via PyPI or "
"from git source: 'pip install git+https://github.com/huggingface/transformers'.",
pkg.pkg_version_info("transformers"),
)
elif quantize == "gptq":
# TODO: support GPTQ loading quantization
if model_id is None:
raise RuntimeError(
"'quantize=%s' requires passing custom path to quantized weights as we are unable to load "
"the model on the fly. See https://github.com/qwopqwop200/GPTQ-for-LLaMa for "
"instruction on how to quantize '%s' with GPTQ.",
quantize,
self,
)
raise NotImplementedError("GPTQ is not supported yet.")
else:
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantize} instead.")
if self.config["use_pipeline"] and _adapters_mapping:
raise ValueError(f"{self} will be used as a Pipeline, which is not yet compatible with LoRA adapter.")
self._adapters_mapping = _adapters_mapping
if self.__llm_implementation__ == "pt":
if not self.config["use_pipeline"]:
@@ -546,10 +717,8 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
model_kwds, tokenizer_kwds = {}, {}
if self.__llm_init_kwargs__:
if t.TYPE_CHECKING:
# the above meta value should determine that this LLM has custom kwargs
assert self.import_kwargs
model_kwds, tokenizer_kwds = self.import_kwargs
# NOTE: recast here for type safety
model_kwds, tokenizer_kwds = t.cast("tuple[dict[str, t.Any], dict[str, t.Any]]", self.__llm_init_kwargs__)
logger.debug(
"'%s' default kwargs for model: '%s', tokenizer: '%s'",
self.__class__.__name__,
@@ -564,7 +733,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
assert model_id is not None
self._model_id = model_id
# parsing tokenizer and model kwargs
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
tokenizer_kwds.update(
{k[len(TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(TOKENIZER_PREFIX)}
)
@@ -588,6 +757,10 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
self.bettertransformer = bettertransformer
else:
non_intrusive_setattr(self, "bettertransformer", self.config["bettertransformer"])
# If LoRA is passed, then disable bettertransformer
if _adapters_mapping and self.bettertransformer is True:
logger.debug("LoRA is visible for %s, disabling BetterTransformer", self)
self.bettertransformer = False
def __setattr__(self, attr: str, value: t.Any):
if attr in _reserved_namespace:
@@ -599,9 +772,21 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
super().__setattr__(attr, value)
def __repr__(self) -> str:
keys = {"model_id", "runner_name", "llm_type", "config"}
return f"{self.__class__.__name__}({', '.join(f'{k}={getattr(self, k)!r}' for k in keys)})"
@property
def adapters_mapping(
self,
) -> dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None:
return self._adapters_mapping
@adapters_mapping.setter
def adapters_mapping(
self, value: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None
):
self._adapters_mapping = value
@property
def __repr_keys__(self) -> set[str]:
return {"model_id", "runner_name", "config"}
@property
def model_id(self) -> str:
@@ -712,9 +897,9 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
@property
def _bentomodel(self) -> bentoml.Model:
if self.__llm_bentomodel__ is None:
# NOTE: Since PR#28, self.__llm_bentomodel__ changed from
# NOTE: Since #28, self.__llm_bentomodel__ changed from
# ensure_model_id_exists() into just returning the model ref.
# This is because we want to save a few seconds of loading time,
# This is purely a performance reason.
# as openllm.Runner and openllm.AutoLLM initialisation is around 700ms
# before #28.
# If users want to make sure to have the model downloaded,
@@ -741,23 +926,24 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
if self.config["requires_gpu"] and len(openllm.utils.gpu_count()) < 1:
raise GpuNotAvailableError(f"{self} only supports running with GPU (None available).") from None
kwds = self._model_attrs
kwds["trust_remote_code"] = self.__llm_trust_remote_code__
is_pipeline = "_pretrained_class" in self._bentomodel.info.metadata
# differentiate when saving tokenizer or other pretrained type.
is_pretrained_model = is_pipeline and "_framework" in self._bentomodel.info.metadata
if self.bettertransformer and is_pipeline and self.config["use_pipeline"]:
# This is a pipeline, provide an accelerator arg
kwds["accelerator"] = "bettertransformer"
if self.__llm_model__ is None:
kwds = self._model_attrs
kwds["trust_remote_code"] = self.__llm_trust_remote_code__
is_pipeline = "_pretrained_class" in self._bentomodel.info.metadata
# differentiate when saving tokenizer or other pretrained type.
is_pretrained_model = is_pipeline and "_framework" in self._bentomodel.info.metadata
if self.bettertransformer and is_pipeline and self.config["use_pipeline"]:
# This is a pipeline, provide an accelerator arg
kwds["accelerator"] = "bettertransformer"
if self.__llm_custom_load__:
self.__llm_model__ = self.load_model(self.tag, *self._model_args, **kwds)
else:
self.__llm_model__ = self._bentomodel.load_model(*self._model_args, **kwds)
# This branch shouldn't be hit when LoRA is enabled.
if (
self.bettertransformer
and is_pretrained_model
@@ -770,6 +956,143 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
self.__llm_model__ = BetterTransformer.transform(self.__llm_model__)
return t.cast(_M, self.__llm_model__)
def _transpose_adapter_mapping(
self,
inference_mode: bool = True,
use_cache: bool = True,
) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]]:
assert self._adapters_mapping is not None, "LoRA mapping is not set up correctly."
if not use_cache:
logger.debug(
"'use_cache' is set to False. This means the adapter mapping resolution will not be cached. This should only be used during training."
)
if self.__llm_adapter_map__ is not None and use_cache:
# early out if we already serialized everything.
return self.__llm_adapter_map__
adapter_map: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] = {}
# this is a temporary check to accept the first adapter without a name as 'default';
# we then raise an error if another optional_name is None in a later iteration.
_converted_first_none = False
for _adapter_type, _adapter_tuple in self._adapters_mapping.items():
if _adapter_type not in adapter_map:
adapter_map[_adapter_type] = {}
default_config = self.config["fine_tune_strategies"].get(
_adapter_type, FineTuneConfig(adapter_type=_adapter_type, llm_config_class=self.config_class)
)
default_config = default_config.eval() if inference_mode else default_config.train()
for pretrained_or_peft_id, optional_name, resolved_mapping in _adapter_tuple:
if not optional_name:
if not _converted_first_none:
_converted_first_none = True
optional_name = "default"
else:
raise ValueError(
f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {pretrained_or_peft_id, resolved_mapping}"
)
assert isinstance(optional_name, str) # optional_name should all be resolved here
if optional_name == "default":
adapter_map[_adapter_type][optional_name] = (
default_config.with_config(**resolved_mapping).to_peft_config(),
pretrained_or_peft_id,
)
else:
adapter_map[_adapter_type][optional_name] = (
FineTuneConfig(
adapter_type=_adapter_type,
adapter_config=resolved_mapping,
inference_mode=inference_mode,
llm_config_class=self.config_class,
).to_peft_config(),
pretrained_or_peft_id,
)
if self.__llm_adapter_map__ is None and use_cache:
self.__llm_adapter_map__ = adapter_map
return self.__llm_adapter_map__
return adapter_map
@requires_dependencies("peft", extra="fine-tune")
def apply_adapter(
self,
inference_mode: bool = True,
adapter_type: AdapterType = "lora",
load_adapters: t.Literal["all"] | list[str] | None = None,
use_cache: bool = True,
) -> peft.PeftModel | _M | torch.nn.Module:
"""Apply given LoRA mapping to the model. Note that the base model can still
be accessed via self.model.get_base_model().
"""
assert self.model, "Internal error: Model is not loaded correctly."
assert self.__llm_model__ is not None
# early out if _adapters_mapping is empty or it is already wrapped
# with peft.
if not self._adapters_mapping:
logger.debug("No adapter mapping is found. Skip applying adapter.")
return self.__llm_model__
_mapping = self._transpose_adapter_mapping(inference_mode=inference_mode, use_cache=use_cache)
if adapter_type not in _mapping:
raise ValueError(
f"Given adapter type {adapter_type} is not supported. Please choose from {list(_mapping.keys())}"
)
adapter_mapping = _mapping[adapter_type]
default_config, peft_model_id = adapter_mapping.pop("default", None)
if default_config is None:
raise ValueError(
"There is no 'default' mapping. Please check the adapter mapping and report this bug to the OpenLLM team."
)
# the below shares similar logic with `get_peft_model`
# TODO: Support PromptLearningConfig
if default_config.task_type not in peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
default_config, peft.PromptLearningConfig
):
logger.debug(
"Given task type '%s' is not supported by peft. This means it can be a custom PeftModel implementation. Make sure the adapter is loaded manually before running inference.",
default_config.task_type,
)
self.__llm_model__ = peft.PeftModel(self.__llm_model__, default_config)
else:
# this is not ideal to serialize like this, wait until https://github.com/huggingface/peft/pull/612
# is merged
peft_class = peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING[default_config.task_type]
if t.cast("str | None", default_config.base_model_name_or_path) is not None:
kwargs: dict[str, t.Any] = {"is_trainable": not inference_mode}
if "config" in inspect.signature(peft_class.from_pretrained).parameters:
kwargs["config"] = default_config
else:
kwargs.update(dict(default_config.to_dict().items()))
self.__llm_model__ = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs)
else:
# in this case, the given base_model_name_or_path is None. This will be hit during training
self.__llm_model__ = peft_class(self.__llm_model__, default_config)
# now we loop through the rest with add_adapter
if len(adapter_mapping) > 0:
for adapter_name, _peft_config in adapter_mapping.items():
self.__llm_model__.add_adapter(adapter_name, _peft_config)
# optionally load adapters. In case of multiple adapters, or on Runner,
# we will need to set load_adapters='all'
if load_adapters is not None:
adapters_to_load = adapter_mapping.keys() if load_adapters == "all" else load_adapters
for adapter_name in adapters_to_load:
_peft_config, _peft_model_id = adapter_mapping[adapter_name]
self.__llm_model__.load_adapter(
_peft_model_id,
adapter_name=adapter_name,
is_trainable=not inference_mode,
**dict(_peft_config.to_dict()),
)
return self.__llm_model__
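# For reference, the runner below applies every adapter at startup via
# 'self.apply_adapter(inference_mode=True, load_adapters="all")'; after that,
# 'self.model' is a peft.PeftModel and the base weights stay reachable
# through 'self.model.get_base_model()'.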
@property
def tokenizer(self) -> _T:
"""The tokenizer to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
@@ -837,18 +1160,45 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
llm_type: str
identifying_params: dict[str, t.Any]
def __init_subclass__(cls, llm_type: str, identifying_params: dict[str, t.Any], **_: t.Any):
cls.llm_type = llm_type
cls.identifying_params = identifying_params
def __init__(__self: _Runnable):
# NOTE: The side effect of this line
# is that it will load the imported model during
# runner creation. So don't remove it!!
self.model
# runner startup. So don't remove it!!
assert self.model, "Internal error: Model is not loaded"
if self.adapters_mapping is not None:
logger.info("Applying LoRA to %s...", self.runner_name)
self.apply_adapter(inference_mode=True, load_adapters="all")
@bentoml.Runnable.method(batchable=False)
def list_adapter(__self) -> dict[str, t.Any]:
if not is_peft_available():
return {
"success": False,
"result": {},
"error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'",
}
if not isinstance(self.model, peft.PeftModel):
return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"}
return {"success": True, "result": self.model.peft_config, "error_msg": ""}
@bentoml.Runnable.method(batchable=False)
def set_adapter(__self, adapter_name: str) -> dict[t.Literal["success", "error_msg"], bool | str]:
if not is_peft_available():
return {
"success": False,
"error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'",
}
if not isinstance(self.model, peft.PeftModel):
return {"success": False, "error_msg": "Model is not a PeftModel"}
try:
self.model.set_adapter(adapter_name)
return {"success": True, "error_msg": ""}
except ValueError:
logger.info("Adapter %s not found", adapter_name)
return {
"success": False,
"error_msg": f"Adapter {adapter_name} not found. Available adapters: {list(self.model.peft_config)}",
}
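# Hedged client-side sketch (mirrors the generated service further below; 'eng_lora'
# is a placeholder adapter name):
#   await runner.list_adapter.async_run()
#   await runner.set_adapter.async_run("eng_lora")
# Both calls return a dict that includes 'success' and 'error_msg' keys.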
@bentoml.Runnable.method(
batchable=generate_sig.batchable,
@@ -916,19 +1266,21 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
"__call__": _wrapped_generate_run,
"__module__": f"openllm.models.{self.config['model_name']}",
"__doc__": self.config["env"].start_docstring,
"__repr_keys__": lambda _: {"llm", "config", "llm_type", "identifying_params"},
}
),
)(
types.new_class(
inflection.camelize(self.config["model_name"]) + "Runnable",
(_Runnable,),
{
"SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu")
if self.config["requires_gpu"]
else ("nvidia.com/gpu",),
"llm_type": self.llm_type,
"identifying_params": self.identifying_params,
},
{},
lambda ns: ns.update(
{
"SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu")
if self.config["requires_gpu"]
else ("nvidia.com/gpu",),
}
),
),
name=self.runner_name,
embedded=False,
@@ -962,7 +1314,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
return self.postprocess_generate(prompt, generated_result, **postprocess_kwargs)
@t.overload
@overload
def Runner(
model_name: str,
*,
@@ -973,7 +1325,7 @@ def Runner(
...
@t.overload
@overload
def Runner(
model_name: str,
*,


@@ -23,7 +23,9 @@ import typing as t
from pathlib import Path
import fs
import fs.copy
import inflection
import orjson
import bentoml
import openllm
@@ -38,8 +40,16 @@ from .utils import is_flax_available
from .utils import is_tf_available
from .utils import is_torch_available
from .utils import pkg
from .utils import resolve_user_filepath
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
from fs.base import FS
@@ -83,19 +93,23 @@ def construct_python_options(
llm: openllm.LLM[t.Any, t.Any],
llm_fs: FS,
extra_dependencies: tuple[str, ...] | None = None,
adapter_map: dict[str, str | None] | None = None,
) -> PythonOptions:
packages = ["openllm"]
if adapter_map is not None:
packages += ["openllm[fine-tune]"]
# NOTE: add openllm to the default dependencies
# if users have openllm custom-built wheels, it will still respect
# that, since bentoml will always install dependencies from requirements.txt
# first, then proceed to install everything inside the wheels/ folder.
if extra_dependencies is not None:
packages += [f"openllm[{k}]" for k in extra_dependencies]
filtered = set(extra_dependencies + ("fine-tune",))
packages += [f"openllm[{k}]" for k in filtered]
if llm.config["requirements"] is not None:
packages.extend(llm.config["requirements"])
if not (str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false"):
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false":
packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
env: ModelEnv = llm.config["env"]
@@ -149,6 +163,7 @@ def construct_docker_options(
workers_per_resource: int | float,
quantize: t.LiteralString | None,
bettertransformer: bool | None,
adapter_map: dict[str, str | None] | None,
) -> DockerOptions:
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options_opts = [
@@ -164,10 +179,14 @@ def construct_docker_options(
env.config: f"'{llm.config.model_dump_json().decode()}'",
"OPENLLM_MODEL": llm.config["model_name"],
"OPENLLM_MODEL_ID": llm.model_id,
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'",
"BENTOML_DEBUG": str(get_debug_mode()),
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
}
if adapter_map:
env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
# We need to handle None separately here, as env from subprocess doesn't
# accept None value.
_env = ModelEnv(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize)
@@ -181,38 +200,6 @@ def construct_docker_options(
return DockerOptions(cuda_version="11.6", env=env_dict, system_packages=["git"])
@t.overload
def build(
model_name: str,
*,
model_id: str | None = ...,
quantize: t.LiteralString | None = ...,
bettertransformer: bool | None = ...,
_extra_dependencies: tuple[str, ...] | None = ...,
_workers_per_resource: int | float | None = ...,
_overwrite_existing_bento: bool = ...,
__cli__: t.Literal[False] = ...,
**attrs: t.Any,
) -> bentoml.Bento:
...
@t.overload
def build(
model_name: str,
*,
model_id: str | None = ...,
quantize: t.LiteralString | None = ...,
bettertransformer: bool | None = ...,
_extra_dependencies: tuple[str, ...] | None = ...,
_workers_per_resource: int | float | None = ...,
_overwrite_existing_bento: bool = ...,
__cli__: t.Literal[True] = ...,
**attrs: t.Any,
) -> tuple[bentoml.Bento, bool]:
...
def _build_bento(
bento_tag: bentoml.Tag,
service_name: str,
@@ -221,34 +208,87 @@ def _build_bento(
workers_per_resource: int | float,
quantize: t.LiteralString | None,
bettertransformer: bool | None,
adapter_map: dict[str, str | None] | None = None,
extra_dependencies: tuple[str, ...] | None = None,
build_ctx: str | None = None,
) -> bentoml.Bento:
framework_envvar = llm.config["env"]["framework_value"]
labels = dict(llm.identifying_params)
labels.update({"_type": llm.llm_type, "_framework": framework_envvar})
logger.info("Building Bento for LLM '%s'", llm.config["start_name"])
if adapter_map is not None:
assert build_ctx is not None, "build_ctx is required when 'adapter_map' is not None"
updated_mapping: dict[str, str | None] = {}
for adapter_id, name in adapter_map.items():
try:
resolve_user_filepath(adapter_id, build_ctx)
src_folder_name = os.path.basename(adapter_id)
src_fs = fs.open_fs(build_ctx)
llm_fs.makedir(src_folder_name, recreate=True)
fs.copy.copy_dir(src_fs, adapter_id, llm_fs, src_folder_name)
updated_mapping[src_folder_name] = name
except FileNotFoundError:
# this is a remote adapter, so just add it back.
# Note that there is a drawback here: if a local adapter path
# has the same name as the remote one, we currently don't
# support that edge case.
updated_mapping[adapter_id] = name
adapter_map = updated_mapping
# add service.py definition to this temporary folder
codegen.write_service(llm.config["model_name"], llm.model_id, adapter_map, llm.config["service_name"], llm_fs)
return bentoml.bentos.build(
f"{service_name}:svc",
name=bento_tag.name,
labels=labels,
description=f"OpenLLM service for {llm.config['start_name']}",
include=list(
llm_fs.walk.files(filter=["*.py"])
), # NOTE: By default, we are using _service.py as the default service, for now.
exclude=["/venv", "__pycache__/", "*.py[cod]", "*$py.class"],
python=construct_python_options(llm, llm_fs, extra_dependencies),
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer),
include=list(llm_fs.walk.files()),
exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"],
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map),
version=bento_tag.version,
build_ctx=llm_fs.getsyspath("/"),
)
@overload
def build(
model_name: str,
*,
model_id: str | None = ...,
quantize: t.LiteralString | None = ...,
bettertransformer: bool | None = ...,
adapter_map: dict[str, str | None] | None = ...,
__cli__: t.Literal[False] = False,
**attrs: t.Any,
) -> bentoml.Bento:
...
@overload
def build(
model_name: str,
*,
model_id: str | None = ...,
quantize: t.LiteralString | None = ...,
bettertransformer: bool | None = ...,
adapter_map: dict[str, str | None] | None = ...,
__cli__: t.Literal[True] = ...,
**attrs: t.Any,
) -> tuple[bentoml.Bento, bool]:
...
def build(
model_name: str,
*,
model_id: str | None = None,
quantize: t.LiteralString | None = None,
bettertransformer: bool | None = None,
adapter_map: dict[str, str | None] | None = None,
_build_ctx: str | None = None,
_extra_dependencies: tuple[str, ...] | None = None,
_workers_per_resource: int | float | None = None,
_overwrite_existing_bento: bool = False,
@@ -270,14 +310,14 @@ def build(
llm_config = openllm.AutoConfig.for_model(model_name)
logger.info("Packing '%s' into a Bento with kwargs=%s...", model_name, attrs)
logger.info("Packing '%s' into a Bento%s...", model_name, f" with 'kwargs={attrs}' " if attrs else "")
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
try:
os.environ["OPENLLM_MODEL"] = inflection.underscore(model_name)
framework_envvar = llm_config["env"]["framework_value"]
framework_envvar = llm_config["env"].framework_value
llm = t.cast(
"_BaseAutoLLMClass",
openllm[framework_envvar], # type: ignore (internal API)
@@ -286,7 +326,9 @@ def build(
model_id=model_id,
llm_config=llm_config,
quantize=quantize,
adapter_map=adapter_map,
bettertransformer=bettertransformer,
return_runner_kwargs=False,
**attrs,
)
@@ -294,41 +336,40 @@ def build(
labels = dict(llm.identifying_params)
labels.update({"_type": llm.llm_type, "_framework": framework_envvar})
service_name = f"generated_{llm_config['model_name']}_service.py"
workers_per_resource = first_not_none(_workers_per_resource, default=llm_config["workers_per_resource"])
with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs:
# add service.py definition to this temporary folder
codegen.write_service(model_name, llm.model_id, service_name, llm_fs)
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
try:
bento = bentoml.get(bento_tag)
if _overwrite_existing_bento:
logger.info("Overwriting previously saved Bento.")
bentoml.delete(bento_tag)
bento = _build_bento(
bento_tag,
service_name,
llm.config["service_name"],
llm_fs,
llm,
workers_per_resource=workers_per_resource,
adapter_map=adapter_map,
quantize=quantize,
bettertransformer=bettertransformer,
extra_dependencies=_extra_dependencies,
build_ctx=_build_ctx,
)
_previously_built = True
except bentoml.exceptions.NotFound:
logger.info("Building Bento for LLM '%s'", llm_config["start_name"])
bento = _build_bento(
bento_tag,
service_name,
llm.config["service_name"],
llm_fs,
llm,
workers_per_resource=workers_per_resource,
adapter_map=adapter_map,
quantize=quantize,
bettertransformer=bettertransformer,
extra_dependencies=_extra_dependencies,
build_ctx=_build_ctx,
)
return (bento, _previously_built) if __cli__ else bento
except Exception as e:


@@ -25,6 +25,7 @@ from __future__ import annotations
import os
import typing as t
import warnings
import attr
import orjson
@@ -40,8 +41,25 @@ if t.TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
# The following warnings come from bitsandbytes and are probably not that
# important for users to see
warnings.filterwarnings(
"ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
)
warnings.filterwarnings(
"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
)
warnings.filterwarnings(
"ignore",
message=(
"The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization"
" are unavailable."
),
)
model = os.environ.get("OPENLLM_MODEL", "{__model_name__}") # openllm: model name
model_id = os.environ.get("OPENLLM_MODEL_ID", "{__model_id__}") # openllm: model id
adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}""") # openllm: model adapter map
llm_config = openllm.AutoConfig.for_model(model)
@@ -51,6 +69,7 @@ runner = openllm.Runner(
llm_config=llm_config,
bettertransformer=llm_config["env"]["bettertransformer_value"],
quantize=llm_config["env"]["quantize_value"],
adapter_map=orjson.loads(adapter_map),
ensure_available=False,
init_local=False,
)
@@ -70,7 +89,19 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
return openllm.GenerationOutput(responses=responses, configuration=config)
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON(), route="/v1/metadata")
@svc.api(
input=bentoml.io.Text(),
output=bentoml.io.JSON.from_sample(
sample={
"model_id": model_id,
"timeout": 3600,
"model_name": llm_config["model_name"],
"framework": "pt",
"configuration": "",
}
),
route="/v1/metadata",
)
def metadata_v1(_: str) -> openllm.MetadataOutput:
return openllm.MetadataOutput(
model_id=model_id,
@@ -81,6 +112,15 @@ def metadata_v1(_: str) -> openllm.MetadataOutput:
)
@svc.api(
input=bentoml.io.Text.from_sample(sample="default"),
output=bentoml.io.JSON.from_sample(sample={"success": True, "error_msg": "some error message"}),
route="/v1/adapters",
)
async def adapters_v1(adapter_name: str) -> dict[str, bool | str]:
return await runner.set_adapter.async_run(adapter_name)
@attr.define
class HfAgentInput:
inputs: str
@@ -105,3 +145,14 @@ async def hf_agent(request: Request) -> Response:
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
svc.mount_asgi_app(hf_app, path="/hf")
async def list_adapter_v1(_: Request) -> Response:
res = await runner.list_adapter.async_run()
if res["success"]:
res["result"] = {k: v.to_dict() for k, v in res["result"].items()}
return JSONResponse(res, status_code=200)
metadata_app = Starlette(debug=True, routes=[Route("/adapters", list_adapter_v1, methods=["GET"])])
svc.mount_asgi_app(metadata_app, path="/v1")

src/openllm/_trainer.py Normal file

@@ -0,0 +1,25 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import transformers
class PeftTrainer(transformers.Trainer):
...
class PeftSaveCallback(transformers.TrainerCallback):
...


@@ -56,8 +56,10 @@ from .utils import first_not_none
from .utils import get_debug_mode
from .utils import get_quiet_mode
from .utils import gpu_count
from .utils import is_peft_available
from .utils import is_torch_available
from .utils import is_transformers_supports_agent
from .utils import resolve_user_filepath
from .utils import set_debug_mode
from .utils import set_quiet_mode
@@ -105,13 +107,14 @@ class NargsOptions(cog.GroupedOption):
options.
"""
_nargs_parser: click.parser.Option
_prev_parser_process: t.Callable[[t.Any, click.parser.ParsingState], None]
def __init__(self, *args: t.Any, **attrs: t.Any):
nargs = attrs.pop("nargs", -1)
if nargs != -1:
raise OpenLLMException(f"'nargs' is set, and must be -1 instead of {nargs}")
super(NargsOptions, self).__init__(*args, **attrs)
self._prev_parser_process: t.Callable[[t.Any, click.parser.ParsingState], None] | None = None
self._nargs_parser: click.parser.Option | None = None
def add_to_parser(self, parser: click.OptionParser, ctx: click.Context) -> None:
def _parser(value: t.Any, state: click.parser.ParsingState):
@@ -238,7 +241,7 @@ def quantize_option(factory: t.Any, build: bool = False):
)
def bettertransformer_option(factory: t.Any, model_env: ModelEnv | None = None):
def bettertransformer_option(factory: t.Any, build: bool = False, model_env: ModelEnv | None = None):
envvar = None
if model_env is not None:
envvar = model_env.bettertransformer
@@ -246,14 +249,35 @@ def bettertransformer_option(factory: t.Any, model_env: ModelEnv | None = None):
"--bettertransformer",
is_flag=True,
default=None,
help="Use BetterTransformer wrapper to serve model. This will applies during serving time.",
help="Apply FasterTransformer wrapper to serve model. This will applies during serving time."
if not build
else "Set defaul environment variable whether to serve this model with FasterTransformer in build time.",
envvar=envvar,
show_envvar=True if envvar is not None else False,
)
_adapter_mapping_key = "adapter_map"
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> dict[str, str] | None:
if not value:
return
if _adapter_mapping_key not in ctx.params:
ctx.params[_adapter_mapping_key] = {}
for v in value:
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
try:
# try to resolve the full path if users pass in a relative one;
# currently we only support one level of path resolution.
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError:
pass
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
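# For illustration, a hedged sketch of what the callback above produces:
#   --adapter-id /path/to/adapter --adapter-id remote/adapter:eng_lora
# resolves to roughly
#   {"/path/to/adapter": None, "remote/adapter": "eng_lora"}
# i.e. everything after the last ':' becomes the adapter name, defaulting to None.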
class OpenLLMCommandGroup(BentoMLCommandGroup):
NUMBER_OF_COMMON_PARAMS = 3
NUMBER_OF_COMMON_PARAMS = 4 # parameters in common_params + 1 faked group option header
@staticmethod
def common_params(f: F[P, t.Any]) -> ClickFunctionWrapper[..., t.Any]:
@@ -265,11 +289,14 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
from bentoml._internal.configuration import DEBUG_ENV_VAR
from bentoml._internal.configuration import QUIET_ENV_VAR
@click.option("-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output.")
@click.option(
@cog.optgroup.group("Miscellaneous options")
@cog.optgroup.option(
"-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output."
)
@cog.optgroup.option(
"--debug", "--verbose", envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help="Print out debug logs."
)
@click.option(
@cog.optgroup.option(
"--do-not-track",
is_flag=True,
default=False,
@@ -507,7 +534,7 @@ _http_server_args = parse_serve_args(False)
_grpc_server_args = parse_serve_args(True)
def start_model_command(
def start_command_factory(
model_name: str,
_context_settings: dict[str, t.Any] | None = None,
_serve_grpc: bool = False,
@@ -531,7 +558,7 @@ def start_model_command(
configure_logging()
llm_config = openllm.AutoConfig.for_model(model_name)
env: ModelEnv = llm_config["env"]
env = llm_config["env"]
docstring = f"""\
{env.start_docstring}
@@ -576,7 +603,7 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
@group.command(**command_attrs)
@llm_config.to_click_options
@serve_decorator
@cog.optgroup.group("General LLM Options")
@cog.optgroup.group("General LLM Options", help="The following options are related to running the LLM Server.")
@cog.optgroup.option(
"--server-timeout",
type=int,
@@ -591,7 +618,25 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
default=False,
help="Bypass auto model checks and setup. This option is ahead-of-serving time.",
)
@cog.optgroup.group("LLM Optimization Options.")
@cog.optgroup.group(
"LLM Optimization Options",
help="""\
These options are related to dynamic, on-the-fly optimization. Currently supported strategies:
- int8: Quantize the model to 8-bit (bitsandbytes required)
- int4: Quantize the model to 4-bit (bitsandbytes required)
- bettertransformer: Convert the given model to BetterTransformer
The following are currently being worked on:
- GPTQ: [paper](https://arxiv.org/abs/2210.17323)
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
""",
)
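A quick usage sketch of the strategies listed above, assuming the flag registered by `quantize_option` is spelled `--quantize` and accepts the `int8`/`int4` values from the help text:

```bash
# 8-bit quantization at serving time (bitsandbytes required)
openllm start opt --model-id facebook/opt-6.7b --quantize int8

# serve with the BetterTransformer wrapper instead
openllm start opt --model-id facebook/opt-6.7b --bettertransformer
```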
@cog.optgroup.option(
"--device",
type=tuple,
@@ -604,6 +649,32 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
)
@quantize_option(cog.optgroup)
@bettertransformer_option(cog.optgroup, model_env=env)
@cog.optgroup.group(
"Fine-tuning related options",
help="""\
Note that the argument `--adapter-id` accepts the following formats:
- `--adapter-id /path/to/adapter` (local adapter)
- `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub)
- `--adapter-id remote/adapter:eng_lora` (either of the previous formats with an explicit adapter_name)
```bash
openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
```
""",
)
@cog.optgroup.option(
"--adapter-id",
default=None,
help="Optional name or path for given LoRA adapter" + f" to wrap '{model_name}'",
multiple=True,
callback=_id_callback,
metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]",
)
@click.pass_context
def model_start(
ctx: click.Context,
@@ -616,6 +687,10 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
fast: bool,
**attrs: t.Any,
) -> openllm.LLMConfig:
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
# remove the raw 'adapter_id' tuple; the parsed mapping lives in 'adapter_map'
attrs.pop("adapter_id", None)
config, server_attrs = llm_config.model_validate_click(**attrs)
# Create a new model env to work with the envvar during CLI invocation
@@ -631,6 +706,13 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
_echo("Quantization is currently only available for PyTorch models.", fg="red")
ctx.exit(1)
if adapter_map and not is_peft_available():
_echo(
"Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'",
fg="red",
)
ctx.exit(1)
# We need to handle None separately here, as env from subprocess doesn't
# accept None value.
env = ModelEnv(env.model_name, bettertransformer=bettertransformer, quantize=quantize)
@@ -686,22 +768,29 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
if fast and not get_quiet_mode():
_echo(
f"Make sure to download the model before 'start': 'openllm download {model_name}{'--model-id ' + model_id if model_id else ''}'",
f"Fast mode is enabled. Make sure to download the model before 'start': 'openllm download {model_name}{'--model-id ' + model_id if model_id else ''}'",
fg="yellow",
)
automodel_attrs = {
"model_id": model_id,
"llm_config": config,
"ensure_available": not fast,
}
automodel_attrs: dict[str, t.Any] = {}
if framework_envvar == "pt":
automodel_attrs.update({"quantize": quantize, "bettertransformer": bettertransformer})
if adapter_map:
_echo(f"OpenLLM will convert '{model_name}' to use provided adapters layers: {list(adapter_map)}")
automodel_attrs.update({"adapter_map": adapter_map})
llm = t.cast(
"_BaseAutoLLMClass",
openllm[framework_envvar], # type: ignore (internal API)
).for_model(model_name, **automodel_attrs)
).for_model(
model_name,
model_id=model_id,
llm_config=config,
ensure_available=not fast,
return_runner_kwargs=False,
**automodel_attrs,
)
start_env.update(
{
@@ -709,6 +798,7 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
env.config: llm.config.model_dump_json().decode(),
"OPENLLM_MODEL": model_name,
"OPENLLM_MODEL_ID": llm.model_id,
"OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map),
"BENTOML_DEBUG": str(get_debug_mode()),
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options_env,
"BENTOML_HOME": os.environ.get("BENTOML_HOME", BentoMLContainer.bentoml_home.get()),
@@ -746,9 +836,9 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
return model_start
_cached_http = {key: start_model_command(key, _context_settings=_CONTEXT_SETTINGS) for key in openllm.CONFIG_MAPPING}
_cached_http = {key: start_command_factory(key, _context_settings=_CONTEXT_SETTINGS) for key in openllm.CONFIG_MAPPING}
_cached_grpc = {
key: start_model_command(key, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=True)
key: start_command_factory(key, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=True)
for key in openllm.CONFIG_MAPPING
}
@@ -765,7 +855,7 @@ def _start(
if framework is not None:
os.environ[_ModelEnv.framework] = framework
start_model_command(model_name, _serve_grpc=_serve_grpc)(standalone_mode=False, **attrs)
start_command_factory(model_name, _serve_grpc=_serve_grpc)(standalone_mode=False, **attrs)
start = functools.partial(_start, _serve_grpc=False)
@@ -792,7 +882,17 @@ start_grpc = functools.partial(_start, _serve_grpc=True)
nargs=1,
metavar="FEATURE[,FEATURE]",
)
@click.option(
"--adapter-id",
default=None,
help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.",
multiple=True,
metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]",
)
@click.option("--build-ctx", default=".", help="Build context. This is required if --adapter-id uses relative path")
@click.pass_context
def build(
ctx: click.Context,
model_name: str,
model_id: str | None,
overwrite: bool,
@@ -801,6 +901,9 @@ def build(
enable_features: tuple[str] | None,
bettertransformer: bool | None,
workers_per_resource: float | None,
adapter_id: tuple[str, ...],
build_ctx: str | None,
**attrs: t.Any,
):
"""Package a given models into a Bento.
@@ -813,6 +916,20 @@ def build(
> NOTE: To run a container built from this Bento with GPU support, make sure
> to have https://github.com/NVIDIA/nvidia-container-toolkit install locally.
"""
adapter_map: dict[str, str | None] | None = None
if adapter_id:
if not build_ctx:
_echo("'build_ctx' must not be None when '--adapter-id' is passsed.", fg="red")
ctx.exit(1)
adapter_map = {}
for v in adapter_id:
_adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
# We don't resolve the full path here; that is left to build.
# We only do the parsing here.
adapter_map[_adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
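For illustration, a hedged example of the build flow described above; the adapter path and name are hypothetical:

```bash
# Relative adapter paths require --build-ctx so they can be resolved at build time.
openllm build opt --model-id facebook/opt-6.7b \
  --adapter-id ./adapters/opt-lora:my_lora --build-ctx .
```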
if output == "porcelain":
set_quiet_mode(True)
configure_logging()
@@ -824,12 +941,15 @@ def build(
if enable_features:
enable_features = tuple(itertools.chain.from_iterable((s.split(",") for s in enable_features)))
# TODO: xxx
bento, _previously_built = openllm.build(
model_name,
__cli__=True,
model_id=model_id,
quantize=quantize,
bettertransformer=bettertransformer,
adapter_map=adapter_map,
_build_ctx=build_ctx,
_extra_dependencies=enable_features,
_workers_per_resource=workers_per_resource,
_overwrite_existing_bento=overwrite,
@@ -851,8 +971,8 @@ def build(
+ "* Push to BentoCloud with `bentoml push`:\n"
+ f" $ bentoml push {bento.tag}\n"
+ "* Containerize your Bento with `bentoml containerize`:\n"
+ f" $ bentoml containerize {bento.tag}\n"
+ " Tip: To enable additional BentoML feature for 'containerize', "
+ f" $ bentoml containerize {bento.tag}\n\n"
+ " Tip: To enable additional BentoML features for 'containerize', "
+ "use '--enable-features=FEATURE[,FEATURE]' "
+ "[see 'bentoml containerize -h' for more advanced usage]\n",
fg="blue",

View File

@@ -41,3 +41,7 @@ class MissingAnnotationAttributeError(OpenLLMException):
class MissingDependencyError(BaseException):
"""Raised when a dependency is missing."""
class FineTuneStrategyNotSupportedError(OpenLLMException):
"""Raised when a fine-tune strategy is not supported for given LLM."""

View File

@@ -27,6 +27,14 @@ import openllm
from .configuration_auto import AutoConfig
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
from collections import _odict_items
from collections import _odict_keys
@@ -54,7 +62,7 @@ class _BaseAutoLLMClass:
"Please use '{self.__class__.__name__}.Runner(model_name)' instead."
)
@t.overload
@overload
@classmethod
def for_model(
cls,
@@ -67,7 +75,7 @@ class _BaseAutoLLMClass:
) -> openllm.LLM[t.Any, t.Any]:
...
@t.overload
@overload
@classmethod
def for_model(
cls,
@@ -116,7 +124,7 @@ class _BaseAutoLLMClass:
llm = cls._model_mapping[type(llm_config)].from_pretrained(
model_id,
llm_config=llm_config,
**llm_config.__openllm_extras__,
**attrs,
)
if ensure_available:
logger.debug(
@@ -234,13 +242,13 @@ class _LazyAutoMapping(ConfigModelOrderedDict):
]
return t.cast(ConfigModelKeysView, mapping_keys + list(self._extra_content.keys()))
@t.overload
@overload
def get(
self, key: type[openllm.LLMConfig], default: t.Any, mapping_type: t.Literal["default"] = "default"
) -> type[openllm.LLM[t.Any, t.Any]]:
...
@t.overload
@overload
def get(self, key: str, default: t.Any, mapping_type: t.Literal["name2model", "name2config"] = ...) -> str:
...

View File

@@ -70,9 +70,6 @@ class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedToken
external_modules=[importlib.import_module(pipeline.__module__)],
)
finally:
import gc
gc.collect()
if openllm.utils.is_torch_available() and torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -65,10 +65,6 @@ class Falcon(openllm.LLM["transformers.TextGenerationPipeline", "transformers.Pr
external_modules=[importlib.import_module(model.__module__)],
)
finally:
import gc
# NOTE: We need to free the cache after saving here so that we can load it back later on.
gc.collect()
torch.cuda.empty_cache()
def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:

View File

@@ -43,6 +43,16 @@ class OPTConfig(openllm.LLMConfig):
"facebook/opt-6.7b",
"facebook/opt-66b",
],
"fine_tune_strategies": (
{
"adapter_type": "lora",
"r": 16,
"lora_alpha": 32,
"target_modules": ["q_proj", "v_proj"],
"lora_dropout": 0.05,
"bias": "none",
},
),
}
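For orientation, the `fine_tune_strategies` defaults in `OPTConfig` above roughly correspond to the following `peft` configuration. This is only an illustration of the values, not how OpenLLM constructs the config internally:

```python
from peft import LoraConfig, TaskType

# Mirrors OPTConfig's default LoRA fine-tune strategy shown above.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,  # assumption: causal LM task for OPT
)
```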
format_outputs: bool = openllm.LLMConfig.Field(

View File

@@ -21,15 +21,18 @@ import bentoml
import openllm
from ..._prompt import default_formatter
from ...utils import is_peft_available
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING:
import peft
import torch
import transformers # noqa
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
peft = openllm.utils.LazyLoader("peft", globals(), "peft")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
@@ -135,9 +138,17 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
self.model.cuda()
input_ids = t.cast(torch.Tensor, self.tokenizer(prompt, return_tensors="pt").input_ids).to(self.device)
if is_peft_available() and isinstance(self.model, peft.PeftModel):
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
else:
inputs = {
"inputs": t.cast(torch.Tensor, self.tokenizer(prompt, return_tensors="pt").input_ids).to(
self.device
)
}
generated_tensors = self.model.generate(
input_ids,
**inputs,
do_sample=True,
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
)

View File

@@ -80,10 +80,7 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
try:
return bentoml.transformers.save_model(tag, model, custom_objects={"tokenizer": tokenizer})
finally:
import gc
# NOTE: We need to free the cache after saving here so that we can load it back later on.
gc.collect()
torch.cuda.empty_cache()
def sanitize_parameters(

View File

@@ -7,3 +7,5 @@ Refer to each script docstring for more information.
python -m openllm.playground.general --help
python -m openllm.playground.ft_opt_lora --help
Note that the scripts in this folder are considered 'breaking' in most cases, so use with care.

View File

@@ -48,12 +48,15 @@ if openllm.utils.DEBUG:
if os.system(f"pip install -U {' '.join(_deps)}") != 0:
raise SystemExit(1)
os.environ["BITSANDBYTES_NOWELCOME"] = str(1)
from datasets import load_dataset
from peft import LoraConfig
from peft import get_peft_model
from peft import prepare_model_for_int8_training
if openllm.utils.pkg.pkg_version_info("peft")[:2] >= (0, 4):
from peft import prepare_model_for_kbit_training
else:
from peft import prepare_model_for_int8_training as prepare_model_for_kbit_training
import transformers
@@ -75,7 +78,7 @@ def load_model(model_id: str) -> tuple[PeftModel, transformers.GPT2TokenizerFast
model, tokenizer = opt.model, opt.tokenizer
# prep the model for int8 training
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
r=16,

View File

@@ -103,6 +103,8 @@ _LOGGING_CONFIG["loggers"].update(
def configure_logging() -> None:
"""Configure logging for OpenLLM. Behaves similar to how BentoML loggers
are being configured."""
if get_quiet_mode():
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
@@ -136,17 +138,22 @@ _import_structure = {
"analytics": [],
"codegen": [],
"dantic": [],
"constants": [],
"representation": ["ReprMixin"],
"import_utils": [
"OPTIONAL_DEPENDENCIES",
"ENV_VARS_TRUE_VALUES",
"DummyMetaclass",
"ModelEnv",
"requires_dependencies",
"is_cpm_kernels_available",
"is_einops_available",
"is_flax_available",
"is_tf_available",
"is_torch_available",
"is_bitsandbytes_available",
"is_peft_available",
"is_datasets_available",
"is_transformers_supports_kbit",
"is_transformers_supports_agent",
"require_backends",
@@ -161,6 +168,7 @@ if t.TYPE_CHECKING:
from . import bentoml_cattr as bentoml_cattr
from . import codegen as codegen
from . import configure_logging as configure_logging
from . import constants as constants
from . import copy_file_to_fs_folder as copy_file_to_fs_folder
from . import dantic as dantic
from . import first_not_none as first_not_none
@@ -180,14 +188,18 @@ if t.TYPE_CHECKING:
from .import_utils import ModelEnv as ModelEnv
from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available
from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
from .import_utils import is_datasets_available as is_datasets_available
from .import_utils import is_einops_available as is_einops_available
from .import_utils import is_flax_available as is_flax_available
from .import_utils import is_peft_available as is_peft_available
from .import_utils import is_tf_available as is_tf_available
from .import_utils import is_torch_available as is_torch_available
from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent
from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit
from .import_utils import require_backends as require_backends
from .import_utils import requires_dependencies as requires_dependencies
from .lazy import LazyModule as LazyModule
from .representation import ReprMixin as ReprMixin
else:
import sys

View File

@@ -15,10 +15,13 @@
from __future__ import annotations
import logging
import os
import string
import typing as t
from pathlib import Path
import orjson
if t.TYPE_CHECKING:
from fs.base import FS
@@ -33,12 +36,13 @@ else:
DictStrAny = dict
_T = t.TypeVar("_T")
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
logger = logging.getLogger(__name__)
OPENLLM_MODEL_NAME = "# openllm: model name"
OPENLLM_MODEL_ID = "# openllm: model id"
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
class ModelNameFormatter(string.Formatter):
@@ -63,10 +67,16 @@ class ModelIdFormatter(ModelNameFormatter):
model_keyword: t.LiteralString = "__model_id__"
class ModelAdapterMapFormatter(ModelNameFormatter):
model_keyword: t.LiteralString = "__model_adapter_map__"
_service_file = Path(__file__).parent.parent / "_service.py"
def write_service(model_name: str, model_id: str, target_path: str, llm_fs: FS):
def write_service(
model_name: str, model_id: str, adapter_map: dict[str, str | None] | None, target_path: str, llm_fs: FS
):
from . import DEBUG
logger.debug("Generating service for %s to %s", model_name, target_path)
@@ -76,11 +86,22 @@ def write_service(model_name: str, model_id: str, target_path: str, llm_fs: FS):
# modify with model name
for it in src_contents:
if OPENLLM_MODEL_NAME in it:
src_contents[src_contents.index(it)] = ModelNameFormatter(model_name).vformat(it)
src_contents[src_contents.index(it)] = (
ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + "\n"
)
elif OPENLLM_MODEL_ID in it:
src_contents[src_contents.index(it)] = ModelIdFormatter(model_id).vformat(it)
src_contents[src_contents.index(it)] = (
ModelIdFormatter(model_id).vformat(it)[: -(len(OPENLLM_MODEL_ID) + 3)] + "\n"
)
elif OPENLLM_MODEL_ADAPTER_MAP in it and adapter_map:
src_contents[src_contents.index(it)] = (
ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[
: -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)
]
+ "\n"
)
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n" + "".join(src_contents)
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
if DEBUG:
logger.info("Generated script:\n%s", script)
@@ -192,7 +213,7 @@ def generate_function(
if annotations:
meth.__annotations__ = annotations
if DEBUG:
if DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3:
logger.info("Generated script for %s:\n\n%s", typ, script)
return meth

View File

@@ -31,6 +31,13 @@ from click import ParamType
import openllm
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
from attr import _ValidatorType
@@ -42,7 +49,7 @@ if t.TYPE_CHECKING:
_T = t.TypeVar("_T")
@t.overload
@overload
def attrs_to_options(
name: str,
field: attr.Attribute[t.Any],
@@ -53,7 +60,7 @@ def attrs_to_options(
...
@t.overload
@overload
def attrs_to_options( # type: ignore (overlapping overload)
name: str,
field: attr.Attribute[O_co],

View File

@@ -17,6 +17,7 @@ Some imports utils are vendorred from transformers/utils/import_utils.py for per
"""
from __future__ import annotations
import functools
import importlib
import importlib.metadata
import importlib.util
@@ -32,9 +33,19 @@ from packaging import version
from bentoml._internal.utils import LazyLoader
from bentoml._internal.utils import pkg
from .representation import ReprMixin
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
from .._types import P
else:
BackendOrderredDict = OrderedDict
@@ -64,9 +75,12 @@ def _is_package_available(package: str) -> bool:
_torch_available = importlib.util.find_spec("torch") is not None
_tf_available = importlib.util.find_spec("tensorflow") is not None
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
_peft_available = _is_package_available("peft")
_einops_available = _is_package_available("einops")
_cpm_kernel_available = _is_package_available("cpm_kernels")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_datasets_available = _is_package_available("datasets")
def is_transformers_supports_kbit() -> bool:
@@ -77,6 +91,14 @@ def is_transformers_supports_agent() -> bool:
return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
def is_datasets_available() -> bool:
return _datasets_available
def is_peft_available() -> bool:
return _peft_available
def is_einops_available():
return _einops_available
@@ -157,6 +179,33 @@ def is_flax_available():
return _flax_available
def requires_dependencies(
package: str | list[str], *, extra: str | list[str] | None = None
) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
import openllm.utils
if isinstance(package, str):
package = [package]
if isinstance(extra, str):
extra = [extra]
def decorator(func: t.Callable[P, t.Any]):
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
for p in package:
cached_check: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{p}_available", None)
if not ((cached_check is not None and cached_check()) or _is_package_available(p)):
raise ImportError(
f"{func.__name__} requires '{p}' to be available locally (Currently missing)."
f"Make sure to have {p} to be installed: 'pip install \"{p if not extra else 'openllm['+', '.join(extra)+']'}\"'"
)
return func(*args, **kwargs)
return wrapper
return decorator
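A small usage sketch for the decorator above; `load_adapter_config` is a hypothetical function, not part of the OpenLLM API:

```python
from openllm.utils import requires_dependencies


@requires_dependencies("peft", extra="fine-tune")
def load_adapter_config(adapter_id: str):
    # Safe to import here: the decorator has already verified availability.
    import peft

    return peft.PeftConfig.from_pretrained(adapter_id)
```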
PYTORCH_IMPORT_ERROR_WITH_TF = """\
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
@@ -249,19 +298,44 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]):
raise ImportError("".join(failed))
class ModelEnv:
class ModelEnv(ReprMixin):
model_name: str
if t.TYPE_CHECKING:
config: property
model_id: property
quantize: property
framework: property
bettertransformer: property
@property
def __repr_keys__(self) -> set[str]:
return {"config", "model_id", "quantize", "framework", "bettertransformer"}
framework_value: property
quantize_value: property
bettertransformer_value: property
if t.TYPE_CHECKING:
config: str
model_id: str
quantize: str
framework: str
bettertransformer: str
framework_value: t.Literal["pt", "tf", "flax"]
quantize_value: str | None
bettertransformer_value: str | None
# fmt: off
@overload
def __getitem__(self, item: t.Literal["config"]) -> str: ...
@overload
def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
@overload
def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
@overload
def __getitem__(self, item: t.Literal["framework"]) -> str: ...
@overload
def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
@overload
def __getitem__(self, item: t.Literal['framework_value']) -> t.Literal['pt', 'tf', 'flax']: ...
@overload
def __getitem__(self, item: t.Literal['quantize_value']) -> str | None: ...
@overload
def __getitem__(self, item: t.Literal['bettertransformer_value']) -> str | None: ...
# fmt: on
def __getitem__(self, item: str | t.Any) -> t.Any:
if hasattr(self, item):
@@ -284,7 +358,7 @@ class ModelEnv:
# gen properties env value
attributes_with_values = {
"quantize": (bool, quantize),
"quantize": (str, quantize),
"bettertransformer": (bool, bettertransformer),
"framework": (str, "pt"),
}

View File

@@ -0,0 +1,61 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
from abc import abstractmethod
import attr
import orjson
if t.TYPE_CHECKING:
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
class ReprMixin:
"""This class display possible representation of given class.
It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
Most subclasses need to implement a __repr_keys__ property.
Based on the design from Pydantic.
The __repr__ will display the json representation of the object for easier interaction.
The __str__ will display either __attrs_repr__ or __repr_str__.
"""
@property
@abstractmethod
def __repr_keys__(self) -> set[str]:
"""This can be overriden by base class using this mixin."""
def __repr__(self) -> str:
from . import bentoml_cattr
serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
def __str__(self) -> str:
return self.__repr_str__(" ")
def __repr_name__(self) -> str:
"""Name of the instance's class, used in __repr__."""
return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f"{a}={repr(v)}" for a, v in self.__repr_args__())
def __repr_args__(self) -> ReprArgs:
attrs = ((k, getattr(self, k)) for k in self.__repr_keys__)
return tuple((k, v) for k, v in attrs if v)
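A brief, assumed example of a class opting into ReprMixin by providing `__repr_keys__`; the `AdapterSpec` class is illustrative only and not part of the diff:

```python
from __future__ import annotations

import attr

from openllm.utils import ReprMixin


@attr.define
class AdapterSpec(ReprMixin):
    adapter_id: str
    adapter_name: str | None = None

    @property
    def __repr_keys__(self) -> set[str]:
        return {"adapter_id", "adapter_name"}


spec = AdapterSpec("aarnphm/opt-6.7b-lora", "french_lora")
print(str(spec))  # roughly: adapter_id='aarnphm/opt-6.7b-lora' adapter_name='french_lora'
```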

View File

@@ -26,6 +26,13 @@ import openllm
import transformers
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
from openllm.models.auto.factory import _BaseAutoLLMClass
@@ -162,11 +169,11 @@ class BaseClient(ClientMixin):
def health(self) -> t.Any:
raise NotImplementedError
@t.overload
@overload
def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
...
@t.overload
@overload
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
...
@@ -222,11 +229,11 @@ class BaseAsyncClient(ClientMixin):
async def health(self) -> t.Any:
raise NotImplementedError
@t.overload
@overload
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
...
@t.overload
@overload
async def query(
self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any
) -> dict[str, t.Any]:

View File

@@ -4,9 +4,11 @@ from __future__ import annotations
import os
import subprocess
import sys
from markdown_it import MarkdownIt
md = MarkdownIt()
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -16,7 +18,17 @@ with open(os.path.join(ROOT, "README.md"), "r") as f:
# NOTE: Currently, we only have one table in README, which is the Model readme.
table = [r for r in readme if r.type == "html_block" and r.content.startswith("<td><a")]
available = subprocess.check_output(["openllm", "models", "-o", "porcelain"]).strip().decode("utf-8").count("\n") + 1
prev = os.environ.pop("OPENLLMDEVDEBUG", str(0))
available = (
subprocess.check_output(
[sys.executable, "-m", "openllm", "models", "-o", "porcelain"],
)
.strip()
.decode("utf-8")
.count("\n")
+ 1
)
os.environ["OPENLLMDEVDEBUG"] = prev
on_table = len(table) # NOTE: minus the header

90
tools/update-config-stubs.py Executable file
View File

@@ -0,0 +1,90 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import os
from pathlib import Path
import openllm
import importlib
from openllm._configuration import ModelSettings
# currently we are assuming the indentation level is 4 for comments
START_COMMENT = f"# {os.path.basename(__file__)}: start\n"
END_COMMENT = f"# {os.path.basename(__file__)}: stop\n"
_TARGET_FILE = Path(__file__).parent.parent / "src" / "openllm" / "_configuration.py"
_imported = importlib.import_module(ModelSettings.__module__)
def process_annotations(annotations: str) -> str:
if "NotRequired" in annotations:
return annotations[len("NotRequired[") : -1]
elif "Required" in annotations:
return annotations[len("Required[") : -1]
else:
return annotations
def main() -> int:
transformed = {"fine_tune_strategies": "t.Dict[AdapterType, FineTuneConfig]"}
with _TARGET_FILE.open("r") as f:
processed = f.readlines()
start_idx, end_idx = processed.index(" " * 4 + START_COMMENT), processed.index(" " * 4 + END_COMMENT)
# convention to use t.TYPE_CHECKING
lines = [" " * 4 + "if t.TYPE_CHECKING:\n"]
for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items():
lines.extend(
list(
map(
lambda line: " " * 8 + line,
[
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
f'def __getitem__(self, item: t.Literal["{keys}"] = ...) -> {transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n',
],
)
)
)
# special case variables: generation_class, extras
lines.extend(
list(
map(
lambda line: " " * 8 + line,
[
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
'def __getitem__(self, item: t.Literal["generation_class"] = ...) -> t.Type[GenerationConfig]: ...\n',
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
'def __getitem__(self, item: t.Literal["extras"] = ...) -> t.Dict[str, t.Any]: ...\n',
],
)
)
)
processed = (
processed[:start_idx] + [" " * 4 + START_COMMENT] + lines + [" " * 4 + END_COMMENT] + processed[end_idx + 1 :]
)
with _TARGET_FILE.open("w") as f:
f.writelines(processed)
return 0
if __name__ == "__main__":
raise SystemExit(main())
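A usage note for the stub generator above: run it from the repository root, and it rewrites the overload block between the start/stop comments in `src/openllm/_configuration.py`.

```bash
python ./tools/update-config-stubs.py
```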

View File

@@ -85,6 +85,7 @@ _NIGHTLY_MAPPING: dict[str, Dependencies] = {
"optimum": Dependencies.from_tuple("optimum", "huggingface/optimum", "main", None),
"accelerate": Dependencies.from_tuple("accelerate", "huggingface/accelerate", "main", None),
"bitsandbytes": Dependencies.from_tuple("bitsandbytes", "TimDettmers/bitsandbytes", "main", None),
"deepspeed": Dependencies.from_tuple("deepspeed", "microsoft/deepspeed", "master", None),
}
FINE_TUNE_DEPS = ["peft", "bitsandbytes", "datasets", "accelerate", "deepspeed"]
@@ -125,8 +126,8 @@ def main() -> int:
with open(os.path.join(ROOT, "pyproject.toml"), "w") as f:
f.write(tomlkit.dumps(pyproject))
with open(os.path.join(ROOT, "nightly-requirements.generated.txt"), "w") as f:
f.write("# This file is generated by `./tools/update-optional-dependencies.py`\n# DO NOT EDIT\n-e .\n")
with open(os.path.join(ROOT, "nightly-requirements.txt"), "w") as f:
f.write("# This file is generated by `./tools/update-optional-dependencies.py`\n# DO NOT EDIT\n-e .[all]\n")
f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values()])
if shutil.which("taplo"):

View File

@@ -493,7 +493,7 @@ dataclass = ...
def _make_init(
cls: type[AttrsInstance],
attrs: tuple[Attribute[_T]],
attrs: tuple[Attribute[Any]],
pre_init: bool,
post_init: bool,
frozen: bool,

View File

@@ -99,7 +99,7 @@ class _OptGroup(Generic[O_co]):
:param attrs: Additional parameters of option group class
"""
...
def option(self, *param_decls: Any, **attrs: Any) -> F[P, ClickFunctionWrapper[P, O_co]]:
def option(self, *param_decls: Any, **attrs: Any) -> FC:
"""The decorator adds a new option to the group
The decorator is lazy. It adds option decls and attrs.

View File

@@ -21,6 +21,6 @@ class Merger:
fallback_strategies: List[str],
type_conflict_strategies: List[str],
) -> None: ...
def merge(self, base: ConfigDictType, nxt: ConfigDictType) -> None: ...
def merge(self, base: ConfigDictType, nxt: ConfigDictType) -> ConfigDictType: ...
def type_conflict_strategy(self, *args: Any) -> Any: ...
def value_strategy(self, path: str, base: StrategyList, nxt: StrategyList) -> None: ...