feat(infra): add tools for managing optional-dependencies

Based on LLM config.

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron
Date: 2023-06-08 08:57:19 -04:00
parent 23d98a2729
commit c0418b76ec
13 changed files with 98 additions and 21 deletions


@@ -65,13 +65,15 @@ name = "openllm"
readme = "README.md"
requires-python = ">=3.8"
# NOTE: Don't modify project.optional-dependencies
# as it is managed by ./tools/update-optional-dependencies.py
[project.optional-dependencies]
all = ['openllm[fine-tune]', 'openllm[chatglm]', 'openllm[falcon]', 'openllm[flan-t5]', 'openllm[starcoder]']
chatglm = ['cpm_kernels', 'sentencepiece']
falcon = ['einops', 'xformers', 'safetensors']
fine-tune = ['peft', 'bitsandbytes', 'datasets', 'accelerate']
flan-t5 = ['flax', 'jax', 'jaxlib', 'tensorflow']
starcoder = ['bitsandbytes']
all = ["openllm[fine-tune]", "openllm[flan-t5]", "openllm[chatglm]", "openllm[starcoder]", "openllm[falcon]"]
chatglm = ["cpm_kernels", "sentencepiece"]
falcon = ["einops", "xformers", "safetensors"]
fine-tune = ["peft", "bitsandbytes", "datasets", "accelerate"]
flan-t5 = ["flax", "jax", "jaxlib", "tensorflow"]
starcoder = ["bitandbytes"]
[project.urls]
Documentation = "https://github.com/llmsys/openllm#readme"
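
A quick consistency check one could run at the repo root (not part of the commit; assumes tomlkit is installed): verify that the managed "all" extra references every other extra and nothing else.

import tomlkit

# Load the extras table managed by tools/update-optional-dependencies.py.
with open("pyproject.toml") as f:
    extras = tomlkit.parse(f.read())["project"]["optional-dependencies"]

# Every extra except "all" should be referenced by "all", and nothing else.
others = {name for name in extras if name != "all"}
referenced = {spec.split("[", 1)[1].rstrip("]") for spec in extras["all"]}
assert referenced == others, f"'all' extra out of sync: {others ^ referenced}"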


@@ -652,6 +652,10 @@ class LLMConfig:
__openllm_url__: str = Field(None, init=False)
"""The resolved url for this LLMConfig."""
__openllm_requirements__: list[str] | None = None
"""The default PyPI requirements needed to run this given LLM. By default, we will depend on
bentoml, torch, transformers."""
GenerationConfig: type = type
"""Users can override this subclass of any given LLMConfig to provide GenerationConfig
default value. For example:
@@ -682,6 +686,7 @@ class LLMConfig:
trust_remote_code: bool = False,
requires_gpu: bool = False,
url: str | None = None,
requirements: list[str] | None = None,
):
if name_type == "dasherize":
model_name = inflection.underscore(cls.__name__.replace("Config", ""))
@@ -699,6 +704,7 @@ class LLMConfig:
cls.__openllm_start_name__ = start_name
cls.__openllm_env__ = openllm.utils.ModelEnv(model_name)
cls.__openllm_url__ = url or "(not set)"
cls.__openllm_requirements__ = requirements
# NOTE: Since we want to enable a pydantic-like experience
# this means we will have to hide the attr abstraction, and generate
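
A minimal sketch of the new hook (the class name, URL, and dependency here are hypothetical; real usage mirrors the ChatGLMConfig change below): requirements is passed as a subclass keyword and stored on the class by __init_subclass__.

import openllm

class MyModelConfig(
    openllm.LLMConfig,
    url="https://example.com/my-model",  # hypothetical model page
    requirements=["einops"],  # extra PyPI dependencies for this model
):
    """Configuration for a hypothetical model."""

# __init_subclass__ above stores the keyword on the class.
assert MyModelConfig.__openllm_requirements__ == ["einops"]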


@@ -173,7 +173,6 @@ _reserved_namespace = _required_namespace | {
"model",
"tokenizer",
"import_kwargs",
"requirements",
}
@@ -199,10 +198,6 @@ class LLMInterface(ABC):
"""The default import kwargs to used when importing the model.
This will be passed into 'openllm.LLM.import_model'."""
requirements: list[str] | None = None
"""The default PyPI requirements needed to run this given LLM. By default, we will depend on
bentoml, torch, transformers."""
@abstractmethod
def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any:
"""The main function implementation for generating from given prompt. It takes the prompt


@@ -72,8 +72,8 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:
packages: list[str] = []
ModelEnv = openllm.utils.ModelEnv(llm.__openllm_start_name__)
if llm.requirements is not None:
packages.extend(llm.requirements)
if llm.config.__openllm_requirements__ is not None:
packages.extend(llm.config.__openllm_requirements__)
if not (str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false"):
packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
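
The lookup moves from the removed class attribute to the config, so bento builds pick up per-model dependencies from a single source. As a side note, the bentoml pin appended above evaluates as in this sketch, assuming pkg.pkg_version_info("bentoml") returns a version tuple such as (1, 0, 22):

# Hypothetical installed version; pkg_version_info returns the real one.
version_info = (1, 0, 22)
pin = f"bentoml>={'.'.join(str(i) for i in version_info)}"
assert pin == "bentoml>=1.0.22"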


@@ -422,8 +422,11 @@ def start_model_command(
}
)
if llm.requirements is not None:
_echo(f"Make sure to have the following dependencies available: {llm.requirements}", fg="yellow")
if llm.config.__openllm_requirements__ is not None:
_echo(
f"Make sure to have the following dependencies available: {llm.config.__openllm_requirements__}",
fg="yellow",
)
if t.TYPE_CHECKING:
server_cls: type[bentoml.HTTPServer] if not _serve_grpc else type[bentoml.GrpcServer]
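
For ChatGLM, for example, the hint would render as in this sketch:

reqs = ["cpm_kernels", "sentencepiece"]
print(f"Make sure to have the following dependencies available: {reqs}")
# Make sure to have the following dependencies available: ['cpm_kernels', 'sentencepiece']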


@@ -23,6 +23,7 @@ class ChatGLMConfig(
default_timeout=3600000,
requires_gpu=True,
url="https://github.com/THUDM/ChatGLM-6B",
requirements=["cpm_kernels", "sentencepiece"],
):
"""
ChatGLM is an open bilingual language model based on


@@ -64,8 +64,6 @@ class ChatGLM(openllm.LLM):
default_model = "THUDM/chatglm-6b-int4"
requirements = ["cpm_kernels", "sentencepiece"]
pretrained = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
device = torch.device("cuda")
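
With the class attribute removed, the config is the single source of truth. A sketch of the new lookup, using the same CONFIG_MAPPING the generator tool iterates (the "chatglm" key is assumed from that mapping):

import openllm

config = openllm.CONFIG_MAPPING["chatglm"]
print(config.__openllm_requirements__)  # expected: ['cpm_kernels', 'sentencepiece']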


@@ -23,6 +23,7 @@ class FalconConfig(
requires_gpu=True,
default_timeout=3600000,
url="https://falconllm.tii.ae/",
requirements=["einops", "xformers", "safetensors"],
):
"""Falcon-7B is a 7B parameters causal decoder-only model built by
TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)


@@ -36,8 +36,6 @@ class Falcon(openllm.LLM):
default_model = "tiiuae/falcon-7b"
requirements = ["einops", "xformers", "safetensors"]
pretrained = ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct"]
import_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}


@@ -21,6 +21,7 @@ class StarCoderConfig(
name_type="lowercase",
requires_gpu=True,
url="https://github.com/bigcode-project/starcoder",
requirements=["bitandbytes"],
):
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from
[The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.


@@ -44,8 +44,6 @@ class StarCoder(openllm.LLM):
default_model = "bigcode/starcoder"
requirements = ["bitandbytes"]
pretrained = ["bigcode/starcoder", "bigcode/starcoderbase"]
device = torch.device("cuda")


@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import shutil

import inflection
import tomlkit

import openllm

ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

FINE_TUNE_DEPS = ["peft", "bitsandbytes", "datasets", "accelerate"]
FLAN_T5_DEPS = ["flax", "jax", "jaxlib", "tensorflow"]


def main() -> int:
    with open(os.path.join(ROOT, "pyproject.toml"), "r") as f:
        pyproject = tomlkit.parse(f.read())

    table = tomlkit.table()
    table.add("fine-tune", FINE_TUNE_DEPS)
    for name, config in openllm.CONFIG_MAPPING.items():
        dashed = inflection.dasherize(name)
        if name == "flan_t5":
            table.add(dashed, FLAN_T5_DEPS)
            continue
        if config.__openllm_requirements__:
            table.add(dashed, config.__openllm_requirements__)
    table.add("all", [f"openllm[{k}]" for k in table.keys()])

    pyproject["project"]["optional-dependencies"] = table
    with open(os.path.join(ROOT, "pyproject.toml"), "w") as f:
        f.write(tomlkit.dumps(pyproject))

    if shutil.which("taplo"):
        return os.system(f"taplo fmt {os.path.join(ROOT, 'pyproject.toml')}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
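
The extras names in pyproject.toml come from the mapping keys via inflection.dasherize, which only swaps underscores for dashes:

import inflection

# How CONFIG_MAPPING keys become extras names.
assert inflection.dasherize("flan_t5") == "flan-t5"
assert inflection.dasherize("chatglm") == "chatglm"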


@@ -1,4 +1,18 @@
#!/usr/bin/env python3
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations