infra: enable compiled wheels for all supported Python (#201)

Aaron Pham
2023-08-12 04:54:50 -04:00
committed by GitHub
parent dc776e9c5a
commit f6317d8003
124 changed files with 1086 additions and 2048 deletions

View File

@@ -32,7 +32,7 @@ concurrency:
cancel-in-progress: true
jobs:
pure-wheels-sdist:
name: Building pure wheels and sdist
name: Pure wheels and sdist distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
@@ -50,31 +50,26 @@ jobs:
path: dist/*
if-no-files-found: error
mypyc:
name: Building compiled mypyc wheels (${{ matrix.name }})
name: Compiled mypyc wheels (${{ matrix.name }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
# NOTE: ubuntu x86
- os: ubuntu-latest
name: linux-x86_64
python: 311
# NOTE: darwin amd64
- os: macos-latest
name: macos-x86_64
macos_arch: "x86_64"
python: 311
# NOTE: darwin arm64
- os: macos-latest
name: macos-arm64
macos_arch: "arm64"
python: 311
# NOTE: darwin universal2
- os: macos-latest
name: macos-universal2
macos_arch: "universal2"
python: 311
steps:
- uses: actions/checkout@v3
with:
@@ -85,9 +80,9 @@ jobs:
with:
python-version: 3.9
- name: Build wheels via cibuildwheel
uses: pypa/cibuildwheel@v2.14.1
uses: pypa/cibuildwheel@v2.15.0
env:
CIBW_BUILD: cp${{ matrix.python }}-*
CIBW_BEFORE_BUILD_MACOS: "rustup target add aarch64-apple-darwin"
CIBW_ARCHS_MACOS: "${{ matrix.macos_arch }}"
MYPYPATH: /project/typings
- name: Upload wheels as workflow artifacts

View File

@@ -19,7 +19,7 @@ All the relevant code for incorporating a new model resides within
`src/openllm/models/{model_name}/__init__.py`
- [ ] Adjust the entrypoints for files at `src/openllm/models/auto/*`
- [ ] Modify the main `__init__.py`: `src/openllm/models/__init__.py`
- [ ] Run the following script to update dummy objects: `hatch run update-dummy`
- [ ] Run the following to update stubs: `hatch run check-stubs`
For a working example, check out any pre-implemented model.

View File

@@ -87,12 +87,11 @@ _If you have any suggestions, feel free to give it on our discord server!_
def foo(x): return rotate_cv(x) if x > 0 else -x
```
- imports should be grouped by their types, and each import should be designated
on its own line.
- imports should be grouped by their types: standard library, third-party, and local
```python
import os
import sys
import os, sys
import orjson, bentoml
```
This is partly to make merge conflicts easier to resolve, and to make it easier for IDEs to navigate to definitions.
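As a hedged illustration of the grouping rule above (the first-party import target is only an example), a conforming import block looks like:
```python
# Illustration only: one group per line — standard library, then third-party, then first-party.
import os, sys, typing as t  # standard library
import orjson, bentoml  # third-party
from openllm import utils  # local / first-party (example target)
```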

View File

@@ -0,0 +1 @@
Added compiled wheels for all supported Python versions on Linux and macOS

View File

@@ -75,25 +75,21 @@ dependencies = [
]
[envs.default.scripts]
changelog = "towncrier build --version main --draft"
check-stubs = [
"./tools/update-config-stubs.py",
"./tools/update-models-import.py",
"./tools/update-init-import.py",
"update-dummy",
]
check-stubs = ["./tools/update-config-stubs.py", "./tools/update-models-import.py", "update-dummy"]
compile = "bash ./compile.sh {args}"
update-dummy = ["- ./tools/update-dummy.py", "./tools/update-dummy.py"]
inplace-changelog = "towncrier build --version main --keep"
quality = [
"./tools/dependencies.py",
"./tools/update-readme.py",
"- ./tools/update-brew-tap.py",
"check-stubs",
"pre-commit run --all-files",
"- pre-commit run --all-files",
]
recompile = ["bash ./clean.sh", "compile"]
setup = "pre-commit install"
tool = ["quality", "recompile -nx"]
typing = ["- pre-commit run mypy {args:-a}", "- pre-commit run pyright {args:-a}"]
update-dummy = ["- ./tools/update-dummy.py", "./tools/update-dummy.py"]
[envs.tests]
dependencies = [
# NOTE: interact with docker for container tests.

View File

@@ -44,6 +44,7 @@ dependencies = [
"httpx",
"click>=8.1.3",
"typing_extensions",
"mypy_extensions",
"ghapi",
"cuda-python;platform_system!=\"Darwin\"",
"bitsandbytes<0.42",
@@ -124,12 +125,12 @@ vllm = ["vllm", "ray"]
[tool.cibuildwheel]
build-verbosity = 1
# So the following enviuronment will be targeted for compiled wheels:
# - Python: CPython 3.8+ only
# So the following environment will be targeted for compiled wheels:
# - Python: CPython 3.8-3.11
# - Architecture (64-bit only): amd64 / x86_64, universal2, and arm64
# - OS: Linux (no musl), and macOS
build = "cp3*-*"
skip = ["*-manylinux_i686", "*-musllinux_*", "*-win32", "pp-*"]
skip = ["*-manylinux_i686", "*-musllinux_*", "*-win32", "pp-*", "cp312-*"]
[tool.cibuildwheel.environment]
HATCH_BUILD_HOOKS_ENABLE = "1"
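As a side note, whether an installed wheel is one of these compiled builds can be checked at runtime via the `COMPILED` flag defined in `src/openllm/__init__.py` later in this diff; a minimal check:
```python
# Minimal check for a compiled (mypyc) wheel; COMPILED is a suffix check on the module file.
import openllm

print("compiled wheel" if openllm.COMPILED else "pure-Python wheel")
```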
@@ -183,6 +184,9 @@ fail-under = 100
verbose = 2
whitelist-regex = ["test_.*"]
[tool.check-wheel-contents]
toplevel = ["openllm"]
[tool.pytest.ini_options]
addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"]
python_files = ["test_*.py", "*_test.py"]
@@ -194,7 +198,6 @@ extend-exclude = [
"examples",
"src/openllm/playground",
"src/openllm/__init__.py",
"src/openllm/_types.py",
"src/openllm/_version.py",
"src/openllm/utils/dummy_*.py",
"src/openllm/models/__init__.py",
@@ -233,7 +236,9 @@ ignore = [
"PLR0915",
"PLR2004", # magic value to use constant
"E501", # ignore line length violation
"E401", # ignore multiple line import
"E702",
"I001", # unsorted imports
"PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs.
"D103", # Just missing docstring for magic methods.
"D102",
@@ -245,13 +250,10 @@ ignore = [
"D105", # magic docstring
"E701", # multiple statement on single line
]
line-length = 119
target-version = "py312"
unfixable = [
"F401", # Don't touch unused imports, just warn about it.
"TCH004", # Don't touch import outside of TYPE_CHECKING block
"RUF100", # unused noqa, just warn about it
]
line-length = 768
target-version = "py38"
typing-modules = ["openllm._typing_compat"]
unfixable = ["TCH004"]
[tool.ruff.flake8-type-checking]
exempt-modules = ["typing", "typing_extensions", "."]
runtime-evaluated-base-classes = [
@@ -260,7 +262,7 @@ runtime-evaluated-base-classes = [
"openllm._configuration.GenerationConfig",
"openllm._configuration.ModelSettings",
]
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen"]
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen", "trait"]
[tool.ruff.pydocstyle]
convention = "google"
[tool.ruff.pycodestyle]
@@ -271,7 +273,7 @@ force-single-line = false
force-wrap-aliases = true
known-first-party = ["openllm", "bentoml"]
known-third-party = ["transformers", "click", "huggingface_hub", "torch", "vllm", "auto_gptq"]
lines-after-imports = 1
lines-after-imports = 0
lines-between-types = 0
no-lines-before = ["future", "standard-library"]
relative-imports-order = "closest-to-furthest"
@@ -279,12 +281,11 @@ required-imports = ["from __future__ import annotations"]
[tool.ruff.flake8-quotes]
avoid-escape = false
[tool.ruff.extend-per-file-ignores]
"src/openllm/_service.py" = ["I001", "E401"]
"src/openllm/_service.py" = ["E401"]
"src/openllm/cli/entrypoint.py" = ["D301"]
"src/openllm/models/**" = ["I001", "E", "D", "F"]
"src/openllm/utils/__init__.py" = ["I001"]
"src/openllm/utils/import_utils.py" = ["PLW0603"]
"src/openllm/client/runtimes/*" = ["D107"]
"src/openllm/models/**" = ["E", "D", "F"]
"src/openllm/utils/import_utils.py" = ["PLW0603"]
"tests/**/*" = [
"S101",
"TID252",
@@ -348,6 +349,7 @@ omit = [
"src/openllm/__init__.py",
"src/openllm/__main__.py",
"src/openllm/utils/dummy_*.py",
"src/openllm/_typing_compat.py",
]
source_pkgs = ["openllm"]
[tool.coverage.report]
@@ -378,6 +380,7 @@ omit = [
"src/openllm/__init__.py",
"src/openllm/__main__.py",
"src/openllm/utils/dummy_*.py",
"src/openllm/_typing_compat.py",
]
precision = 2
show_missing = true
@@ -387,16 +390,17 @@ analysis.useLibraryCodeForTypes = true
exclude = [
"__pypackages__/*",
"src/openllm/playground/",
"src/openllm/models/",
"src/openllm/__init__.py",
"src/openllm/__main__.py",
"src/openllm/utils/dummy_*.py",
"src/openllm/models",
"src/openllm/_typing_compat.py",
"tools",
"examples",
"tests",
]
include = ["src/openllm"]
pythonVersion = "3.12"
pythonVersion = "3.8"
reportMissingImports = "warning"
reportMissingTypeStubs = false
reportPrivateUsage = "warning"
@@ -408,13 +412,16 @@ reportWildcardImportFromLibrary = "warning"
typeCheckingMode = "strict"
[tool.mypy]
# TODO: Enable model for strict type checking
exclude = ["src/openllm/playground/", "src/openllm/utils/dummy_*.py", "src/openllm/models"]
local_partial_types = true
exclude = [
"src/openllm/playground/",
"src/openllm/utils/dummy_*.py",
"src/openllm/models",
"src/openllm/_typing_compat.py",
]
modules = ["openllm"]
mypy_path = "typings"
pretty = true
python_version = "3.11"
python_version = "3.8"
show_error_codes = true
warn_no_return = false
warn_return_any = false
@@ -430,6 +437,7 @@ module = [
"optimum.*",
"inflection.*",
"huggingface_hub.*",
"click_option_group.*",
"peft.*",
"auto_gptq.*",
"vllm.*",
@@ -437,13 +445,13 @@ module = [
"httpx.*",
"cloudpickle.*",
"circus.*",
"grpc_health.*",
"grpc_health.v1.*",
"transformers.*",
"ghapi.*",
]
[[tool.mypy.overrides]]
ignore_errors = true
module = ["openllm.models.*", "openllm._types", "openllm.playground.*"]
module = ["openllm.models.*", "openllm.playground.*", "openllm._typing_compat"]
[tool.hatch.version]
fallback-version = "0.0.0"
@@ -476,24 +484,12 @@ dependencies = [
"types-protobuf",
]
enable-by-default = false
exclude = [
# no reason to compile model since framework is already written in C++
"src/openllm/models/**",
"src/openllm/bundle/**",
"src/openllm/cli/**",
"src/openllm/playground/**",
"src/openllm/__main__.py",
"src/openllm/_types.py",
"src/openllm/_llm.py",
"src/openllm/_service.py",
"src/openllm/_configuration.py",
# can't compile serialisation for transformers since it will raise segfault
"src/openllm/serialisation/transformers",
"src/openllm/utils/analytics.py",
]
include = [
"src/openllm/serialisation",
"src/openllm/bundle",
"src/openllm/models/__init__.py",
"src/openllm/models/auto/__init__.py",
"src/openllm/utils/__init__.py",
"src/openllm/utils/codegen.py",
"src/openllm/__init__.py",
"src/openllm/_prompt.py",
"src/openllm/_schema.py",
@@ -505,20 +501,18 @@ include = [
]
# NOTE: This is consistent with pyproject.toml
mypy-args = [
"--pretty",
"--strict",
# this is because all transient libraries don't have types
"--allow-subclassing-any",
"--check-untyped-defs",
"--follow-imports=skip",
"--python-version=3.11",
"--check-untyped-defs",
"--ignore-missing-imports",
"--no-warn-return-any",
"--warn-unreachable",
"--no-warn-no-return",
"--no-warn-unused-ignores",
"--exclude='/src\\/openllm\\/playground\\/**'",
"--exclude='/src\\/openllm\\/_types\\.py$'",
"--exclude='/src\\/openllm\\/_typing_compat\\.py$'",
]
options = { verbose = true, debug_level = "2", opt_level = "3" }
options = { verbose = true, strip_asserts = true, debug_level = "2", opt_level = "3", include_runtime_files = true }
require-runtime-dependencies = true

View File

@@ -9,12 +9,8 @@ deploy, and monitor any LLMs with ease.
* Native integration with BentoML and LangChain for custom LLM apps
"""
from __future__ import annotations
import logging as _logging
import os as _os
import typing as _t
import warnings as _warnings
import logging as _logging, os as _os, typing as _t, warnings as _warnings
from pathlib import Path as _Path
from . import exceptions as exceptions, utils as utils
if utils.DEBUG:
@@ -35,7 +31,6 @@ _import_structure: dict[str, list[str]] = {
"exceptions": [], "models": [], "client": [], "bundle": [], "playground": [], "testing": [], "utils": ["infer_auto_class"], "serialisation": ["ggml", "transformers"], "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"], "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "LLMEmbeddings"], "_configuration": ["LLMConfig", "GenerationConfig", "SamplingParams"], "_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"], "_quantisation": ["infer_quantisation_config"], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs", "HfAgentInput"],
"models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": ["ChatGLMConfig"], "models.baichuan": ["BaichuanConfig"], "models.dolly_v2": ["DollyV2Config"], "models.falcon": ["FalconConfig"], "models.flan_t5": ["FlanT5Config"], "models.gpt_neox": ["GPTNeoXConfig"], "models.llama": ["LlamaConfig"], "models.mpt": ["MPTConfig"], "models.opt": ["OPTConfig"], "models.stablelm": ["StableLMConfig"], "models.starcoder": ["StarCoderConfig"]
}
COMPILED = _Path(__file__).suffix in (".pyd", ".so")
if _t.TYPE_CHECKING:
@@ -61,25 +56,38 @@ if _t.TYPE_CHECKING:
from .serialisation import ggml as ggml, transformers as transformers
from openllm.utils import infer_auto_class as infer_auto_class
try:
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise exceptions.MissingDependencyError
except exceptions.MissingDependencyError:
_import_structure["utils.dummy_pt_objects"] = ["ChatGLM", "Baichuan"]
else:
_import_structure["models.chatglm"].extend(["ChatGLM"])
_import_structure["models.baichuan"].extend(["Baichuan"])
if _t.TYPE_CHECKING:
from .models.baichuan import Baichuan as Baichuan
from .models.chatglm import ChatGLM as ChatGLM
try:
if not (utils.is_torch_available() and utils.is_triton_available()): raise exceptions.MissingDependencyError
except exceptions.MissingDependencyError:
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["MPT"])
else: _import_structure["utils.dummy_pt_objects"] = ["MPT"]
else:
_import_structure["models.mpt"].extend(["MPT"])
if _t.TYPE_CHECKING: from .models.mpt import MPT as MPT
try:
if not (utils.is_torch_available() and utils.is_einops_available()): raise exceptions.MissingDependencyError
except exceptions.MissingDependencyError:
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["Falcon"])
else: _import_structure["utils.dummy_pt_objects"] = ["Falcon"]
else:
_import_structure["models.falcon"].extend(["Falcon"])
if _t.TYPE_CHECKING: from .models.falcon import Falcon as Falcon
try:
if not utils.is_torch_available(): raise exceptions.MissingDependencyError
except exceptions.MissingDependencyError:
_import_structure["utils.dummy_pt_objects"] = utils.dummy_pt_objects.__all__
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(utils.dummy_pt_objects) if not name.startswith("_") and name not in ("ChatGLM", "Baichuan", "MPT", "Falcon", "annotations")]
else:
if utils.is_cpm_kernels_available():
_import_structure["models.chatglm"].extend(["ChatGLM"])
_import_structure["models.baichuan"].extend(["Baichuan"])
if _t.TYPE_CHECKING:
from .models.baichuan import Baichuan as Baichuan
from .models.chatglm import ChatGLM as ChatGLM
if utils.is_einops_available():
_import_structure["models.falcon"].extend(["Falcon"])
if _t.TYPE_CHECKING:
from .models.falcon import Falcon as Falcon
if utils.is_triton_available():
_import_structure["models.mpt"].extend(["MPT"])
if _t.TYPE_CHECKING:
from .models.mpt import MPT as MPT
_import_structure["models.flan_t5"].extend(["FlanT5"])
_import_structure["models.dolly_v2"].extend(["DollyV2"])
_import_structure["models.starcoder"].extend(["StarCoder"])
@@ -100,7 +108,7 @@ else:
try:
if not utils.is_vllm_available(): raise exceptions.MissingDependencyError
except exceptions.MissingDependencyError:
_import_structure["utils.dummy_vllm_objects"] = utils.dummy_vllm_objects.__all__
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(utils.dummy_vllm_objects) if not name.startswith("_") and name not in ("annotations",)]
else:
_import_structure["models.baichuan"].extend(["VLLMBaichuan"])
_import_structure["models.llama"].extend(["VLLMLlama"])
@@ -124,7 +132,7 @@ else:
try:
if not utils.is_flax_available(): raise exceptions.MissingDependencyError
except exceptions.MissingDependencyError:
_import_structure["utils.dummy_flax_objects"] = utils.dummy_flax_objects.__all__
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(utils.dummy_flax_objects) if not name.startswith("_") and name not in ("annotations",)]
else:
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
_import_structure["models.opt"].extend(["FlaxOPT"])
@@ -136,7 +144,7 @@ else:
try:
if not utils.is_tf_available(): raise exceptions.MissingDependencyError
except exceptions.MissingDependencyError:
_import_structure["utils.dummy_tf_objects"] = utils.dummy_tf_objects.__all__
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(utils.dummy_tf_objects) if not name.startswith("_") and name not in ("annotations",)]
else:
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
_import_structure["models.opt"].extend(["TFOPT"])

View File

@@ -1,3 +1,4 @@
# mypy: disable-error-code="attr-defined,no-untyped-call,type-var,operator,arg-type,no-redef"
"""Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``.
Highlight feature: Each fields in ``openllm.LLMConfig`` will also automatically generate a environment
@@ -33,29 +34,16 @@ dynamically during serve, ahead-of-serve or per requests.
Refer to ``openllm.LLMConfig`` docstring for more information.
"""
from __future__ import annotations
import copy
import enum
import logging
import os
import sys
import types
import typing as t
import attr
import click_option_group as cog
import inflection
import orjson
import copy, enum, logging, os, sys, types, typing as t
import attr, click_option_group as cog, inflection, orjson, openllm
from cattr.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
from deepmerge.merger import Merger
import openllm
from ._strategies import LiteralResourceSpec, available_resource_spec, resource_spec
from ._typing_compat import LiteralString, NotRequired, Required, overload, AdapterType, LiteralRuntime
from .exceptions import ForbiddenAttributeError
from .utils import (
ENV_VARS_TRUE_VALUES,
MYPY,
LazyType,
ReprMixin,
bentoml_cattr,
codegen,
@@ -63,43 +51,18 @@ from .utils import (
field_env_key,
first_not_none,
lenient_issubclass,
non_intrusive_setattr,
)
from .utils.import_utils import BACKENDS_MAPPING
# NOTE: We need to do check overload import
# so that it can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import NotRequired, Required, dataclass_transform, overload
else:
from typing_extensions import NotRequired, Required, dataclass_transform, overload
# NOTE: Using internal API from attr here, since we are actually
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
# NOTE: Using internal API from attr here, since we are actually allowing subclass of openllm.LLMConfig to become 'attrs'-ish
from attr._compat import set_closure_cell
from attr._make import _CountingAttr, _make_init, _transform_attrs
_T = t.TypeVar("_T")
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
from ._typing_compat import AnyCallable, At, Self, ListStr, DictStrAny
if t.TYPE_CHECKING:
import click
import peft
import transformers
import vllm
import click, peft, transformers, vllm
from transformers.generation.beam_constraints import Constraint
from ._types import AnyCallable, At
DictStrAny = dict[str, t.Any]
ListStr = list[str]
else:
Constraint = t.Any
ListStr = list
DictStrAny = dict
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
peft = openllm.utils.LazyLoader("peft", globals(), "peft")
@@ -107,7 +70,7 @@ else:
__all__ = ["LLMConfig", "GenerationConfig", "SamplingParams"]
logger = logging.getLogger(__name__)
config_merger = Merger([(DictStrAny, "merge")], ["override"], ["override"])
config_merger = Merger([(dict, "merge")], ["override"], ["override"])
# case insensitive, but rename to conform with type
class _PeftEnumMeta(enum.EnumMeta):
@@ -141,8 +104,6 @@ class PeftType(str, enum.Enum, metaclass=_PeftEnumMeta):
_PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"}
AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"]
_object_setattr = object.__setattr__
def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType:
@@ -263,8 +224,7 @@ class GenerationConfig(ReprMixin):
if t.TYPE_CHECKING and not MYPY:
# stubs this for pyright as mypy already has a attr plugin builtin
def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None:
...
def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: ...
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
if not _internal: raise RuntimeError("GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config")
@@ -322,7 +282,7 @@ class SamplingParams(ReprMixin):
def to_vllm(self) -> vllm.SamplingParams: return vllm.SamplingParams(max_tokens=self.max_tokens, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, **bentoml_cattr.unstructure(self))
@classmethod
def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> t.Self:
def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> Self:
"""The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``."""
stop = attrs.pop("stop", None)
if stop is not None and isinstance(stop, str): stop = [stop]
@@ -480,7 +440,7 @@ bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings)
def _setattr_class(attr_name: str, value_var: t.Any) -> str: return f"setattr(cls, '{attr_name}', {value_var})"
def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: t.LiteralString = "openllm") -> t.Callable[..., None]:
def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: LiteralString = "openllm") -> t.Callable[..., None]:
"""Generate the assignment script with prefix attributes __openllm_<value>__."""
args: ListStr = []
globs: DictStrAny = {"cls": cls, "_cached_attribute": attributes, "_cached_getattribute_get": _object_getattribute.__get__}
@@ -497,11 +457,6 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance
_reserved_namespace = {"__config__", "GenerationConfig", "SamplingParams"}
@dataclass_transform(kw_only_default=True, order_default=True, field_specifiers=(attr.field, dantic.Field))
def llm_config_transform(cls: type[LLMConfig]) -> type[LLMConfig]:
non_intrusive_setattr(cls, "__dataclass_transform__", {"order_default": True, "kw_only_default": True, "field_specifiers": (attr.field, dantic.Field)})
return cls
@attr.define(slots=True)
class _ConfigAttr:
Field = dantic.Field
@@ -669,7 +624,7 @@ class _ConfigBuilder:
__slots__ = ("_cls", "_cls_dict", "_attr_names", "_attrs", "_model_name", "_base_attr_map", "_base_names", "_has_pre_init", "_has_post_init")
def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr[t.Any]], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True):
def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True):
attrs, base_attrs, base_attr_map = _transform_attrs(cls, these, auto_attribs, kw_only, collect_by_mro, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__))
self._cls, self._model_name, self._cls_dict, self._attrs, self._base_names, self._base_attr_map = cls, cls.__openllm_model_name__, dict(cls.__dict__), attrs, {a.name for a in base_attrs}, base_attr_map
self._attr_names = tuple(a.name for a in attrs)
@@ -742,17 +697,16 @@ class _ConfigBuilder:
if not closure_cells: continue # Catch None or the empty list.
for cell in closure_cells:
try: match = cell.cell_contents is self._cls
except ValueError: pass # ValueError: Cell is empty
except ValueError: pass # noqa: PERF203 # ValueError: Cell is empty
else:
if match: set_closure_cell(cell, cls)
return cls
return llm_config_transform(cls)
def add_attrs_init(self) -> t.Self:
def add_attrs_init(self) -> Self:
self._cls_dict["__attrs_init__"] = codegen.add_method_dunders(self._cls, _make_init(self._cls, self._attrs, self._has_pre_init, self._has_post_init, False, True, False, self._base_attr_map, False, None, True))
return self
def add_repr(self) -> t.Self:
def add_repr(self) -> Self:
for key, fn in ReprMixin.__dict__.items():
if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config", "sampling_config"})
@@ -865,7 +819,7 @@ class LLMConfig(_ConfigAttr):
# auto assignment attributes generated from __config__ after create the new slot class.
_make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls)
def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: t.LiteralString | None = None) -> type[At]:
def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: LiteralString | None = None) -> type[At]:
camel_name = cls.__name__.replace("Config", "")
klass = attr.make_class(f"{camel_name}{class_attr}", [], bases=(base,), slots=True, weakref_slot=True, frozen=True, repr=False, init=False, collect_by_mro=True, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__, suffix=suffix_env, globs=globs, default_callback=lambda field_name, field_default: getattr(getattr(cls, class_attr), field_name, field_default) if codegen.has_own_attribute(cls, class_attr) else field_default))
# For pickling to work, the __module__ variable needs to be set to the
@@ -882,19 +836,19 @@ class LLMConfig(_ConfigAttr):
anns = codegen.get_annotations(cls)
# _CountingAttr is the underlying representation of attr.field
ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)}
these: dict[str, _CountingAttr[t.Any]] = {}
these: dict[str, _CountingAttr] = {}
annotated_names: set[str] = set()
for attr_name, typ in anns.items():
if codegen.is_class_var(typ): continue
annotated_names.add(attr_name)
val = cd.get(attr_name, attr.NOTHING)
if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val):
if not isinstance(val, _CountingAttr):
if val is attr.NOTHING: val = cls.Field(env=field_env_key(cls.__openllm_model_name__, attr_name))
else: val = cls.Field(default=val, env=field_env_key(cls.__openllm_model_name__, attr_name))
these[attr_name] = val
unannotated = ca_names - annotated_names
if len(unannotated) > 0:
missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter)
missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr", cd.get(n)).counter)
raise openllm.exceptions.MissingAnnotationAttributeError(f"The following field doesn't have a type annotation: {missing_annotated}")
# We need to set the accepted key before generation_config
# as generation_config is a special field that users shouldn't pass.
@@ -1102,7 +1056,7 @@ class LLMConfig(_ConfigAttr):
def __getitem__(self, item: t.Literal["ia3"]) -> dict[str, t.Any]: ...
# update-config-stubs.py: stop
def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any:
def __getitem__(self, item: LiteralString | t.Any) -> t.Any:
"""Allowing access LLMConfig as a dictionary. The order will always evaluate as.
__openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__
@@ -1179,13 +1133,13 @@ class LLMConfig(_ConfigAttr):
def model_dump_json(self, **kwargs: t.Any) -> bytes: return orjson.dumps(self.model_dump(**kwargs))
@classmethod
def model_construct_json(cls, json_str: str | bytes) -> t.Self:
def model_construct_json(cls, json_str: str | bytes) -> Self:
try: attrs = orjson.loads(json_str)
except orjson.JSONDecodeError as err: raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None
return bentoml_cattr.structure(attrs, cls)
@classmethod
def model_construct_env(cls, **attrs: t.Any) -> t.Self:
def model_construct_env(cls, **attrs: t.Any) -> Self:
"""A helpers that respect configuration values environment variables."""
attrs = {k: v for k, v in attrs.items() if v is not None}
model_config = cls.__openllm_env__.config
@@ -1198,7 +1152,7 @@ class LLMConfig(_ConfigAttr):
if "generation_config" in attrs:
generation_config = attrs.pop("generation_config")
if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
if not isinstance(generation_config, dict): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
else: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)}
for k in tuple(attrs.keys()):
@@ -1258,7 +1212,7 @@ class LLMConfig(_ConfigAttr):
total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(attr.fields_dict(cls.__openllm_sampling_class__))
if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return f
if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return t.cast("click.Command", f)
# We pop out 'generation_config' as it is a attribute that we don't need to expose to CLI.
for name, field in attr.fields_dict(cls).items():
ty = cls.__openllm_hints__.get(name)
@@ -1285,13 +1239,13 @@ def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig:
Otherwise, we will first filter out all keys that belong to LLMConfig, parse those in, then
parse the remaining keys into LLMConfig.generation_config
"""
if not LazyType(DictStrAny).isinstance(data): raise RuntimeError(f"Expected a dictionary, but got {type(data)}")
if not isinstance(data, dict): raise RuntimeError(f"Expected a dictionary, but got {type(data)}")
cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__}
generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__)
if "generation_config" in data:
generation_config = data.pop("generation_config")
if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
if not isinstance(generation_config, dict): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields})
else:
generation_config = {k: v for k, v in data.items() if k in generation_cls_fields}
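The `_setattr_class`/`_make_assignment_script` pair referenced in the hunks above generates a small assignment function as source text and `exec`s it once per config class. A minimal, self-contained sketch of that idea (the class and model name below are illustrative):
```python
# Sketch of exec-based attribute assignment, mirroring _setattr_class above.
def _setattr_class(attr_name: str, value_var: str) -> str:
    return f"setattr(cls, '{attr_name}', {value_var})"

# Build a one-off assigner from generated source, compile it, then run it against a class.
lines = [_setattr_class("__openllm_model_name__", "model_name")]
src = "def _assign(cls, model_name):\n    " + "\n    ".join(lines)
namespace: dict = {}
exec(src, namespace)

class DemoConfig: ...

namespace["_assign"](DemoConfig, "flan_t5")
assert DemoConfig.__openllm_model_name__ == "flan_t5"
```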

View File

@@ -1,17 +1,9 @@
"""Generation utilities to be reused throughout."""
from __future__ import annotations
import typing as t
import transformers
if t.TYPE_CHECKING:
import torch
import openllm
import typing as t, transformers
if t.TYPE_CHECKING: import torch, openllm
LogitsProcessorList = transformers.LogitsProcessorList
StoppingCriteriaList = transformers.StoppingCriteriaList
class StopSequenceCriteria(transformers.StoppingCriteria):
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]

View File

@@ -1,32 +1,13 @@
from __future__ import annotations
import collections
import functools
import inspect
import logging
import os
import re
import sys
import traceback
import types
import typing as t
import uuid
import functools, inspect, logging, os, re, traceback, types, typing as t, uuid
from abc import ABC, abstractmethod
from pathlib import Path
import attr
import fs.path
import inflection
import orjson
import attr, fs.path, inflection, orjson, bentoml, openllm
from huggingface_hub import hf_hub_download
import bentoml
import openllm
from bentoml._internal.models.model import ModelSignature
from ._configuration import (
AdapterType,
FineTuneConfig,
LiteralRuntime,
LLMConfig,
_object_getattribute,
_setattr_class,
@@ -57,54 +38,37 @@ from .utils import (
validate_is_path,
)
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import NotRequired, overload
else:
from typing_extensions import NotRequired, overload
from ._typing_compat import (
AdaptersMapping,
AdaptersTuple,
AnyCallable,
AdapterType,
LiteralRuntime,
DictStrAny,
ListStr,
LLMEmbeddings,
LLMRunnable,
LLMRunner,
ModelSignatureDict as _ModelSignatureDict,
PeftAdapterOutput,
TupleAny,
NotRequired, overload, M, T, LiteralString
)
if t.TYPE_CHECKING:
import auto_gptq as autogptq
import peft
import torch
import transformers
import vllm
import auto_gptq as autogptq, peft, torch, transformers, vllm
from ._configuration import PeftType
from ._types import (
AdaptersMapping,
AdaptersTuple,
AnyCallable,
DictStrAny,
ListStr,
LLMEmbeddings,
LLMRunnable,
LLMRunner,
ModelSignatureDict as _ModelSignatureDict,
PeftAdapterOutput,
TupleAny,
)
from .utils.representation import ReprArgs
UserDictAny = collections.UserDict[str, t.Any]
ResolvedAdaptersMapping = dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]]
else:
DictStrAny = dict
TupleAny = tuple
UserDictAny = collections.UserDict
LLMRunnable = bentoml.Runnable
LLMRunner = bentoml.Runner
LLMEmbeddings = dict
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
vllm = LazyLoader("vllm", globals(), "vllm")
transformers = LazyLoader("transformers", globals(), "transformers")
torch = LazyLoader("torch", globals(), "torch")
peft = LazyLoader("peft", globals(), "peft")
logger = logging.getLogger(__name__)
ResolvedAdaptersMapping = t.Dict[AdapterType, t.Dict[t.Union[str, t.Literal["default"]], t.Tuple["peft.PeftConfig", str]]]
logger = logging.getLogger(__name__)
class ModelSignatureDict(t.TypedDict, total=False):
batchable: bool
batch_dim: t.Union[t.Tuple[int, int], int]
@@ -150,9 +114,6 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapp
_reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"}
M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]")
T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]")
class LLMInterface(ABC, t.Generic[M, T]):
"""This defines the loose contract for all openllm.LLM implementations."""
@property
@@ -257,7 +218,7 @@ class LLMInterface(ABC, t.Generic[M, T]):
raise NotImplementedError
# NOTE: All fields below are attributes that can be accessed by users.
config_class: type[LLMConfig]
config_class: t.Type[LLMConfig]
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
bettertransformer: bool
"""Whether to load this LLM with FasterTransformer enabled. The order of loading is:
@@ -270,7 +231,7 @@ class LLMInterface(ABC, t.Generic[M, T]):
"""
device: "torch.device"
"""The device to be used for this LLM. If the implementation is 'pt', then it will be torch.device, else string."""
tokenizer_id: t.LiteralString | t.Literal["local"]
tokenizer_id: t.Union[t.Literal["local"], LiteralString]
"""optional tokenizer_id for loading with vLLM if the model supports vLLM."""
# NOTE: The following will be populated by __init_subclass__, note that these should be immutable.
__llm_trust_remote_code__: bool
@@ -290,13 +251,13 @@ class LLMInterface(ABC, t.Generic[M, T]):
An additional naming for all VLLM backend: VLLMLlama -> `vllm`
"""
__llm_model__: M | None
__llm_model__: t.Optional[M]
"""A reference to the actual model. Instead of access this directly, you should use `model` property instead."""
__llm_tokenizer__: T | None
__llm_tokenizer__: t.Optional[T]
"""A reference to the actual tokenizer. Instead of access this directly, you should use `tokenizer` property instead."""
__llm_bentomodel__: bentoml.Model | None
__llm_bentomodel__: t.Optional[bentoml.Model]
"""A reference to the bentomodel used for this LLM. Instead of access this directly, you should use `_bentomodel` property instead."""
__llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None
__llm_adapter_map__: t.Optional[ResolvedAdaptersMapping]
"""A reference to the the cached LoRA adapter mapping."""
__llm_supports_embeddings__: bool
"""A boolean to determine whether models does implement ``LLM.embeddings``."""
@@ -307,11 +268,10 @@ class LLMInterface(ABC, t.Generic[M, T]):
__llm_supports_generate_iterator__: bool
"""A boolean to determine whether models does implement ``LLM.generate_iterator``."""
if t.TYPE_CHECKING and not MYPY:
def __attrs_init__(self, config: LLMConfig, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, model_id: str, runtime: t.Literal["ggml", "transformers"], model_decls: TupleAny, model_attrs: DictStrAny, tokenizer_attrs: DictStrAny, tag: bentoml.Tag, adapters_mapping: AdaptersMapping | None, model_version: str | None, quantize_method: t.Literal["int8", "int4", "gptq"] | None, serialisation_format: t.Literal["safetensors", "legacy"], _local: bool, **attrs: t.Any) -> None:
def __attrs_init__(self, config: LLMConfig, quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, autogptq.BaseQuantizeConfig]], model_id: str, runtime: t.Literal["ggml", "transformers"], model_decls: TupleAny, model_attrs: DictStrAny, tokenizer_attrs: DictStrAny, tag: bentoml.Tag, adapters_mapping: t.Optional[AdaptersMapping], model_version: t.Optional[str], quantize_method: t.Optional[t.Literal["int8", "int4", "gptq"]], serialisation_format: t.Literal["safetensors", "legacy"], _local: bool, **attrs: t.Any) -> None:
"""Generated __attrs_init__ for openllm.LLM."""
_R = t.TypeVar("_R", covariant=True)
class _import_model_wrapper(t.Generic[_R, M, T], t.Protocol):
def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R: ...
class _load_model_wrapper(t.Generic[M, T], t.Protocol):
@@ -333,7 +293,7 @@ def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]) -> t.Ca
(model_decls, model_attrs), _ = self.llm_parameters
decls = (*model_decls, *decls)
attrs = {**model_attrs, **attrs}
return f(self, *decls, trust_remote_code=t.cast(bool, trust_remote_code), **attrs)
return f(self, *decls, trust_remote_code=trust_remote_code, **attrs)
return wrapper
_DEFAULT_TOKENIZER = "hf-internal-testing/llama-tokenizer"
@@ -483,7 +443,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
@overload
def __getitem__(self, item: t.Literal["bentomodel"]) -> bentoml.Model | None: ...
@overload
def __getitem__(self, item: t.Literal["adapter_map"]) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None: ...
def __getitem__(self, item: t.Literal["adapter_map"]) -> ResolvedAdaptersMapping | None: ...
@overload
def __getitem__(self, item: t.Literal["supports_embeddings"]) -> bool: ...
@overload
@@ -492,7 +452,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
def __getitem__(self, item: t.Literal["supports_generate_one"]) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["supports_generate_iterator"]) -> bool: ...
def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any:
def __getitem__(self, item: t.Union[LiteralString, t.Any]) -> t.Any:
if item is None: raise TypeError(f"{self} doesn't understand how to index None.")
item = inflection.underscore(item)
internal_attributes = f"__llm_{item}__"
@@ -575,7 +535,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
"""
cfg_cls = cls.config_class
_local = False
_model_id: str = t.cast(str, first_not_none(model_id, os.environ.get(cfg_cls.__openllm_env__["model_id"]), default=cfg_cls.__openllm_default_id__))
_model_id: str = first_not_none(model_id, os.environ.get(cfg_cls.__openllm_env__["model_id"]), default=cfg_cls.__openllm_default_id__)
if validate_is_path(_model_id): _model_id, _local = resolve_filepath(_model_id), True
quantize = first_not_none(quantize, t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.environ.get(cfg_cls.__openllm_env__["quantize"])), default=None)
@@ -889,7 +849,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
adapter_mapping = _mapping[adapter_type]
self.__llm_model__ = self._wrap_default_peft_model(adapter_mapping, inference_mode=inference_mode)
# now we loop through the rest with add_adapter
if len(adapter_mapping) > 0:
for adapter_name, (_peft_config, _) in adapter_mapping.items():
@@ -1108,11 +1067,10 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
def available_adapters(_: LLMRunner[M, T]) -> PeftAdapterOutput:
if not is_peft_available(): return {"success": False, "result": {}, "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'"}
if self.__llm_adapter_map__ is None: return {"success": False, "result": {}, "error_msg": "No adapters available for current running server."}
if not isinstance(self.model, peft.PeftModel): return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"}
return {"success": True, "result": self.model.peft_config, "error_msg": ""}
if not is_peft_available(): return PeftAdapterOutput(success=False, result={}, error_msg="peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'")
if self.__llm_adapter_map__ is None: return PeftAdapterOutput(success=False, result={}, error_msg="No adapters available for current running server.")
if not isinstance(self.model, peft.PeftModel): return PeftAdapterOutput(success=False, result={}, error_msg="Model is not a PeftModel")
return PeftAdapterOutput(success=True, result=self.model.peft_config, error_msg="")
def _wrapped_generate_run(__self: LLMRunner[M, T], prompt: str, **kwargs: t.Any) -> t.Any:
"""Wrapper for runner.generate.run() to handle the prompt and postprocessing.

View File

@@ -1,22 +1,17 @@
from __future__ import annotations
import string
import typing as t
import string, typing as t
class PromptFormatter(string.Formatter):
"""This PromptFormatter is largely based on langchain's implementation."""
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
if len(args) > 0: raise ValueError("Positional arguments are not supported")
return super().vformat(format_string, args, kwargs)
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
extras = set(kwargs).difference(used_args)
if extras: raise KeyError(f"Extra params passed: {extras}")
def extract_template_variables(self, template: str) -> t.Sequence[str]:
return [field[1] for field in self.parse(template) if field[1] is not None]
default_formatter = PromptFormatter()
def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str:
# Currently, all default prompts will always have an `instruction` key.
if not use_prompt_template: return prompt
@@ -24,7 +19,5 @@ def process_prompt(prompt: str, template: str | None = None, use_prompt_template
template_variables = default_formatter.extract_template_variables(template)
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
try:
return template.format(instruction=prompt, **prompt_variables)
except KeyError as e:
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
try: return template.format(instruction=prompt, **prompt_variables)
except KeyError as e: raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
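A brief usage sketch of `process_prompt`, with the functions above in scope; the template string is illustrative rather than a model's real default template:
```python
# Illustrative template; real default templates are defined per model elsewhere in OpenLLM.
template = "### Instruction:\n{instruction}\n\n### Response:\n"
result = process_prompt("Summarise the release notes.", template=template, use_prompt_template=True)
print(result)  # the prompt substituted into the {instruction} placeholder
```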

View File

@@ -1,18 +1,9 @@
# mypy: disable-error-code="name-defined"
from __future__ import annotations
import logging
import sys
import typing as t
import logging, sys, typing as t
from .utils import LazyLoader, is_autogptq_available, is_bitsandbytes_available, is_transformers_supports_kbit, pkg
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
if sys.version_info[:2] >= (3, 11): from typing import overload
else: from typing_extensions import overload
if t.TYPE_CHECKING:
from ._llm import LLM
from ._types import DictStrAny

View File

@@ -1,18 +1,10 @@
"""Schema definition for OpenLLM. This can be use for client interaction."""
from __future__ import annotations
import functools
import typing as t
import attr
import inflection
import openllm
import functools, typing as t
import attr, inflection, openllm
from ._configuration import GenerationConfig, LLMConfig
from .utils import bentoml_cattr
if t.TYPE_CHECKING:
import vllm
if t.TYPE_CHECKING: import vllm
@attr.frozen(slots=True)
class GenerationInput:
@@ -30,7 +22,6 @@ class GenerationInput:
def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]: return cls.from_llm_config(openllm.AutoConfig.for_model(model_name, **attrs))
@classmethod
def from_llm_config(cls, llm_config: openllm.LLMConfig) -> type[GenerationInput]: return attr.make_class(inflection.camelize(llm_config["model_name"]) + "GenerationInput", attrs={"prompt": attr.field(type=str), "llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)), "adapter_name": attr.field(default=None, type=str)})
@attr.frozen(slots=True)
class GenerationOutput:
responses: t.List[t.Any]
@@ -43,7 +34,6 @@ class GenerationOutput:
if hasattr(self, key): return getattr(self, key)
elif key in self.configuration: return self.configuration[key]
else: raise KeyError(key)
@attr.frozen(slots=True)
class MetadataOutput:
model_id: str
@@ -53,14 +43,11 @@ class MetadataOutput:
configuration: str
supports_embeddings: bool
supports_hf_agent: bool
@attr.frozen(slots=True)
class EmbeddingsOutput:
embeddings: t.List[t.List[float]]
num_tokens: int
def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> dict[str, t.Any]: return dict(request_id=request_output.request_id, prompt=request_output.prompt, finished=request_output.finished, prompt_token_ids=request_output.prompt_token_ids, outputs=[dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason) for it in request_output.outputs])
@attr.define
class HfAgentInput:
inputs: str
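A hedged construction example for the output schema above; the configuration values are placeholders, and the dictionary-style lookup relies on the fallback shown in `GenerationOutput` earlier:
```python
# Placeholder values; mirrors how _service.py (later in this diff) constructs GenerationOutput.
import openllm

out = openllm.GenerationOutput(responses=["generated text"], configuration={"temperature": 0.75})
print(out.responses)
print(out["temperature"])  # falls back to the configuration dict when no attribute matches
```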

View File

@@ -1,13 +1,11 @@
# mypy: disable-error-code="arg-type,misc"
"""The service definition for running any LLMService.
Note that the line `model = ...` is a special line and should not be modified. This will be handled by openllm
internally to generate the correct model service when bundling the LLM to a Bento.
This will ensure that 'bentoml serve llm-bento' will work accordingly.
The generation code lives under utils/codegen.py
Lines with the comment '# openllm: ...' must not be modified, as they are managed internally by OpenLLM.
Codegen can be found under 'openllm.utils.codegen'
"""
from __future__ import annotations
import os, typing as t, warnings, orjson, bentoml, openllm
import os, warnings, orjson, bentoml, openllm, typing as t
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
@@ -24,48 +22,31 @@ llm_config = openllm.AutoConfig.for_model(model)
runner = openllm.Runner(model, llm_config=llm_config, ensure_available=False, adapter_map=orjson.loads(adapter_map))
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner])
@svc.api(route="/v1/generate", input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), # type: ignore[arg-type] # XXX: remove once JSON supports Attrs class
output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}))
@svc.api(route="/v1/generate", input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}))
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
config = qa_inputs.llm_config.model_dump()
responses = await runner.generate.async_run(qa_inputs.prompt, **{"adapter_name": qa_inputs.adapter_name, **config})
return openllm.GenerationOutput(responses=responses, configuration=config)
@svc.api(
route="/v1/metadata", input=bentoml.io.Text(), # type: ignore[misc] # XXX: remove once JSON supports Attrs class
output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent})
)
@svc.api(route="/v1/metadata", input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent}))
def metadata_v1(_: str) -> openllm.MetadataOutput:
return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,)
return openllm.MetadataOutput(timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], model_id=runner.llm.model_id, configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent)
if runner.supports_embeddings:
@svc.api( # type: ignore[arg-type] # XXX: remove once JSON supports Attrs class
input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({
"embeddings": [
0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918,
0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076
], "num_tokens": 20
}), route="/v1/embeddings"
)
@svc.api(route="/v1/embeddings", input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}))
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
responses = await runner.embeddings.async_run(phrases)
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
async def hf_agent(request: Request) -> Response:
json_str = await request.body()
try:
input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput)
except orjson.JSONDecodeError as err:
raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
try: input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput)
except orjson.JSONDecodeError as err: raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
stop = input_data.parameters.pop("stop", ["\n"])
try:
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
except NotImplementedError:
return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
try: return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
except NotImplementedError: return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
svc.mount_asgi_app(hf_app, path="/hf")
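A hedged client-side sketch against the `/v1/generate` route defined above; the host, port, and empty `llm_config` payload are assumptions (valid keys depend on the model being served):
```python
# Hypothetical client call; assumes the service is reachable on localhost:3000.
import httpx

payload = {"prompt": "What does this commit change?", "llm_config": {}}
resp = httpx.post("http://localhost:3000/v1/generate", json=payload, timeout=60.0)
resp.raise_for_status()
print(resp.json()["responses"])
```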

View File

@@ -1,38 +1,17 @@
from __future__ import annotations
import functools
import inspect
import logging
import math
import os
import sys
import types
import typing as t
import warnings
import psutil
import bentoml
import functools, inspect, logging, math, os, sys, types, typing as t, warnings, psutil, bentoml
from bentoml._internal.resource import get_resource, system_resources
from bentoml._internal.runner.strategy import THREAD_ENVS
from .utils import DEBUG, ReprMixin
if sys.version_info[:2] >= (3, 11): from typing import overload
else: from typing_extensions import overload
class DynResource(t.Protocol):
resource_id: t.ClassVar[str]
@classmethod
def from_system(cls) -> t.Sequence[t.Any]: ...
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
logger = logging.getLogger(__name__)
def _strtoul(s: str) -> int:
"""Return -1 or positive integer sequence string starts with,."""
if not s: return -1

View File

@@ -1,100 +0,0 @@
"""Types definition for OpenLLM.
Note that this module SHOULD NOT BE IMPORTED DURING RUNTIME, as it serves only typing purposes.
It will raise a RuntimeError if it is imported eagerly.
"""
from __future__ import annotations
import typing as t
if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime")
import attr
import bentoml
from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict
from ._configuration import (
AdapterType,
LiteralRuntime as LiteralRuntime,
)
if t.TYPE_CHECKING:
import peft
import openllm
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
from bentoml._internal.runner.strategy import Strategy
from ._llm import (
M as _M,
T as _T,
)
from .bundle.oci import LiteralContainerVersionStrategy
from .utils.lazy import VersionInfo
AnyCallable = t.Callable[..., t.Any]
DictStrAny = dict[str, t.Any]
ListAny = list[t.Any]
ListStr = list[str]
TupleAny = tuple[t.Any, ...]
P = t.ParamSpec("P")
T = t.TypeVar("T")
At = t.TypeVar("At", bound=attr.AttrsInstance)
class PeftAdapterOutput(t.TypedDict):
success: bool
result: dict[str, peft.PeftConfig]
error_msg: str
class LLMEmbeddings(t.TypedDict):
embeddings: t.List[t.List[float]]
num_tokens: int
class AdaptersTuple(TupleAny):
adapter_id: str
name: str | None
config: DictStrAny
class RefTuple(TupleAny):
git_hash: str
version: VersionInfo
strategy: LiteralContainerVersionStrategy
AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]]
class LLMRunnable(bentoml.Runnable, t.Generic[_M, _T]):
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
__call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
class LLMRunner(bentoml.Runner, t.Generic[_M, _T]):
__doc__: str
__module__: str
llm_type: str
identifying_params: dict[str, t.Any]
llm: openllm.LLM[_M, _T]
config: openllm.LLMConfig
implementation: LiteralRuntime
supports_embeddings: bool
supports_hf_agent: bool
has_adapters: bool
embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
def __init__(self, runnable_class: type[LLMRunnable[_M, _T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,) -> None: ...
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ...
def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ...
def run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
def download_model(self) -> bentoml.Model: ...
@property
def peft_adapters(self) -> PeftAdapterOutput: ...
@property
def __repr_keys__(self) -> set[str]: ...

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import sys, typing as t, bentoml, attr, abc
from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict
if t.TYPE_CHECKING:
import openllm, peft, transformers, auto_gptq as autogptq, vllm
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
from bentoml._internal.runner.strategy import Strategy
from .bundle.oci import LiteralContainerVersionStrategy
from .utils.lazy import VersionInfo
M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]")
T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]")
AnyCallable = t.Callable[..., t.Any]
DictStrAny = t.Dict[str, t.Any]
ListAny = t.List[t.Any]
ListStr = t.List[str]
TupleAny = t.Tuple[t.Any, ...]
At = t.TypeVar("At", bound=attr.AttrsInstance)
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"]
if sys.version_info[:2] >= (3, 11):
from typing import LiteralString as LiteralString, Self as Self, overload as overload
from typing import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform
else:
from typing_extensions import LiteralString as LiteralString, Self as Self, overload as overload
from typing_extensions import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform
if sys.version_info[:2] >= (3, 10):
from typing import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate
else:
from typing_extensions import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate
if sys.version_info[:2] >= (3, 9):
from typing import TypedDict as TypedDict
else:
from typing_extensions import TypedDict as TypedDict
class PeftAdapterOutput(TypedDict):
success: bool
result: t.Dict[str, peft.PeftConfig]
error_msg: str
class LLMEmbeddings(t.TypedDict):
embeddings: t.List[t.List[float]]
num_tokens: int
class AdaptersTuple(TupleAny):
adapter_id: str
name: t.Optional[str]
config: DictStrAny
AdaptersMapping = t.Dict[AdapterType, t.Tuple[AdaptersTuple, ...]]
class RefTuple(TupleAny):
git_hash: str
version: VersionInfo
strategy: LiteralContainerVersionStrategy
class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
__call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[t.Any, None, None]]
class LLMRunner(bentoml.Runner, t.Generic[M, T]):
__doc__: str
__module__: str
llm_type: str
identifying_params: dict[str, t.Any]
llm: openllm.LLM[M, T]
config: openllm.LLMConfig
implementation: LiteralRuntime
supports_embeddings: bool
supports_hf_agent: bool
has_adapters: bool
embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]]
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[t.Any, None, None]]
def __init__(self, runnable_class: type[LLMRunnable[M, T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,) -> None: ...
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ...
@abc.abstractmethod
def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ...
def run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
@abc.abstractmethod
def download_model(self) -> bentoml.Model: ...
@property
@abc.abstractmethod
def peft_adapters(self) -> PeftAdapterOutput: ...
@property
@abc.abstractmethod
def __repr_keys__(self) -> set[str]: ...

View File

@@ -3,10 +3,8 @@
These utilities will stay internal, and their API can change without backward compatibility.
"""
from __future__ import annotations
import sys
import typing as t
import openllm
import os, typing as t
from openllm.utils import LazyModule
_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]}
@@ -29,5 +27,8 @@ if t.TYPE_CHECKING:
get_base_container_tag as get_base_container_tag,
supported_registries as supported_registries,
)
else:
sys.modules[__name__] = openllm.utils.LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
__lazy = LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__

View File

@@ -1,30 +1,18 @@
# mypy: disable-error-code="misc"
from __future__ import annotations
import importlib.metadata
import inspect
import logging
import os
import typing as t
import importlib.metadata, inspect, logging, os, typing as t
from pathlib import Path
import fs
import fs.copy
import fs.errors
import orjson
import fs, fs.copy, fs.errors, orjson, bentoml, openllm
from simple_di import Provide, inject
import bentoml
import openllm
from bentoml._internal.bento.build_config import BentoBuildConfig, DockerOptions, ModelSpec, PythonOptions
from bentoml._internal.configuration.containers import BentoMLContainer
from . import oci
if t.TYPE_CHECKING:
from fs.base import FS
from openllm._typing_compat import LiteralString
from bentoml._internal.bento import BentoStore
from bentoml._internal.models.model import ModelStore
from .oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
logger = logging.getLogger(__name__)
@@ -39,7 +27,7 @@ def build_editable(path: str) -> str | None:
from build.env import IsolatedEnvBuilder
module_location = openllm.utils.pkg.source_locations("openllm")
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
pyproject_path = Path(module_location).parent.parent / "pyproject.toml"
pyproject_path = Path(module_location).parent.parent/"pyproject.toml"
if os.path.isfile(pyproject_path.__fspath__()):
logger.info("OpenLLM is installed in editable mode. Generating built wheels...")
with IsolatedEnvBuilder() as env:
@@ -80,31 +68,26 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
_tf_version = importlib.metadata.version(candidate)
packages.extend([f"tensorflow>={_tf_version}"])
break
except importlib.metadata.PackageNotFoundError:
pass
except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we need to check all possible TensorFlow distributions.
else:
if not openllm.utils.is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
wheels: list[str] = []
built_wheels = build_editable(llm_fs.getsyspath("/"))
if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}"))
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
def construct_docker_options(
llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry,
container_version_strategy: LiteralContainerVersionStrategy,
) -> DockerOptions:
def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options_opts = [
"tracing.sample_rate=1.0", f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'api_server.traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
]
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'api_server.traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
_bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts)
env: openllm.utils.EnvVarMixin = llm.config["env"]
if env["framework_value"] == "vllm": serialisation_format = "legacy"
env_dict = {
env.framework: env["framework_value"], env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}"
env.framework: env["framework_value"], env.config: f"'{llm.config.model_dump_json().decode()}'",
env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}",
"OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format,
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
}
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
@@ -117,10 +100,9 @@ def construct_docker_options(
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template)
@inject
def create_bento(
bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, device: tuple[str, ...] | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None, runtime: t.Literal[
"ggml", "transformers"] = "transformers", serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release", _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store],
) -> bentoml.Bento:
def create_bento(bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None,
runtime: t.Literal[ "ggml", "transformers"] = "transformers", serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release",
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
framework_envvar = llm.config["env"]["framework_value"]
labels = dict(llm.identifying_params)
labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
@@ -129,13 +111,9 @@ def create_bento(
if workers_per_resource == "round_robin": workers_per_resource = 1.0
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if openllm.utils.device_count() == 0 else float(1 / openllm.utils.device_count())
else:
try:
workers_per_resource = float(workers_per_resource)
except ValueError:
raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
elif isinstance(workers_per_resource, int):
workers_per_resource = float(workers_per_resource)
try: workers_per_resource = float(workers_per_resource)
except ValueError: raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
elif isinstance(workers_per_resource, int): workers_per_resource = float(workers_per_resource)
logger.info("Building Bento for '%s'", llm.config["start_name"])
# add service.py definition to this temporary folder
openllm.utils.codegen.write_service(llm, adapter_map, llm_fs)
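For reference, the `workers_per_resource` normalisation performed earlier in `create_bento` boils down to the following mapping (a sketch, assuming `openllm.utils.device_count()` reports the number of visible GPUs):

```python
# Sketch of the strategy-to-float conversion used by create_bento above.
def normalise_workers_per_resource(value: str | int | float, gpus: int) -> float:
    if value == "round_robin":
        return 1.0
    if value == "conserved":
        # spread a single worker across all GPUs; plain 1.0 on CPU-only hosts
        return 1.0 if gpus == 0 else 1.0 / gpus
    try:
        return float(value)
    except ValueError:
        raise ValueError("'workers_per_resource' only accepts ['round_robin', 'conserved'] as possible strategies, otherwise pass in a float.") from None
```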

View File

@@ -1,31 +1,21 @@
# mypy: disable-error-code="misc"
"""OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change."""
from __future__ import annotations
import functools
import importlib
import logging
import pathlib
import shutil
import subprocess
import typing as t
import functools, importlib, logging, os, pathlib, shutil, subprocess, typing as t
from datetime import datetime, timedelta, timezone
import attr
import orjson
import bentoml
import openllm
import attr, orjson, bentoml, openllm
from openllm.utils.lazy import VersionInfo
if t.TYPE_CHECKING:
from ghapi import all
from openllm._typing_compat import RefTuple
from openllm._types import DictStrAny, RefTuple
else:
all = openllm.utils.LazyLoader("all", globals(), "ghapi.all")
all = openllm.utils.LazyLoader("all", globals(), "ghapi.all") # noqa: F811
logger = logging.getLogger(__name__)
_BUILDER = bentoml.container.get_backend("buildx")
ROOT_DIR = pathlib.Path(__file__).parent.parent.parent
ROOT_DIR = pathlib.Path(os.path.abspath("__file__")).parent.parent.parent
# TODO: support quay
LiteralContainerRegistry = t.Literal["docker", "gh", "ecr"]
@@ -45,14 +35,10 @@ _module_location = openllm.utils.pkg.source_locations("openllm")
@functools.lru_cache
@openllm.utils.apply(str.lower)
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
return _CONTAINER_REGISTRY[reg]
def get_base_container_name(reg: LiteralContainerRegistry) -> str: return _CONTAINER_REGISTRY[reg]
def _convert_version_from_string(s: str) -> openllm.utils.VersionInfo:
return openllm.utils.VersionInfo.from_version_string(s)
def _commit_time_range(r: int = 5) -> str:
return (datetime.now(timezone.utc) - timedelta(days=r)).strftime("%Y-%m-%dT%H:%M:%SZ")
def _convert_version_from_string(s: str) -> VersionInfo: return VersionInfo.from_version_string(s)
def _commit_time_range(r: int = 5) -> str: return (datetime.now(timezone.utc) - timedelta(days=r)).strftime("%Y-%m-%dT%H:%M:%SZ")
class VersionNotSupported(openllm.exceptions.OpenLLMException):
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
@@ -66,54 +52,42 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
docker_bin = shutil.which("docker")
if docker_bin is None:
logger.warning("To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)")
commits = t.cast("list[DictStrAny]", cls._ghapi.repos.list_commits(since=_commit_time_range()))
commits = t.cast("list[dict[str, t.Any]]", cls._ghapi.repos.list_commits(since=_commit_time_range()))
return next(f'sha-{it["sha"][:7]}' for it in commits if "[skip ci]" not in it["commit"]["message"])
# this is now the correct behaviour
return orjson.loads(subprocess.check_output([docker_bin, "run", "--rm", "-it", "quay.io/skopeo/stable:latest", "list-tags", "docker://ghcr.io/bentoml/openllm"]).decode().strip())["Tags"][-2]
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
class RefResolver:
"""TODO: Support offline mode.
Maybe we need to save git hash when building the Bento.
"""
git_hash: str = attr.field()
version: openllm.utils.VersionInfo = attr.field(converter=_convert_version_from_string)
strategy: LiteralContainerVersionStrategy = attr.field()
_ghapi: all.GhApi = all.GhApi(owner=_OWNER, repo=_REPO) # TODO: support offline mode
_ghapi: t.ClassVar[all.GhApi] = all.GhApi(owner=_OWNER, repo=_REPO)
@classmethod
def _nightly_ref(cls) -> RefTuple:
return _RefTuple((nightly_resolver(cls), "refs/heads/main", "nightly"))
def _nightly_ref(cls) -> RefTuple: return _RefTuple((nightly_resolver(cls), "refs/heads/main", "nightly"))
@classmethod
def _release_ref(cls, version_str: str | None = None) -> RefTuple:
_use_base_strategy = version_str is None
if version_str is None:
# NOTE: This strategy will only support openllm>0.2.12
meta: DictStrAny = cls._ghapi.repos.get_latest_release()
meta: dict[str, t.Any] = cls._ghapi.repos.get_latest_release()
version_str = meta["name"].lstrip("v")
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")["object"]["sha"], version_str)
else:
version = ("", version_str)
else: version = ("", version_str)
if openllm.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12): raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
return _RefTuple((*version, "release" if _use_base_strategy else "custom"))
@classmethod
@functools.lru_cache(maxsize=64)
def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | str | None = None) -> RefResolver:
if strategy_or_version is None or strategy_or_version == "release":
logger.debug("Using default strategy 'release' for resolving base image version.")
return cls(*cls._release_ref())
elif strategy_or_version == "latest":
return cls("latest", "0.0.0", "latest")
# using default strategy
if strategy_or_version is None or strategy_or_version == "release": return cls(*cls._release_ref())
elif strategy_or_version == "latest": return cls("latest", "0.0.0", "latest")
elif strategy_or_version == "nightly":
_ref = cls._nightly_ref()
return cls(_ref[0], "0.0.0", _ref[-1])
else:
logger.warning("Using custom %s. Make sure that it is at lease 0.2.12 for base container support.", strategy_or_version)
return cls(*cls._release_ref(version_str=strategy_or_version))
@property
def tag(self) -> str:
# NOTE: the latest tag can also be a nightly build, but using it is discouraged. For nightly builds, use sha-<git_hash_short> instead.
@@ -122,33 +96,24 @@ class RefResolver:
else: return repr(self.version)
@functools.lru_cache(maxsize=256)
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
return RefResolver.from_strategy(strategy).tag
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str: return RefResolver.from_strategy(strategy).tag
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
"""This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry.
"""
try:
if not _BUILDER.health(): raise openllm.exceptions.Error
except (openllm.exceptions.Error, subprocess.CalledProcessError):
raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
except (openllm.exceptions.Error, subprocess.CalledProcessError): raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
if openllm.utils.device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernel in container.")
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
if t.TYPE_CHECKING: tags: dict[str | LiteralContainerRegistry, str]
if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
if not registries: tags: dict[str | LiteralContainerRegistry, str] = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
else:
registries = [registries] if isinstance(registries, str) else list(registries)
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
try:
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if openllm.utils.get_debug_mode() else "auto", quiet=machine)
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
except Exception as err:
raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
except Exception as err: raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
return tags
if t.TYPE_CHECKING:
@@ -156,10 +121,7 @@ if t.TYPE_CHECKING:
supported_registries: list[str]
__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]
def __dir__() -> list[str]:
return sorted(__all__)
def __dir__() -> list[str]: return sorted(__all__)
def __getattr__(name: str) -> t.Any:
if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY

View File

@@ -1,37 +1,23 @@
from __future__ import annotations
import functools
import importlib.util
import os
import typing as t
import click
import click_option_group as cog
import inflection
import orjson
import functools, importlib.util, os, typing as t
import click, click_option_group as cog, inflection, orjson, bentoml, openllm
from bentoml_cli.utils import BentoMLCommandGroup
from click.shell_completion import CompletionItem
import bentoml
import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm._typing_compat import LiteralString, DictStrAny, ParamSpec, Concatenate
from . import termui
if t.TYPE_CHECKING:
import subprocess
from openllm._configuration import LLMConfig
from .._configuration import LLMConfig
from .._types import DictStrAny, P
TupleStr = tuple[str, ...]
else:
TupleStr = tuple
P = ParamSpec("P")
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
_AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command])
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny:
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny:
# TODO: Support amd.com/gpu on k8s
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
@@ -44,17 +30,15 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
_adapter_mapping_key = "adapter_map"
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None:
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
if not value: return None
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
for v in value:
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
# try to resolve the full path if users pass in a relative path;
# currently only supports one level of path resolution relative to the current directory
try:
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError:
pass
try: adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError: pass
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
return None
@@ -104,7 +88,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
@start_decorator(llm_config, serve_grpc=_serve_grpc)
@click.pass_context
def start_cmd(
ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool,
ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | LiteralString, device: t.Tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool,
serialisation_format: t.Literal["safetensors", "legacy"], adapter_id: str | None, return_process: bool, **attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
fast = str(fast).upper() in openllm.utils.ENV_VARS_TRUE_VALUES
@@ -152,7 +136,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
llm = openllm.utils.infer_auto_class(env["framework_value"]).for_model(model, model_id=start_env[env.model_id], model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format)
start_env.update({env.config: llm.config.model_dump_json().decode()})
server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs)
server = bentoml.GrpcServer("_service:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service:svc", **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
@@ -192,7 +176,7 @@ def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, *
return noop
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
if adapter_map and not openllm.utils.is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config['env']['framework']}=\"pt\"' to run with quantization.")
requirements = llm_config["requirements"]
@@ -201,8 +185,8 @@ def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.Li
if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow")
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
return lambda fn: openllm.utils.compose(
*[
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
composed = openllm.utils.compose(
llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
model_id_option(factory=cog.optgroup, model_env=llm_config["env"]),
@@ -246,13 +230,14 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
),
cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"),
click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True),
]
)(fn)
)
return composed(fn)
return wrapper
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> TupleStr | None:
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
if value is None: return value
if not openllm.utils.LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
el: TupleStr = tuple(i for k in value for i in k)
if not isinstance(value, tuple): ctx.fail(f"{param} only accepts multiple values, not {type(value)} (value: {value})")
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
# NOTE: --device all is a special case
if len(el) == 1 and el[0] == "all": return tuple(map(str, openllm.utils.available_devices()))
return el
@@ -269,7 +254,7 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
command = "serve" if not serve_grpc else "serve-grpc"
group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",)
def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]:
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
serve_command = cli.commands[command]
# The first variable is the argument bento
# The last five are from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
@@ -285,7 +270,6 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts"))
f = cog.optgroup.option(*param_decls, **attrs)(f)
return group(f)
return decorator
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
@@ -299,12 +283,10 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC |
factory = attrs.pop("factory", click)
factory_attr = attrs.pop("attr", "option")
if factory_attr != "argument": attrs.setdefault("help", "General option for OpenLLM CLI.")
def decorator(f: FC | None) -> FC:
callback = getattr(factory, factory_attr, None)
if callback is None: raise ValueError(f"Factory {factory} has no attribute {factory_attr}.")
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
return decorator
cli_option = functools.partial(_click_factory_type, attr="option")
@@ -312,12 +294,8 @@ cli_argument = functools.partial(_click_factory_type, attr="argument")
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = "pretty", **attrs: t.Any) -> t.Callable[[FC], FC]:
output = ["json", "pretty", "porcelain"]
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
return [CompletionItem(it) for it in output]
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: return [CompletionItem(it) for it in output]
return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True, envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f)
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True, help="""Whether to skip checking if models is already in store.
@@ -325,18 +303,10 @@ def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC
This is useful if you have already downloaded or set up the model beforehand.
""", **attrs
)(f)
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]:
return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required)(f)
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]: return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required)(f)
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
@@ -423,10 +393,8 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
value = inflection.underscore(value)
if value in _wpr_strategies: return value
else:
try:
float(value) # type: ignore[arg-type]
except ValueError:
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
try: float(value) # type: ignore[arg-type]
except ValueError: raise click.BadParameter(f"'workers_per_resource' only accepts '{_wpr_strategies}' as possible strategies; otherwise pass in a float.", ctx, param) from None
else:
return value

View File

@@ -1,34 +1,21 @@
from __future__ import annotations
import itertools
import logging
import os
import re
import subprocess
import sys
import typing as t
import itertools, logging, os, re, subprocess, sys, typing as t
import bentoml, openllm
from simple_di import Provide, inject
import bentoml
import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm.exceptions import OpenLLMException
from . import termui
from ._factory import start_command_factory
if t.TYPE_CHECKING:
from openllm._typing_compat import LiteralString, LiteralRuntime
from bentoml._internal.bento import BentoStore
from openllm._configuration import LiteralRuntime, LLMConfig
from openllm._configuration import LLMConfig
from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
logger = logging.getLogger(__name__)
def _start(
model_name: str, /, *, model_id: str | None = None, timeout: int = 30, workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, device: tuple[str, ...] | t.Literal["all"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",
fast: bool = False, adapter_map: dict[t.LiteralString, str | None] | None = None, framework: LiteralRuntime | None = None, additional_args: list[str] | None = None, _serve_grpc: bool = False, __test__: bool = False, **_: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
def _start(model_name: str, /, *, model_id: str | None = None, timeout: int = 30, workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, device: tuple[str, ...] | t.Literal["all"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", fast: bool = False, adapter_map: dict[LiteralString, str | None] | None = None, framework: LiteralRuntime | None = None, additional_args: list[str] | None = None, _serve_grpc: bool = False, __test__: bool = False, **_: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
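A hedged sketch of this Python API, assuming `openllm.start` is the public alias of `_start`; the model name and id are illustrative, taken from examples elsewhere in this change:

```python
# Illustrative only: keyword arguments map one-to-one onto the CLI flags.
import openllm

config = openllm.start(
    "falcon",
    model_id="tiiuae/falcon-7b-instruct",
    workers_per_resource="conserved",
    timeout=360,
    # any extra CLI flag can be forwarded as strings via `additional_args`
)
```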
@@ -169,9 +156,7 @@ def _build(model_name: str, /, *, model_id: str | None = None, model_version: st
if matched is None: raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
return bentoml.get(matched.group(1), _bento_store=bento_store)
def _import_model(
model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", implementation: LiteralRuntime = "pt", quantize: t.Literal["int8", "int4", "gptq"] | None = None, serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", additional_args: t.Sequence[str] | None = None,
) -> bentoml.Model:
def _import_model(model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", implementation: LiteralRuntime = "pt", quantize: t.Literal["int8", "int4", "gptq"] | None = None, serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
"""Import a LLM into local store.
> [!NOTE]

View File

@@ -20,36 +20,12 @@ bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct'
```
"""
from __future__ import annotations
import functools
import http.client
import inspect
import itertools
import logging
import os
import platform
import re
import subprocess
import sys
import time
import traceback
import typing as t
import attr
import click
import click_option_group as cog
import fs
import fs.copy
import fs.errors
import inflection
import orjson
import functools, http.client, inspect, itertools, logging, os, platform, re, subprocess, sys, time, traceback, typing as t
import attr, click, click_option_group as cog, fs, fs.copy, fs.errors, inflection, orjson, bentoml, openllm
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
from simple_di import Provide, inject
import bentoml
import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelStore
from . import termui
from ._factory import (
FC,
@@ -69,9 +45,9 @@ from ._factory import (
start_command_factory,
workers_per_resource_option,
)
from .. import bundle, serialisation
from ..exceptions import OpenLLMException
from ..models.auto import (
from openllm import bundle, serialisation
from openllm.exceptions import OpenLLMException
from openllm.models.auto import (
CONFIG_MAPPING,
MODEL_FLAX_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
@@ -80,7 +56,8 @@ from ..models.auto import (
AutoConfig,
AutoLLM,
)
from ..utils import (
from openllm._typing_compat import DictStrAny, ParamSpec, Concatenate, LiteralString, Self, LiteralRuntime
from openllm.utils import (
DEBUG,
DEBUG_ENV_VAR,
OPTIONAL_DEPENDENCIES,
@@ -105,19 +82,15 @@ from ..utils import (
if t.TYPE_CHECKING:
import torch
from bentoml._internal.bento import BentoStore
from bentoml._internal.container import DefaultBuilder
from openllm.client import BaseClient
from openllm._schema import EmbeddingsOutput
from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
else: torch = LazyLoader("torch", globals(), "torch")
from .._schema import EmbeddingsOutput
from .._types import DictStrAny, LiteralRuntime, P
from ..bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
else:
torch, jupytext, nbformat = LazyLoader("torch", globals(), "torch"), LazyLoader("jupytext", globals(), "jupytext"), LazyLoader("nbformat", globals(), "nbformat")
P = ParamSpec("P")
logger = logging.getLogger(__name__)
OPENLLM_FIGLET = """\
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
@@ -132,9 +105,7 @@ ServeCommand = t.Literal["serve", "serve-grpc"]
@attr.define
class GlobalOptions:
cloud_context: str | None = attr.field(default=None)
def with_options(self, **attrs: t.Any) -> t.Self:
return attr.evolve(self, **attrs)
def with_options(self, **attrs: t.Any) -> Self: return attr.evolve(self, **attrs)
GrpType = t.TypeVar("GrpType", bound=click.Group)
@@ -173,7 +144,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
return wrapper
@staticmethod
def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[t.Concatenate[bool, P], t.Any]:
def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[Concatenate[bool, P], t.Any]:
command_name = attrs.get("name", func.__name__)
@functools.wraps(func)
@@ -197,7 +168,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
event.return_code = 2 if isinstance(e, KeyboardInterrupt) else 1
analytics.track(event)
raise
return t.cast("t.Callable[t.Concatenate[bool, P], t.Any]", wrapper)
return t.cast(t.Callable[Concatenate[bool, P], t.Any], wrapper)
@staticmethod
def exception_handling(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[P, t.Any]:
@@ -266,20 +237,17 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
for subcommand in self.list_commands(ctx):
cmd = self.get_command(ctx, subcommand)
if cmd is None or cmd.hidden: continue
if subcommand in _cached_extensions:
extensions.append((subcommand, cmd))
else:
commands.append((subcommand, cmd))
if subcommand in _cached_extensions: extensions.append((subcommand, cmd))
else: commands.append((subcommand, cmd))
# allow for 3 times the default spacing
if len(commands):
limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands)
rows = []
rows: list[tuple[str, str]] = []
for subcommand, cmd in commands:
help = cmd.get_short_help_str(limit)
rows.append((subcommand, help))
if rows:
with formatter.section(_("Commands")):
formatter.write_dl(rows)
with formatter.section(_("Commands")): formatter.write_dl(rows)
if len(extensions):
limit = formatter.width - 6 - max(len(cmd[0]) for cmd in extensions)
rows = []
@@ -287,8 +255,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
help = cmd.get_short_help_str(limit)
rows.append((inflection.dasherize(subcommand), help))
if rows:
with formatter.section(_("Extensions")):
formatter.write_dl(rows)
with formatter.section(_("Extensions")): formatter.write_dl(rows)
@click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="openllm")
@click.version_option(None, "--version", "-v", message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}")
@@ -524,10 +491,11 @@ def build_command(
_previously_built = True
except bentoml.exceptions.NotFound:
bento = bundle.create_bento(
bento_tag, llm_fs, llm, workers_per_resource=workers_per_resource, device=device, adapter_map=adapter_map, quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=enable_features, dockerfile_template=dockerfile_template_path, runtime=runtime, container_registry=container_registry, container_version_strategy=container_version_strategy
bento_tag, llm_fs, llm, workers_per_resource=workers_per_resource, adapter_map=adapter_map,
quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=enable_features, dockerfile_template=dockerfile_template_path, runtime=runtime,
container_registry=container_registry, container_version_strategy=container_version_strategy
)
except Exception as err:
raise err from None
except Exception as err: raise err from None
if machine: termui.echo(f"__tag__:{bento.tag}", fg="white")
elif output == "pretty":
@@ -535,26 +503,17 @@ def build_command(
termui.echo("\n" + OPENLLM_FIGLET, fg="white")
if not _previously_built: termui.echo(f"Successfully built {bento}.", fg="green")
elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg="yellow")
termui.echo(
"📖 Next steps:\n\n" + "* Push to BentoCloud with 'bentoml push':\n" + f" $ bentoml push {bento.tag}\n\n" + "* Containerize your Bento with 'bentoml containerize':\n" + f" $ bentoml containerize {bento.tag} --opt progress=plain" + "\n\n" + " Tip: To enable additional BentoML features for 'containerize', " + "use '--enable-features=FEATURE[,FEATURE]' " +
"[see 'bentoml containerize -h' for more advanced usage]\n", fg="blue",
)
elif output == "json":
termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
else:
termui.echo(bento.tag)
termui.echo("📖 Next steps:\n\n" + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" + f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" + "\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n", fg="blue",)
elif output == "json": termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
else: termui.echo(bento.tag)
if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push)
elif containerize:
backend = t.cast("DefaultBuilder", os.environ.get("BENTOML_CONTAINERIZE_BACKEND", "docker"))
try:
bentoml.container.health(backend)
except subprocess.CalledProcessError:
raise OpenLLMException(f"Failed to use backend {backend}") from None
try:
bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io"))
except Exception as err:
raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
try: bentoml.container.health(backend)
except subprocess.CalledProcessError: raise OpenLLMException(f"Failed to use backend {backend}") from None
try: bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io"))
except Exception as err: raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
return bento
@cli.command()
@@ -613,7 +572,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
tabulate.PRESERVE_WHITESPACE = True
# llm, architecture, url, model_id, installation, cpu, gpu, runtime_impl
data: list[str | tuple[str, str, list[str], str, t.LiteralString, t.LiteralString, tuple[LiteralRuntime, ...]]] = []
data: list[str | tuple[str, str, list[str], str, LiteralString, LiteralString, tuple[LiteralRuntime, ...]]] = []
for m, v in json_data.items():
data.extend([(m, v["architecture"], v["model_id"], v["installation"], "" if not v["cpu"] else "", "", v["runtime_impl"],)])
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 4),]
@@ -622,7 +581,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
termui.echo("Exception found while parsing models:\n", fg="yellow")
for m, err in failed_initialized:
termui.echo(f"- {m}: ", fg="yellow", nl=False)
termui.echo(traceback.print_exception(err, limit=3), fg="red")
termui.echo(traceback.print_exception(None, err, None, limit=5), fg="red") # type: ignore[func-returns-value]
sys.exit(1)
table = tabulate.tabulate(data, tablefmt="fancy_grid", headers=["LLM", "Architecture", "Models Id", "pip install", "CPU", "GPU", "Runtime"], maxcolwidths=column_widths)
@@ -699,7 +658,7 @@ def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal
@click.option("--remote", is_flag=True, default=False, help="Whether or not to use remote tools (inference endpoints) instead of local ones.", show_default=True)
@click.option("--opt", help="Define prompt options. "
"(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]")
def instruct_command(endpoint: str, timeout: int, agent: t.LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str:
def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str:
"""Instruct agents interactively for given tasks, from a terminal.
\b

View File

@@ -1,20 +1,11 @@
from __future__ import annotations
import typing as t
import click
import inflection
import orjson
import bentoml
import openllm
import typing as t, bentoml, openllm, orjson, inflection, click
from bentoml._internal.utils import human_readable_size
from .. import termui
from .._factory import LiteralOutput, model_name_argument, output_option
from openllm.cli import termui
from openllm.cli._factory import LiteralOutput, model_name_argument, output_option
if t.TYPE_CHECKING:
from ..._types import DictStrAny
if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny
@click.command("list_models", context_settings=termui.CONTEXT_SETTINGS)
@model_name_argument(required=False)

View File

@@ -1,26 +1,13 @@
from __future__ import annotations
import importlib.machinery
import logging
import os
import pkgutil
import subprocess
import sys
import tempfile
import typing as t
import click
import yaml
from .. import termui
from ... import playground
from ...utils import is_jupyter_available, is_jupytext_available, is_notebook_available
import importlib.machinery, logging, os, pkgutil, subprocess, sys, tempfile, typing as t
import click, yaml
from openllm.cli import termui
from openllm import playground
from openllm.utils import is_jupyter_available, is_jupytext_available, is_notebook_available
if t.TYPE_CHECKING:
import jupytext
import nbformat
from openllm._types import DictStrAny
import jupytext, nbformat
from openllm._typing_compat import DictStrAny
logger = logging.getLogger(__name__)

View File

@@ -1,22 +1,11 @@
from __future__ import annotations
import os
import typing as t
import click
import inflection
import openllm
if t.TYPE_CHECKING:
from .._types import DictStrAny
import os, typing as t, click, inflection, openllm
if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny
def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None:
attrs["fg"] = fg if not openllm.utils.get_debug_mode() else None
if not openllm.utils.get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs)
COLUMNS: int = int(os.environ.get("COLUMNS", str(120)))
CONTEXT_SETTINGS: DictStrAny = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore}
__all__ = ["echo", "COLUMNS", "CONTEXT_SETTINGS"]

View File

@@ -1,12 +1,18 @@
"""The actual client implementation.
"""OpenLLM Python client.
Use ``openllm.client`` instead.
This holds the implementation of the client, which is used to communicate with the
OpenLLM server: it sends requests to the server and receives responses.
```python
client = openllm.client.HTTPClient("http://localhost:8080")
client.query("What is the difference between gather and scatter?")
```
If the server supports embeddings, use it via `client.embed`:
```python
client.embed("What is the difference between gather and scatter?")
```
"""
from __future__ import annotations
from .runtimes import (
from openllm.client.runtimes import (
AsyncGrpcClient as AsyncGrpcClient,
AsyncHTTPClient as AsyncHTTPClient,
BaseAsyncClient as BaseAsyncClient,
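A minimal async counterpart to the docstring example above, assuming the `AsyncHTTPClient` re-exported here and the awaitable `query` defined on `BaseAsyncClient` later in this commit:
```python
import asyncio, openllm

async def main() -> None:
    # Address and prompt mirror the docstring example; adjust to your server.
    client = openllm.client.AsyncHTTPClient("http://localhost:8080")
    # `query` defaults to return_response="processed" and yields a plain string.
    answer = await client.query("What is the difference between gather and scatter?")
    print(answer)

asyncio.run(main())
```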

View File

@@ -1,15 +1,15 @@
"""Client that supports REST/gRPC protocol to interact with a LLMServer."""
from __future__ import annotations
from .base import (
from openllm.client.runtimes.base import (
BaseAsyncClient as BaseAsyncClient,
BaseClient as BaseClient,
)
from .grpc import (
from openllm.client.runtimes.grpc import (
AsyncGrpcClient as AsyncGrpcClient,
GrpcClient as GrpcClient,
)
from .http import (
from openllm.client.runtimes.http import (
AsyncHTTPClient as AsyncHTTPClient,
HTTPClient as HTTPClient,
)

View File

@@ -1,51 +1,35 @@
# mypy: disable-error-code="name-defined"
from __future__ import annotations
import asyncio
import logging
import sys
import typing as t
import asyncio, logging, typing as t
import bentoml, bentoml.client, openllm, httpx
from abc import abstractmethod
from http import HTTPStatus
from urllib.parse import urljoin
import httpx
import bentoml
import openllm
from bentoml._internal.client import Client
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
from openllm._typing_compat import overload, LiteralString
T = t.TypeVar("T")
T_co = t.TypeVar("T_co", covariant=True)
if t.TYPE_CHECKING:
import transformers
from openllm._typing_compat import DictStrAny, LiteralRuntime
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") # noqa: F811
from openllm._types import DictStrAny, LiteralRuntime
class AnnotatedClient(Client, t.Generic[T]):
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any:
...
async def async_health(self) -> t.Any:
...
def generate_v1(self, qa: openllm.GenerationInput) -> T:
...
def metadata_v1(self) -> T:
...
def embeddings_v1(self) -> t.Sequence[float]:
...
else:
AnnotatedClient = Client
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
DictStrAny = dict
class AnnotatedClient(t.Protocol[T_co]):
server_url: str
_svc: bentoml.Service
endpoints: list[str]
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ...
async def async_health(self) -> t.Any: ...
def generate_v1(self, qa: openllm.GenerationInput) -> T_co: ...
def metadata_v1(self) -> T_co: ...
def embeddings_v1(self) -> t.Sequence[float]: ...
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ...
async def async_call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ...
@staticmethod
def wait_until_server_ready(host: str, port: int, timeout: float = 30, **kwargs: t.Any) -> None: ...
@staticmethod
def from_url(server_url: str) -> AnnotatedClient[t.Any]: ...
logger = logging.getLogger(__name__)
@@ -53,12 +37,11 @@ def in_async_context() -> bool:
try:
_ = asyncio.get_running_loop()
return True
except RuntimeError:
return False
except RuntimeError: return False
class ClientMeta(t.Generic[T]):
_api_version: str
_client_class: type[bentoml.client.Client]
_client_type: t.Literal["GrpcClient", "HTTPClient"]
_host: str
_port: str
@@ -66,15 +49,8 @@ class ClientMeta(t.Generic[T]):
__agent__: transformers.HfAgent | None = None
__llm__: openllm.LLM[t.Any, t.Any] | None = None
def __init__(self, address: str, timeout: int = 30):
self._address = address
self._timeout = timeout
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
"""Initialise subclass for HTTP and gRPC client type."""
cls._client_class = t.cast(t.Type[bentoml.client.Client], bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient)
cls._api_version = api_version
def __init__(self, address: str, timeout: int = 30): self._address, self._timeout = address, timeout
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): cls._client_type, cls._api_version = "HTTPClient" if client_type == "http" else "GrpcClient", api_version
@property
def _hf_agent(self) -> transformers.HfAgent:
if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.")
@@ -82,99 +58,62 @@ class ClientMeta(t.Generic[T]):
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent"))
return self.__agent__
@property
def _metadata(self) -> T:
if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
return self.call("metadata")
def _metadata(self) -> T: return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() if in_async_context() else self.call("metadata")
@property
@abstractmethod
def model_name(self) -> str:
raise NotImplementedError
def model_name(self) -> str: raise NotImplementedError
@property
@abstractmethod
def framework(self) -> LiteralRuntime:
raise NotImplementedError
def framework(self) -> LiteralRuntime: raise NotImplementedError
@property
@abstractmethod
def timeout(self) -> int:
raise NotImplementedError
def timeout(self) -> int: raise NotImplementedError
@property
@abstractmethod
def model_id(self) -> str:
raise NotImplementedError
def model_id(self) -> str: raise NotImplementedError
@property
@abstractmethod
def configuration(self) -> dict[str, t.Any]:
raise NotImplementedError
def configuration(self) -> dict[str, t.Any]: raise NotImplementedError
@property
@abstractmethod
def supports_embeddings(self) -> bool:
raise NotImplementedError
def supports_embeddings(self) -> bool: raise NotImplementedError
@property
@abstractmethod
def supports_hf_agent(self) -> bool:
raise NotImplementedError
def supports_hf_agent(self) -> bool: raise NotImplementedError
@abstractmethod
def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ...
@abstractmethod
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
@property
def config(self) -> openllm.LLMConfig: return self.llm.config
@property
def llm(self) -> openllm.LLM[t.Any, t.Any]:
# XXX: if the server runs vllm or any framework that is not available from the user client, client will fail.
if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
return self.__llm__
@property
def config(self) -> openllm.LLMConfig:
return self.llm.config
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T:
return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> T:
return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
@property
def _cached(self) -> AnnotatedClient[T]:
client_class = t.cast(AnnotatedClient[T], getattr(bentoml.client, self._client_type))
if self.__client__ is None:
self._client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
self.__client__ = t.cast("AnnotatedClient[T]", self._client_class.from_url(self._address))
client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
self.__client__ = client_class.from_url(self._address)
return self.__client__
@abstractmethod
def postprocess(self, result: t.Any) -> openllm.GenerationOutput:
...
@abstractmethod
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
...
class BaseClient(ClientMeta[T]):
def health(self) -> t.Any:
raise NotImplementedError
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
raise NotImplementedError
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
raise NotImplementedError
def health(self) -> t.Any: raise NotImplementedError
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
@overload
def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
...
def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
@overload
def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
...
def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
@overload
def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
...
def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
return_raw_response = attrs.pop("return_raw_response", None)
if return_raw_response is not None:
@@ -197,21 +136,14 @@ class BaseClient(ClientMeta[T]):
# NOTE: Scikit interface
@overload
def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
...
def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
@overload
def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
...
def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
@overload
def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
...
def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs))
def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs))
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
@@ -220,34 +152,21 @@ class BaseClient(ClientMeta[T]):
task = kwargs.pop("task", args[0])
return_code = kwargs.pop("return_code", False)
remote = kwargs.pop("remote", False)
try:
return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
try: return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
except Exception as err:
logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address)
logger.info("Tip: LLMServer at '%s' might not support 'generate_one'.", self._address)
class BaseAsyncClient(ClientMeta[T]):
async def health(self) -> t.Any:
raise NotImplementedError
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
raise NotImplementedError
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
raise NotImplementedError
async def health(self) -> t.Any: raise NotImplementedError
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
@overload
async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
...
async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
@overload
async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
...
async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
@overload
async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
...
async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
return_raw_response = attrs.pop("return_raw_response", None)
if return_raw_response is not None:
@@ -270,25 +189,16 @@ class BaseAsyncClient(ClientMeta[T]):
# NOTE: Scikit interface
@overload
async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
...
async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
@overload
async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
...
async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
@overload
async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
...
async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs))
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs))
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
"""Async version of agent.run."""
if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
@@ -317,9 +227,7 @@ class BaseAsyncClient(ClientMeta[T]):
# the below have the same logic as agent.run API
explanation, code = clean_code_for_run(result)
_hf_agent.log(f"==Explanation from the agent==\n{explanation}")
_hf_agent.log(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
_hf_agent.log("\n\n==Result==")

View File

@@ -1,102 +1,93 @@
from __future__ import annotations
import asyncio
import logging
import typing as t
import orjson
import openllm
import asyncio, logging, typing as t
import orjson, openllm
from openllm._typing_compat import LiteralRuntime
from .base import BaseAsyncClient, BaseClient
if t.TYPE_CHECKING:
from grpc_health.v1 import health_pb2
from bentoml.grpc.v1.service_pb2 import Response
logger = logging.getLogger(__name__)
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
class GrpcClientMixin:
if t.TYPE_CHECKING:
@property
def _metadata(self) -> Response:
...
class GrpcClient(BaseClient["Response"], client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._host, self._port = address.split(":")
super().__init__(address, timeout)
def health(self) -> health_pb2.HealthCheckResponse: return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
@property
def model_name(self) -> str:
try:
return self._metadata.json.struct_value.fields["model_name"].string_value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata.json.struct_value.fields["model_name"].string_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def framework(self) -> LiteralRuntime:
try:
value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value)
if value not in ("pt", "flax", "tf", "vllm"): raise KeyError
return value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def timeout(self) -> int:
try:
return int(self._metadata.json.struct_value.fields["timeout"].number_value)
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return int(self._metadata.json.struct_value.fields["timeout"].number_value)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def model_id(self) -> str:
try:
return self._metadata.json.struct_value.fields["model_id"].string_value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata.json.struct_value.fields["model_id"].string_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def configuration(self) -> dict[str, t.Any]:
try:
v = self._metadata.json.struct_value.fields["configuration"].string_value
return orjson.loads(v)
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_embeddings(self) -> bool:
try:
return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_hf_agent(self) -> bool:
try:
return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
if isinstance(result, dict):
return openllm.GenerationOutput(**result)
from google.protobuf.json_format import MessageToDict
if isinstance(result, dict): return openllm.GenerationOutput(**result)
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"):
class AsyncGrpcClient(BaseAsyncClient["Response"], client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._host, self._port = address.split(":")
super().__init__(address, timeout)
def health(self) -> health_pb2.HealthCheckResponse:
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._host, self._port = address.split(":")
super().__init__(address, timeout)
async def health(self) -> health_pb2.HealthCheckResponse:
return await self._cached.health("bentoml.grpc.v1.BentoService")
async def health(self) -> health_pb2.HealthCheckResponse: return await self._cached.health("bentoml.grpc.v1.BentoService")
@property
def model_name(self) -> str:
try: return self._metadata.json.struct_value.fields["model_name"].string_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def framework(self) -> LiteralRuntime:
try:
value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value)
if value not in ("pt", "flax", "tf", "vllm"): raise KeyError
return value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def timeout(self) -> int:
try: return int(self._metadata.json.struct_value.fields["timeout"].number_value)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def model_id(self) -> str:
try: return self._metadata.json.struct_value.fields["model_id"].string_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def configuration(self) -> dict[str, t.Any]:
try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_embeddings(self) -> bool:
try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_hf_agent(self) -> bool:
try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
from google.protobuf.json_format import MessageToDict
if isinstance(result, dict): return openllm.GenerationOutput(**result)
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))

View File

@@ -1,110 +1,98 @@
from __future__ import annotations
import logging
import typing as t
import logging, typing as t
from urllib.parse import urljoin, urlparse
import httpx
import orjson
import openllm
import httpx, orjson, openllm
from .base import BaseAsyncClient, BaseClient, in_async_context
if t.TYPE_CHECKING:
from openllm._types import DictStrAny, LiteralRuntime
else:
DictStrAny = dict
from openllm._typing_compat import DictStrAny, LiteralRuntime
logger = logging.getLogger(__name__)
def process_address(self: AsyncHTTPClient | HTTPClient, address: str) -> None:
address = address if "://" in address else "http://" + address
parsed = urlparse(address)
self._host, *_port = parsed.netloc.split(":")
if len(_port) == 0: self._port = "80" if parsed.scheme == "http" else "443"
else: self._port = next(iter(_port))
class HTTPClientMixin:
if t.TYPE_CHECKING:
class HTTPClient(BaseClient[DictStrAny]):
def __init__(self, address: str, timeout: int = 30):
process_address(self, address)
super().__init__(address, timeout)
@property
def _metadata(self) -> DictStrAny:
...
def health(self) -> t.Any: return self._cached.health()
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
if not self.supports_embeddings: raise ValueError("This model does not support embeddings.")
if isinstance(prompt, str): prompt = [prompt]
result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout).json() if in_async_context() else self.call("embeddings", list(prompt))
return openllm.EmbeddingsOutput(**result)
@property
def model_name(self) -> str:
try:
return self._metadata["model_name"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata["model_name"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def model_id(self) -> str:
try:
return self._metadata["model_name"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata["model_name"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def framework(self) -> LiteralRuntime:
try:
return self._metadata["framework"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata["framework"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def timeout(self) -> int:
try:
return self._metadata["timeout"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata["timeout"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def configuration(self) -> dict[str, t.Any]:
try:
return orjson.loads(self._metadata["configuration"])
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return orjson.loads(self._metadata["configuration"])
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_embeddings(self) -> bool:
try:
return self._metadata.get("supports_embeddings", False)
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata.get("supports_embeddings", False)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_hf_agent(self) -> bool:
try:
return self._metadata.get("supports_hf_agent", False)
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
try: return self._metadata.get("supports_hf_agent", False)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result)
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
return openllm.GenerationOutput(**result)
class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]):
class AsyncHTTPClient(BaseAsyncClient[DictStrAny]):
def __init__(self, address: str, timeout: int = 30):
address = address if "://" in address else "http://" + address
self._host, self._port = urlparse(address).netloc.split(":")
process_address(self, address)
super().__init__(address, timeout)
def health(self) -> t.Any:
return self._cached.health()
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
if not self.supports_embeddings:
raise ValueError("This model does not support embeddings.")
if isinstance(prompt, str): prompt = [prompt]
if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout).json()
else: result = self.call("embeddings", list(prompt))
return openllm.EmbeddingsOutput(**result)
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]):
def __init__(self, address: str, timeout: int = 30):
address = address if "://" in address else "http://" + address
self._host, self._port = urlparse(address).netloc.split(":")
super().__init__(address, timeout)
async def health(self) -> t.Any:
return await self._cached.async_health()
async def health(self) -> t.Any: return await self._cached.async_health()
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
if not self.supports_embeddings:
raise ValueError("This model does not support embeddings.")
if not self.supports_embeddings: raise ValueError("This model does not support embeddings.")
if isinstance(prompt, str): prompt = [prompt]
res = await self.acall("embeddings", list(prompt))
return openllm.EmbeddingsOutput(**res)
@property
def model_name(self) -> str:
try: return self._metadata["model_name"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def model_id(self) -> str:
try: return self._metadata["model_name"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def framework(self) -> LiteralRuntime:
try: return self._metadata["framework"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def timeout(self) -> int:
try: return self._metadata["timeout"]
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def configuration(self) -> dict[str, t.Any]:
try: return orjson.loads(self._metadata["configuration"])
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_embeddings(self) -> bool:
try: return self._metadata.get("supports_embeddings", False)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def supports_hf_agent(self) -> bool:
try: return self._metadata.get("supports_hf_agent", False)
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result)

View File

@@ -1,28 +1,19 @@
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
from __future__ import annotations
import bentoml
class OpenLLMException(bentoml.exceptions.BentoMLException):
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
class GpuNotAvailableError(OpenLLMException):
"""Raised when there is no GPU available in given system."""
class ValidationError(OpenLLMException):
"""Raised when a validation fails."""
class ForbiddenAttributeError(OpenLLMException):
"""Raised when using an _internal field."""
class MissingAnnotationAttributeError(OpenLLMException):
"""Raised when a field under openllm.LLMConfig is missing annotations."""
class MissingDependencyError(BaseException):
"""Raised when a dependency is missing."""
class Error(BaseException):
"""To be used instead of naked raise."""
class FineTuneStrategyNotSupportedError(OpenLLMException):
"""Raised when a fine-tune strategy is not supported for given LLM."""

View File

@@ -1,11 +1,11 @@
# This file is generated by tools/update-models-import.py. DO NOT EDIT MANUALLY!
# To update this, run ./tools/update-models-import.py
from __future__ import annotations
import typing as t
import typing as t, os
from openllm.utils import LazyModule
_MODELS: set[str] = {"auto", "baichuan", "chatglm", "dolly_v2", "falcon", "flan_t5", "gpt_neox", "llama", "mpt", "opt", "stablelm", "starcoder"}
if t.TYPE_CHECKING: from . import auto as auto,baichuan as baichuan,chatglm as chatglm,dolly_v2 as dolly_v2,falcon as falcon,flan_t5 as flan_t5,gpt_neox as gpt_neox,llama as llama,mpt as mpt,opt as opt,stablelm as stablelm,starcoder as starcoder
__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS})
__lazy=LazyModule(__name__, os.path.abspath("__file__"), {k: [] for k in _MODELS})
__all__=__lazy.__all__
__dir__=__lazy.__dir__
__getattr__=__lazy.__getattr__

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
import typing as t, os
import openllm
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available
@@ -40,4 +39,7 @@ else:
_import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
if t.TYPE_CHECKING: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
__lazy=LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
__all__=__lazy.__all__
__dir__=__lazy.__dir__
__getattr__=__lazy.__getattr__

View File

@@ -1,16 +1,15 @@
# mypy: disable-error-code="type-arg"
from __future__ import annotations
import typing as t
from collections import OrderedDict
import inflection
import openllm
import inflection, openllm
from openllm.utils import ReprMixin
if t.TYPE_CHECKING:
import types
from openllm._typing_compat import LiteralString
from collections import _odict_items, _odict_keys, _odict_values
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
@@ -18,8 +17,8 @@ if t.TYPE_CHECKING:
# NOTE: This is the entrypoint when adding new model config
CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), ("llama", "LlamaConfig"), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig")])
class _LazyConfigMapping(OrderedDict, ReprMixin): # type: ignore[type-arg]
def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]):
class _LazyConfigMapping(OrderedDict, ReprMixin):
def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]):
self._mapping = mapping
self._extra_content: dict[str, t.Any] = {}
self._modules: dict[str, types.ModuleType] = {}

View File

@@ -1,23 +1,16 @@
# mypy: disable-error-code="type-arg"
from __future__ import annotations
import importlib
import inspect
import logging
import typing as t
import importlib, inspect, logging, typing as t
from collections import OrderedDict
import inflection
import openllm
import inflection, openllm
from openllm.utils import ReprMixin
if t.TYPE_CHECKING:
from openllm._typing_compat import LiteralString, LLMRunner
import types
from collections import _odict_items, _odict_keys, _odict_values
from _typeshed import SupportsIter
from ..._llm import LLMRunner
ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
@@ -25,7 +18,7 @@ if t.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class BaseAutoLLMClass:
_model_mapping: _LazyAutoMapping
_model_mapping: t.ClassVar[_LazyAutoMapping]
def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
@classmethod
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
@@ -83,13 +76,13 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
raise ValueError(f"Could not find {attr} in {openllm_module}!")
class _LazyAutoMapping(OrderedDict, ReprMixin): # type: ignore[type-arg]
class _LazyAutoMapping(OrderedDict, ReprMixin):
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
This OrderedDict's values() and keys() return lists directly, so you don't
have to call list(mapping.values()) to get the list of values.
"""
def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
def __init__(self, config_mapping: OrderedDict[LiteralString, LiteralString], model_mapping: OrderedDict[LiteralString, LiteralString]):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
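The `_LazyAutoMapping` docstring above describes list-returning `keys()` and `values()`; a small sketch of that behaviour, assuming `MODEL_MAPPING` is exported from `openllm.models.auto` as laid out in this commit:
```python
# Assumed import path per this commit's layout; resolving the mapping may lazily
# import the per-model config and modeling modules.
from openllm.models.auto import MODEL_MAPPING

config_classes = MODEL_MAPPING.keys()    # already a plain list of LLMConfig subclasses
llm_classes = MODEL_MAPPING.values()     # already a plain list of LLM subclasses
print(len(config_classes), len(llm_classes))
```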

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
import typing as t
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass, _LazyAutoMapping
MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), ("opt", "OPT"), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan")])
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
class AutoLLM(BaseAutoLLMClass):
_model_mapping = MODEL_MAPPING
_model_mapping: t.ClassVar = MODEL_MAPPING

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
import typing as t
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass, _LazyAutoMapping
MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT")])
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
class AutoFlaxLLM(BaseAutoLLMClass):
_model_mapping = MODEL_FLAX_MAPPING
_model_mapping: t.ClassVar = MODEL_FLAX_MAPPING

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
import typing as t
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass, _LazyAutoMapping
MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT")])
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
class AutoTFLLM(BaseAutoLLMClass):
_model_mapping = MODEL_TF_MAPPING
_model_mapping: t.ClassVar = MODEL_TF_MAPPING

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
import typing as t
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass, _LazyAutoMapping
MODEL_VLLM_MAPPING_NAMES = OrderedDict([("baichuan", "VLLMBaichuan"), ("dolly_v2", "VLLMDollyV2"), ("gpt_neox", "VLLMGPTNeoX"), ("mpt", "VLLMMPT"), ("opt", "VLLMOPT"), ("stablelm", "VLLMStableLM"), ("starcoder", "VLLMStarCoder"), ("llama", "VLLMLlama")])
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
class AutoVLLM(BaseAutoLLMClass):
_model_mapping = MODEL_VLLM_MAPPING
_model_mapping: t.ClassVar = MODEL_VLLM_MAPPING

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available, is_vllm_available

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import openllm
class BaichuanConfig(openllm.LLMConfig):

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import torch, transformers
@@ -15,6 +12,6 @@ class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrai
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

View File

@@ -1,15 +1,9 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)
class VLLMBaichuan(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizerBase"]):
__openllm_internal__ = True
tokenizer_id = "local"

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import openllm
class ChatGLMConfig(openllm.LLMConfig):

View File

@@ -1,8 +1,5 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_torch_available, is_vllm_available

View File

@@ -1,8 +1,5 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
if t.TYPE_CHECKING: import transformers
class DollyV2Config(openllm.LLMConfig):

View File

@@ -1,105 +1,107 @@
from __future__ import annotations
import logging
import re
import typing as t
import openllm
import logging, re, typing as t, openllm
from openllm._prompt import process_prompt
from openllm._typing_compat import overload
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE, END_KEY, RESPONSE_KEY, get_special_token_id
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow")
logger = logging.getLogger(__name__)
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
class InstructionTextGenerationPipeline(transformers.Pipeline):
def __init__(self, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, /, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
preprocess_params: dict[str, t.Any] = {}
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
# append a newline to yield a single token. find whatever token is configured for the response key.
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
response_key_token_id = None
end_key_token_id = None
if tokenizer_response_key:
try:
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
# Ensure generation stops once it generates "### End"
generate_kwargs["eos_token_id"] = end_key_token_id
except ValueError: pass
forward_params = generate_kwargs
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
return preprocess_params, forward_params, postprocess_params
def preprocess(self, input_: str, **generate_kwargs: t.Any) -> dict[str, t.Any]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
inputs = self.tokenizer(prompt_text, return_tensors="pt")
inputs["prompt_text"] = prompt_text
inputs["instruction_text"] = input_
return inputs
def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any) -> dict[str, t.Any]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None)
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
else: in_b = input_ids.shape[0]
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
out_b = generated_sequence.shape[0]
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
instruction_text = model_inputs.pop("instruction_text")
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal["generated_text"], str]]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
_generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist()
records: list[dict[t.Literal["generated_text"], str]] = []
for sequence in generated_sequence:
# The response will be set to this variable if we can identify it.
decoded = None
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
if response_key_token_id and end_key_token_id:
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
# prompt, we should definitely find it. We will return the tokens found after this token.
try: response_pos = sequence.index(response_key_token_id)
except ValueError: response_pos = None
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
if response_pos:
# Next find where "### End" is located. The model has been trained to end its responses with this
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
# this token, as the response could be truncated. If we don't find it then just return everything
# to the end. Note that even though we set eos_token_id, we still see this token at the end.
try: end_pos = sequence.index(end_key_token_id)
except ValueError: end_pos = None
decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip()
if not decoded:
# Otherwise we'll decode everything and use a regex to find the response and end.
fully_decoded = self.tokenizer.decode(sequence)
# The response appears after "### Response:". The model has been trained to append "### End" at the
# end.
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else:
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
# return everything after "### Response:".
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
@overload
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline: ...
@overload
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]: ...
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
class InstructionTextGenerationPipeline(transformers.Pipeline):
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
preprocess_params: dict[str, t.Any] = {}
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
# append a newline to yield a single token. find whatever token is configured for the response key.
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
response_key_token_id = None
end_key_token_id = None
if tokenizer_response_key:
try:
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
# Ensure generation stops once it generates "### End"
generate_kwargs["eos_token_id"] = end_key_token_id
except ValueError: pass
forward_params = generate_kwargs
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
return preprocess_params, forward_params, postprocess_params
def preprocess(self, input_: str, **generate_kwargs: t.Any) -> t.Dict[str, t.Any]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
inputs = self.tokenizer(prompt_text, return_tensors="pt")
inputs["prompt_text"] = prompt_text
inputs["instruction_text"] = input_
return t.cast(t.Dict[str, t.Any], inputs)
def _forward(self, input_tensors: dict[str, t.Any], **generate_kwargs: t.Any) -> transformers.utils.generic.ModelOutput:
if t.TYPE_CHECKING: assert self.tokenizer is not None
input_ids, attention_mask = input_tensors["input_ids"], input_tensors.get("attention_mask", None)
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
else: in_b = input_ids.shape[0]
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
out_b = generated_sequence.shape[0]
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
instruction_text = input_tensors.pop("instruction_text")
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal["generated_text"], str]]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
_generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist()
records: list[dict[t.Literal["generated_text"], str]] = []
for sequence in generated_sequence:
# The response will be set to this variable if we can identify it.
decoded = None
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
if response_key_token_id and end_key_token_id:
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
# prompt, we should definitely find it. We will return the tokens found after this token.
try: response_pos = sequence.index(response_key_token_id)
except ValueError: response_pos = None
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
if response_pos:
# Next find where "### End" is located. The model has been trained to end its responses with this
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
# this token, as the response could be truncated. If we don't find it then just return everything
# to the end. Note that even though we set eos_token_id, we still see this token at the end.
try: end_pos = sequence.index(end_key_token_id)
except ValueError: end_pos = None
decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip()
if not decoded:
# Otherwise we'll decode everything and use a regex to find the response and end.
fully_decoded = self.tokenizer.decode(sequence)
# The response appears after "### Response:". The model has been trained to append "### End" at the
# end.
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been
# trained on, but to the client it will appear to be the full text.
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
records.append({"generated_text": t.cast(str, decoded)})
return records
else:
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
# return everything after "### Response:".
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been
# trained on, but to the client it will appear to be the full text.
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
records.append({"generated_text": t.cast(str, decoded)})
return records
return InstructionTextGenerationPipeline() if _init else InstructionTextGenerationPipeline
class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedTokenizer"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16}, {}
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return InstructionTextGenerationPipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, return_full_text=self.config.return_full_text)
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: return generation_result[0]["generated_text"]
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:

View File

@@ -1,12 +1,7 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_torch_available

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import openllm
class FalconConfig(openllm.LLMConfig):

View File

@@ -1,23 +1,19 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda_device_count() > 1 else None}, {}
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import openllm
class FlanT5Config(openllm.LLMConfig):

View File

@@ -1,11 +1,7 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")

View File

@@ -1,11 +1,7 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import transformers
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):

View File

@@ -1,11 +1,7 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import transformers
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_torch_available, is_vllm_available

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import openllm
class GPTNeoXConfig(openllm.LLMConfig):

View File

@@ -1,16 +1,11 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}

View File

@@ -1,13 +1,7 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import typing as t, openllm, logging
from openllm._prompt import process_prompt
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import sys
import typing as t
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_torch_available, is_vllm_available

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import sys, typing as t
import openllm
import typing as t, openllm
class LlamaConfig(openllm.LLMConfig):
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.

View File

@@ -1,18 +1,15 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
@@ -25,4 +22,4 @@ class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaToke
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item())
return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))

View File

@@ -1,15 +1,10 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_torch_available, is_vllm_available

View File

@@ -1,18 +1,13 @@
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
import logging, typing as t, bentoml, openllm
from openllm._prompt import process_prompt
from openllm.utils import generate_labels, is_triton_available
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType
if t.TYPE_CHECKING: import transformers, torch
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True) -> transformers.PretrainedConfig:
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
@@ -62,7 +57,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()}
with torch.inference_mode():
if torch.cuda.is_available():
with torch.autocast("cuda", torch.float16):
with torch.autocast("cuda", torch.float16): # type: ignore[attr-defined]
generated_tensors = self.model.generate(**inputs, **attrs)
else: generated_tensors = self.model.generate(**inputs, **attrs)
return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True)

View File

@@ -1,28 +1,8 @@
# Copyright 2023 BentoML Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType
if t.TYPE_CHECKING:
import transformers
import vllm
if t.TYPE_CHECKING: import transformers, vllm
logger = logging.getLogger(__name__)
class VLLMMPT(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import openllm
class OPTConfig(openllm.LLMConfig):

View File

@@ -1,14 +1,8 @@
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
import logging, typing as t, bentoml, openllm
from openllm._prompt import process_prompt
from openllm.utils import generate_labels
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import transformers
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")

View File

@@ -1,12 +1,7 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")

View File

@@ -1,18 +1,12 @@
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
import logging, typing as t, bentoml, openllm
from openllm._prompt import process_prompt
from openllm.utils import generate_labels
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import transformers
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
__openllm_internal__ = True
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:

View File

@@ -1,12 +1,7 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import sys
import typing as t
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_torch_available, is_vllm_available

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import openllm
class StableLMConfig(openllm.LLMConfig):

View File

@@ -1,10 +1,6 @@
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT
if t.TYPE_CHECKING: import transformers, torch

View File

@@ -1,25 +1,7 @@
# Copyright 2023 BentoML Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from openllm._prompt import process_prompt
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys, typing as t
from openllm.exceptions import MissingDependencyError
from openllm.utils import LazyModule, is_torch_available, is_vllm_available

View File

@@ -1,17 +1,11 @@
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
import logging, typing as t, bentoml, openllm
from openllm.utils import generate_labels
from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
__openllm_internal__ = True
@property
@@ -33,7 +27,6 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
# XXX: This value for pad_token_id is currently a hack; we still need to investigate why the
# default starcoder doesn't include the same value as santacoder EOD
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode():

View File

@@ -1,27 +1,9 @@
# Copyright 2023 BentoML Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import openllm
import logging, typing as t, openllm
from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
if t.TYPE_CHECKING: import vllm, transformers
logger = logging.getLogger(__name__)
class VLLMStarCoder(openllm.LLM["vllm.LLMEngine", "transformers.GPT2TokenizerFast"]):
__openllm_internal__ = True
tokenizer_id = "local"

View File

@@ -23,24 +23,20 @@ llm.save_pretrained("./path/to/local-dolly")
```
"""
from __future__ import annotations
import importlib
import typing as t
import cloudpickle
import fs
import openllm
import importlib, typing as t
import cloudpickle, fs, openllm
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from openllm._typing_compat import M, T, ParamSpec, Concatenate
if t.TYPE_CHECKING:
from openllm._llm import T
import bentoml
from . import (
constants as constants,
ggml as ggml,
transformers as transformers,
)
P = ParamSpec("P")
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
"""Load the tokenizer from BentoML store.
@@ -67,9 +63,8 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
return tokenizer
_extras = ["get", "import_model", "save_pretrained", "load_model"]
def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]:
def caller(llm: openllm.LLM[t.Any, t.Any], *args: t.Any, **kwargs: t.Any) -> t.Any:
def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[openllm.LLM[t.Any, t.Any], P], t.Any]:
def caller(llm: openllm.LLM[t.Any, t.Any], *args: P.args, **kwargs: P.kwargs) -> t.Any:
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.runtime="transformers"'
@@ -77,11 +72,15 @@ def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]:
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.runtime="ggml"'
"""
return getattr(importlib.import_module(f".{llm.runtime}", __name__), fn)(llm, *args, **kwargs)
return caller
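The `caller` closure above resolves the serialisation submodule from the LLM runtime at call time. The same importlib-based dispatch pattern, reduced to a minimal sketch with the runtime passed explicitly (in the real module it is read from `llm.runtime`), looks roughly like this:

```python
import importlib
import typing as t

def make_dispatch(fn_name: str, package: str) -> t.Callable[..., t.Any]:
    """Return a callable that forwards to ``<package>.<runtime>.<fn_name>`` at call time."""
    def caller(runtime: str, *args: t.Any, **kwargs: t.Any) -> t.Any:
        # e.g. runtime="transformers" resolves ".transformers" relative to the given package
        module = importlib.import_module(f".{runtime}", package)
        return getattr(module, fn_name)(*args, **kwargs)
    return caller

# Hypothetical usage, mirroring the names in this module:
#   get = make_dispatch("get", "openllm.serialisation")
#   get("transformers", llm)
```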
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []}
if t.TYPE_CHECKING:
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
def save_pretrained(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> None: ...
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []}
__all__ = ["ggml", "transformers", "constants", "load_tokenizer", *_extras]
def __dir__() -> list[str]: return sorted(__all__)
def __getattr__(name: str) -> t.Any:

View File

@@ -1,10 +1,3 @@
from __future__ import annotations
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
# NOTE: vllm will use PyTorch implementation of transformers for serialisation
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")
}
# this logic below is synonymous to handling `from_pretrained` attrs.
FRAMEWORK_TO_AUTOCLASS_MAPPING = {"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")}
HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token"]

View File

@@ -4,18 +4,13 @@ This requires ctransformers to be installed.
"""
from __future__ import annotations
import typing as t
import bentoml, openllm
import bentoml
import openllm
if t.TYPE_CHECKING:
from openllm._llm import M
if t.TYPE_CHECKING: from openllm._typing_compat import M
_conversion_strategy = {"pt": "ggml"}
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
raise NotImplementedError("Currently work in progress.")
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model: raise NotImplementedError("Currently work in progress.")
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
@@ -35,14 +30,5 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
if auto_import:
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
"""Load the model from BentoML store.
By default, it will try to find the model in the local store.
If the model is not found, it will raise a ``bentoml.exceptions.NotFound``.
"""
raise NotImplementedError("Currently work in progress.")
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None:
raise NotImplementedError("Currently work in progress.")
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: raise NotImplementedError("Currently work in progress.")
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None: raise NotImplementedError("Currently work in progress.")

View File

@@ -1,19 +1,12 @@
# mypy: disable-error-code="name-defined,misc"
"""Serialisation related implementation for Transformers-based implementation."""
from __future__ import annotations
import importlib
import logging
import typing as t
import importlib, logging, typing as t
import bentoml, openllm
from huggingface_hub import snapshot_download
from simple_di import Provide, inject
import bentoml
import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions
from openllm.serialisation.transformers.weights import HfIgnore
from .weights import HfIgnore
from ._helpers import (
check_unintialised_params,
infer_autoclass_from_llm,
@@ -26,16 +19,16 @@ from ._helpers import (
if t.TYPE_CHECKING:
import types
import vllm, auto_gptq as autogptq, transformers, torch
import torch.nn
from bentoml._internal.models import ModelStore
from openllm._llm import M, T
from openllm._types import DictStrAny
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
from openllm._typing_compat import DictStrAny, M, T
else:
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
logger = logging.getLogger(__name__)

View File

@@ -1,25 +1,14 @@
from __future__ import annotations
import copy
import typing as t
import openllm
import copy, typing as t, openllm
from bentoml._internal.models.model import ModelInfo, ModelSignature
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, HUB_ATTRS
if t.TYPE_CHECKING:
import torch
import transformers
import torch, transformers, bentoml
from transformers.models.auto.auto_factory import _BaseAutoModelClass
import bentoml
from bentoml._internal.models.model import ModelSignaturesType
from ..._llm import M, T
from ..._types import DictStrAny
else:
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
from openllm._typing_compat import DictStrAny, M, T
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
_object_setattr = object.__setattr__

View File

@@ -1,22 +1,17 @@
from __future__ import annotations
import typing as t
import attr
import typing as t, attr
from huggingface_hub import HfApi
if t.TYPE_CHECKING:
import openllm
from openllm._llm import M, T
from openllm._typing_compat import M, T
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool: return any(s.rfilename.endswith(".safetensors") for s in HfApi().model_info(model_id, revision=revision).siblings)
@attr.define(slots=True)
class HfIgnore:
safetensors = "*.safetensors"
pt = "*.bin"
tf = "*.h5"
flax = "*.msgpack"
@classmethod
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
if llm.__llm_implementation__ == "vllm": base = [cls.tf, cls.flax, cls.safetensors]

View File

@@ -1,37 +1,24 @@
"""Tests utilities for OpenLLM."""
from __future__ import annotations
import contextlib
import logging
import shutil
import subprocess
import typing as t
import bentoml
import openllm
if t.TYPE_CHECKING:
from ._configuration import LiteralRuntime
import contextlib, logging, shutil, subprocess, typing as t, bentoml, openllm
if t.TYPE_CHECKING: from ._typing_compat import LiteralRuntime
logger = logging.getLogger(__name__)
@contextlib.contextmanager
def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False,) -> t.Iterator[bentoml.Bento]:
def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False) -> t.Iterator[bentoml.Bento]:
logger.info("Building BentoML for %s", model)
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
yield bento
if cleanup:
logger.info("Deleting %s", bento.tag)
bentoml.bentos.delete(bento.tag)
@contextlib.contextmanager
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any,) -> t.Iterator[str]:
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any) -> t.Iterator[str]:
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
else: bento_tag = bentoml.Tag.from_taglike(bento)
if image_tag is None: image_tag = str(bento_tag)
executable = shutil.which("docker")
if not executable: raise RuntimeError("docker executable not found")
try:
logger.info("Building container for %s", bento_tag)
bentoml.container.build(bento_tag, backend="docker", image_tag=(image_tag,), progress="plain", **attrs,)
@@ -40,19 +27,15 @@ def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | N
if cleanup:
logger.info("Deleting container %s", image_tag)
subprocess.check_output([executable, "rmi", "-f", image_tag])
@contextlib.contextmanager
def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True,) -> t.Iterator[str]:
def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True) -> t.Iterator[str]:
if clean_context is None:
clean_context = contextlib.ExitStack()
cleanup = True
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
else: bento = bentoml.get(bento_tag)
container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")
if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
yield container_name

View File

@@ -4,19 +4,9 @@ User can import these function for convenience, but
we won't ensure backward compatibility for these functions. So use with caution.
"""
from __future__ import annotations
import contextlib
import functools
import hashlib
import logging
import logging.config
import os
import sys
import types
import typing as t
import contextlib, functools, hashlib, logging, logging.config, os, sys, types, typing as t, openllm
from pathlib import Path
from circus.exc import ConflictError
from bentoml._internal.configuration import (
DEBUG_ENV_VAR as DEBUG_ENV_VAR,
GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR,
@@ -41,20 +31,14 @@ from openllm.utils.lazy import (
VersionInfo as VersionInfo,
)
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from openllm._typing_compat import AnyCallable, LiteralRuntime
logger = logging.getLogger(__name__)
try: from typing import GenericAlias as _TypingGenericAlias # type: ignore
except ImportError: _TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # _GenericAlias is the actual GenericAlias implementation
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11): from typing import overload as _overload
else: from typing_extensions import overload as _overload
if t.TYPE_CHECKING:
import openllm
from .._types import AnyCallable, LiteralRuntime
DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
@@ -194,9 +178,7 @@ def compose(*funcs: AnyCallable) -> AnyCallable:
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable:
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
@@ -241,7 +223,6 @@ def resolve_filepath(path: str, ctx: str | None = None) -> str:
def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
def generate_context(framework_name: str) -> _ModelContext:
import openllm
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
if openllm.utils.is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
if openllm.utils.is_tf_available():
@@ -261,14 +242,6 @@ def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
return attrs, tokenizer_attrs
@_overload
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: ...
@_overload
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: ...
@_overload
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: ...
@_overload
def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]: ...
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM] | type[openllm.AutoTFLLM] | type[openllm.AutoFlaxLLM] | type[openllm.AutoVLLM]:
import openllm
if implementation == "tf": return openllm.AutoTFLLM

View File

@@ -3,33 +3,17 @@
Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
"""
from __future__ import annotations
import contextlib
import functools
import importlib.metadata
import logging
import os
import re
import sys
import typing as t
import attr
import openllm
import contextlib, functools, logging, os, re, typing as t, importlib.metadata
import attr, openllm
from bentoml._internal.utils import analytics as _internal_analytics
if sys.version_info[:2] >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
from openllm._typing_compat import ParamSpec
P = ParamSpec("P")
T = t.TypeVar("T")
logger = logging.getLogger(__name__)
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
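Given the proxy variable above, analytics can presumably be disabled by exporting `OPENLLM_DO_NOT_TRACK` before `openllm` is imported; a minimal sketch under that assumption:

```python
# Opt out of OpenLLM usage analytics for this process by setting the proxy
# variable before openllm is imported (a sketch; string values such as "True"
# follow the str(False)/upper() handling shown above).
import os

os.environ["OPENLLM_DO_NOT_TRACK"] = "True"

import openllm  # noqa: E402  # imported only after the environment variable is set
```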
@functools.lru_cache(maxsize=1)
@@ -75,18 +59,15 @@ class EventMeta:
class ModelSaveEvent(EventMeta):
module: str
model_size_in_kb: float
@attr.define
class OpenllmCliEvent(EventMeta):
cmd_group: str
cmd_name: str
openllm_version: str = importlib.metadata.version("openllm")
# NOTE: reserved for the do_not_track logic
duration_in_ms: t.Any = attr.field(default=None)
error_type: str = attr.field(default=None)
return_code: int = attr.field(default=None)
@attr.define
class StartInitEvent(EventMeta):
model_name: str

View File

@@ -1,22 +1,14 @@
from __future__ import annotations
import functools
import inspect
import linecache
import logging
import string
import types
import typing as t
import functools, inspect, linecache, os, logging, string, types, typing as t
from operator import itemgetter
from pathlib import Path
import orjson
if t.TYPE_CHECKING:
from fs.base import FS
import openllm
from .._types import AnyCallable, DictStrAny, ListStr
from openllm._typing_compat import LiteralString, AnyCallable, DictStrAny, ListStr
PartialAny = functools.partial[t.Any]
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
@@ -24,7 +16,7 @@ logger = logging.getLogger(__name__)
OPENLLM_MODEL_NAME = "# openllm: model name"
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
class ModelNameFormatter(string.Formatter):
model_keyword: t.LiteralString = "__model_name__"
model_keyword: LiteralString = "__model_name__"
def __init__(self, model_name: str):
"""The formatter that extends model_name to be formatted the 'service.py'."""
super().__init__()
@@ -36,14 +28,13 @@ class ModelNameFormatter(string.Formatter):
return True
except ValueError: return False
class ModelIdFormatter(ModelNameFormatter):
model_keyword: t.LiteralString = "__model_id__"
model_keyword: LiteralString = "__model_id__"
class ModelAdapterMapFormatter(ModelNameFormatter):
model_keyword: t.LiteralString = "__model_adapter_map__"
model_keyword: LiteralString = "__model_adapter_map__"
_service_file = Path(__file__).parent.parent / "_service.py"
_service_file = Path(os.path.abspath("__file__")).parent.parent/"_service.py"
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
from . import DEBUG
from openllm.utils import DEBUG
model_name = llm.config["model_name"]
logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/"))
with open(_service_file.__fspath__(), "r") as f: src_contents = f.readlines()
@@ -119,33 +110,26 @@ def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.
return globs[attr_class_name]
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None,) -> AnyCallable:
from . import SHOW_CODEGEN
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None) -> AnyCallable:
from openllm.utils import SHOW_CODEGEN
script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
if annotations: meth.__annotations__ = annotations
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
return meth
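`generate_function` above follows the attrs-style codegen approach: render the function body as source, compile it under a synthetic filename, and pull the resulting callable out of the exec namespace. A self-contained sketch of that idea (independent of `_make_method`, whose definition is not shown here):

```python
import typing as t

def _compile_function(func_name: str, args: tuple[str, ...], lines: list[str],
                      globs: dict[str, t.Any]) -> t.Callable[..., t.Any]:
    # Render the body as source; a single leading space is enough indentation.
    script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args), "\n ".join(lines) or "pass")
    namespace: dict[str, t.Any] = {}
    exec(compile(script, f"<generated {func_name}>", "exec"), globs, namespace)
    return namespace[func_name]

# Hypothetical usage:
double = _compile_function("double", ("x",), ["return 2 * x"], {})
assert double(21) == 42
```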
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: t.LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
from . import dantic, field_env_key
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
from openllm.utils import dantic, field_env_key
def identity(_: str, x_value: t.Any) -> t.Any: return x_value
default_callback = identity if default_callback is None else default_callback
globs = {} if globs is None else globs
globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
lines: ListStr = ["__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },", " )", " for f in fields", "]"]
fields_ann = "list[attr.Attribute[t.Any]]"
return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann})
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
"""Enhance sdk with nice repr that plays well with your brain."""
from .representation import ReprMixin
from openllm.utils import ReprMixin
if name is None: name = func.__name__.strip("_")
_signatures = inspect.signature(func).parameters
def _repr(self: ReprMixin) -> str: return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
@@ -153,3 +137,5 @@ def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
else: doc = func.__doc__
return t.cast(_T, functools.update_wrapper(types.new_class(name, (t.cast("PartialAny", functools.partial), ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))
__all__ = ["gen_sdk", "make_attr_tuple_class", "make_env_transformer", "generate_unique_filename", "generate_function", "OPENLLM_MODEL_NAME", "OPENLLM_MODEL_ADAPTER_MAP"]

View File

@@ -1,30 +1,22 @@
"""An interface provides the best of pydantic and attrs."""
from __future__ import annotations
import functools
import importlib
import os
import sys
import typing as t
import functools, importlib, os, sys, typing as t
from enum import Enum
import attr
import click
import click_option_group as cog
import inflection
import orjson
import attr, click, click_option_group as cog, inflection, orjson
from click import (
ParamType,
shell_completion as sc,
types as click_types,
)
if t.TYPE_CHECKING:
from attr import _ValidatorType
if t.TYPE_CHECKING: from attr import _ValidatorType
_T = t.TypeVar("_T")
AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
__all__ = ["FC", "attrs_to_options", "Field", "parse_type", "is_typing", "is_literal", "ModuleType", "EnumChoice", "LiteralChoice", "allows_multiple", "is_mapping", "is_container", "parse_container_args", "parse_single_arg", "CUDA", "JsonType", "BytesType"]
def __dir__() -> list[str]: return sorted(__all__)
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
# TODO: support parsing nested attrs class and Union
envvar = field.metadata["env"]
@@ -47,13 +39,11 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any:
if env is not None:
value = os.environ.get(env, value)
if value is not None and isinstance(value, str):
try:
return orjson.loads(value.lower())
except orjson.JSONDecodeError as err:
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
try: return orjson.loads(value.lower())
except orjson.JSONDecodeError as err: raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
return value
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any,) -> t.Any:
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[t.Any] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any) -> t.Any:
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
By default, if both validator and ge are provided, then ge will be

View File

@@ -1,34 +1,16 @@
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
from __future__ import annotations
import importlib
import importlib.metadata
import importlib.util
import logging
import os
import sys
import typing as t
from abc import ABCMeta
import importlib, importlib.metadata, importlib.util, logging, os, abc, typing as t
from collections import OrderedDict
import inflection
from packaging import version
import inflection, packaging.version
from bentoml._internal.utils import LazyLoader, pkg
from openllm._typing_compat import overload, LiteralString
from .representation import ReprMixin
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
from .._types import LiteralRuntime
else:
BackendOrderredDict = OrderedDict
BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]]
from openllm._typing_compat import LiteralRuntime
logger = logging.getLogger(__name__)
OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",}
@@ -104,10 +86,10 @@ def is_tf_available() -> bool:
try:
_tf_version = importlib.metadata.version(_pkg)
break
except importlib.metadata.PackageNotFoundError: pass
except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distributions.
_tf_available = _tf_version is not None
if _tf_available:
if _tf_version and version.parse(_tf_version) < version.parse("2"):
if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse("2"):
logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
_tf_available = False
else:
@@ -232,11 +214,13 @@ You can install it with pip: `pip install fairscale`. Please note that you may n
your runtime after installation.
"""
BACKENDS_MAPPING = BackendOrderredDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))])
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))])
class DummyMetaclass(ABCMeta):
class DummyMetaclass(abc.ABCMeta):
"""Metaclass for dummy object.
It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
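The behaviour described here is the usual dummy-object pattern (as in transformers): any attribute access on the placeholder class re-raises the missing-backend error. A minimal sketch of that pattern, using a simplified stand-in for `require_backends` that always raises (the real helper only raises when the backend is unavailable), and a hypothetical placeholder class:

```python
import abc
import typing as t

def require_backends(o: t.Any, backends: list[str]) -> None:
    # Simplified stand-in: the real helper checks BACKENDS_MAPPING first.
    raise ImportError(f"{getattr(o, '__name__', o)} requires {', '.join(backends)} to be installed.")

class DummyMetaclass(abc.ABCMeta):
    _backends: list[str] = []
    def __getattr__(cls, key: str) -> t.Any:
        # Triggered only when normal class attribute lookup fails.
        require_backends(cls, cls._backends)

class DummyTorchModel(metaclass=DummyMetaclass):  # hypothetical placeholder class
    _backends = ["torch"]

# DummyTorchModel.from_pretrained  -> raises the ImportError from require_backends
```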
@@ -258,19 +242,17 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
failed = [msg.format(name) for available, msg in (BACKENDS_MAPPING[backend] for backend in backends) if not available()]
if failed: raise ImportError("".join(failed))
class EnvVarMixin(ReprMixin):
model_name: str
if t.TYPE_CHECKING:
config: str
model_id: str
quantize: str
framework: str
bettertransformer: str
runtime: str
config: str
model_id: str
quantize: str
framework: str
bettertransformer: str
runtime: str
@overload
def __getitem__(self, item: t.Literal["config"]) -> str: ...
@overload
@@ -297,9 +279,9 @@ class EnvVarMixin(ReprMixin):
if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")()
elif hasattr(self, item): return getattr(self, item)
raise KeyError(f"Key {item} not found in {self}")
def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
"""EnvVarMixin is a mixin class that returns the value extracted from environment variables."""
from .._configuration import field_env_key
from openllm._configuration import field_env_key
self.model_name = inflection.underscore(model_name)
self._implementation = implementation
self._model_id = model_id

View File

@@ -1,19 +1,6 @@
from __future__ import annotations
import functools
import importlib
import importlib.machinery
import importlib.metadata
import importlib.util
import itertools
import os
import time
import types
import typing as t
import warnings
import attr
import openllm
import functools, importlib, importlib.machinery, importlib.metadata, importlib.util, itertools, os, time, types, warnings, typing as t
import attr, openllm
__all__ = ["VersionInfo", "LazyModule"]
# vendored from attrs
@@ -68,7 +55,7 @@ class LazyModule(types.ModuleType):
for key, values in import_structure.items():
for value in values: self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
self.__all__: list[str] = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
self.__file__ = module_file
self.__spec__ = module_spec or importlib.util.find_spec(name)
self.__path__ = [os.path.dirname(module_file)]
@@ -93,7 +80,7 @@ class LazyModule(types.ModuleType):
if name in dunder_to_metadata:
if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2)
meta = importlib.metadata.metadata("openllm")
project_url = dict(url.split(", ") for url in meta.get_all("Project-URL"))
project_url = dict(url.split(", ") for url in t.cast(t.List[str], meta.get_all("Project-URL")))
if name == "__license__": return "Apache-2.0"
elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
elif name in ("__uri__", "__url__"): return project_url["GitHub"]

View File

@@ -1,57 +1,32 @@
from __future__ import annotations
import typing as t
from abc import abstractmethod
import attr, orjson
from openllm import utils
if t.TYPE_CHECKING: from openllm._typing_compat import TypeAlias
import attr
import orjson
if t.TYPE_CHECKING:
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]
class ReprMixin:
"""This class display possible representation of given class.
It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
Most subclass needs to implement a __repr_keys__ property.
Based on the design from Pydantic.
The __repr__ will display the json representation of the object for easier interaction.
The __str__ will display either __attrs_repr__ or __repr_str__.
"""
@property
@abstractmethod
def __repr_keys__(self) -> set[str]:
"""This can be overriden by base class using this mixin."""
def __repr_keys__(self) -> set[str]: raise NotImplementedError
"""This can be overriden by base class using this mixin."""
def __repr__(self) -> str: return f"{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}"
"""The `__repr__` for any subclass of Mixin.
def __repr__(self) -> str:
"""The `__repr__` for any subclass of Mixin.
It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
"""
def __str__(self) -> str: return self.__repr_str__(" ")
"""The string representation of the given Mixin subclass.
It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
"""
from . import bentoml_cattr
It will contain all of the attributes from __repr_keys__
"""
def __repr_name__(self) -> str: return self.__class__.__name__
"""Name of the instance's class, used in __repr__."""
def __repr_str__(self, join_str: str) -> str: return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
"""To be used with __str__."""
def __repr_args__(self) -> ReprArgs: return ((k, getattr(self, k)) for k in self.__repr_keys__)
"""This can also be overriden by base class using this mixin.
serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
def __str__(self) -> str:
"""The string representation of the given Mixin subclass.
It will contain all of the attributes from __repr_keys__
"""
return self.__repr_str__(" ")
def __repr_name__(self) -> str:
"""Name of the instance's class, used in __repr__."""
return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str:
"""To be used with __str__."""
return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
def __repr_args__(self) -> ReprArgs:
"""This can also be overriden by base class using this mixin.
By default it does a getattr of the current object from __repr_keys__.
"""
return ((k, getattr(self, k)) for k in self.__repr_keys__)
By default it does a getattr of the current object from __repr_keys__.
"""

View File

@@ -1,16 +1,3 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os

View File

@@ -1,13 +0,0 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,17 +1,3 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import typing as t

Some files were not shown because too many files have changed in this diff.