Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-22 06:19:35 -05:00)

feat: serve adapter layers (#52)
.gitattributes (vendored): new file, 2 lines
@@ -0,0 +1,2 @@
nightly-requirements.txt linguist-generated=true
* text=auto eol=lf
@@ -346,10 +346,9 @@ async def prompt(input_text: str) -> str:

OpenLLM seamlessly integrates with Hugging Face Agents.

> **Warning** The Hugging Face Agent is still in the experimental stage. It is
> recommended to install OpenLLM with
> `pip install -r nightly-requirements.generated.txt` to get the latest API
> update for Hugging Face agent.
> **Warning** The HuggingFace Agent is still at an experimental stage. It is
> recommended to install OpenLLM with `pip install -r nightly-requirements.txt` to get
> the latest API update for HuggingFace agent.

```python
import transformers

changelog.d/52.feature.md: new file, 45 lines
@@ -0,0 +1,45 @@

#### Serving LLMs with fine-tuned LoRA/QLoRA adapter layers

The given fine-tuning weights can then be served alongside the model via
`openllm start`:

```bash
openllm start opt --model-id facebook/opt-6.7b --adapter-id /path/to/adapters
```

If you just wish to try a pretrained adapter checkpoint, you can use
`--adapter-id`:

```bash
openllm start opt --model-id facebook/opt-6.7b --adapter-id aarnphm/opt-6.7b-lora
```

To use multiple adapters, use the following format:

```bash
openllm start opt --model-id facebook/opt-6.7b --adapter-id aarnphm/opt-6.7b-lora --adapter-id aarnphm/opt-6.7b-lora:french_lora
```

By default, the first `--adapter-id` will be the default LoRA layer, but users
can optionally change which LoRA layer to use for inference via `/v1/adapters`:

```bash
curl -X POST http://localhost:3000/v1/adapters --json '{"adapter_name": "vn_lora"}'
```

> Note that when using multiple `adapter-name` and `adapter-id` pairs, it is
> recommended to switch back to the default adapter before sending inference
> requests, to avoid any performance degradation.

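Switching adapters from Python might look roughly like the following (a sketch
using `requests`; the endpoint shape follows the `curl` call above):

```python
import requests

# Point the client at a running `openllm start` server.
resp = requests.post("http://localhost:3000/v1/adapters", json={"adapter_name": "default"})
resp.raise_for_status()
print(resp.json())  # e.g. {"success": True, "error_msg": ""}
```
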
To include the adapters in the Bento, one can also provide `--adapter-id` to
`openllm build`:

```bash
openllm build opt --model-id facebook/opt-6.7b --adapter-id ...
```

### Rework

Separated out the configuration builder to make it more flexible for future
configuration generation.
@@ -1,9 +1,10 @@
# This file is generated by `./tools/update-optional-dependencies.py`
# DO NOT EDIT
-e .
-e .[all]
bentoml[grpc,io] @ git+https://github.com/bentoml/bentoml.git@main
peft @ git+https://github.com/huggingface/peft.git@main
transformers[torch,tokenizers,accelerate] @ git+https://github.com/huggingface/transformers.git@main
optimum @ git+https://github.com/huggingface/optimum.git@main
accelerate @ git+https://github.com/huggingface/accelerate.git@main
bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git@main
deepspeed @ git+https://github.com/microsoft/deepspeed.git@master
@@ -58,15 +58,15 @@ requires-python = ">=3.8"
# NOTE: Don't modify project.optional-dependencies
# as it is managed by ./tools/update-optional-dependencies.py
[project.optional-dependencies]
agents = ["transformers[agents]", "diffusers", "soundfile"]
agents = ["transformers[agents]>=4.30", "diffusers", "soundfile"]
all = [
"openllm[chatglm]",
"openllm[starcoder]",
"openllm[falcon]",
"openllm[agents]",
"openllm[flan-t5]",
"openllm[fine-tune]",
"openllm[openai]",
"openllm[flan-t5]",
]
chatglm = ["cpm_kernels", "sentencepiece"]
falcon = ["einops", "xformers", "safetensors"]

@@ -26,7 +26,9 @@ deploy, and monitor any LLMs with ease.
from __future__ import annotations

import logging
import os
import typing as t
import warnings

from . import utils as utils
from .__about__ import __version__ as __version__
@@ -39,6 +41,24 @@ if utils.DEBUG:

utils.configure_logging()
logging.basicConfig(level=logging.NOTSET)
else:
# configuration for bitsandbytes before import
os.environ["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
# The following warnings come from bitsandbytes and are probably not that
# important for users to see when DEBUG is False
warnings.filterwarnings(
"ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
)
warnings.filterwarnings(
"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
)
warnings.filterwarnings(
"ignore",
message=(
"The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization"
" are unavailable."
),
)


_import_structure = {

(File diff suppressed because it is too large.)
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -29,34 +30,48 @@ from abc import abstractmethod
|
||||
import attr
|
||||
import inflection
|
||||
import orjson
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
from bentoml._internal.types import ModelSignatureDict
|
||||
|
||||
from ._configuration import FineTuneConfig
|
||||
from .exceptions import ForbiddenAttributeError
|
||||
from .exceptions import GpuNotAvailableError
|
||||
from .exceptions import OpenLLMException
|
||||
from .utils import DEBUG
|
||||
from .utils import LazyLoader
|
||||
from .utils import ModelEnv
|
||||
from .utils import ReprMixin
|
||||
from .utils import bentoml_cattr
|
||||
from .utils import first_not_none
|
||||
from .utils import get_debug_mode
|
||||
from .utils import is_bitsandbytes_available
|
||||
from .utils import is_peft_available
|
||||
from .utils import is_torch_available
|
||||
from .utils import is_transformers_supports_kbit
|
||||
from .utils import non_intrusive_setattr
|
||||
from .utils import pkg
|
||||
from .utils import requires_dependencies
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
from bentoml._internal.runner.strategy import Strategy
|
||||
|
||||
from ._configuration import AdapterType
|
||||
from .models.auto.factory import _BaseAutoLLMClass
|
||||
|
||||
class LLMRunner(bentoml.Runner):
|
||||
@@ -74,6 +89,7 @@ else:
|
||||
LLMRunner = bentoml.Runner
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
peft = LazyLoader("peft", globals(), "peft")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,6 +112,41 @@ def convert_transformers_model_name(name: str | None) -> str:
|
||||
return re.sub("[^a-zA-Z0-9]+", "-", name)
|
||||
|
||||
|
||||
# the below is similar to peft.utils.other.CONFIG_NAME
|
||||
PEFT_CONFIG_NAME = "adapter_config.json"
|
||||
|
||||
|
||||
def resolve_peft_config_type(adapter_map: dict[str, str | None] | None):
|
||||
"""Resolve the type of the PeftConfig given the adapter_map.
|
||||
This is similar to how PeftConfig resolve its config type.
|
||||
"""
|
||||
if adapter_map is None:
|
||||
return
|
||||
|
||||
resolved: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] = {}
|
||||
_has_set_default = False
|
||||
for path_or_adapter_id, name in adapter_map.items():
|
||||
if _has_set_default:
|
||||
raise ValueError("Only one adapter can be set as default.")
|
||||
if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
|
||||
config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
except Exception:
|
||||
raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'")
|
||||
with open(config_file, "r") as file:
|
||||
resolved_config = orjson.loads(file.read())
|
||||
# all peft_type should be available in PEFT_CONFIG_NAME
|
||||
_peft_type: AdapterType = resolved_config["peft_type"].lower()
|
||||
if _peft_type not in resolved:
|
||||
resolved[_peft_type] = ()
|
||||
resolved[_peft_type] += ((path_or_adapter_id, name, resolved_config),)
|
||||
if name == "default":
|
||||
_has_set_default = True
|
||||
return resolved
|
||||
|
||||
|
||||
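A rough illustration of the mapping shape this helper produces (a sketch; the
adapter ids are taken from the changelog, and the config values in the comments
are placeholders rather than real adapter configs):

```python
# Hypothetical input, mirroring what `openllm start --adapter-id ...` collects:
adapter_map = {"/path/to/adapters": None, "aarnphm/opt-6.7b-lora": "french_lora"}

# Expected output shape: {adapter_type: ((adapter_id, adapter_name, adapter_config.json contents), ...)}
# e.g. {"lora": (("/path/to/adapters", None, {"peft_type": "LORA", ...}),
#                ("aarnphm/opt-6.7b-lora", "french_lora", {"peft_type": "LORA", ...}))}
resolved = resolve_peft_config_type(adapter_map)
```
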
def import_model(
|
||||
model_id: str,
|
||||
tag: bentoml.Tag,
|
||||
@@ -179,10 +230,6 @@ def import_model(
|
||||
try:
|
||||
return bentoml.transformers.save_model(tag, model, custom_objects={"tokenizer": tokenizer})
|
||||
finally:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
# in the case where users first run openllm start without the model
|
||||
# available locally.
|
||||
@@ -193,11 +240,12 @@ def import_model(
|
||||
_reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"}
|
||||
|
||||
|
||||
class LLMInterface(ABC):
|
||||
"""This defines the loose contract for all openllm.LLM implementations."""
|
||||
_M = t.TypeVar("_M")
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
config_class: type[openllm.LLMConfig]
|
||||
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
|
||||
|
||||
class LLMInterface(ABC, t.Generic[_M, _T]):
|
||||
"""This defines the loose contract for all openllm.LLM implementations."""
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]] | None:
|
||||
@@ -270,32 +318,76 @@ class LLMInterface(ABC):
|
||||
example implementation."""
|
||||
raise NotImplementedError
|
||||
|
||||
# NOTE: All fields below are attributes that can be accessed by users.
|
||||
config_class: type[openllm.LLMConfig]
|
||||
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
|
||||
|
||||
_M = t.TypeVar("_M")
|
||||
_T = t.TypeVar("_T")
|
||||
config: openllm.LLMConfig
|
||||
"""The config instance to use for this LLM. This will be created based on config_class"""
|
||||
|
||||
bettertransformer: bool
|
||||
"""Whether to load this LLM with FasterTransformer enabled. The order of loading is:
|
||||
- If pass within `for_model`, `from_pretrained` or `__init__`.
|
||||
- If `self.bettertransformer` is set within `llm_post_init`.
|
||||
- Finally, if none of the above, default to self.config['bettertransformer']
|
||||
|
||||
> **Note** that if LoRA is enabled, bettertransformer will be disabled.
|
||||
"""
|
||||
|
||||
# NOTE: The following will be populated by __init_subclass__, note that these should not
|
||||
# be mutated by users.
|
||||
__llm_trust_remote_code__: bool
|
||||
"""This is used to determine during 'import_model' whether to trust remote code or not.
|
||||
This works synonymous with `trust_remote_code` kwarg in transformers Auto classes. If not passed,
|
||||
then by default fallback to config_class['trust_remote_code']
|
||||
"""
|
||||
|
||||
__llm_implementation__: t.Literal["pt", "tf", "flax"]
|
||||
"""This is used to determine which implementation that this LLM has. Usually, this will inferred from
|
||||
class name, that follows the HuggingFace's naming convention:
|
||||
|
||||
- `OPTForConditionalGeneration` -> `pt`
|
||||
- `TFOPTForConditionalGeneration` -> `tf`
|
||||
- `FlaxOPTForConditionalGeneration` -> `flax`
|
||||
"""
|
||||
|
||||
__llm_model__: _M | peft.PeftModel | torch.nn.Module | None
"""A reference to the actual model. Instead of accessing this directly, you should use the `model` property."""

__llm_tokenizer__: _T | None
"""A reference to the actual tokenizer. Instead of accessing this directly, you should use the `tokenizer` property."""

__llm_tag__: bentoml.Tag | None
"""A reference to the tag used for this LLM. Instead of accessing this directly, you should use the `tag` property."""

__llm_bentomodel__: bentoml.Model | None
"""A reference to the bentomodel used for this LLM. Instead of accessing this directly, you should use the `_bentomodel` property."""

__llm_trainer__: transformers.Trainer | None
"""A reference to the Trainer to be used for fine-tuning this LLM."""

__llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], peft.PeftConfig]] | None
"""A reference to the cached LoRA adapter mapping."""
|
||||
|
||||
__llm_post_init__: t.Callable[[t.Self], None] | None
|
||||
"""A callable that will be called after the LLM is initialized. This is set if subclass contains a 'llm_post_init'"""
|
||||
|
||||
__llm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
|
||||
"""A callable that will be called after the model is loaded. This is set when 'load_model' is implemented"""
|
||||
|
||||
__llm_init_kwargs__: property | None
|
||||
"""A check if 'import_kwargs' is implemented in subclass."""
|
||||
|
||||
# The following are internal, users shouldn't access this directly.
|
||||
_model_args: tuple[t.Any, ...]
|
||||
_model_attrs: dict[str, t.Any]
|
||||
_tokenizer_attrs: dict[str, t.Any]
|
||||
|
||||
_adapters_mapping: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None
|
||||
|
||||
|
||||
@attr.define(slots=True, repr=False)
|
||||
class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
if t.TYPE_CHECKING:
|
||||
# The following will be populated by metaclass
|
||||
__llm_trust_remote_code__: bool
|
||||
__llm_implementation__: t.Literal["pt", "tf", "flax"]
|
||||
__llm_model__: _M | None
|
||||
__llm_tokenizer__: _T | None
|
||||
__llm_tag__: bentoml.Tag | None
|
||||
__llm_bentomodel__: bentoml.Model | None
|
||||
|
||||
__llm_post_init__: t.Callable[[t.Self], None] | None
|
||||
__llm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
|
||||
__llm_init_kwargs__: property | None
|
||||
|
||||
_model_args: tuple[t.Any, ...]
|
||||
_model_attrs: dict[str, t.Any]
|
||||
_tokenizer_attrs: dict[str, t.Any]
|
||||
|
||||
bettertransformer: bool
|
||||
|
||||
class LLM(LLMInterface[_M, _T], ReprMixin):
|
||||
def __init_subclass__(cls):
|
||||
cd = cls.__dict__
|
||||
prefix_class_name_config = cls.__name__
|
||||
@@ -321,19 +413,33 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
"Missing required key 'config_class'. Make sure to define it within the LLM subclass."
|
||||
)
|
||||
|
||||
if cls.import_model is LLMInterface.import_model:
|
||||
# using the default import model
|
||||
if cls.import_model is LLMInterface[_M, _T].import_model:
|
||||
# using the default import model if no custom import is set
|
||||
setattr(cls, "import_model", functools.partial(import_model, _model_framework=implementation))
|
||||
else:
|
||||
logger.debug("Custom 'import_model' will be used when loading model %s", cls.__name__)
|
||||
|
||||
cls.__llm_post_init__ = None if cls.llm_post_init is LLMInterface.llm_post_init else cls.llm_post_init
|
||||
cls.__llm_custom_load__ = None if cls.load_model is LLMInterface.load_model else cls.load_model
|
||||
cls.__llm_init_kwargs__ = None if cls.import_kwargs is LLMInterface.import_kwargs else cls.import_kwargs
|
||||
cls.__llm_post_init__ = None if cls.llm_post_init is LLMInterface[_M, _T].llm_post_init else cls.llm_post_init
|
||||
cls.__llm_custom_load__ = None if cls.load_model is LLMInterface[_M, _T].load_model else cls.load_model
|
||||
cls.__llm_init_kwargs__ = (
|
||||
None if cls.import_kwargs is LLMInterface[_M, _T].import_kwargs else cls.import_kwargs
|
||||
)
|
||||
|
||||
for at in {"bentomodel", "tag", "model", "tokenizer"}:
|
||||
for at in {"bentomodel", "tag", "model", "tokenizer", "adapter_map", "trainer"}:
|
||||
setattr(cls, f"__llm_{at}__", None)
|
||||
|
||||
# update docstring for given entrypoint
|
||||
for fn in {"generate", "generate_one", "generate_iterator"}:
|
||||
original_fn = getattr(cls, fn, getattr(LLMInterface, fn))
|
||||
original_fn.__doc__ = (
|
||||
original_fn.__doc__
|
||||
or f"""\
|
||||
'{fn}' implementation {cls.__name__}.
|
||||
|
||||
Note that if LoRA is enabled (via either SDK or CLI), `self.model` will become a `peft.PeftModel`
|
||||
The original can then be accessed with 'self.model.get_base_model()'.
|
||||
"""
|
||||
)
|
||||
setattr(cls, fn, original_fn)
|
||||
|
||||
# The following is the similar interface to HuggingFace pretrained protocol.
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -343,14 +449,146 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
*args: t.Any,
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
adapter_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> LLM[_M, _T]:
|
||||
"""Instantiate a pretrained LLM.
|
||||
It follows the same design principle as HuggingFace's `from_pretrained` method, plus the following:
|
||||
|
||||
Optimization options:
|
||||
|
||||
> This is most notable during serving time.
|
||||
|
||||
- quantize: quantize the model with the given quantization method. Currently supported int8, int4 quantization
|
||||
- bettertransformer: Apply FasterTransformer to given pretrained weight
|
||||
|
||||
> Currently, the above two options are mutually exclusive.
|
||||
|
||||
Adapter options:
|
||||
|
||||
> This is used in conjunction with the fine-tuning features
|
||||
|
||||
- adapter_id: Optional [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to apply to said model.
|
||||
- adapter_name: Optional name of the adapter to apply to said model. If not provided, it will be handled internally by OpenLLM.
|
||||
- adapter_map: optional dictionary of adapter_id to adapter_name. Note that this is mutually exclusive with adapter_id/adapter_name arguments.
|
||||
|
||||
Args:
|
||||
model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used.
|
||||
llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
|
||||
will use `config_class` to construct default configuration.
|
||||
quantize: The quantization to use for this LLM. Defaults to None. Possible values
|
||||
include int8, int4 and gptq.
|
||||
bettertransformer: Whether to use BetterTransformer with this model. Defaults to False.
|
||||
adapter_id: The [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to use for this LLM. Defaults to None.
|
||||
adapter_name: The adapter name to use for this LLM. Defaults to None.
|
||||
adapter_map: The adapter map to use for this LLM. Defaults to None. Note that this is mutually exclusive with adapter_id/adapter_name arguments.
|
||||
*args: The args to be passed to the model.
|
||||
**attrs: The kwargs to be passed to the model.
|
||||
"""
|
||||
quantization_config = attrs.pop("quantization_config", None)
|
||||
if quantization_config and quantize:
|
||||
raise ValueError(
|
||||
"""'quantization_config' and 'quantize' are mutually exclusive. Either customise
|
||||
your quantization_config or use the quantize argument."""
|
||||
)
|
||||
|
||||
# quantization setup
|
||||
quantization_config = attrs.pop("quantization_config", None)
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
|
||||
cpu_offloading = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
|
||||
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop("llm_bnb_4bit_compute_dtype", torch.bfloat16)
|
||||
int4_quant_type = attrs.pop("llm_bnb_4bit_quant_type", "nf4")
|
||||
int4_use_double_quant = attrs.pop("llm_bnb_4bit_use_double_quant", True)
|
||||
|
||||
# NOTE: Quantization setup
|
||||
if quantization_config is None:
|
||||
# quantize is an openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
if quantize is not None:
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError(
|
||||
"Quantization requires bitsandbytes to be installed. Make "
|
||||
"sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
logger.debug(
|
||||
"'quantize' is not None. %s will use a default 'quantization_config' for %s. "
|
||||
"If you want to customise the quantization config, make sure to pass your "
|
||||
"own 'quantization_config'",
|
||||
cls.__name__,
|
||||
quantize,
|
||||
)
|
||||
if quantize == "int8":
|
||||
if int8_skip_modules is None:
|
||||
int8_skip_modules = []
|
||||
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
|
||||
int8_skip_modules.append("lm_head")
|
||||
quantization_config = transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=cpu_offloading,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
elif quantize == "int4":
|
||||
if is_transformers_supports_kbit():
|
||||
quantization_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
llm_bnb_4bit_quant_type=int4_quant_type,
|
||||
llm_bnb_4bit_use_double_quant=int4_use_double_quant,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"'quantize' is set to int4, while the current transformers version %s does not support "
|
||||
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
|
||||
"make sure to install the latest version of transformers either via PyPI or "
|
||||
"from git source: 'pip install git+https://github.com/huggingface/transformers'.",
|
||||
pkg.pkg_version_info("transformers"),
|
||||
)
|
||||
elif quantize == "gptq":
|
||||
# TODO: support GPTQ loading quantization
|
||||
raise NotImplementedError("GPTQ is not supported yet.")
|
||||
if model_id is None:
|
||||
raise RuntimeError(
|
||||
"'quantize=%s' requires passing custom path to quantized weights as we are unable to load "
|
||||
"the model on the fly. See https://github.com/qwopqwop200/GPTQ-for-LLaMa for "
|
||||
"instruction on how to quantize '%s' with GPTQ.",
|
||||
quantize,
|
||||
cls.__name__,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantize} instead.")
|
||||
|
||||
# NOTE: Fine-tuning setup
|
||||
if adapter_map and adapter_id:
|
||||
raise ValueError(
|
||||
"""'adapter_map' and 'adapter_id' are mutually exclusive. Either provide a
|
||||
'adapter_map' ({adapter_id: adapter_name | None, ...}) or use
|
||||
the combination of adapter_id/adapter_name arguments.
|
||||
"""
|
||||
)
|
||||
if adapter_map is None and adapter_id is not None:
|
||||
adapter_map = {adapter_id: adapter_name}
|
||||
|
||||
if adapter_map is not None and not is_peft_available():
|
||||
raise RuntimeError(
|
||||
"LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
|
||||
return cls(
|
||||
model_id=model_id,
|
||||
llm_config=llm_config,
|
||||
*args,
|
||||
quantize=quantize,
|
||||
bettertransformer=bettertransformer,
|
||||
_adapters_mapping=resolve_peft_config_type(adapter_map),
|
||||
quantization_config=quantization_config,
|
||||
**attrs,
|
||||
)
|
||||
|
||||
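As a rough illustration of the options documented in `from_pretrained` above (a
sketch; it assumes `openllm.AutoLLM.for_model` forwards these keyword arguments,
and the adapter id is taken from the changelog):

```python
import openllm

# Serve the base weights together with a LoRA adapter; the first adapter becomes "default".
llm = openllm.AutoLLM.for_model(
    "opt",
    model_id="facebook/opt-6.7b",
    adapter_id="aarnphm/opt-6.7b-lora",  # a peft hub id or a local path containing adapter_config.json
    quantize="int8",                     # mutually exclusive with passing a custom quantization_config
)
```
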
@@ -359,12 +597,17 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
model_id: str | None = None,
|
||||
llm_config: openllm.LLMConfig | None = None,
|
||||
*args: t.Any,
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
_adapters_mapping: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]]
|
||||
| None = None,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
"""Initialize the LLM with given pretrained model.
|
||||
|
||||
> **Warning**
|
||||
> To initialize any LLM, you should use `openllm.AutoLLM` or `openllm.LLM.from_pretrained` instead.
|
||||
> `__init__` initialization is only for internal use.
|
||||
|
||||
Note:
|
||||
- *args to be passed to the model.
|
||||
- **attrs will first be parsed to the AutoConfig, then the rest will be parsed to the import_model
|
||||
@@ -437,8 +680,6 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
model_id: The pretrained model to use. Defaults to None. If None, 'self.default_id' will be used.
|
||||
llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
|
||||
will use `config_class` to construct default configuration.
|
||||
quantize: The quantization to use for this LLM. Defaults to None. Possible values
|
||||
include int8, int4 and gptq.
|
||||
bettertransformer: Whether to use BetterTransformer with this model. Defaults to False.
|
||||
*args: The args to be passed to the model.
|
||||
**attrs: The kwargs to be passed to the model.
|
||||
@@ -454,18 +695,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
# low_cpu_mem_usage is only available for model
|
||||
# this is helpful on system with low memory to avoid OOM
|
||||
low_cpu_mem_usage = attrs.pop("low_cpu_mem_usage", True)
|
||||
|
||||
# quantization setup
|
||||
quantization_config = attrs.pop("quantization_config", None)
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
|
||||
cpu_offloading = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
|
||||
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop("llm_bnb_4bit_compute_dtype", torch.bfloat16)
|
||||
int4_quant_type = attrs.pop("llm_bnb_4bit_quant_type", "nf4")
|
||||
int4_use_double_quant = attrs.pop("llm_bnb_4bit_use_double_quant", True)
|
||||
|
||||
if llm_config is not None:
|
||||
logger.debug("Using provided LLMConfig to initialize LLM instead of from default: %r", llm_config)
|
||||
@@ -475,69 +705,10 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
# The rests of the kwargs that is not used by the config class should be stored into __openllm_extras__.
|
||||
attrs = self.config["extras"]
|
||||
|
||||
if quantization_config and quantize:
|
||||
raise ValueError(
|
||||
"""'quantization_config' and 'quantize' are mutually exclusive. Either customise
|
||||
your quantization_config or use the quantize argument."""
|
||||
)
|
||||
if quantization_config is None:
|
||||
# quantize is an openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
if quantize is not None:
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError(
|
||||
"Quantization requires bitsandbytes to be installed. Make "
|
||||
"sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
logger.debug(
|
||||
"'quantize' is not None. %s will use a default 'quantization_config' for %s. "
|
||||
"If you want to customise the quantization config, make sure to pass your "
|
||||
"own 'quantization_config'",
|
||||
self,
|
||||
quantize,
|
||||
)
|
||||
if quantize == "int8":
|
||||
if int8_skip_modules is None:
|
||||
int8_skip_modules = []
|
||||
if "lm_head" not in int8_skip_modules and self.config["model_type"] == "causal_lm":
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", self)
|
||||
int8_skip_modules.append("lm_head")
|
||||
quantization_config = transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=cpu_offloading,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
elif quantize == "int4":
|
||||
if is_transformers_supports_kbit():
|
||||
quantization_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
llm_bnb_4bit_quant_type=int4_quant_type,
|
||||
llm_bnb_4bit_use_double_quant=int4_use_double_quant,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"'quantize' is set to int4, while the current transformers version %s does not support "
|
||||
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
|
||||
"make sure to install the latest version of transformers either via PyPI or "
|
||||
"from git source: 'pip install git+https://github.com/huggingface/transformers'.",
|
||||
pkg.pkg_version_info("transformers"),
|
||||
)
|
||||
elif quantize == "gptq":
|
||||
# TODO: support GPTQ loading quantization
|
||||
if model_id is None:
|
||||
raise RuntimeError(
|
||||
"'quantize=%s' requires passing custom path to quantized weights as we are unable to load "
|
||||
"the model on the fly. See https://github.com/qwopqwop200/GPTQ-for-LLaMa for "
|
||||
"instruction on how to quantize '%s' with GPTQ.",
|
||||
quantize,
|
||||
self,
|
||||
)
|
||||
raise NotImplementedError("GPTQ is not supported yet.")
|
||||
else:
|
||||
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantize} instead.")
|
||||
if self.config["use_pipeline"] and _adapters_mapping:
|
||||
raise ValueError(f"{self} will be used as a Pipeline, which is not yet compatible with LoRA adapter.")
|
||||
|
||||
self._adapters_mapping = _adapters_mapping
|
||||
|
||||
if self.__llm_implementation__ == "pt":
|
||||
if not self.config["use_pipeline"]:
|
||||
@@ -546,10 +717,8 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
|
||||
model_kwds, tokenizer_kwds = {}, {}
|
||||
if self.__llm_init_kwargs__:
|
||||
if t.TYPE_CHECKING:
|
||||
# the above meta value should determine that this LLM has custom kwargs
|
||||
assert self.import_kwargs
|
||||
model_kwds, tokenizer_kwds = self.import_kwargs
|
||||
# NOTE: recast here for type safety
|
||||
model_kwds, tokenizer_kwds = t.cast("tuple[dict[str, t.Any], dict[str, t.Any]]", self.__llm_init_kwargs__)
|
||||
logger.debug(
|
||||
"'%s' default kwargs for model: '%s', tokenizer: '%s'",
|
||||
self.__class__.__name__,
|
||||
@@ -564,7 +733,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
assert model_id is not None
|
||||
self._model_id = model_id
|
||||
|
||||
# parsing tokenizer and model kwargs
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is: passed params > defaults
|
||||
tokenizer_kwds.update(
|
||||
{k[len(TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(TOKENIZER_PREFIX)}
|
||||
)
|
||||
@@ -588,6 +757,10 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
self.bettertransformer = bettertransformer
|
||||
else:
|
||||
non_intrusive_setattr(self, "bettertransformer", self.config["bettertransformer"])
|
||||
# If LoRA is passed, then disable bettertransformer
|
||||
if _adapters_mapping and self.bettertransformer is True:
|
||||
logger.debug("LoRA is visible for %s, disabling BetterTransformer", self)
|
||||
self.bettertransformer = False
|
||||
|
||||
def __setattr__(self, attr: str, value: t.Any):
|
||||
if attr in _reserved_namespace:
|
||||
@@ -599,9 +772,21 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
keys = {"model_id", "runner_name", "llm_type", "config"}
|
||||
return f"{self.__class__.__name__}({', '.join(f'{k}={getattr(self, k)!r}' for k in keys)})"
|
||||
@property
|
||||
def adapters_mapping(
|
||||
self,
|
||||
) -> dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None:
|
||||
return self._adapters_mapping
|
||||
|
||||
@adapters_mapping.setter
|
||||
def adapters_mapping(
|
||||
self, value: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None
|
||||
):
|
||||
self._adapters_mapping = value
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
return {"model_id", "runner_name", "config"}
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
@@ -712,9 +897,9 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
@property
|
||||
def _bentomodel(self) -> bentoml.Model:
|
||||
if self.__llm_bentomodel__ is None:
|
||||
# NOTE: Since PR#28, self.__llm_bentomodel__ changed from
|
||||
# NOTE: Since #28, self.__llm_bentomodel__ changed from
|
||||
# ensure_model_id_exists() into just returning the model ref.
|
||||
# This is because we want to save a few seconds of loading time,
|
||||
# This is purely for performance reasons.
|
||||
# as openllm.Runner and openllm.AutoLLM initialisation is around 700ms
|
||||
# before #28.
|
||||
# If users want to make sure to have the model downloaded,
|
||||
@@ -741,23 +926,24 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
if self.config["requires_gpu"] and len(openllm.utils.gpu_count()) < 1:
|
||||
raise GpuNotAvailableError(f"{self} only supports running with GPU (None available).") from None
|
||||
|
||||
kwds = self._model_attrs
|
||||
kwds["trust_remote_code"] = self.__llm_trust_remote_code__
|
||||
|
||||
is_pipeline = "_pretrained_class" in self._bentomodel.info.metadata
|
||||
# differentiate when saving tokenizer or other pretrained type.
|
||||
is_pretrained_model = is_pipeline and "_framework" in self._bentomodel.info.metadata
|
||||
|
||||
if self.bettertransformer and is_pipeline and self.config["use_pipeline"]:
|
||||
# This is a pipeline, provide an accelerator argument
|
||||
kwds["accelerator"] = "bettertransformer"
|
||||
|
||||
if self.__llm_model__ is None:
|
||||
kwds = self._model_attrs
|
||||
kwds["trust_remote_code"] = self.__llm_trust_remote_code__
|
||||
|
||||
is_pipeline = "_pretrained_class" in self._bentomodel.info.metadata
|
||||
# differentiate when saving tokenizer or other pretrained type.
|
||||
is_pretrained_model = is_pipeline and "_framework" in self._bentomodel.info.metadata
|
||||
|
||||
if self.bettertransformer and is_pipeline and self.config["use_pipeline"]:
|
||||
# This is a pipeline, provide an accelerator argument
|
||||
kwds["accelerator"] = "bettertransformer"
|
||||
|
||||
if self.__llm_custom_load__:
|
||||
self.__llm_model__ = self.load_model(self.tag, *self._model_args, **kwds)
|
||||
else:
|
||||
self.__llm_model__ = self._bentomodel.load_model(*self._model_args, **kwds)
|
||||
|
||||
# This branch shouldn't hit when LoRA is visible.
|
||||
if (
|
||||
self.bettertransformer
|
||||
and is_pretrained_model
|
||||
@@ -770,6 +956,143 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
self.__llm_model__ = BetterTransformer.transform(self.__llm_model__)
|
||||
return t.cast(_M, self.__llm_model__)
|
||||
|
||||
def _transpose_adapter_mapping(
|
||||
self,
|
||||
inference_mode: bool = True,
|
||||
use_cache: bool = True,
|
||||
) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]]:
|
||||
assert self._adapters_mapping is not None, "LoRA mapping is not set up correctly."
|
||||
|
||||
if not use_cache:
|
||||
logger.debug(
|
||||
"'use_cache' is set to False. This means the adapter mapping resolution will not be cached. This should only be used during training."
|
||||
)
|
||||
|
||||
if self.__llm_adapter_map__ is not None and use_cache:
|
||||
# early out if we already serialized everything.
|
||||
return self.__llm_adapter_map__
|
||||
|
||||
adapter_map: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] = {}
|
||||
# this is a temporary check to accept the first option name as 'default'
|
||||
# then we will raise an error when the optional_name is set to None in the next iteration.
|
||||
_converted_first_none = False
|
||||
for _adapter_type, _adapter_tuple in self._adapters_mapping.items():
|
||||
if _adapter_type not in adapter_map:
|
||||
adapter_map[_adapter_type] = {}
|
||||
default_config = self.config["fine_tune_strategies"].get(
|
||||
_adapter_type, FineTuneConfig(adapter_type=_adapter_type, llm_config_class=self.config_class)
|
||||
)
|
||||
default_config = default_config.eval() if inference_mode else default_config.train()
|
||||
for pretrained_or_peft_id, optional_name, resolved_mapping in _adapter_tuple:
|
||||
if not optional_name:
|
||||
if not _converted_first_none:
|
||||
_converted_first_none = True
|
||||
optional_name = "default"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {pretrained_or_peft_id, resolved_mapping}"
|
||||
)
|
||||
assert isinstance(optional_name, str) # optional_name should all be resolved here
|
||||
if optional_name == "default":
|
||||
adapter_map[_adapter_type][optional_name] = (
|
||||
default_config.with_config(**resolved_mapping).to_peft_config(),
|
||||
pretrained_or_peft_id,
|
||||
)
|
||||
else:
|
||||
adapter_map[_adapter_type][optional_name] = (
|
||||
FineTuneConfig(
|
||||
adapter_type=_adapter_type,
|
||||
adapter_config=resolved_mapping,
|
||||
inference_mode=inference_mode,
|
||||
llm_config_class=self.config_class,
|
||||
).to_peft_config(),
|
||||
pretrained_or_peft_id,
|
||||
)
|
||||
|
||||
if self.__llm_adapter_map__ is None and use_cache:
|
||||
self.__llm_adapter_map__ = adapter_map
|
||||
|
||||
return self.__llm_adapter_map__
|
||||
|
||||
return adapter_map
|
||||
|
||||
@requires_dependencies("peft", extra="fine-tune")
|
||||
def apply_adapter(
|
||||
self,
|
||||
inference_mode: bool = True,
|
||||
adapter_type: AdapterType = "lora",
|
||||
load_adapters: t.Literal["all"] | list[str] | None = None,
|
||||
use_cache: bool = True,
|
||||
) -> peft.PeftModel | _M | torch.nn.Module:
|
||||
"""Apply given LoRA mapping to the model. Note that the base model can still
|
||||
be accessed via self.model.get_base_model().
|
||||
"""
|
||||
assert self.model, "Internal error: Model is not loaded correctly."
|
||||
assert self.__llm_model__ is not None
|
||||
|
||||
# early out if _adapters_mapping is empty or it is already wrapped
|
||||
# with peft.
|
||||
if not self._adapters_mapping:
|
||||
logger.debug("No adapter mapping is found. Skip applying adapter.")
|
||||
return self.__llm_model__
|
||||
|
||||
_mapping = self._transpose_adapter_mapping(inference_mode=inference_mode, use_cache=use_cache)
|
||||
if adapter_type not in _mapping:
|
||||
raise ValueError(
|
||||
f"Given adapter type {adapter_type} is not supported. Please choose from {list(_mapping.keys())}"
|
||||
)
|
||||
adapter_mapping = _mapping[adapter_type]
|
||||
default_config, peft_model_id = adapter_mapping.pop("default", None)
|
||||
if default_config is None:
|
||||
raise ValueError(
|
||||
"There is no 'default' mapping. Please check the adapter mapping and report this bug to the OpenLLM team."
|
||||
)
|
||||
|
||||
# the below shares similar logic with `get_peft_model`
|
||||
# TODO: Support PromptLearningConfig
|
||||
if default_config.task_type not in peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
|
||||
default_config, peft.PromptLearningConfig
|
||||
):
|
||||
logger.debug(
|
||||
"Given task type '%s' is not supported by peft. This means it can be a custom PeftModel implementation. Make sure the adapter is loaded manually before running inference.",
|
||||
default_config.task_type,
|
||||
)
|
||||
self.__llm_model__ = peft.PeftModel(self.__llm_model__, default_config)
|
||||
else:
|
||||
# this is not ideal to serialize like this, wait until https://github.com/huggingface/peft/pull/612
|
||||
# is merged
|
||||
peft_class = peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING[default_config.task_type]
|
||||
if t.cast("str | None", default_config.base_model_name_or_path) is not None:
|
||||
kwargs: dict[str, t.Any] = {"is_trainable": not inference_mode}
|
||||
if "config" in inspect.signature(peft_class.from_pretrained).parameters:
|
||||
kwargs["config"] = default_config
|
||||
else:
|
||||
kwargs.update(dict(default_config.to_dict().items()))
|
||||
self.__llm_model__ = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs)
|
||||
else:
|
||||
# in this case, the given base_model_name_or_path is None. This will be hit during training
|
||||
self.__llm_model__ = peft_class(self.__llm_model__, default_config)
|
||||
|
||||
# now we loop through the rest with add_adapter
|
||||
if len(adapter_mapping) > 0:
|
||||
for adapter_name, _peft_config in adapter_mapping.items():
|
||||
self.__llm_model__.add_adapter(adapter_name, _peft_config)
|
||||
|
||||
# optionally load adapters. In case of multiple adapters, or on Runner,
|
||||
# we will need to set load_adapters='all'
|
||||
if load_adapters is not None:
|
||||
adapters_to_load = adapter_mapping.keys() if load_adapters == "all" else load_adapters
|
||||
for adapter_name in adapters_to_load:
|
||||
_peft_config, _peft_model_id = adapter_mapping[adapter_name]
|
||||
self.__llm_model__.load_adapter(
|
||||
_peft_model_id,
|
||||
adapter_name=adapter_name,
|
||||
is_trainable=not inference_mode,
|
||||
**dict(_peft_config.to_dict()),
|
||||
)
|
||||
|
||||
return self.__llm_model__
|
||||
|
||||
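A hedged sketch of calling `apply_adapter` manually (assuming `llm` was created
with an `adapter_id`, as in the earlier example; during normal serving the
Runnable's `__init__` below already does this):

```python
# Wrap the base model with peft and load every named adapter for inference.
model = llm.apply_adapter(inference_mode=True, load_adapters="all")
# The original weights remain reachable via model.get_base_model().
```
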
@property
|
||||
def tokenizer(self) -> _T:
|
||||
"""The tokenizer to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
|
||||
@@ -837,18 +1160,45 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
llm_type: str
|
||||
identifying_params: dict[str, t.Any]
|
||||
|
||||
def __init_subclass__(cls, llm_type: str, identifying_params: dict[str, t.Any], **_: t.Any):
|
||||
cls.llm_type = llm_type
|
||||
cls.identifying_params = identifying_params
|
||||
|
||||
def __init__(__self: _Runnable):
|
||||
# NOTE: The side effect of this line
|
||||
# is that it will load the imported model during
|
||||
# runner creation. So don't remove it!!
|
||||
self.model
|
||||
# runner startup. So don't remove it!!
|
||||
assert self.model, "Internal error: Model is not loaded"
|
||||
if self.adapters_mapping is not None:
|
||||
logger.info("Applying LoRA to %s...", self.runner_name)
|
||||
self.apply_adapter(inference_mode=True, load_adapters="all")
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
def list_adapter(__self) -> dict[str, t.Any]:
|
||||
if not is_peft_available():
|
||||
return {
|
||||
"success": False,
|
||||
"result": {},
|
||||
"error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'",
|
||||
}
|
||||
if not isinstance(self.model, peft.PeftModel):
|
||||
return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"}
|
||||
return {"success": True, "result": self.model.peft_config, "error_msg": ""}
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
def set_adapter(__self, adapter_name: str) -> dict[t.Literal["success", "error_msg"], bool | str]:
|
||||
if not is_peft_available():
|
||||
return {
|
||||
"success": False,
|
||||
"error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'",
|
||||
}
|
||||
if not isinstance(self.model, peft.PeftModel):
|
||||
return {"success": False, "error_msg": "Model is not a PeftModel"}
|
||||
try:
|
||||
self.model.set_adapter(adapter_name)
|
||||
return {"success": True, "error_msg": ""}
|
||||
except ValueError:
|
||||
logger.info("Adapter %s not found", adapter_name)
|
||||
return {
|
||||
"success": False,
|
||||
"error_msg": f"Adapter {adapter_name} not found. Available adapters: {list(self.model.peft_config)}",
|
||||
}
|
||||
|
||||
@bentoml.Runnable.method(
|
||||
batchable=generate_sig.batchable,
|
||||
@@ -916,19 +1266,21 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
"__call__": _wrapped_generate_run,
|
||||
"__module__": f"openllm.models.{self.config['model_name']}",
|
||||
"__doc__": self.config["env"].start_docstring,
|
||||
"__repr_keys__": lambda _: {"llm", "config", "llm_type", "identifying_params"},
|
||||
}
|
||||
),
|
||||
)(
|
||||
types.new_class(
|
||||
inflection.camelize(self.config["model_name"]) + "Runnable",
|
||||
(_Runnable,),
|
||||
{
|
||||
"SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu")
|
||||
if self.config["requires_gpu"]
|
||||
else ("nvidia.com/gpu",),
|
||||
"llm_type": self.llm_type,
|
||||
"identifying_params": self.identifying_params,
|
||||
},
|
||||
{},
|
||||
lambda ns: ns.update(
|
||||
{
|
||||
"SUPPORTED_RESOURCES": ("nvidia.com/gpu", "cpu")
|
||||
if self.config["requires_gpu"]
|
||||
else ("nvidia.com/gpu",),
|
||||
}
|
||||
),
|
||||
),
|
||||
name=self.runner_name,
|
||||
embedded=False,
|
||||
@@ -962,7 +1314,7 @@ class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
return self.postprocess_generate(prompt, generated_result, **postprocess_kwargs)
|
||||
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def Runner(
|
||||
model_name: str,
|
||||
*,
|
||||
@@ -973,7 +1325,7 @@ def Runner(
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def Runner(
|
||||
model_name: str,
|
||||
*,
|
||||
|
||||
@@ -23,7 +23,9 @@ import typing as t
|
||||
from pathlib import Path
|
||||
|
||||
import fs
|
||||
import fs.copy
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
@@ -38,8 +40,16 @@ from .utils import is_flax_available
|
||||
from .utils import is_tf_available
|
||||
from .utils import is_torch_available
|
||||
from .utils import pkg
|
||||
from .utils import resolve_user_filepath
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
@@ -83,19 +93,23 @@ def construct_python_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
llm_fs: FS,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
) -> PythonOptions:
|
||||
packages = ["openllm"]
|
||||
if adapter_map is not None:
|
||||
packages += ["openllm[fine-tune]"]
|
||||
# NOTE: add openllm to the default dependencies
|
||||
# if users have openllm custom-built wheels, it will still respect
|
||||
# that since bentoml will always install dependencies from requirements.txt
|
||||
# first, then proceed to install everything inside the wheels/ folder.
|
||||
if extra_dependencies is not None:
|
||||
packages += [f"openllm[{k}]" for k in extra_dependencies]
|
||||
filtered = set(extra_dependencies + ("fine-tune",))
|
||||
packages += [f"openllm[{k}]" for k in filtered]
|
||||
|
||||
if llm.config["requirements"] is not None:
|
||||
packages.extend(llm.config["requirements"])
|
||||
|
||||
if not (str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false"):
|
||||
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false":
|
||||
packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
env: ModelEnv = llm.config["env"]
|
||||
@@ -149,6 +163,7 @@ def construct_docker_options(
|
||||
workers_per_resource: int | float,
|
||||
quantize: t.LiteralString | None,
|
||||
bettertransformer: bool | None,
|
||||
adapter_map: dict[str, str | None] | None,
|
||||
) -> DockerOptions:
|
||||
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_opts = [
|
||||
@@ -164,10 +179,14 @@ def construct_docker_options(
|
||||
env.config: f"'{llm.config.model_dump_json().decode()}'",
|
||||
"OPENLLM_MODEL": llm.config["model_name"],
|
||||
"OPENLLM_MODEL_ID": llm.model_id,
|
||||
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'",
|
||||
"BENTOML_DEBUG": str(get_debug_mode()),
|
||||
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
|
||||
}
|
||||
|
||||
if adapter_map:
|
||||
env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
|
||||
# We need to handle None separately here, as env from subprocess doesn't
|
||||
# accept None value.
|
||||
_env = ModelEnv(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize)
|
||||
@@ -181,38 +200,6 @@ def construct_docker_options(
|
||||
return DockerOptions(cuda_version="11.6", env=env_dict, system_packages=["git"])
|
||||
|
||||
|
||||
@t.overload
|
||||
def build(
|
||||
model_name: str,
|
||||
*,
|
||||
model_id: str | None = ...,
|
||||
quantize: t.LiteralString | None = ...,
|
||||
bettertransformer: bool | None = ...,
|
||||
_extra_dependencies: tuple[str, ...] | None = ...,
|
||||
_workers_per_resource: int | float | None = ...,
|
||||
_overwrite_existing_bento: bool = ...,
|
||||
__cli__: t.Literal[False] = ...,
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Bento:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def build(
|
||||
model_name: str,
|
||||
*,
|
||||
model_id: str | None = ...,
|
||||
quantize: t.LiteralString | None = ...,
|
||||
bettertransformer: bool | None = ...,
|
||||
_extra_dependencies: tuple[str, ...] | None = ...,
|
||||
_workers_per_resource: int | float | None = ...,
|
||||
_overwrite_existing_bento: bool = ...,
|
||||
__cli__: t.Literal[True] = ...,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[bentoml.Bento, bool]:
|
||||
...
|
||||
|
||||
|
||||
def _build_bento(
|
||||
bento_tag: bentoml.Tag,
|
||||
service_name: str,
|
||||
@@ -221,34 +208,87 @@ def _build_bento(
|
||||
workers_per_resource: int | float,
|
||||
quantize: t.LiteralString | None,
|
||||
bettertransformer: bool | None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
build_ctx: str | None = None,
|
||||
) -> bentoml.Bento:
|
||||
framework_envvar = llm.config["env"]["framework_value"]
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({"_type": llm.llm_type, "_framework": framework_envvar})
|
||||
logger.info("Building Bento for LLM '%s'", llm.config["start_name"])
|
||||
|
||||
if adapter_map is not None:
|
||||
assert build_ctx is not None, "build_ctx is required when 'adapter_map' is not None"
|
||||
updated_mapping: dict[str, str | None] = {}
|
||||
for adapter_id, name in adapter_map.items():
|
||||
try:
|
||||
resolve_user_filepath(adapter_id, build_ctx)
|
||||
src_folder_name = os.path.basename(adapter_id)
|
||||
src_fs = fs.open_fs(build_ctx)
|
||||
llm_fs.makedir(src_folder_name, recreate=True)
|
||||
fs.copy.copy_dir(src_fs, adapter_id, llm_fs, src_folder_name)
|
||||
updated_mapping[src_folder_name] = name
|
||||
except FileNotFoundError:
|
||||
# this is a remote adapter, so just add it back.
# note that there is a drawback here: if the local adapter path
# has the same name as the remote one, then we currently don't
# support that edge case.
|
||||
updated_mapping[adapter_id] = name
|
||||
adapter_map = updated_mapping
|
||||
|
||||
# add service.py definition to this temporary folder
|
||||
codegen.write_service(llm.config["model_name"], llm.model_id, adapter_map, llm.config["service_name"], llm_fs)
|
||||
|
||||
return bentoml.bentos.build(
|
||||
f"{service_name}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(
|
||||
llm_fs.walk.files(filter=["*.py"])
|
||||
), # NOTE: By default, we are using _service.py as the default service, for now.
|
||||
exclude=["/venv", "__pycache__/", "*.py[cod]", "*$py.class"],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies),
|
||||
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer),
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map),
|
||||
version=bento_tag.version,
|
||||
build_ctx=llm_fs.getsyspath("/"),
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def build(
|
||||
model_name: str,
|
||||
*,
|
||||
model_id: str | None = ...,
|
||||
quantize: t.LiteralString | None = ...,
|
||||
bettertransformer: bool | None = ...,
|
||||
adapter_map: dict[str, str | None] | None = ...,
|
||||
__cli__: t.Literal[False] = False,
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Bento:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def build(
|
||||
model_name: str,
|
||||
*,
|
||||
model_id: str | None = ...,
|
||||
quantize: t.LiteralString | None = ...,
|
||||
bettertransformer: bool | None = ...,
|
||||
adapter_map: dict[str, str | None] | None = ...,
|
||||
__cli__: t.Literal[True] = ...,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[bentoml.Bento, bool]:
|
||||
...
|
||||
|
||||
|
||||
def build(
|
||||
model_name: str,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
quantize: t.LiteralString | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
_build_ctx: str | None = None,
|
||||
_extra_dependencies: tuple[str, ...] | None = None,
|
||||
_workers_per_resource: int | float | None = None,
|
||||
_overwrite_existing_bento: bool = False,
|
||||
@@ -270,14 +310,14 @@ def build(
|
||||
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
|
||||
logger.info("Packing '%s' into a Bento with kwargs=%s...", model_name, attrs)
|
||||
logger.info("Packing '%s' into a Bento%s...", model_name, f" with 'kwargs={attrs}' " if attrs else "")
|
||||
|
||||
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
|
||||
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
|
||||
try:
|
||||
os.environ["OPENLLM_MODEL"] = inflection.underscore(model_name)
|
||||
|
||||
framework_envvar = llm_config["env"]["framework_value"]
|
||||
framework_envvar = llm_config["env"].framework_value
|
||||
llm = t.cast(
|
||||
"_BaseAutoLLMClass",
|
||||
openllm[framework_envvar], # type: ignore (internal API)
|
||||
@@ -286,7 +326,9 @@ def build(
|
||||
model_id=model_id,
|
||||
llm_config=llm_config,
|
||||
quantize=quantize,
|
||||
adapter_map=adapter_map,
|
||||
bettertransformer=bettertransformer,
|
||||
return_runner_kwargs=False,
|
||||
**attrs,
|
||||
)
|
||||
|
||||
@@ -294,41 +336,40 @@ def build(
|
||||
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({"_type": llm.llm_type, "_framework": framework_envvar})
|
||||
service_name = f"generated_{llm_config['model_name']}_service.py"
|
||||
workers_per_resource = first_not_none(_workers_per_resource, default=llm_config["workers_per_resource"])
|
||||
|
||||
with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs:
|
||||
# add service.py definition to this temporary folder
|
||||
codegen.write_service(model_name, llm.model_id, service_name, llm_fs)
|
||||
|
||||
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
|
||||
try:
|
||||
bento = bentoml.get(bento_tag)
|
||||
if _overwrite_existing_bento:
|
||||
logger.info("Overwriting previously saved Bento.")
|
||||
bentoml.delete(bento_tag)
|
||||
bento = _build_bento(
|
||||
bento_tag,
|
||||
service_name,
|
||||
llm.config["service_name"],
|
||||
llm_fs,
|
||||
llm,
|
||||
workers_per_resource=workers_per_resource,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
bettertransformer=bettertransformer,
|
||||
extra_dependencies=_extra_dependencies,
|
||||
build_ctx=_build_ctx,
|
||||
)
|
||||
_previously_built = True
|
||||
except bentoml.exceptions.NotFound:
|
||||
logger.info("Building Bento for LLM '%s'", llm_config["start_name"])
|
||||
bento = _build_bento(
|
||||
bento_tag,
|
||||
service_name,
|
||||
llm.config["service_name"],
|
||||
llm_fs,
|
||||
llm,
|
||||
workers_per_resource=workers_per_resource,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
bettertransformer=bettertransformer,
|
||||
extra_dependencies=_extra_dependencies,
|
||||
build_ctx=_build_ctx,
|
||||
)
|
||||
return (bento, _previously_built) if __cli__ else bento
|
||||
except Exception as e:
|
||||
|
||||
@@ -25,6 +25,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
import warnings
|
||||
|
||||
import attr
|
||||
import orjson
|
||||
@@ -40,8 +41,25 @@ if t.TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
# The following warnings are from bitsandbytes and are probably not that important
# for users to see
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=(
|
||||
"The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization"
|
||||
" are unavailable."
|
||||
),
|
||||
)
|
||||
|
||||
model = os.environ.get("OPENLLM_MODEL", "{__model_name__}") # openllm: model name
|
||||
model_id = os.environ.get("OPENLLM_MODEL_ID", "{__model_id__}") # openllm: model id
|
||||
adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}""") # openllm: model adapter map
|
||||
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
|
||||
@@ -51,6 +69,7 @@ runner = openllm.Runner(
|
||||
llm_config=llm_config,
|
||||
bettertransformer=llm_config["env"]["bettertransformer_value"],
|
||||
quantize=llm_config["env"]["quantize_value"],
|
||||
adapter_map=orjson.loads(adapter_map),
|
||||
ensure_available=False,
|
||||
init_local=False,
|
||||
)
|
||||
@@ -70,7 +89,19 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return openllm.GenerationOutput(responses=responses, configuration=config)
|
||||
|
||||
|
||||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON(), route="/v1/metadata")
|
||||
@svc.api(
|
||||
input=bentoml.io.Text(),
|
||||
output=bentoml.io.JSON.from_sample(
|
||||
sample={
|
||||
"model_id": model_id,
|
||||
"timeout": 3600,
|
||||
"model_name": llm_config["model_name"],
|
||||
"framework": "pt",
|
||||
"configuration": "",
|
||||
}
|
||||
),
|
||||
route="/v1/metadata",
|
||||
)
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(
|
||||
model_id=model_id,
|
||||
@@ -81,6 +112,15 @@ def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
)
|
||||
|
||||
|
||||
@svc.api(
|
||||
input=bentoml.io.Text.from_sample(sample="default"),
|
||||
output=bentoml.io.JSON.from_sample(sample={"success": True, "error_msg": "some error message"}),
|
||||
route="/v1/adapters",
|
||||
)
|
||||
async def adapters_v1(adapter_name: str) -> dict[str, bool | str]:
|
||||
return await runner.set_adapter.async_run(adapter_name)
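
As a rough illustration, a client could switch the active adapter through this endpoint as sketched below; the host, port, and adapter name are assumptions, and the request body is plain text per the `bentoml.io.Text` descriptor:

```python
# Hypothetical client-side sketch for POST /v1/adapters (host and adapter name are examples).
import requests

resp = requests.post(
    "http://localhost:3000/v1/adapters",
    data="french_lora",                       # plain-text body: the adapter name to activate
    headers={"Content-Type": "text/plain"},
)
print(resp.json())  # expected shape: {"success": bool, "error_msg": str}
```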
|
||||
|
||||
|
||||
@attr.define
|
||||
class HfAgentInput:
|
||||
inputs: str
|
||||
@@ -105,3 +145,14 @@ async def hf_agent(request: Request) -> Response:
|
||||
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
|
||||
|
||||
svc.mount_asgi_app(hf_app, path="/hf")
|
||||
|
||||
|
||||
async def list_adapter_v1(_: Request) -> Response:
|
||||
res = await runner.list_adapter.async_run()
|
||||
if res["success"]:
|
||||
res["result"] = {k: v.to_dict() for k, v in res["result"].items()}
|
||||
return JSONResponse(res, status_code=200)
|
||||
|
||||
|
||||
metadata_app = Starlette(debug=True, routes=[Route("/adapters", list_adapter_v1, methods=["GET"])])
|
||||
svc.mount_asgi_app(metadata_app, path="/v1")
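
Similarly, a hedged sketch of listing the configured adapters through the `/v1/adapters` route mounted above (host and port are assumptions):

```python
# Hypothetical client-side sketch for GET /v1/adapters.
import requests

payload = requests.get("http://localhost:3000/v1/adapters").json()
if payload.get("success"):
    print(list(payload["result"]))  # names/ids of the configured adapter layers
```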
|
||||
|
||||
src/openllm/_trainer.py (new file, 25 lines)
@@ -0,0 +1,25 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
class PeftTrainer(transformers.Trainer):
|
||||
...
|
||||
|
||||
|
||||
class PeftSaveCallback(transformers.TrainerCallback):
|
||||
...
|
||||
@@ -56,8 +56,10 @@ from .utils import first_not_none
|
||||
from .utils import get_debug_mode
|
||||
from .utils import get_quiet_mode
|
||||
from .utils import gpu_count
|
||||
from .utils import is_peft_available
|
||||
from .utils import is_torch_available
|
||||
from .utils import is_transformers_supports_agent
|
||||
from .utils import resolve_user_filepath
|
||||
from .utils import set_debug_mode
|
||||
from .utils import set_quiet_mode
|
||||
|
||||
@@ -105,13 +107,14 @@ class NargsOptions(cog.GroupedOption):
|
||||
options.
|
||||
"""
|
||||
|
||||
_nargs_parser: click.parser.Option
|
||||
_prev_parser_process: t.Callable[[t.Any, click.parser.ParsingState], None]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
nargs = attrs.pop("nargs", -1)
|
||||
if nargs != -1:
|
||||
raise OpenLLMException(f"'nargs' is set, and must be -1 instead of {nargs}")
|
||||
super(NargsOptions, self).__init__(*args, **attrs)
|
||||
self._prev_parser_process: t.Callable[[t.Any, click.parser.ParsingState], None] | None = None
|
||||
self._nargs_parser: click.parser.Option | None = None
|
||||
|
||||
def add_to_parser(self, parser: click.OptionParser, ctx: click.Context) -> None:
|
||||
def _parser(value: t.Any, state: click.parser.ParsingState):
|
||||
@@ -238,7 +241,7 @@ def quantize_option(factory: t.Any, build: bool = False):
|
||||
)
|
||||
|
||||
|
||||
def bettertransformer_option(factory: t.Any, model_env: ModelEnv | None = None):
|
||||
def bettertransformer_option(factory: t.Any, build: bool = False, model_env: ModelEnv | None = None):
|
||||
envvar = None
|
||||
if model_env is not None:
|
||||
envvar = model_env.bettertransformer
|
||||
@@ -246,14 +249,35 @@ def bettertransformer_option(factory: t.Any, model_env: ModelEnv | None = None):
|
||||
"--bettertransformer",
|
||||
is_flag=True,
|
||||
default=None,
|
||||
help="Use the BetterTransformer wrapper to serve the model. This applies during serving time.",
help="Apply the BetterTransformer wrapper to serve the model. This applies during serving time."
if not build
else "Set the default environment variable for whether to serve this model with BetterTransformer at build time.",
|
||||
envvar=envvar,
|
||||
show_envvar=True if envvar is not None else False,
|
||||
)
|
||||
|
||||
|
||||
_adapter_mapping_key = "adapter_map"
|
||||
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> dict[str, str] | None:
|
||||
if not value:
|
||||
return
|
||||
if _adapter_mapping_key not in ctx.params:
|
||||
ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
|
||||
try:
|
||||
# try to resolve the full path if users pass in a relative path;
# currently only one level of path resolution is supported.
|
||||
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
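
To make the `--adapter-id` parsing above easier to follow, here is a small standalone sketch of the same `rsplit`-based logic, without the click context or path resolution (function name and inputs are illustrative):

```python
# Standalone sketch of the "--adapter-id" parsing logic above (illustrative only).
from __future__ import annotations


def parse_adapter_ids(values: tuple[str, ...]) -> dict[str, str | None]:
    adapter_map: dict[str, str | None] = {}
    for v in values:
        # "remote/adapter:eng_lora" -> ("remote/adapter", "eng_lora"); a bare id maps to None
        adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
        adapter_map[adapter_id] = adapter_name[0] if adapter_name else None
    return adapter_map


print(parse_adapter_ids(("/path/to/adapter", "remote/adapter:eng_lora")))
# {'/path/to/adapter': None, 'remote/adapter': 'eng_lora'}
```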
|
||||
|
||||
|
||||
class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
NUMBER_OF_COMMON_PARAMS = 3
|
||||
NUMBER_OF_COMMON_PARAMS = 4 # parameters in common_params + 1 faked group option header
|
||||
|
||||
@staticmethod
|
||||
def common_params(f: F[P, t.Any]) -> ClickFunctionWrapper[..., t.Any]:
|
||||
@@ -265,11 +289,14 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
from bentoml._internal.configuration import DEBUG_ENV_VAR
|
||||
from bentoml._internal.configuration import QUIET_ENV_VAR
|
||||
|
||||
@click.option("-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output.")
|
||||
@click.option(
|
||||
@cog.optgroup.group("Miscellaneous options")
|
||||
@cog.optgroup.option(
|
||||
"-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output."
|
||||
)
|
||||
@cog.optgroup.option(
|
||||
"--debug", "--verbose", envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help="Print out debug logs."
|
||||
)
|
||||
@click.option(
|
||||
@cog.optgroup.option(
|
||||
"--do-not-track",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
@@ -507,7 +534,7 @@ _http_server_args = parse_serve_args(False)
|
||||
_grpc_server_args = parse_serve_args(True)
|
||||
|
||||
|
||||
def start_model_command(
|
||||
def start_command_factory(
|
||||
model_name: str,
|
||||
_context_settings: dict[str, t.Any] | None = None,
|
||||
_serve_grpc: bool = False,
|
||||
@@ -531,7 +558,7 @@ def start_model_command(
|
||||
configure_logging()
|
||||
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
env: ModelEnv = llm_config["env"]
|
||||
env = llm_config["env"]
|
||||
|
||||
docstring = f"""\
|
||||
{env.start_docstring}
|
||||
@@ -576,7 +603,7 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
@group.command(**command_attrs)
|
||||
@llm_config.to_click_options
|
||||
@serve_decorator
|
||||
@cog.optgroup.group("General LLM Options")
|
||||
@cog.optgroup.group("General LLM Options", help="The following options are related to running the LLM Server.")
|
||||
@cog.optgroup.option(
|
||||
"--server-timeout",
|
||||
type=int,
|
||||
@@ -591,7 +618,25 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
default=False,
|
||||
help="Bypass auto model checks and setup. This option is ahead-of-serving time.",
|
||||
)
|
||||
@cog.optgroup.group("LLM Optimization Options.")
|
||||
@cog.optgroup.group(
|
||||
"LLM Optimization Options",
|
||||
help="""\
|
||||
These options are related to dynamic optimization on the fly. Currently supported strategies:

- int8: Quantize the model with 8-bit (bitsandbytes required)

- int4: Quantize the model with 4-bit (bitsandbytes required)

- bettertransformer: Convert the given model to BetterTransformer
|
||||
|
||||
The following are currently being worked on:
|
||||
|
||||
- GPTQ: [paper](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
|
||||
""",
|
||||
)
|
||||
@cog.optgroup.option(
|
||||
"--device",
|
||||
type=tuple,
|
||||
@@ -604,6 +649,32 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
)
|
||||
@quantize_option(cog.optgroup)
|
||||
@bettertransformer_option(cog.optgroup, model_env=env)
|
||||
@cog.optgroup.group(
|
||||
"Fine-tuning related options",
|
||||
help="""\
|
||||
Note that the argument `--adapter-id` can accept the following format:
|
||||
|
||||
- `--adapter-id /path/to/adapter` (local adapter)
|
||||
|
||||
- `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub)
|
||||
|
||||
- `--adapter-id remote/adapter:eng_lora` (two previous adapter options with the given adapter_name)
|
||||
|
||||
```bash
|
||||
|
||||
openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
|
||||
|
||||
```
|
||||
""",
|
||||
)
|
||||
@cog.optgroup.option(
|
||||
"--adapter-id",
|
||||
default=None,
|
||||
help="Optional name or path for a given LoRA adapter" + f" to wrap '{model_name}'",
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]",
|
||||
)
|
||||
@click.pass_context
|
||||
def model_start(
|
||||
ctx: click.Context,
|
||||
@@ -616,6 +687,10 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
fast: bool,
|
||||
**attrs: t.Any,
|
||||
) -> openllm.LLMConfig:
|
||||
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
# remove adapter_id
|
||||
attrs.pop("adapter_id", None)
|
||||
|
||||
config, server_attrs = llm_config.model_validate_click(**attrs)
|
||||
|
||||
# Create a new model env to work with the envvar during CLI invocation
|
||||
@@ -631,6 +706,13 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
_echo("Quantization is currently only available for PyTorch models.", fg="red")
|
||||
ctx.exit(1)
|
||||
|
||||
if adapter_map and not is_peft_available():
|
||||
_echo(
"Using adapters requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'",
|
||||
fg="red",
|
||||
)
|
||||
ctx.exit(1)
|
||||
|
||||
# We need to handle None separately here, as env from subprocess doesn't
|
||||
# accept None value.
|
||||
env = ModelEnv(env.model_name, bettertransformer=bettertransformer, quantize=quantize)
|
||||
@@ -686,22 +768,29 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
|
||||
if fast and not get_quiet_mode():
|
||||
_echo(
|
||||
f"Make sure to download the model before 'start': 'openllm download {model_name}{'--model-id ' + model_id if model_id else ''}'",
|
||||
f"Fast mode is enabled. Make sure to download the model before 'start': 'openllm download {model_name}{' --model-id ' + model_id if model_id else ''}'",
|
||||
fg="yellow",
|
||||
)
|
||||
automodel_attrs = {
|
||||
"model_id": model_id,
|
||||
"llm_config": config,
|
||||
"ensure_available": not fast,
|
||||
}
|
||||
|
||||
automodel_attrs: dict[str, t.Any] = {}
|
||||
if framework_envvar == "pt":
|
||||
automodel_attrs.update({"quantize": quantize, "bettertransformer": bettertransformer})
|
||||
|
||||
if adapter_map:
|
||||
_echo(f"OpenLLM will convert '{model_name}' to use the provided adapter layers: {list(adapter_map)}")
|
||||
automodel_attrs.update({"adapter_map": adapter_map})
|
||||
|
||||
llm = t.cast(
|
||||
"_BaseAutoLLMClass",
|
||||
openllm[framework_envvar], # type: ignore (internal API)
|
||||
).for_model(model_name, **automodel_attrs)
|
||||
).for_model(
|
||||
model_name,
|
||||
model_id=model_id,
|
||||
llm_config=config,
|
||||
ensure_available=not fast,
|
||||
return_runner_kwargs=False,
|
||||
**automodel_attrs,
|
||||
)
|
||||
|
||||
start_env.update(
|
||||
{
|
||||
@@ -709,6 +798,7 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
env.config: llm.config.model_dump_json().decode(),
|
||||
"OPENLLM_MODEL": model_name,
|
||||
"OPENLLM_MODEL_ID": llm.model_id,
|
||||
"OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map),
|
||||
"BENTOML_DEBUG": str(get_debug_mode()),
|
||||
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options_env,
|
||||
"BENTOML_HOME": os.environ.get("BENTOML_HOME", BentoMLContainer.bentoml_home.get()),
|
||||
@@ -746,9 +836,9 @@ Available model_id(s): {llm_config['model_ids']} [default: {llm_config['default_
|
||||
return model_start
|
||||
|
||||
|
||||
_cached_http = {key: start_model_command(key, _context_settings=_CONTEXT_SETTINGS) for key in openllm.CONFIG_MAPPING}
|
||||
_cached_http = {key: start_command_factory(key, _context_settings=_CONTEXT_SETTINGS) for key in openllm.CONFIG_MAPPING}
|
||||
_cached_grpc = {
|
||||
key: start_model_command(key, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=True)
|
||||
key: start_command_factory(key, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=True)
|
||||
for key in openllm.CONFIG_MAPPING
|
||||
}
|
||||
|
||||
@@ -765,7 +855,7 @@ def _start(
|
||||
|
||||
if framework is not None:
|
||||
os.environ[_ModelEnv.framework] = framework
|
||||
start_model_command(model_name, _serve_grpc=_serve_grpc)(standalone_mode=False, **attrs)
|
||||
start_command_factory(model_name, _serve_grpc=_serve_grpc)(standalone_mode=False, **attrs)
|
||||
|
||||
|
||||
start = functools.partial(_start, _serve_grpc=False)
|
||||
@@ -792,7 +882,17 @@ start_grpc = functools.partial(_start, _serve_grpc=True)
|
||||
nargs=1,
|
||||
metavar="FEATURE[,FEATURE]",
|
||||
)
|
||||
@click.option(
|
||||
"--adapter-id",
|
||||
default=None,
|
||||
help="Optional adapter id(s) to be included within the Bento. Note that if you are using a relative path, '--build-ctx' must be passed.",
|
||||
multiple=True,
|
||||
metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]",
|
||||
)
|
||||
@click.option("--build-ctx", default=".", help="Build context. This is required if --adapter-id uses relative path")
|
||||
@click.pass_context
|
||||
def build(
|
||||
ctx: click.Context,
|
||||
model_name: str,
|
||||
model_id: str | None,
|
||||
overwrite: bool,
|
||||
@@ -801,6 +901,9 @@ def build(
|
||||
enable_features: tuple[str] | None,
|
||||
bettertransformer: bool | None,
|
||||
workers_per_resource: float | None,
|
||||
adapter_id: tuple[str, ...],
|
||||
build_ctx: str | None,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
"""Package a given models into a Bento.
|
||||
|
||||
@@ -813,6 +916,20 @@ def build(
|
||||
> NOTE: To run a container built from this Bento with GPU support, make sure
|
||||
> to have https://github.com/NVIDIA/nvidia-container-toolkit install locally.
|
||||
"""
|
||||
adapter_map: dict[str, str | None] | None = None
|
||||
|
||||
if adapter_id:
|
||||
if not build_ctx:
|
||||
_echo("'build_ctx' must not be None when '--adapter-id' is passed.", fg="red")
|
||||
ctx.exit(1)
|
||||
|
||||
adapter_map = {}
|
||||
for v in adapter_id:
|
||||
_adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
|
||||
# We don't resolve the full path here; that is left to build.
# We only do the parsing here.
|
||||
adapter_map[_adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
|
||||
if output == "porcelain":
|
||||
set_quiet_mode(True)
|
||||
configure_logging()
|
||||
@@ -824,12 +941,15 @@ def build(
|
||||
if enable_features:
|
||||
enable_features = tuple(itertools.chain.from_iterable((s.split(",") for s in enable_features)))
|
||||
|
||||
# TODO: xxx
|
||||
bento, _previously_built = openllm.build(
|
||||
model_name,
|
||||
__cli__=True,
|
||||
model_id=model_id,
|
||||
quantize=quantize,
|
||||
bettertransformer=bettertransformer,
|
||||
adapter_map=adapter_map,
|
||||
_build_ctx=build_ctx,
|
||||
_extra_dependencies=enable_features,
|
||||
_workers_per_resource=workers_per_resource,
|
||||
_overwrite_existing_bento=overwrite,
|
||||
@@ -851,8 +971,8 @@ def build(
|
||||
+ "* Push to BentoCloud with `bentoml push`:\n"
|
||||
+ f" $ bentoml push {bento.tag}\n"
|
||||
+ "* Containerize your Bento with `bentoml containerize`:\n"
|
||||
+ f" $ bentoml containerize {bento.tag}\n"
|
||||
+ " Tip: To enable additional BentoML feature for 'containerize', "
|
||||
+ f" $ bentoml containerize {bento.tag}\n\n"
|
||||
+ " Tip: To enable additional BentoML features for 'containerize', "
|
||||
+ "use '--enable-features=FEATURE[,FEATURE]' "
|
||||
+ "[see 'bentoml containerize -h' for more advanced usage]\n",
|
||||
fg="blue",
|
||||
|
||||
@@ -41,3 +41,7 @@ class MissingAnnotationAttributeError(OpenLLMException):
|
||||
|
||||
class MissingDependencyError(BaseException):
|
||||
"""Raised when a dependency is missing."""
|
||||
|
||||
|
||||
class FineTuneStrategyNotSupportedError(OpenLLMException):
|
||||
"""Raised when a fine-tune strategy is not supported for given LLM."""
|
||||
|
||||
@@ -27,6 +27,14 @@ import openllm
|
||||
from .configuration_auto import AutoConfig
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from collections import _odict_items
|
||||
from collections import _odict_keys
|
||||
@@ -54,7 +62,7 @@ class _BaseAutoLLMClass:
|
||||
"Please use '{self.__class__.__name__}.Runner(model_name)' instead."
|
||||
)
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
@classmethod
|
||||
def for_model(
|
||||
cls,
|
||||
@@ -67,7 +75,7 @@ class _BaseAutoLLMClass:
|
||||
) -> openllm.LLM[t.Any, t.Any]:
|
||||
...
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
@classmethod
|
||||
def for_model(
|
||||
cls,
|
||||
@@ -116,7 +124,7 @@ class _BaseAutoLLMClass:
|
||||
llm = cls._model_mapping[type(llm_config)].from_pretrained(
|
||||
model_id,
|
||||
llm_config=llm_config,
|
||||
**llm_config.__openllm_extras__,
|
||||
**attrs,
|
||||
)
|
||||
if ensure_available:
|
||||
logger.debug(
|
||||
@@ -234,13 +242,13 @@ class _LazyAutoMapping(ConfigModelOrderedDict):
|
||||
]
|
||||
return t.cast(ConfigModelKeysView, mapping_keys + list(self._extra_content.keys()))
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def get(
|
||||
self, key: type[openllm.LLMConfig], default: t.Any, mapping_type: t.Literal["default"] = "default"
|
||||
) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
...
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def get(self, key: str, default: t.Any, mapping_type: t.Literal["name2model", "name2config"] = ...) -> str:
|
||||
...
|
||||
|
||||
|
||||
@@ -70,9 +70,6 @@ class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedToken
|
||||
external_modules=[importlib.import_module(pipeline.__module__)],
|
||||
)
|
||||
finally:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -65,10 +65,6 @@ class Falcon(openllm.LLM["transformers.TextGenerationPipeline", "transformers.Pr
|
||||
external_modules=[importlib.import_module(model.__module__)],
|
||||
)
|
||||
finally:
|
||||
import gc
|
||||
|
||||
# NOTE: We need to free the cache after saving here so that we can load it back later on.
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
|
||||
@@ -43,6 +43,16 @@ class OPTConfig(openllm.LLMConfig):
|
||||
"facebook/opt-6.7b",
|
||||
"facebook/opt-66b",
|
||||
],
|
||||
"fine_tune_strategies": (
|
||||
{
|
||||
"adapter_type": "lora",
|
||||
"r": 16,
|
||||
"lora_alpha": 32,
|
||||
"target_modules": ["q_proj", "v_proj"],
|
||||
"lora_dropout": 0.05,
|
||||
"bias": "none",
|
||||
},
|
||||
),
|
||||
}
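
For reference, a sketch of how the LoRA defaults above would map onto a `peft.LoraConfig`; the `task_type` value is an assumption for a causal-LM such as OPT and is not part of the config shown here:

```python
# Rough equivalent of the fine_tune_strategies defaults above as a peft LoraConfig.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",  # assumption: causal language modelling for OPT
)
```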
|
||||
|
||||
format_outputs: bool = openllm.LLMConfig.Field(
|
||||
|
||||
@@ -21,15 +21,18 @@ import bentoml
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from ...utils import is_peft_available
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
import torch
|
||||
|
||||
import transformers # noqa
|
||||
else:
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
peft = openllm.utils.LazyLoader("peft", globals(), "peft")
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -135,9 +138,17 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
self.model.cuda()
|
||||
|
||||
input_ids = t.cast(torch.Tensor, self.tokenizer(prompt, return_tensors="pt").input_ids).to(self.device)
|
||||
if is_peft_available() and isinstance(self.model, peft.PeftModel):
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
else:
|
||||
inputs = {
|
||||
"inputs": t.cast(torch.Tensor, self.tokenizer(prompt, return_tensors="pt").input_ids).to(
|
||||
self.device
|
||||
)
|
||||
}
|
||||
|
||||
generated_tensors = self.model.generate(
|
||||
input_ids,
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
)
|
||||
|
||||
@@ -80,10 +80,7 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
|
||||
try:
|
||||
return bentoml.transformers.save_model(tag, model, custom_objects={"tokenizer": tokenizer})
|
||||
finally:
|
||||
import gc
|
||||
|
||||
# NOTE: We need to free the cache after saving here so that we can load it back later on.
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def sanitize_parameters(
|
||||
|
||||
@@ -7,3 +7,5 @@ Refer to each script docstring for more information.
|
||||
python -m openllm.playground.general --help

python -m openllm.playground.ft_opt_lora --help

Note that this folder is considered experimental and subject to breaking changes in most cases, so use with care.
|
||||
|
||||
@@ -48,12 +48,15 @@ if openllm.utils.DEBUG:
|
||||
if os.system(f"pip install -U {' '.join(_deps)}") != 0:
|
||||
raise SystemExit(1)
|
||||
|
||||
os.environ["BITSANDBYTES_NOWELCOME"] = str(1)
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from peft import get_peft_model
|
||||
from peft import prepare_model_for_int8_training
|
||||
|
||||
|
||||
if openllm.utils.pkg.pkg_version_info("peft")[:2] >= (0, 4):
|
||||
from peft import prepare_model_for_kbit_training
|
||||
else:
|
||||
from peft import prepare_model_for_int8_training as prepare_model_for_kbit_training
|
||||
|
||||
import transformers
|
||||
|
||||
@@ -75,7 +78,7 @@ def load_model(model_id: str) -> tuple[PeftModel, transformers.GPT2TokenizerFast
|
||||
model, tokenizer = opt.model, opt.tokenizer
|
||||
|
||||
# prep the model for int8 training
|
||||
model = prepare_model_for_int8_training(model)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
|
||||
@@ -103,6 +103,8 @@ _LOGGING_CONFIG["loggers"].update(
|
||||
|
||||
|
||||
def configure_logging() -> None:
"""Configure logging for OpenLLM. Behaves similarly to how BentoML loggers
are configured."""
|
||||
if get_quiet_mode():
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
|
||||
@@ -136,17 +138,22 @@ _import_structure = {
|
||||
"analytics": [],
|
||||
"codegen": [],
|
||||
"dantic": [],
|
||||
"constants": [],
|
||||
"representation": ["ReprMixin"],
|
||||
"import_utils": [
|
||||
"OPTIONAL_DEPENDENCIES",
|
||||
"ENV_VARS_TRUE_VALUES",
|
||||
"DummyMetaclass",
|
||||
"ModelEnv",
|
||||
"requires_dependencies",
|
||||
"is_cpm_kernels_available",
|
||||
"is_einops_available",
|
||||
"is_flax_available",
|
||||
"is_tf_available",
|
||||
"is_torch_available",
|
||||
"is_bitsandbytes_available",
|
||||
"is_peft_available",
|
||||
"is_datasets_available",
|
||||
"is_transformers_supports_kbit",
|
||||
"is_transformers_supports_agent",
|
||||
"require_backends",
|
||||
@@ -161,6 +168,7 @@ if t.TYPE_CHECKING:
|
||||
from . import bentoml_cattr as bentoml_cattr
|
||||
from . import codegen as codegen
|
||||
from . import configure_logging as configure_logging
|
||||
from . import constants as constants
|
||||
from . import copy_file_to_fs_folder as copy_file_to_fs_folder
|
||||
from . import dantic as dantic
|
||||
from . import first_not_none as first_not_none
|
||||
@@ -180,14 +188,18 @@ if t.TYPE_CHECKING:
|
||||
from .import_utils import ModelEnv as ModelEnv
|
||||
from .import_utils import is_bitsandbytes_available as is_bitsandbytes_available
|
||||
from .import_utils import is_cpm_kernels_available as is_cpm_kernels_available
|
||||
from .import_utils import is_datasets_available as is_datasets_available
|
||||
from .import_utils import is_einops_available as is_einops_available
|
||||
from .import_utils import is_flax_available as is_flax_available
|
||||
from .import_utils import is_peft_available as is_peft_available
|
||||
from .import_utils import is_tf_available as is_tf_available
|
||||
from .import_utils import is_torch_available as is_torch_available
|
||||
from .import_utils import is_transformers_supports_agent as is_transformers_supports_agent
|
||||
from .import_utils import is_transformers_supports_kbit as is_transformers_supports_kbit
|
||||
from .import_utils import require_backends as require_backends
|
||||
from .import_utils import requires_dependencies as requires_dependencies
|
||||
from .lazy import LazyModule as LazyModule
|
||||
from .representation import ReprMixin as ReprMixin
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -15,10 +15,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import string
|
||||
import typing as t
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
@@ -33,12 +36,13 @@ else:
|
||||
|
||||
DictStrAny = dict
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_MODEL_NAME = "# openllm: model name"
|
||||
OPENLLM_MODEL_ID = "# openllm: model id"
|
||||
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
|
||||
|
||||
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
@@ -63,10 +67,16 @@ class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: t.LiteralString = "__model_id__"
|
||||
|
||||
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: t.LiteralString = "__model_adapter_map__"
|
||||
|
||||
|
||||
_service_file = Path(__file__).parent.parent / "_service.py"
|
||||
|
||||
|
||||
def write_service(model_name: str, model_id: str, target_path: str, llm_fs: FS):
|
||||
def write_service(
|
||||
model_name: str, model_id: str, adapter_map: dict[str, str | None] | None, target_path: str, llm_fs: FS
|
||||
):
|
||||
from . import DEBUG
|
||||
|
||||
logger.debug("Generating service for %s to %s", model_name, target_path)
|
||||
@@ -76,11 +86,22 @@ def write_service(model_name: str, model_id: str, target_path: str, llm_fs: FS):
|
||||
# modify with model name
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it:
|
||||
src_contents[src_contents.index(it)] = ModelNameFormatter(model_name).vformat(it)
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + "\n"
|
||||
)
|
||||
elif OPENLLM_MODEL_ID in it:
|
||||
src_contents[src_contents.index(it)] = ModelIdFormatter(model_id).vformat(it)
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelIdFormatter(model_id).vformat(it)[: -(len(OPENLLM_MODEL_ID) + 3)] + "\n"
|
||||
)
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it and adapter_map:
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[
|
||||
: -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)
|
||||
]
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n" + "".join(src_contents)
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
|
||||
|
||||
if DEBUG:
|
||||
logger.info("Generated script:\n%s", script)
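
To illustrate what the substitution above produces, a simplified sketch using a plain `str.replace` instead of the `string.Formatter` subclasses; the adapter ids are examples:

```python
# Simplified illustration of the adapter-map placeholder substitution in write_service.
import orjson

template = 'adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}""")  # openllm: model adapter map'
adapter_map = {"/path/to/adapter": None, "remote/adapter": "eng_lora"}

rendered = template.replace("{__model_adapter_map__}", orjson.dumps(adapter_map).decode())
# write_service then strips the trailing "# openllm: model adapter map" marker from the line.
print(rendered)
```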
|
||||
@@ -192,7 +213,7 @@ def generate_function(
|
||||
if annotations:
|
||||
meth.__annotations__ = annotations
|
||||
|
||||
if DEBUG:
|
||||
if DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3:
|
||||
logger.info("Generated script for %s:\n\n%s", typ, script)
|
||||
|
||||
return meth
|
||||
|
||||
@@ -31,6 +31,13 @@ from click import ParamType
|
||||
import openllm
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import _ValidatorType
|
||||
|
||||
@@ -42,7 +49,7 @@ if t.TYPE_CHECKING:
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
@@ -53,7 +60,7 @@ def attrs_to_options(
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def attrs_to_options( # type: ignore (overlapping overload)
|
||||
name: str,
|
||||
field: attr.Attribute[O_co],
|
||||
|
||||
@@ -17,6 +17,7 @@ Some import utils are vendored from transformers/utils/import_utils.py for per
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
@@ -32,9 +33,19 @@ from packaging import version
|
||||
from bentoml._internal.utils import LazyLoader
|
||||
from bentoml._internal.utils import pkg
|
||||
|
||||
from .representation import ReprMixin
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
|
||||
from .._types import P
|
||||
else:
|
||||
BackendOrderredDict = OrderedDict
|
||||
|
||||
@@ -64,9 +75,12 @@ def _is_package_available(package: str) -> bool:
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
|
||||
_peft_available = _is_package_available("peft")
|
||||
_einops_available = _is_package_available("einops")
|
||||
_cpm_kernel_available = _is_package_available("cpm_kernels")
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_datasets_available = _is_package_available("datasets")
|
||||
|
||||
|
||||
def is_transformers_supports_kbit() -> bool:
|
||||
@@ -77,6 +91,14 @@ def is_transformers_supports_agent() -> bool:
|
||||
return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
|
||||
|
||||
|
||||
def is_datasets_available() -> bool:
|
||||
return _datasets_available
|
||||
|
||||
|
||||
def is_peft_available() -> bool:
|
||||
return _peft_available
|
||||
|
||||
|
||||
def is_einops_available():
|
||||
return _einops_available
|
||||
|
||||
@@ -157,6 +179,33 @@ def is_flax_available():
|
||||
return _flax_available
|
||||
|
||||
|
||||
def requires_dependencies(
|
||||
package: str | list[str], *, extra: str | list[str] | None = None
|
||||
) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
|
||||
import openllm.utils
|
||||
|
||||
if isinstance(package, str):
|
||||
package = [package]
|
||||
if isinstance(extra, str):
|
||||
extra = [extra]
|
||||
|
||||
def decorator(func: t.Callable[P, t.Any]):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
for p in package:
|
||||
cached_check: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{p}_available", None)
|
||||
if not ((cached_check is not None and cached_check()) or _is_package_available(p)):
|
||||
raise ImportError(
|
||||
f"{func.__name__} requires '{p}' to be available locally (Currently missing)."
|
||||
f"Make sure to have {p} to be installed: 'pip install \"{p if not extra else 'openllm['+', '.join(extra)+']'}\"'"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
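
A hedged usage sketch of this decorator; the decorated function is hypothetical, and only the decorator itself and the `fine-tune` extra come from the codebase:

```python
# Hypothetical usage of requires_dependencies; the decorated function is illustrative.
from openllm.utils import requires_dependencies


@requires_dependencies("peft", extra="fine-tune")
def merge_adapter_weights() -> None:
    import peft  # only reached when 'peft' is importable

    ...
```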
|
||||
|
||||
|
||||
PYTORCH_IMPORT_ERROR_WITH_TF = """\
|
||||
{0} requires the PyTorch library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
@@ -249,19 +298,44 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]):
|
||||
raise ImportError("".join(failed))
|
||||
|
||||
|
||||
class ModelEnv:
|
||||
class ModelEnv(ReprMixin):
|
||||
model_name: str
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
config: property
|
||||
model_id: property
|
||||
quantize: property
|
||||
framework: property
|
||||
bettertransformer: property
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
return {"config", "model_id", "quantize", "framework", "bettertransformer"}
|
||||
|
||||
framework_value: property
|
||||
quantize_value: property
|
||||
bettertransformer_value: property
|
||||
if t.TYPE_CHECKING:
|
||||
config: str
|
||||
model_id: str
|
||||
quantize: str
|
||||
framework: str
|
||||
bettertransformer: str
|
||||
|
||||
framework_value: t.Literal["pt", "tf", "flax"]
|
||||
quantize_value: str | None
|
||||
bettertransformer_value: str | None
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["config"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["framework"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['framework_value']) -> t.Literal['pt', 'tf', 'flax']: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['quantize_value']) -> str | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['bettertransformer_value']) -> str | None: ...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def __getitem__(self, item: str | t.Any) -> t.Any:
|
||||
if hasattr(self, item):
|
||||
@@ -284,7 +358,7 @@ class ModelEnv:
|
||||
|
||||
# gen properties env value
|
||||
attributes_with_values = {
|
||||
"quantize": (bool, quantize),
|
||||
"quantize": (str, quantize),
|
||||
"bettertransformer": (bool, bettertransformer),
|
||||
"framework": (str, "pt"),
|
||||
}
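
A small sketch of the dictionary-style access these overloads describe; the concrete environment-variable names and resolved values are assumptions, since they depend on the running environment:

```python
# Illustrative ModelEnv access (values shown in comments are assumptions).
from openllm.utils import ModelEnv

env = ModelEnv("opt", quantize="int8", bettertransformer=None)
print(env["framework"])        # name of the framework environment variable for 'opt'
print(env["framework_value"])  # one of "pt", "tf", "flax"; defaults to "pt"
print(env["quantize_value"])   # resolved quantize value, e.g. "int8" unless overridden
```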
|
||||
|
||||
src/openllm/utils/representation.py (new file, 61 lines)
@@ -0,0 +1,61 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
|
||||
import attr
|
||||
import orjson
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
|
||||
|
||||
|
||||
class ReprMixin:
"""This class displays possible representations of a given class.
It can be used to implement __rich_pretty__ and __pretty__ methods in the future.
Most subclasses need to implement a __repr_keys__ property.

Based on the design from Pydantic.
The __repr__ will display the JSON representation of the object for easier interaction.
The __str__ will display either __attrs_repr__ or __repr_str__.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def __repr_keys__(self) -> set[str]:
"""This can be overridden by classes using this mixin."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
from . import bentoml_cattr
|
||||
|
||||
serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
|
||||
return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr_str__(" ")
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
"""Name of the instance's class, used in __repr__."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
return join_str.join(repr(v) if a is None else f"{a}={repr(v)}" for a, v in self.__repr_args__())
|
||||
|
||||
def __repr_args__(self) -> ReprArgs:
|
||||
attrs = ((k, getattr(self, k)) for k in self.__repr_keys__)
|
||||
return tuple((k, v) for k, v in attrs if v)
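
A minimal sketch of a class opting into `ReprMixin`; the class and field names are illustrative:

```python
# Hypothetical consumer of ReprMixin.
from __future__ import annotations

from openllm.utils import ReprMixin


class ExampleEnv(ReprMixin):
    def __init__(self, model_name: str, quantize: str | None = None) -> None:
        self.model_name = model_name
        self.quantize = quantize

    @property
    def __repr_keys__(self) -> set[str]:
        return {"model_name", "quantize"}


env = ExampleEnv("opt", quantize="int8")
print(repr(env))  # class name followed by an indented JSON view of the repr keys
print(str(env))   # model_name='opt' quantize='int8'
```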
|
||||
@@ -26,6 +26,13 @@ import openllm
|
||||
import transformers
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm.models.auto.factory import _BaseAutoLLMClass
|
||||
|
||||
@@ -162,11 +169,11 @@ class BaseClient(ClientMixin):
|
||||
def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
@@ -222,11 +229,11 @@ class BaseAsyncClient(ClientMixin):
|
||||
async def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
@t.overload
|
||||
@overload
|
||||
async def query(
|
||||
self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any
|
||||
) -> dict[str, t.Any]:
|
||||
|
||||
@@ -4,9 +4,11 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
|
||||
|
||||
md = MarkdownIt()
|
||||
|
||||
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -16,7 +18,17 @@ with open(os.path.join(ROOT, "README.md"), "r") as f:
|
||||
# NOTE: Currently, we only have one table in README, which is the Model readme.
|
||||
table = [r for r in readme if r.type == "html_block" and r.content.startswith("<td><a")]
|
||||
|
||||
available = subprocess.check_output(["openllm", "models", "-o", "porcelain"]).strip().decode("utf-8").count("\n") + 1
|
||||
prev = os.environ.pop("OPENLLMDEVDEBUG", str(0))
|
||||
available = (
|
||||
subprocess.check_output(
|
||||
[sys.executable, "-m", "openllm", "models", "-o", "porcelain"],
|
||||
)
|
||||
.strip()
|
||||
.decode("utf-8")
|
||||
.count("\n")
|
||||
+ 1
|
||||
)
|
||||
os.environ["OPENLLMDEVDEBUG"] = prev
|
||||
|
||||
on_table = len(table) # NOTE: minus the header
|
||||
|
||||
|
||||
tools/update-config-stubs.py (new executable file, 90 lines)
@@ -0,0 +1,90 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import openllm
|
||||
import importlib
|
||||
from openllm._configuration import ModelSettings
|
||||
|
||||
# currently we assume the indentation level is 4 for comments
|
||||
START_COMMENT = f"# {os.path.basename(__file__)}: start\n"
|
||||
END_COMMENT = f"# {os.path.basename(__file__)}: stop\n"
|
||||
|
||||
_TARGET_FILE = Path(__file__).parent.parent / "src" / "openllm" / "_configuration.py"
|
||||
|
||||
_imported = importlib.import_module(ModelSettings.__module__)
|
||||
|
||||
|
||||
def process_annotations(annotations: str) -> str:
|
||||
if "NotRequired" in annotations:
|
||||
return annotations[len("NotRequired[") : -1]
|
||||
elif "Required" in annotations:
|
||||
return annotations[len("Required[") : -1]
|
||||
else:
|
||||
return annotations
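
A quick check of the behaviour, assuming the `process_annotations` function above is in scope; the inputs are examples:

```python
# Example inputs/outputs for process_annotations (illustrative).
print(process_annotations("NotRequired[int]"))  # -> "int"
print(process_annotations("Required[str]"))     # -> "str"
print(process_annotations("bool"))              # -> "bool"
```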
|
||||
|
||||
|
||||
def main() -> int:
|
||||
transformed = {"fine_tune_strategies": "t.Dict[AdapterType, FineTuneConfig]"}
|
||||
with _TARGET_FILE.open("r") as f:
|
||||
processed = f.readlines()
|
||||
|
||||
start_idx, end_idx = processed.index(" " * 4 + START_COMMENT), processed.index(" " * 4 + END_COMMENT)
|
||||
|
||||
# convention to use t.TYPE_CHECKING
|
||||
lines = [" " * 4 + "if t.TYPE_CHECKING:\n"]
|
||||
for keys, ForwardRef in openllm.utils.codegen.get_annotations(ModelSettings).items():
|
||||
lines.extend(
|
||||
list(
|
||||
map(
|
||||
lambda line: " " * 8 + line,
|
||||
[
|
||||
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
|
||||
f'def __getitem__(self, item: t.Literal["{keys}"] = ...) -> {transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n',
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
# special case variables: generation_class, extras
|
||||
lines.extend(
|
||||
list(
|
||||
map(
|
||||
lambda line: " " * 8 + line,
|
||||
[
|
||||
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
|
||||
'def __getitem__(self, item: t.Literal["generation_class"] = ...) -> t.Type[GenerationConfig]: ...\n',
|
||||
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
|
||||
'def __getitem__(self, item: t.Literal["extras"] = ...) -> t.Dict[str, t.Any]: ...\n',
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
processed = (
|
||||
processed[:start_idx] + [" " * 4 + START_COMMENT] + lines + [" " * 4 + END_COMMENT] + processed[end_idx + 1 :]
|
||||
)
|
||||
with _TARGET_FILE.open("w") as f:
|
||||
f.writelines(processed)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -85,6 +85,7 @@ _NIGHTLY_MAPPING: dict[str, Dependencies] = {
|
||||
"optimum": Dependencies.from_tuple("optimum", "huggingface/optimum", "main", None),
|
||||
"accelerate": Dependencies.from_tuple("accelerate", "huggingface/accelerate", "main", None),
|
||||
"bitsandbytes": Dependencies.from_tuple("bitsandbytes", "TimDettmers/bitsandbytes", "main", None),
|
||||
"deepspeed": Dependencies.from_tuple("deepspeed", "microsoft/deepspeed", "master", None),
|
||||
}
|
||||
|
||||
FINE_TUNE_DEPS = ["peft", "bitsandbytes", "datasets", "accelerate", "deepspeed"]
|
||||
@@ -125,8 +126,8 @@ def main() -> int:
|
||||
with open(os.path.join(ROOT, "pyproject.toml"), "w") as f:
|
||||
f.write(tomlkit.dumps(pyproject))
|
||||
|
||||
with open(os.path.join(ROOT, "nightly-requirements.generated.txt"), "w") as f:
|
||||
f.write("# This file is generated by `./tools/update-optional-dependencies.py`\n# DO NOT EDIT\n-e .\n")
|
||||
with open(os.path.join(ROOT, "nightly-requirements.txt"), "w") as f:
|
||||
f.write("# This file is generated by `./tools/update-optional-dependencies.py`\n# DO NOT EDIT\n-e .[all]\n")
|
||||
f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values()])
|
||||
|
||||
if shutil.which("taplo"):
|
||||
|
||||
@@ -493,7 +493,7 @@ dataclass = ...
|
||||
|
||||
def _make_init(
|
||||
cls: type[AttrsInstance],
|
||||
attrs: tuple[Attribute[_T]],
|
||||
attrs: tuple[Attribute[Any]],
|
||||
pre_init: bool,
|
||||
post_init: bool,
|
||||
frozen: bool,
|
||||
|
||||
@@ -99,7 +99,7 @@ class _OptGroup(Generic[O_co]):
|
||||
:param attrs: Additional parameters of option group class
|
||||
"""
|
||||
...
|
||||
def option(self, *param_decls: Any, **attrs: Any) -> F[P, ClickFunctionWrapper[P, O_co]]:
|
||||
def option(self, *param_decls: Any, **attrs: Any) -> FC:
|
||||
"""The decorator adds a new option to the group
|
||||
|
||||
The decorator is lazy. It adds option decls and attrs.
|
||||
|
||||
@@ -21,6 +21,6 @@ class Merger:
|
||||
fallback_strategies: List[str],
|
||||
type_conflict_strategies: List[str],
|
||||
) -> None: ...
|
||||
def merge(self, base: ConfigDictType, nxt: ConfigDictType) -> None: ...
|
||||
def merge(self, base: ConfigDictType, nxt: ConfigDictType) -> ConfigDictType: ...
|
||||
def type_conflict_strategy(self, *args: Any) -> Any: ...
|