mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-26 10:01:30 -04:00
infra: enable compiled wheels for all supported Python (#201)
This commit is contained in:
13
.github/workflows/compile-pypi.yml
vendored
13
.github/workflows/compile-pypi.yml
vendored
@@ -32,7 +32,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
pure-wheels-sdist:
|
||||
name: Building pure wheels and sdist
|
||||
name: Pure wheels and sdist distribution
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
@@ -50,31 +50,26 @@ jobs:
|
||||
path: dist/*
|
||||
if-no-files-found: error
|
||||
mypyc:
|
||||
name: Building compiled mypyc wheels (${{ matrix.name }})
|
||||
name: Compiled mypyc wheels (${{ matrix.name }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
# NOTE: ubuntu x86
|
||||
- os: ubuntu-latest
|
||||
name: linux-x86_64
|
||||
python: 311
|
||||
# NOTE: darwin amd64
|
||||
- os: macos-latest
|
||||
name: macos-x86_64
|
||||
macos_arch: "x86_64"
|
||||
python: 311
|
||||
# NOTE: darwin arm64
|
||||
- os: macos-latest
|
||||
name: macos-arm64
|
||||
macos_arch: "arm64"
|
||||
python: 311
|
||||
# NOTE: darwin universal2
|
||||
- os: macos-latest
|
||||
name: macos-universal2
|
||||
macos_arch: "universal2"
|
||||
python: 311
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
@@ -85,9 +80,9 @@ jobs:
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Build wheels via cibuildwheel
|
||||
uses: pypa/cibuildwheel@v2.14.1
|
||||
uses: pypa/cibuildwheel@v2.15.0
|
||||
env:
|
||||
CIBW_BUILD: cp${{ matrix.python }}-*
|
||||
CIBW_BEFORE_BUILD_MACOS: "rustup target add aarch64-apple-darwin"
|
||||
CIBW_ARCHS_MACOS: "${{ matrix.macos_arch }}"
|
||||
MYPYPATH: /project/typings
|
||||
- name: Upload wheels as workflow artifacts
|
||||
|
||||
@@ -19,7 +19,7 @@ All the relevant code for incorporating a new model resides within
|
||||
`src/openllm/models/{model_name}/__init__.py`
|
||||
- [ ] Adjust the entrypoints for files at `src/openllm/models/auto/*`
|
||||
- [ ] Modify the main `__init__.py`: `src/openllm/models/__init__.py`
|
||||
- [ ] Run the following script to update dummy objects: `hatch run update-dummy`
|
||||
- [ ] Run the following to update stubs: `hatch run check-stubs`
|
||||
|
||||
For a working example, check out any pre-implemented model.
|
||||
|
||||
|
||||
7
STYLE.md
7
STYLE.md
@@ -87,12 +87,11 @@ _If you have any suggestions, feel free to give it on our discord server!_
|
||||
def foo(x): return rotate_cv(x) if x > 0 else -x
|
||||
```
|
||||
|
||||
- imports should be grouped by their types, and each import should be designated
|
||||
on its own line.
|
||||
- imports should be grouped by their types: standard library, third-party, and local
|
||||
|
||||
```python
|
||||
import os
|
||||
import sys
|
||||
import os, sys
|
||||
import orjson, bentoml
|
||||
```
|
||||
This is partially to make it easier to work with merge-conflicts, and easier
|
||||
for IDE to navigate context definition.
|
||||
|
||||
1
changelog.d/201.feature.md
Normal file
1
changelog.d/201.feature.md
Normal file
@@ -0,0 +1 @@
|
||||
Added all compiled wheels for all supported Python version for Linux and MacOS
|
||||
12
hatch.toml
12
hatch.toml
@@ -75,25 +75,21 @@ dependencies = [
|
||||
]
|
||||
[envs.default.scripts]
|
||||
changelog = "towncrier build --version main --draft"
|
||||
check-stubs = [
|
||||
"./tools/update-config-stubs.py",
|
||||
"./tools/update-models-import.py",
|
||||
"./tools/update-init-import.py",
|
||||
"update-dummy",
|
||||
]
|
||||
check-stubs = ["./tools/update-config-stubs.py", "./tools/update-models-import.py", "update-dummy"]
|
||||
compile = "bash ./compile.sh {args}"
|
||||
update-dummy = ["- ./tools/update-dummy.py", "./tools/update-dummy.py"]
|
||||
inplace-changelog = "towncrier build --version main --keep"
|
||||
quality = [
|
||||
"./tools/dependencies.py",
|
||||
"./tools/update-readme.py",
|
||||
"- ./tools/update-brew-tap.py",
|
||||
"check-stubs",
|
||||
"pre-commit run --all-files",
|
||||
"- pre-commit run --all-files",
|
||||
]
|
||||
recompile = ["bash ./clean.sh", "compile"]
|
||||
setup = "pre-commit install"
|
||||
tool = ["quality", "recompile -nx"]
|
||||
typing = ["- pre-commit run mypy {args:-a}", "- pre-commit run pyright {args:-a}"]
|
||||
update-dummy = ["- ./tools/update-dummy.py", "./tools/update-dummy.py"]
|
||||
[envs.tests]
|
||||
dependencies = [
|
||||
# NOTE: interact with docker for container tests.
|
||||
|
||||
@@ -44,6 +44,7 @@ dependencies = [
|
||||
"httpx",
|
||||
"click>=8.1.3",
|
||||
"typing_extensions",
|
||||
"mypy_extensions",
|
||||
"ghapi",
|
||||
"cuda-python;platform_system!=\"Darwin\"",
|
||||
"bitsandbytes<0.42",
|
||||
@@ -124,12 +125,12 @@ vllm = ["vllm", "ray"]
|
||||
|
||||
[tool.cibuildwheel]
|
||||
build-verbosity = 1
|
||||
# So the following enviuronment will be targeted for compiled wheels:
|
||||
# - Python: CPython 3.8+ only
|
||||
# So the following environment will be targeted for compiled wheels:
|
||||
# - Python: CPython 3.8-3.11
|
||||
# - Architecture (64-bit only): amd64 / x86_64, universal2, and arm64
|
||||
# - OS: Linux (no musl), and macOS
|
||||
build = "cp3*-*"
|
||||
skip = ["*-manylinux_i686", "*-musllinux_*", "*-win32", "pp-*"]
|
||||
skip = ["*-manylinux_i686", "*-musllinux_*", "*-win32", "pp-*", "cp312-*"]
|
||||
|
||||
[tool.cibuildwheel.environment]
|
||||
HATCH_BUILD_HOOKS_ENABLE = "1"
|
||||
@@ -183,6 +184,9 @@ fail-under = 100
|
||||
verbose = 2
|
||||
whitelist-regex = ["test_.*"]
|
||||
|
||||
[tool.check-wheel-contents]
|
||||
toplevel = ["openllm"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
@@ -194,7 +198,6 @@ extend-exclude = [
|
||||
"examples",
|
||||
"src/openllm/playground",
|
||||
"src/openllm/__init__.py",
|
||||
"src/openllm/_types.py",
|
||||
"src/openllm/_version.py",
|
||||
"src/openllm/utils/dummy_*.py",
|
||||
"src/openllm/models/__init__.py",
|
||||
@@ -233,7 +236,9 @@ ignore = [
|
||||
"PLR0915",
|
||||
"PLR2004", # magic value to use constant
|
||||
"E501", # ignore line length violation
|
||||
"E401", # ignore multiple line import
|
||||
"E702",
|
||||
"I001", # unsorted imports
|
||||
"PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs.
|
||||
"D103", # Just missing docstring for magic methods.
|
||||
"D102",
|
||||
@@ -245,13 +250,10 @@ ignore = [
|
||||
"D105", # magic docstring
|
||||
"E701", # multiple statement on single line
|
||||
]
|
||||
line-length = 119
|
||||
target-version = "py312"
|
||||
unfixable = [
|
||||
"F401", # Don't touch unused imports, just warn about it.
|
||||
"TCH004", # Don't touch import outside of TYPE_CHECKING block
|
||||
"RUF100", # unused noqa, just warn about it
|
||||
]
|
||||
line-length = 768
|
||||
target-version = "py38"
|
||||
typing-modules = ["openllm._typing_compat"]
|
||||
unfixable = ["TCH004"]
|
||||
[tool.ruff.flake8-type-checking]
|
||||
exempt-modules = ["typing", "typing_extensions", "."]
|
||||
runtime-evaluated-base-classes = [
|
||||
@@ -260,7 +262,7 @@ runtime-evaluated-base-classes = [
|
||||
"openllm._configuration.GenerationConfig",
|
||||
"openllm._configuration.ModelSettings",
|
||||
]
|
||||
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen"]
|
||||
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen", "trait"]
|
||||
[tool.ruff.pydocstyle]
|
||||
convention = "google"
|
||||
[tool.ruff.pycodestyle]
|
||||
@@ -271,7 +273,7 @@ force-single-line = false
|
||||
force-wrap-aliases = true
|
||||
known-first-party = ["openllm", "bentoml"]
|
||||
known-third-party = ["transformers", "click", "huggingface_hub", "torch", "vllm", "auto_gptq"]
|
||||
lines-after-imports = 1
|
||||
lines-after-imports = 0
|
||||
lines-between-types = 0
|
||||
no-lines-before = ["future", "standard-library"]
|
||||
relative-imports-order = "closest-to-furthest"
|
||||
@@ -279,12 +281,11 @@ required-imports = ["from __future__ import annotations"]
|
||||
[tool.ruff.flake8-quotes]
|
||||
avoid-escape = false
|
||||
[tool.ruff.extend-per-file-ignores]
|
||||
"src/openllm/_service.py" = ["I001", "E401"]
|
||||
"src/openllm/_service.py" = ["E401"]
|
||||
"src/openllm/cli/entrypoint.py" = ["D301"]
|
||||
"src/openllm/models/**" = ["I001", "E", "D", "F"]
|
||||
"src/openllm/utils/__init__.py" = ["I001"]
|
||||
"src/openllm/utils/import_utils.py" = ["PLW0603"]
|
||||
"src/openllm/client/runtimes/*" = ["D107"]
|
||||
"src/openllm/models/**" = ["E", "D", "F"]
|
||||
"src/openllm/utils/import_utils.py" = ["PLW0603"]
|
||||
"tests/**/*" = [
|
||||
"S101",
|
||||
"TID252",
|
||||
@@ -348,6 +349,7 @@ omit = [
|
||||
"src/openllm/__init__.py",
|
||||
"src/openllm/__main__.py",
|
||||
"src/openllm/utils/dummy_*.py",
|
||||
"src/openllm/_typing_compat.py",
|
||||
]
|
||||
source_pkgs = ["openllm"]
|
||||
[tool.coverage.report]
|
||||
@@ -378,6 +380,7 @@ omit = [
|
||||
"src/openllm/__init__.py",
|
||||
"src/openllm/__main__.py",
|
||||
"src/openllm/utils/dummy_*.py",
|
||||
"src/openllm/_typing_compat.py",
|
||||
]
|
||||
precision = 2
|
||||
show_missing = true
|
||||
@@ -387,16 +390,17 @@ analysis.useLibraryCodeForTypes = true
|
||||
exclude = [
|
||||
"__pypackages__/*",
|
||||
"src/openllm/playground/",
|
||||
"src/openllm/models/",
|
||||
"src/openllm/__init__.py",
|
||||
"src/openllm/__main__.py",
|
||||
"src/openllm/utils/dummy_*.py",
|
||||
"src/openllm/models",
|
||||
"src/openllm/_typing_compat.py",
|
||||
"tools",
|
||||
"examples",
|
||||
"tests",
|
||||
]
|
||||
include = ["src/openllm"]
|
||||
pythonVersion = "3.12"
|
||||
pythonVersion = "3.8"
|
||||
reportMissingImports = "warning"
|
||||
reportMissingTypeStubs = false
|
||||
reportPrivateUsage = "warning"
|
||||
@@ -408,13 +412,16 @@ reportWildcardImportFromLibrary = "warning"
|
||||
typeCheckingMode = "strict"
|
||||
|
||||
[tool.mypy]
|
||||
# TODO: Enable model for strict type checking
|
||||
exclude = ["src/openllm/playground/", "src/openllm/utils/dummy_*.py", "src/openllm/models"]
|
||||
local_partial_types = true
|
||||
exclude = [
|
||||
"src/openllm/playground/",
|
||||
"src/openllm/utils/dummy_*.py",
|
||||
"src/openllm/models",
|
||||
"src/openllm/_typing_compat.py",
|
||||
]
|
||||
modules = ["openllm"]
|
||||
mypy_path = "typings"
|
||||
pretty = true
|
||||
python_version = "3.11"
|
||||
python_version = "3.8"
|
||||
show_error_codes = true
|
||||
warn_no_return = false
|
||||
warn_return_any = false
|
||||
@@ -430,6 +437,7 @@ module = [
|
||||
"optimum.*",
|
||||
"inflection.*",
|
||||
"huggingface_hub.*",
|
||||
"click_option_group.*",
|
||||
"peft.*",
|
||||
"auto_gptq.*",
|
||||
"vllm.*",
|
||||
@@ -437,13 +445,13 @@ module = [
|
||||
"httpx.*",
|
||||
"cloudpickle.*",
|
||||
"circus.*",
|
||||
"grpc_health.*",
|
||||
"grpc_health.v1.*",
|
||||
"transformers.*",
|
||||
"ghapi.*",
|
||||
]
|
||||
[[tool.mypy.overrides]]
|
||||
ignore_errors = true
|
||||
module = ["openllm.models.*", "openllm._types", "openllm.playground.*"]
|
||||
module = ["openllm.models.*", "openllm.playground.*", "openllm._typing_compat"]
|
||||
|
||||
[tool.hatch.version]
|
||||
fallback-version = "0.0.0"
|
||||
@@ -476,24 +484,12 @@ dependencies = [
|
||||
"types-protobuf",
|
||||
]
|
||||
enable-by-default = false
|
||||
exclude = [
|
||||
# no reason to compile model since framework is already written in C++
|
||||
"src/openllm/models/**",
|
||||
"src/openllm/bundle/**",
|
||||
"src/openllm/cli/**",
|
||||
"src/openllm/playground/**",
|
||||
"src/openllm/__main__.py",
|
||||
"src/openllm/_types.py",
|
||||
"src/openllm/_llm.py",
|
||||
"src/openllm/_service.py",
|
||||
"src/openllm/_configuration.py",
|
||||
# can't compile serialisation for transformers since it will raise segfault
|
||||
"src/openllm/serialisation/transformers",
|
||||
"src/openllm/utils/analytics.py",
|
||||
]
|
||||
include = [
|
||||
"src/openllm/serialisation",
|
||||
"src/openllm/bundle",
|
||||
"src/openllm/models/__init__.py",
|
||||
"src/openllm/models/auto/__init__.py",
|
||||
"src/openllm/utils/__init__.py",
|
||||
"src/openllm/utils/codegen.py",
|
||||
"src/openllm/__init__.py",
|
||||
"src/openllm/_prompt.py",
|
||||
"src/openllm/_schema.py",
|
||||
@@ -505,20 +501,18 @@ include = [
|
||||
]
|
||||
# NOTE: This is consistent with pyproject.toml
|
||||
mypy-args = [
|
||||
"--pretty",
|
||||
"--strict",
|
||||
# this is because all transient library doesn't have types
|
||||
"--allow-subclassing-any",
|
||||
"--check-untyped-defs",
|
||||
"--follow-imports=skip",
|
||||
"--python-version=3.11",
|
||||
"--check-untyped-defs",
|
||||
"--ignore-missing-imports",
|
||||
"--no-warn-return-any",
|
||||
"--warn-unreachable",
|
||||
"--no-warn-no-return",
|
||||
"--no-warn-unused-ignores",
|
||||
"--exclude='/src\\/openllm\\/playground\\/**'",
|
||||
"--exclude='/src\\/openllm\\/_types\\.py$'",
|
||||
"--exclude='/src\\/openllm\\/_typing_compat\\.py$'",
|
||||
]
|
||||
options = { verbose = true, debug_level = "2", opt_level = "3" }
|
||||
options = { verbose = true, strip_asserts = true, debug_level = "2", opt_level = "3", include_runtime_files = true }
|
||||
require-runtime-dependencies = true
|
||||
|
||||
@@ -9,12 +9,8 @@ deploy, and monitor any LLMs with ease.
|
||||
* Native integration with BentoML and LangChain for custom LLM apps
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
import typing as _t
|
||||
import warnings as _warnings
|
||||
import logging as _logging, os as _os, typing as _t, warnings as _warnings
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from . import exceptions as exceptions, utils as utils
|
||||
|
||||
if utils.DEBUG:
|
||||
@@ -35,7 +31,6 @@ _import_structure: dict[str, list[str]] = {
|
||||
"exceptions": [], "models": [], "client": [], "bundle": [], "playground": [], "testing": [], "utils": ["infer_auto_class"], "serialisation": ["ggml", "transformers"], "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"], "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "LLMEmbeddings"], "_configuration": ["LLMConfig", "GenerationConfig", "SamplingParams"], "_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"], "_quantisation": ["infer_quantisation_config"], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs", "HfAgentInput"],
|
||||
"models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": ["ChatGLMConfig"], "models.baichuan": ["BaichuanConfig"], "models.dolly_v2": ["DollyV2Config"], "models.falcon": ["FalconConfig"], "models.flan_t5": ["FlanT5Config"], "models.gpt_neox": ["GPTNeoXConfig"], "models.llama": ["LlamaConfig"], "models.mpt": ["MPTConfig"], "models.opt": ["OPTConfig"], "models.stablelm": ["StableLMConfig"], "models.starcoder": ["StarCoderConfig"]
|
||||
}
|
||||
|
||||
COMPILED = _Path(__file__).suffix in (".pyd", ".so")
|
||||
|
||||
if _t.TYPE_CHECKING:
|
||||
@@ -61,25 +56,38 @@ if _t.TYPE_CHECKING:
|
||||
from .serialisation import ggml as ggml, transformers as transformers
|
||||
from openllm.utils import infer_auto_class as infer_auto_class
|
||||
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = ["ChatGLM", "Baichuan"]
|
||||
else:
|
||||
_import_structure["models.chatglm"].extend(["ChatGLM"])
|
||||
_import_structure["models.baichuan"].extend(["Baichuan"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.baichuan import Baichuan as Baichuan
|
||||
from .models.chatglm import ChatGLM as ChatGLM
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_triton_available()): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["MPT"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["MPT"]
|
||||
else:
|
||||
_import_structure["models.mpt"].extend(["MPT"])
|
||||
if _t.TYPE_CHECKING: from .models.mpt import MPT as MPT
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_einops_available()): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["Falcon"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["Falcon"]
|
||||
else:
|
||||
_import_structure["models.falcon"].extend(["Falcon"])
|
||||
if _t.TYPE_CHECKING: from .models.falcon import Falcon as Falcon
|
||||
|
||||
try:
|
||||
if not utils.is_torch_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = utils.dummy_pt_objects.__all__
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(utils.dummy_pt_objects) if not name.startswith("_") and name not in ("ChatGLM", "Baichuan", "MPT", "Falcon", "annotations")]
|
||||
else:
|
||||
if utils.is_cpm_kernels_available():
|
||||
_import_structure["models.chatglm"].extend(["ChatGLM"])
|
||||
_import_structure["models.baichuan"].extend(["Baichuan"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.baichuan import Baichuan as Baichuan
|
||||
from .models.chatglm import ChatGLM as ChatGLM
|
||||
if utils.is_einops_available():
|
||||
_import_structure["models.falcon"].extend(["Falcon"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.falcon import Falcon as Falcon
|
||||
if utils.is_triton_available():
|
||||
_import_structure["models.mpt"].extend(["MPT"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.mpt import MPT as MPT
|
||||
_import_structure["models.flan_t5"].extend(["FlanT5"])
|
||||
_import_structure["models.dolly_v2"].extend(["DollyV2"])
|
||||
_import_structure["models.starcoder"].extend(["StarCoder"])
|
||||
@@ -100,7 +108,7 @@ else:
|
||||
try:
|
||||
if not utils.is_vllm_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_vllm_objects"] = utils.dummy_vllm_objects.__all__
|
||||
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(utils.dummy_vllm_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
_import_structure["models.baichuan"].extend(["VLLMBaichuan"])
|
||||
_import_structure["models.llama"].extend(["VLLMLlama"])
|
||||
@@ -124,7 +132,7 @@ else:
|
||||
try:
|
||||
if not utils.is_flax_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_flax_objects"] = utils.dummy_flax_objects.__all__
|
||||
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(utils.dummy_flax_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
|
||||
_import_structure["models.opt"].extend(["FlaxOPT"])
|
||||
@@ -136,7 +144,7 @@ else:
|
||||
try:
|
||||
if not utils.is_tf_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_tf_objects"] = utils.dummy_tf_objects.__all__
|
||||
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(utils.dummy_tf_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
|
||||
_import_structure["models.opt"].extend(["TFOPT"])
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# mypy: disable-error-code="attr-defined,no-untyped-call,type-var,operator,arg-type,no-redef"
|
||||
"""Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``.
|
||||
|
||||
Highlight feature: Each fields in ``openllm.LLMConfig`` will also automatically generate a environment
|
||||
@@ -33,29 +34,16 @@ dynamically during serve, ahead-of-serve or per requests.
|
||||
Refer to ``openllm.LLMConfig`` docstring for more information.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click_option_group as cog
|
||||
import inflection
|
||||
import orjson
|
||||
import copy, enum, logging, os, sys, types, typing as t
|
||||
import attr, click_option_group as cog, inflection, orjson, openllm
|
||||
from cattr.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
|
||||
from deepmerge.merger import Merger
|
||||
|
||||
import openllm
|
||||
|
||||
from ._strategies import LiteralResourceSpec, available_resource_spec, resource_spec
|
||||
from ._typing_compat import LiteralString, NotRequired, Required, overload, AdapterType, LiteralRuntime
|
||||
from .exceptions import ForbiddenAttributeError
|
||||
from .utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
MYPY,
|
||||
LazyType,
|
||||
ReprMixin,
|
||||
bentoml_cattr,
|
||||
codegen,
|
||||
@@ -63,43 +51,18 @@ from .utils import (
|
||||
field_env_key,
|
||||
first_not_none,
|
||||
lenient_issubclass,
|
||||
non_intrusive_setattr,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
# NOTE: We need to do check overload import
|
||||
# so that it can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import NotRequired, Required, dataclass_transform, overload
|
||||
else:
|
||||
from typing_extensions import NotRequired, Required, dataclass_transform, overload
|
||||
|
||||
# NOTE: Using internal API from attr here, since we are actually
|
||||
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
|
||||
# NOTE: Using internal API from attr here, since we are actually allowing subclass of openllm.LLMConfig to become 'attrs'-ish
|
||||
from attr._compat import set_closure_cell
|
||||
from attr._make import _CountingAttr, _make_init, _transform_attrs
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
|
||||
from ._typing_compat import AnyCallable, At, Self, ListStr, DictStrAny
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import click
|
||||
import peft
|
||||
import transformers
|
||||
import vllm
|
||||
import click, peft, transformers, vllm
|
||||
from transformers.generation.beam_constraints import Constraint
|
||||
|
||||
from ._types import AnyCallable, At
|
||||
|
||||
DictStrAny = dict[str, t.Any]
|
||||
ListStr = list[str]
|
||||
else:
|
||||
Constraint = t.Any
|
||||
ListStr = list
|
||||
DictStrAny = dict
|
||||
|
||||
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
peft = openllm.utils.LazyLoader("peft", globals(), "peft")
|
||||
@@ -107,7 +70,7 @@ else:
|
||||
__all__ = ["LLMConfig", "GenerationConfig", "SamplingParams"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config_merger = Merger([(DictStrAny, "merge")], ["override"], ["override"])
|
||||
config_merger = Merger([(dict, "merge")], ["override"], ["override"])
|
||||
|
||||
# case insensitive, but rename to conform with type
|
||||
class _PeftEnumMeta(enum.EnumMeta):
|
||||
@@ -141,8 +104,6 @@ class PeftType(str, enum.Enum, metaclass=_PeftEnumMeta):
|
||||
|
||||
_PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"}
|
||||
|
||||
AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"]
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType:
|
||||
@@ -263,8 +224,7 @@ class GenerationConfig(ReprMixin):
|
||||
|
||||
if t.TYPE_CHECKING and not MYPY:
|
||||
# stubs this for pyright as mypy already has a attr plugin builtin
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None:
|
||||
...
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: ...
|
||||
|
||||
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
|
||||
if not _internal: raise RuntimeError("GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config")
|
||||
@@ -322,7 +282,7 @@ class SamplingParams(ReprMixin):
|
||||
|
||||
def to_vllm(self) -> vllm.SamplingParams: return vllm.SamplingParams(max_tokens=self.max_tokens, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, **bentoml_cattr.unstructure(self))
|
||||
@classmethod
|
||||
def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> t.Self:
|
||||
def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> Self:
|
||||
"""The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``."""
|
||||
stop = attrs.pop("stop", None)
|
||||
if stop is not None and isinstance(stop, str): stop = [stop]
|
||||
@@ -480,7 +440,7 @@ bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings)
|
||||
|
||||
def _setattr_class(attr_name: str, value_var: t.Any) -> str: return f"setattr(cls, '{attr_name}', {value_var})"
|
||||
|
||||
def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: t.LiteralString = "openllm") -> t.Callable[..., None]:
|
||||
def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: LiteralString = "openllm") -> t.Callable[..., None]:
|
||||
"""Generate the assignment script with prefix attributes __openllm_<value>__."""
|
||||
args: ListStr = []
|
||||
globs: DictStrAny = {"cls": cls, "_cached_attribute": attributes, "_cached_getattribute_get": _object_getattribute.__get__}
|
||||
@@ -497,11 +457,6 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance
|
||||
|
||||
_reserved_namespace = {"__config__", "GenerationConfig", "SamplingParams"}
|
||||
|
||||
@dataclass_transform(kw_only_default=True, order_default=True, field_specifiers=(attr.field, dantic.Field))
|
||||
def llm_config_transform(cls: type[LLMConfig]) -> type[LLMConfig]:
|
||||
non_intrusive_setattr(cls, "__dataclass_transform__", {"order_default": True, "kw_only_default": True, "field_specifiers": (attr.field, dantic.Field)})
|
||||
return cls
|
||||
|
||||
@attr.define(slots=True)
|
||||
class _ConfigAttr:
|
||||
Field = dantic.Field
|
||||
@@ -669,7 +624,7 @@ class _ConfigBuilder:
|
||||
|
||||
__slots__ = ("_cls", "_cls_dict", "_attr_names", "_attrs", "_model_name", "_base_attr_map", "_base_names", "_has_pre_init", "_has_post_init")
|
||||
|
||||
def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr[t.Any]], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True):
|
||||
def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True):
|
||||
attrs, base_attrs, base_attr_map = _transform_attrs(cls, these, auto_attribs, kw_only, collect_by_mro, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__))
|
||||
self._cls, self._model_name, self._cls_dict, self._attrs, self._base_names, self._base_attr_map = cls, cls.__openllm_model_name__, dict(cls.__dict__), attrs, {a.name for a in base_attrs}, base_attr_map
|
||||
self._attr_names = tuple(a.name for a in attrs)
|
||||
@@ -742,17 +697,16 @@ class _ConfigBuilder:
|
||||
if not closure_cells: continue # Catch None or the empty list.
|
||||
for cell in closure_cells:
|
||||
try: match = cell.cell_contents is self._cls
|
||||
except ValueError: pass # ValueError: Cell is empty
|
||||
except ValueError: pass # noqa: PERF203 # ValueError: Cell is empty
|
||||
else:
|
||||
if match: set_closure_cell(cell, cls)
|
||||
return cls
|
||||
|
||||
return llm_config_transform(cls)
|
||||
|
||||
def add_attrs_init(self) -> t.Self:
|
||||
def add_attrs_init(self) -> Self:
|
||||
self._cls_dict["__attrs_init__"] = codegen.add_method_dunders(self._cls, _make_init(self._cls, self._attrs, self._has_pre_init, self._has_post_init, False, True, False, self._base_attr_map, False, None, True))
|
||||
return self
|
||||
|
||||
def add_repr(self) -> t.Self:
|
||||
def add_repr(self) -> Self:
|
||||
for key, fn in ReprMixin.__dict__.items():
|
||||
if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
|
||||
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config", "sampling_config"})
|
||||
@@ -865,7 +819,7 @@ class LLMConfig(_ConfigAttr):
|
||||
|
||||
# auto assignment attributes generated from __config__ after create the new slot class.
|
||||
_make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls)
|
||||
def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: t.LiteralString | None = None) -> type[At]:
|
||||
def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: LiteralString | None = None) -> type[At]:
|
||||
camel_name = cls.__name__.replace("Config", "")
|
||||
klass = attr.make_class(f"{camel_name}{class_attr}", [], bases=(base,), slots=True, weakref_slot=True, frozen=True, repr=False, init=False, collect_by_mro=True, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__, suffix=suffix_env, globs=globs, default_callback=lambda field_name, field_default: getattr(getattr(cls, class_attr), field_name, field_default) if codegen.has_own_attribute(cls, class_attr) else field_default))
|
||||
# For pickling to work, the __module__ variable needs to be set to the
|
||||
@@ -882,19 +836,19 @@ class LLMConfig(_ConfigAttr):
|
||||
anns = codegen.get_annotations(cls)
|
||||
# _CountingAttr is the underlying representation of attr.field
|
||||
ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)}
|
||||
these: dict[str, _CountingAttr[t.Any]] = {}
|
||||
these: dict[str, _CountingAttr] = {}
|
||||
annotated_names: set[str] = set()
|
||||
for attr_name, typ in anns.items():
|
||||
if codegen.is_class_var(typ): continue
|
||||
annotated_names.add(attr_name)
|
||||
val = cd.get(attr_name, attr.NOTHING)
|
||||
if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val):
|
||||
if not isinstance(val, _CountingAttr):
|
||||
if val is attr.NOTHING: val = cls.Field(env=field_env_key(cls.__openllm_model_name__, attr_name))
|
||||
else: val = cls.Field(default=val, env=field_env_key(cls.__openllm_model_name__, attr_name))
|
||||
these[attr_name] = val
|
||||
unannotated = ca_names - annotated_names
|
||||
if len(unannotated) > 0:
|
||||
missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter)
|
||||
missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr", cd.get(n)).counter)
|
||||
raise openllm.exceptions.MissingAnnotationAttributeError(f"The following field doesn't have a type annotation: {missing_annotated}")
|
||||
# We need to set the accepted key before generation_config
|
||||
# as generation_config is a special field that users shouldn't pass.
|
||||
@@ -1102,7 +1056,7 @@ class LLMConfig(_ConfigAttr):
|
||||
def __getitem__(self, item: t.Literal["ia3"]) -> dict[str, t.Any]: ...
|
||||
# update-config-stubs.py: stop
|
||||
|
||||
def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any:
|
||||
def __getitem__(self, item: LiteralString | t.Any) -> t.Any:
|
||||
"""Allowing access LLMConfig as a dictionary. The order will always evaluate as.
|
||||
|
||||
__openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__
|
||||
@@ -1179,13 +1133,13 @@ class LLMConfig(_ConfigAttr):
|
||||
def model_dump_json(self, **kwargs: t.Any) -> bytes: return orjson.dumps(self.model_dump(**kwargs))
|
||||
|
||||
@classmethod
|
||||
def model_construct_json(cls, json_str: str | bytes) -> t.Self:
|
||||
def model_construct_json(cls, json_str: str | bytes) -> Self:
|
||||
try: attrs = orjson.loads(json_str)
|
||||
except orjson.JSONDecodeError as err: raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None
|
||||
return bentoml_cattr.structure(attrs, cls)
|
||||
|
||||
@classmethod
|
||||
def model_construct_env(cls, **attrs: t.Any) -> t.Self:
|
||||
def model_construct_env(cls, **attrs: t.Any) -> Self:
|
||||
"""A helpers that respect configuration values environment variables."""
|
||||
attrs = {k: v for k, v in attrs.items() if v is not None}
|
||||
model_config = cls.__openllm_env__.config
|
||||
@@ -1198,7 +1152,7 @@ class LLMConfig(_ConfigAttr):
|
||||
|
||||
if "generation_config" in attrs:
|
||||
generation_config = attrs.pop("generation_config")
|
||||
if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
|
||||
if not isinstance(generation_config, dict): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
|
||||
else: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)}
|
||||
|
||||
for k in tuple(attrs.keys()):
|
||||
@@ -1258,7 +1212,7 @@ class LLMConfig(_ConfigAttr):
|
||||
|
||||
total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(attr.fields_dict(cls.__openllm_sampling_class__))
|
||||
|
||||
if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return f
|
||||
if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return t.cast("click.Command", f)
|
||||
# We pop out 'generation_config' as it is a attribute that we don't need to expose to CLI.
|
||||
for name, field in attr.fields_dict(cls).items():
|
||||
ty = cls.__openllm_hints__.get(name)
|
||||
@@ -1285,13 +1239,13 @@ def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig:
|
||||
Otherwise, we will filter out all keys are first in LLMConfig, parse it in, then
|
||||
parse the remaining keys into LLMConfig.generation_config
|
||||
"""
|
||||
if not LazyType(DictStrAny).isinstance(data): raise RuntimeError(f"Expected a dictionary, but got {type(data)}")
|
||||
if not isinstance(data, dict): raise RuntimeError(f"Expected a dictionary, but got {type(data)}")
|
||||
|
||||
cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__}
|
||||
generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__)
|
||||
if "generation_config" in data:
|
||||
generation_config = data.pop("generation_config")
|
||||
if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
|
||||
if not isinstance(generation_config, dict): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
|
||||
config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields})
|
||||
else:
|
||||
generation_config = {k: v for k, v in data.items() if k in generation_cls_fields}
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
"""Generation utilities to be reused throughout."""
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import transformers
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
import openllm
|
||||
import typing as t, transformers
|
||||
if t.TYPE_CHECKING: import torch, openllm
|
||||
|
||||
LogitsProcessorList = transformers.LogitsProcessorList
|
||||
StoppingCriteriaList = transformers.StoppingCriteriaList
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
|
||||
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
|
||||
|
||||
@@ -1,32 +1,13 @@
|
||||
from __future__ import annotations
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
import typing as t
|
||||
import uuid
|
||||
import functools, inspect, logging, os, re, traceback, types, typing as t, uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import attr
|
||||
import fs.path
|
||||
import inflection
|
||||
import orjson
|
||||
import attr, fs.path, inflection, orjson, bentoml, openllm
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
|
||||
from ._configuration import (
|
||||
AdapterType,
|
||||
FineTuneConfig,
|
||||
LiteralRuntime,
|
||||
LLMConfig,
|
||||
_object_getattribute,
|
||||
_setattr_class,
|
||||
@@ -57,54 +38,37 @@ from .utils import (
|
||||
validate_is_path,
|
||||
)
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import NotRequired, overload
|
||||
else:
|
||||
from typing_extensions import NotRequired, overload
|
||||
from ._typing_compat import (
|
||||
AdaptersMapping,
|
||||
AdaptersTuple,
|
||||
AnyCallable,
|
||||
AdapterType,
|
||||
LiteralRuntime,
|
||||
DictStrAny,
|
||||
ListStr,
|
||||
LLMEmbeddings,
|
||||
LLMRunnable,
|
||||
LLMRunner,
|
||||
ModelSignatureDict as _ModelSignatureDict,
|
||||
PeftAdapterOutput,
|
||||
TupleAny,
|
||||
NotRequired, overload, M, T, LiteralString
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import auto_gptq as autogptq
|
||||
import peft
|
||||
import torch
|
||||
import transformers
|
||||
import vllm
|
||||
|
||||
import auto_gptq as autogptq, peft, torch, transformers, vllm
|
||||
from ._configuration import PeftType
|
||||
from ._types import (
|
||||
AdaptersMapping,
|
||||
AdaptersTuple,
|
||||
AnyCallable,
|
||||
DictStrAny,
|
||||
ListStr,
|
||||
LLMEmbeddings,
|
||||
LLMRunnable,
|
||||
LLMRunner,
|
||||
ModelSignatureDict as _ModelSignatureDict,
|
||||
PeftAdapterOutput,
|
||||
TupleAny,
|
||||
)
|
||||
from .utils.representation import ReprArgs
|
||||
|
||||
UserDictAny = collections.UserDict[str, t.Any]
|
||||
ResolvedAdaptersMapping = dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]]
|
||||
else:
|
||||
DictStrAny = dict
|
||||
TupleAny = tuple
|
||||
UserDictAny = collections.UserDict
|
||||
LLMRunnable = bentoml.Runnable
|
||||
LLMRunner = bentoml.Runner
|
||||
LLMEmbeddings = dict
|
||||
|
||||
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
vllm = LazyLoader("vllm", globals(), "vllm")
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
peft = LazyLoader("peft", globals(), "peft")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ResolvedAdaptersMapping = t.Dict[AdapterType, t.Dict[t.Union[str, t.Literal["default"]], t.Tuple["peft.PeftConfig", str]]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class ModelSignatureDict(t.TypedDict, total=False):
|
||||
batchable: bool
|
||||
batch_dim: t.Union[t.Tuple[int, int], int]
|
||||
@@ -150,9 +114,6 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapp
|
||||
|
||||
_reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"}
|
||||
|
||||
M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]")
|
||||
T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]")
|
||||
|
||||
class LLMInterface(ABC, t.Generic[M, T]):
|
||||
"""This defines the loose contract for all openllm.LLM implementations."""
|
||||
@property
|
||||
@@ -257,7 +218,7 @@ class LLMInterface(ABC, t.Generic[M, T]):
|
||||
raise NotImplementedError
|
||||
|
||||
# NOTE: All fields below are attributes that can be accessed by users.
|
||||
config_class: type[LLMConfig]
|
||||
config_class: t.Type[LLMConfig]
|
||||
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
|
||||
bettertransformer: bool
|
||||
"""Whether to load this LLM with FasterTransformer enabled. The order of loading is:
|
||||
@@ -270,7 +231,7 @@ class LLMInterface(ABC, t.Generic[M, T]):
|
||||
"""
|
||||
device: "torch.device"
|
||||
"""The device to be used for this LLM. If the implementation is 'pt', then it will be torch.device, else string."""
|
||||
tokenizer_id: t.LiteralString | t.Literal["local"]
|
||||
tokenizer_id: t.Union[t.Literal["local"], LiteralString]
|
||||
"""optional tokenizer_id for loading with vLLM if the model supports vLLM."""
|
||||
# NOTE: The following will be populated by __init_subclass__, note that these should be immutable.
|
||||
__llm_trust_remote_code__: bool
|
||||
@@ -290,13 +251,13 @@ class LLMInterface(ABC, t.Generic[M, T]):
|
||||
|
||||
An additional naming for all VLLM backend: VLLMLlama -> `vllm`
|
||||
"""
|
||||
__llm_model__: M | None
|
||||
__llm_model__: t.Optional[M]
|
||||
"""A reference to the actual model. Instead of access this directly, you should use `model` property instead."""
|
||||
__llm_tokenizer__: T | None
|
||||
__llm_tokenizer__: t.Optional[T]
|
||||
"""A reference to the actual tokenizer. Instead of access this directly, you should use `tokenizer` property instead."""
|
||||
__llm_bentomodel__: bentoml.Model | None
|
||||
__llm_bentomodel__: t.Optional[bentoml.Model]
|
||||
"""A reference to the bentomodel used for this LLM. Instead of access this directly, you should use `_bentomodel` property instead."""
|
||||
__llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None
|
||||
__llm_adapter_map__: t.Optional[ResolvedAdaptersMapping]
|
||||
"""A reference to the the cached LoRA adapter mapping."""
|
||||
__llm_supports_embeddings__: bool
|
||||
"""A boolean to determine whether models does implement ``LLM.embeddings``."""
|
||||
@@ -307,11 +268,10 @@ class LLMInterface(ABC, t.Generic[M, T]):
|
||||
__llm_supports_generate_iterator__: bool
|
||||
"""A boolean to determine whether models does implement ``LLM.generate_iterator``."""
|
||||
if t.TYPE_CHECKING and not MYPY:
|
||||
def __attrs_init__(self, config: LLMConfig, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, model_id: str, runtime: t.Literal["ggml", "transformers"], model_decls: TupleAny, model_attrs: DictStrAny, tokenizer_attrs: DictStrAny, tag: bentoml.Tag, adapters_mapping: AdaptersMapping | None, model_version: str | None, quantize_method: t.Literal["int8", "int4", "gptq"] | None, serialisation_format: t.Literal["safetensors", "legacy"], _local: bool, **attrs: t.Any) -> None:
|
||||
def __attrs_init__(self, config: LLMConfig, quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, autogptq.BaseQuantizeConfig]], model_id: str, runtime: t.Literal["ggml", "transformers"], model_decls: TupleAny, model_attrs: DictStrAny, tokenizer_attrs: DictStrAny, tag: bentoml.Tag, adapters_mapping: t.Optional[AdaptersMapping], model_version: t.Optional[str], quantize_method: t.Optional[t.Literal["int8", "int4", "gptq"]], serialisation_format: t.Literal["safetensors", "legacy"], _local: bool, **attrs: t.Any) -> None:
|
||||
"""Generated __attrs_init__ for openllm.LLM."""
|
||||
|
||||
_R = t.TypeVar("_R", covariant=True)
|
||||
|
||||
class _import_model_wrapper(t.Generic[_R, M, T], t.Protocol):
|
||||
def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R: ...
|
||||
class _load_model_wrapper(t.Generic[M, T], t.Protocol):
|
||||
@@ -333,7 +293,7 @@ def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]) -> t.Ca
|
||||
(model_decls, model_attrs), _ = self.llm_parameters
|
||||
decls = (*model_decls, *decls)
|
||||
attrs = {**model_attrs, **attrs}
|
||||
return f(self, *decls, trust_remote_code=t.cast(bool, trust_remote_code), **attrs)
|
||||
return f(self, *decls, trust_remote_code=trust_remote_code, **attrs)
|
||||
return wrapper
|
||||
|
||||
_DEFAULT_TOKENIZER = "hf-internal-testing/llama-tokenizer"
|
||||
@@ -483,7 +443,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bentomodel"]) -> bentoml.Model | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["adapter_map"]) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None: ...
|
||||
def __getitem__(self, item: t.Literal["adapter_map"]) -> ResolvedAdaptersMapping | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["supports_embeddings"]) -> bool: ...
|
||||
@overload
|
||||
@@ -492,7 +452,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
def __getitem__(self, item: t.Literal["supports_generate_one"]) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["supports_generate_iterator"]) -> bool: ...
|
||||
def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any:
|
||||
def __getitem__(self, item: t.Union[LiteralString, t.Any]) -> t.Any:
|
||||
if item is None: raise TypeError(f"{self} doesn't understand how to index None.")
|
||||
item = inflection.underscore(item)
|
||||
internal_attributes = f"__llm_{item}__"
|
||||
@@ -575,7 +535,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
"""
|
||||
cfg_cls = cls.config_class
|
||||
_local = False
|
||||
_model_id: str = t.cast(str, first_not_none(model_id, os.environ.get(cfg_cls.__openllm_env__["model_id"]), default=cfg_cls.__openllm_default_id__))
|
||||
_model_id: str = first_not_none(model_id, os.environ.get(cfg_cls.__openllm_env__["model_id"]), default=cfg_cls.__openllm_default_id__)
|
||||
if validate_is_path(_model_id): _model_id, _local = resolve_filepath(_model_id), True
|
||||
quantize = first_not_none(quantize, t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.environ.get(cfg_cls.__openllm_env__["quantize"])), default=None)
|
||||
|
||||
@@ -889,7 +849,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
adapter_mapping = _mapping[adapter_type]
|
||||
|
||||
self.__llm_model__ = self._wrap_default_peft_model(adapter_mapping, inference_mode=inference_mode)
|
||||
|
||||
# now we loop through the rest with add_adapter
|
||||
if len(adapter_mapping) > 0:
|
||||
for adapter_name, (_peft_config, _) in adapter_mapping.items():
|
||||
@@ -1108,11 +1067,10 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
|
||||
def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
|
||||
def available_adapters(_: LLMRunner[M, T]) -> PeftAdapterOutput:
|
||||
if not is_peft_available(): return {"success": False, "result": {}, "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'"}
|
||||
if self.__llm_adapter_map__ is None: return {"success": False, "result": {}, "error_msg": "No adapters available for current running server."}
|
||||
if not isinstance(self.model, peft.PeftModel): return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"}
|
||||
return {"success": True, "result": self.model.peft_config, "error_msg": ""}
|
||||
|
||||
if not is_peft_available(): return PeftAdapterOutput(success=False, result={}, error_msg="peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'")
|
||||
if self.__llm_adapter_map__ is None: return PeftAdapterOutput(success=False, result={}, error_msg="No adapters available for current running server.")
|
||||
if not isinstance(self.model, peft.PeftModel): return PeftAdapterOutput(success=False, result={}, error_msg="Model is not a PeftModel")
|
||||
return PeftAdapterOutput(success=True, result=self.model.peft_config, error_msg="")
|
||||
def _wrapped_generate_run(__self: LLMRunner[M, T], prompt: str, **kwargs: t.Any) -> t.Any:
|
||||
"""Wrapper for runner.generate.run() to handle the prompt and postprocessing.
|
||||
|
||||
|
||||
@@ -1,22 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
import string, typing as t
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0: raise ValueError("Positional arguments are not supported")
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras: raise KeyError(f"Extra params passed: {extras}")
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str:
|
||||
# Currently, all default prompt will always have `instruction` key.
|
||||
if not use_prompt_template: return prompt
|
||||
@@ -24,7 +19,5 @@ def process_prompt(prompt: str, template: str | None = None, use_prompt_template
|
||||
template_variables = default_formatter.extract_template_variables(template)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
|
||||
try:
|
||||
return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
|
||||
try: return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e: raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
|
||||
|
||||
@@ -1,18 +1,9 @@
|
||||
# mypy: disable-error-code="name-defined"
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import logging, sys, typing as t
|
||||
from .utils import LazyLoader, is_autogptq_available, is_bitsandbytes_available, is_transformers_supports_kbit, pkg
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if sys.version_info[:2] >= (3, 11): from typing import overload
|
||||
else: from typing_extensions import overload
|
||||
if t.TYPE_CHECKING:
|
||||
from ._llm import LLM
|
||||
from ._types import DictStrAny
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
"""Schema definition for OpenLLM. This can be use for client interaction."""
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import inflection
|
||||
|
||||
import openllm
|
||||
|
||||
import functools, typing as t
|
||||
import attr, inflection, openllm
|
||||
from ._configuration import GenerationConfig, LLMConfig
|
||||
from .utils import bentoml_cattr
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import vllm
|
||||
if t.TYPE_CHECKING: import vllm
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationInput:
|
||||
@@ -30,7 +22,6 @@ class GenerationInput:
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]: return cls.from_llm_config(openllm.AutoConfig.for_model(model_name, **attrs))
|
||||
@classmethod
|
||||
def from_llm_config(cls, llm_config: openllm.LLMConfig) -> type[GenerationInput]: return attr.make_class(inflection.camelize(llm_config["model_name"]) + "GenerationInput", attrs={"prompt": attr.field(type=str), "llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)), "adapter_name": attr.field(default=None, type=str)})
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationOutput:
|
||||
responses: t.List[t.Any]
|
||||
@@ -43,7 +34,6 @@ class GenerationOutput:
|
||||
if hasattr(self, key): return getattr(self, key)
|
||||
elif key in self.configuration: return self.configuration[key]
|
||||
else: raise KeyError(key)
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class MetadataOutput:
|
||||
model_id: str
|
||||
@@ -53,14 +43,11 @@ class MetadataOutput:
|
||||
configuration: str
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class EmbeddingsOutput:
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> dict[str, t.Any]: return dict(request_id=request_output.request_id, prompt=request_output.prompt, finished=request_output.finished, prompt_token_ids=request_output.prompt_token_ids, outputs=[dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason) for it in request_output.outputs])
|
||||
|
||||
@attr.define
|
||||
class HfAgentInput:
|
||||
inputs: str
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# mypy: disable-error-code="arg-type,misc"
|
||||
"""The service definition for running any LLMService.
|
||||
|
||||
Note that the line `model = ...` is a special line and should not be modified. This will be handled by openllm
|
||||
internally to generate the correct model service when bundling the LLM to a Bento.
|
||||
This will ensure that 'bentoml serve llm-bento' will work accordingly.
|
||||
|
||||
The generation code lives under utils/codegen.py
|
||||
For line with comment '# openllm: ...', it must not be modified as it is managed internally by OpenLLM.
|
||||
Codegen can be found under 'openllm.utils.codegen'
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import os, typing as t, warnings, orjson, bentoml, openllm
|
||||
import os, warnings, orjson, bentoml, openllm, typing as t
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
@@ -24,48 +22,31 @@ llm_config = openllm.AutoConfig.for_model(model)
|
||||
runner = openllm.Runner(model, llm_config=llm_config, ensure_available=False, adapter_map=orjson.loads(adapter_map))
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner])
|
||||
|
||||
@svc.api(route="/v1/generate", input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), # type: ignore[arg-type] # XXX: remove once JSON supports Attrs class
|
||||
output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}))
|
||||
@svc.api(route="/v1/generate", input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, **{"adapter_name": qa_inputs.adapter_name, **config})
|
||||
return openllm.GenerationOutput(responses=responses, configuration=config)
|
||||
|
||||
@svc.api(
|
||||
route="/v1/metadata", input=bentoml.io.Text(), # type: ignore[misc] # XXX: remove once JSON supports Attrs class
|
||||
output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent})
|
||||
)
|
||||
@svc.api(route="/v1/metadata", input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent}))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,)
|
||||
return openllm.MetadataOutput(timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], model_id=runner.llm.model_id, configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent)
|
||||
|
||||
if runner.supports_embeddings:
|
||||
|
||||
@svc.api( # type: ignore[arg-type] # XXX: remove once JSON supports Attrs class
|
||||
input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({
|
||||
"embeddings": [
|
||||
0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918,
|
||||
0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076
|
||||
], "num_tokens": 20
|
||||
}), route="/v1/embeddings"
|
||||
)
|
||||
@svc.api(route="/v1/embeddings", input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}))
|
||||
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
|
||||
responses = await runner.embeddings.async_run(phrases)
|
||||
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
|
||||
|
||||
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
|
||||
async def hf_agent(request: Request) -> Response:
|
||||
json_str = await request.body()
|
||||
try:
|
||||
input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
|
||||
try: input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput)
|
||||
except orjson.JSONDecodeError as err: raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
|
||||
stop = input_data.parameters.pop("stop", ["\n"])
|
||||
try:
|
||||
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
|
||||
except NotImplementedError:
|
||||
return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
|
||||
try: return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
|
||||
except NotImplementedError: return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
|
||||
|
||||
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
|
||||
svc.mount_asgi_app(hf_app, path="/hf")
|
||||
|
||||
@@ -1,38 +1,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import typing as t
|
||||
import warnings
|
||||
|
||||
import psutil
|
||||
|
||||
import bentoml
|
||||
import functools, inspect, logging, math, os, sys, types, typing as t, warnings, psutil, bentoml
|
||||
from bentoml._internal.resource import get_resource, system_resources
|
||||
from bentoml._internal.runner.strategy import THREAD_ENVS
|
||||
|
||||
from .utils import DEBUG, ReprMixin
|
||||
if sys.version_info[:2] >= (3, 11): from typing import overload
|
||||
else: from typing_extensions import overload
|
||||
|
||||
class DynResource(t.Protocol):
|
||||
resource_id: t.ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def from_system(cls) -> t.Sequence[t.Any]: ...
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _strtoul(s: str) -> int:
|
||||
"""Return -1 or positive integer sequence string starts with,."""
|
||||
if not s: return -1
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Types definition for OpenLLM.
|
||||
|
||||
Note that this module SHOULD NOT BE IMPORTED DURING RUNTIME, as this serve only for typing purposes.
|
||||
It will raises a RuntimeError if this is imported eagerly.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime")
|
||||
|
||||
import attr
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict
|
||||
|
||||
from ._configuration import (
|
||||
AdapterType,
|
||||
LiteralRuntime as LiteralRuntime,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
|
||||
import openllm
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from bentoml._internal.runner.strategy import Strategy
|
||||
|
||||
from ._llm import (
|
||||
M as _M,
|
||||
T as _T,
|
||||
)
|
||||
from .bundle.oci import LiteralContainerVersionStrategy
|
||||
from .utils.lazy import VersionInfo
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = dict[str, t.Any]
|
||||
ListAny = list[t.Any]
|
||||
ListStr = list[str]
|
||||
TupleAny = tuple[t.Any, ...]
|
||||
P = t.ParamSpec("P")
|
||||
T = t.TypeVar("T")
|
||||
At = t.TypeVar("At", bound=attr.AttrsInstance)
|
||||
|
||||
class PeftAdapterOutput(t.TypedDict):
|
||||
success: bool
|
||||
result: dict[str, peft.PeftConfig]
|
||||
error_msg: str
|
||||
|
||||
class LLMEmbeddings(t.TypedDict):
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
class AdaptersTuple(TupleAny):
|
||||
adapter_id: str
|
||||
name: str | None
|
||||
config: DictStrAny
|
||||
|
||||
class RefTuple(TupleAny):
|
||||
git_hash: str
|
||||
version: VersionInfo
|
||||
strategy: LiteralContainerVersionStrategy
|
||||
|
||||
AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]]
|
||||
|
||||
class LLMRunnable(bentoml.Runnable, t.Generic[_M, _T]):
|
||||
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
__call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
|
||||
embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
|
||||
|
||||
class LLMRunner(bentoml.Runner, t.Generic[_M, _T]):
|
||||
__doc__: str
|
||||
__module__: str
|
||||
llm_type: str
|
||||
identifying_params: dict[str, t.Any]
|
||||
llm: openllm.LLM[_M, _T]
|
||||
config: openllm.LLMConfig
|
||||
implementation: LiteralRuntime
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
has_adapters: bool
|
||||
embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]]
|
||||
generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnerMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]]
|
||||
def __init__(self, runnable_class: type[LLMRunnable[_M, _T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,) -> None: ...
|
||||
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ...
|
||||
def run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
def download_model(self) -> bentoml.Model: ...
|
||||
@property
|
||||
def peft_adapters(self) -> PeftAdapterOutput: ...
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]: ...
|
||||
102
src/openllm/_typing_compat.py
Normal file
102
src/openllm/_typing_compat.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t, bentoml, attr, abc
|
||||
from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm, peft, transformers, auto_gptq as autogptq, vllm
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from bentoml._internal.runner.strategy import Strategy
|
||||
|
||||
from .bundle.oci import LiteralContainerVersionStrategy
|
||||
from .utils.lazy import VersionInfo
|
||||
|
||||
M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]")
|
||||
T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]")
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = t.Dict[str, t.Any]
|
||||
ListAny = t.List[t.Any]
|
||||
ListStr = t.List[str]
|
||||
TupleAny = t.Tuple[t.Any, ...]
|
||||
At = t.TypeVar("At", bound=attr.AttrsInstance)
|
||||
|
||||
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
|
||||
AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"]
|
||||
|
||||
if sys.version_info[:2] >= (3,11):
|
||||
from typing import LiteralString as LiteralString, Self as Self, overload as overload
|
||||
from typing import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform
|
||||
else:
|
||||
from typing_extensions import LiteralString as LiteralString, Self as Self, overload as overload
|
||||
from typing_extensions import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform
|
||||
|
||||
if sys.version_info[:2] >= (3,10):
|
||||
from typing import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate
|
||||
else:
|
||||
from typing_extensions import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate
|
||||
|
||||
if sys.version_info[:2] >= (3,9):
|
||||
from typing import TypedDict as TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict as TypedDict
|
||||
|
||||
class PeftAdapterOutput(TypedDict):
|
||||
success: bool
|
||||
result: t.Dict[str, peft.PeftConfig]
|
||||
error_msg: str
|
||||
|
||||
class LLMEmbeddings(t.TypedDict):
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
class AdaptersTuple(TupleAny):
|
||||
adapter_id: str
|
||||
name: t.Optional[str]
|
||||
config: DictStrAny
|
||||
|
||||
AdaptersMapping = t.Dict[AdapterType, t.Tuple[AdaptersTuple, ...]]
|
||||
|
||||
class RefTuple(TupleAny):
|
||||
git_hash: str
|
||||
version: VersionInfo
|
||||
strategy: LiteralContainerVersionStrategy
|
||||
|
||||
class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
__call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
|
||||
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[t.Any, None, None]]
|
||||
|
||||
class LLMRunner(bentoml.Runner, t.Generic[M, T]):
|
||||
__doc__: str
|
||||
__module__: str
|
||||
llm_type: str
|
||||
identifying_params: dict[str, t.Any]
|
||||
llm: openllm.LLM[M, T]
|
||||
config: openllm.LLMConfig
|
||||
implementation: LiteralRuntime
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
has_adapters: bool
|
||||
embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[t.Any, None, None]]
|
||||
def __init__(self, runnable_class: type[LLMRunnable[M, T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,) -> None: ...
|
||||
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
@abc.abstractmethod
|
||||
def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ...
|
||||
def run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
@abc.abstractmethod
|
||||
def download_model(self) -> bentoml.Model: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def peft_adapters(self) -> PeftAdapterOutput: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def __repr_keys__(self) -> set[str]: ...
|
||||
@@ -3,10 +3,8 @@
|
||||
These utilities will stay internal, and its API can be changed or updated without backward-compatibility.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import os, typing as t
|
||||
from openllm.utils import LazyModule
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]}
|
||||
|
||||
@@ -29,5 +27,8 @@ if t.TYPE_CHECKING:
|
||||
get_base_container_tag as get_base_container_tag,
|
||||
supported_registries as supported_registries,
|
||||
)
|
||||
else:
|
||||
sys.modules[__name__] = openllm.utils.LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
__lazy=LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
|
||||
__all__=__lazy.__all__
|
||||
__dir__=__lazy.__dir__
|
||||
__getattr__=__lazy.__getattr__
|
||||
|
||||
@@ -1,30 +1,18 @@
|
||||
# mypy: disable-error-code="misc"
|
||||
from __future__ import annotations
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
import importlib.metadata, inspect, logging, os, typing as t
|
||||
from pathlib import Path
|
||||
|
||||
import fs
|
||||
import fs.copy
|
||||
import fs.errors
|
||||
import orjson
|
||||
import fs, fs.copy, fs.errors, orjson, bentoml, openllm
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.bento.build_config import BentoBuildConfig, DockerOptions, ModelSpec, PythonOptions
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
|
||||
from . import oci
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
from openllm._typing_compat import LiteralString
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
|
||||
from .oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,7 +27,7 @@ def build_editable(path: str) -> str | None:
|
||||
from build.env import IsolatedEnvBuilder
|
||||
module_location = openllm.utils.pkg.source_locations("openllm")
|
||||
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
|
||||
pyproject_path = Path(module_location).parent.parent / "pyproject.toml"
|
||||
pyproject_path = Path(module_location).parent.parent/"pyproject.toml"
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info("OpenLLM is installed in editable mode. Generating built wheels...")
|
||||
with IsolatedEnvBuilder() as env:
|
||||
@@ -80,31 +68,26 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
_tf_version = importlib.metadata.version(candidate)
|
||||
packages.extend([f"tensorflow>={_tf_version}"])
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass
|
||||
except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
else:
|
||||
if not openllm.utils.is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
|
||||
wheels: list[str] = []
|
||||
built_wheels = build_editable(llm_fs.getsyspath("/"))
|
||||
if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}"))
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
|
||||
|
||||
def construct_docker_options(
|
||||
llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy,
|
||||
) -> DockerOptions:
|
||||
def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
|
||||
_bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_opts = [
|
||||
"tracing.sample_rate=1.0", f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'api_server.traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
|
||||
]
|
||||
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'api_server.traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
|
||||
_bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts)
|
||||
env: openllm.utils.EnvVarMixin = llm.config["env"]
|
||||
if env["framework_value"] == "vllm": serialisation_format = "legacy"
|
||||
env_dict = {
|
||||
env.framework: env["framework_value"], env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
|
||||
env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}"
|
||||
env.framework: env["framework_value"], env.config: f"'{llm.config.model_dump_json().decode()}'",
|
||||
env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}",
|
||||
"OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format,
|
||||
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
|
||||
}
|
||||
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
|
||||
@@ -117,10 +100,9 @@ def construct_docker_options(
|
||||
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template)
|
||||
|
||||
@inject
|
||||
def create_bento(
|
||||
bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, device: tuple[str, ...] | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None, runtime: t.Literal[
|
||||
"ggml", "transformers"] = "transformers", serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release", _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
) -> bentoml.Bento:
|
||||
def create_bento(bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None,
|
||||
runtime: t.Literal[ "ggml", "transformers"] = "transformers", serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release",
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
|
||||
framework_envvar = llm.config["env"]["framework_value"]
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
|
||||
@@ -129,13 +111,9 @@ def create_bento(
|
||||
if workers_per_resource == "round_robin": workers_per_resource = 1.0
|
||||
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if openllm.utils.device_count() == 0 else float(1 / openllm.utils.device_count())
|
||||
else:
|
||||
try:
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
except ValueError:
|
||||
raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
elif isinstance(workers_per_resource, int):
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
|
||||
try: workers_per_resource = float(workers_per_resource)
|
||||
except ValueError: raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
elif isinstance(workers_per_resource, int): workers_per_resource = float(workers_per_resource)
|
||||
logger.info("Building Bento for '%s'", llm.config["start_name"])
|
||||
# add service.py definition to this temporary folder
|
||||
openllm.utils.codegen.write_service(llm, adapter_map, llm_fs)
|
||||
|
||||
@@ -1,31 +1,21 @@
|
||||
# mypy: disable-error-code="misc"
|
||||
"""OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change."""
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import pathlib
|
||||
import shutil
|
||||
import subprocess
|
||||
import typing as t
|
||||
import functools, importlib, logging, os, pathlib, shutil, subprocess, typing as t
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import attr
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import attr, orjson, bentoml, openllm
|
||||
from openllm.utils.lazy import VersionInfo
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ghapi import all
|
||||
from openllm._typing_compat import RefTuple
|
||||
|
||||
from openllm._types import DictStrAny, RefTuple
|
||||
else:
|
||||
all = openllm.utils.LazyLoader("all", globals(), "ghapi.all")
|
||||
all = openllm.utils.LazyLoader("all", globals(), "ghapi.all") # noqa: F811
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BUILDER = bentoml.container.get_backend("buildx")
|
||||
ROOT_DIR = pathlib.Path(__file__).parent.parent.parent
|
||||
ROOT_DIR = pathlib.Path(os.path.abspath("__file__")).parent.parent.parent
|
||||
|
||||
# TODO: support quay
|
||||
LiteralContainerRegistry = t.Literal["docker", "gh", "ecr"]
|
||||
@@ -45,14 +35,10 @@ _module_location = openllm.utils.pkg.source_locations("openllm")
|
||||
|
||||
@functools.lru_cache
|
||||
@openllm.utils.apply(str.lower)
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
|
||||
return _CONTAINER_REGISTRY[reg]
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str: return _CONTAINER_REGISTRY[reg]
|
||||
|
||||
def _convert_version_from_string(s: str) -> openllm.utils.VersionInfo:
|
||||
return openllm.utils.VersionInfo.from_version_string(s)
|
||||
|
||||
def _commit_time_range(r: int = 5) -> str:
|
||||
return (datetime.now(timezone.utc) - timedelta(days=r)).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
def _convert_version_from_string(s: str) -> VersionInfo: return VersionInfo.from_version_string(s)
|
||||
def _commit_time_range(r: int = 5) -> str: return (datetime.now(timezone.utc) - timedelta(days=r)).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
class VersionNotSupported(openllm.exceptions.OpenLLMException):
|
||||
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
|
||||
@@ -66,54 +52,42 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
docker_bin = shutil.which("docker")
|
||||
if docker_bin is None:
|
||||
logger.warning("To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)")
|
||||
commits = t.cast("list[DictStrAny]", cls._ghapi.repos.list_commits(since=_commit_time_range()))
|
||||
commits = t.cast("list[dict[str, t.Any]]", cls._ghapi.repos.list_commits(since=_commit_time_range()))
|
||||
return next(f'sha-{it["sha"][:7]}' for it in commits if "[skip ci]" not in it["commit"]["message"])
|
||||
# now is the correct behaviour
|
||||
return orjson.loads(subprocess.check_output([docker_bin, "run", "--rm", "-it", "quay.io/skopeo/stable:latest", "list-tags", "docker://ghcr.io/bentoml/openllm"]).decode().strip())["Tags"][-2]
|
||||
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
"""TODO: Support offline mode.
|
||||
|
||||
Maybe we need to save git hash when building the Bento.
|
||||
"""
|
||||
git_hash: str = attr.field()
|
||||
version: openllm.utils.VersionInfo = attr.field(converter=_convert_version_from_string)
|
||||
strategy: LiteralContainerVersionStrategy = attr.field()
|
||||
_ghapi: all.GhApi = all.GhApi(owner=_OWNER, repo=_REPO) # TODO: support offline mode
|
||||
|
||||
_ghapi: t.ClassVar[all.GhApi] = all.GhApi(owner=_OWNER, repo=_REPO)
|
||||
@classmethod
|
||||
def _nightly_ref(cls) -> RefTuple:
|
||||
return _RefTuple((nightly_resolver(cls), "refs/heads/main", "nightly"))
|
||||
|
||||
def _nightly_ref(cls) -> RefTuple: return _RefTuple((nightly_resolver(cls), "refs/heads/main", "nightly"))
|
||||
@classmethod
|
||||
def _release_ref(cls, version_str: str | None = None) -> RefTuple:
|
||||
_use_base_strategy = version_str is None
|
||||
if version_str is None:
|
||||
# NOTE: This strategy will only support openllm>0.2.12
|
||||
meta: DictStrAny = cls._ghapi.repos.get_latest_release()
|
||||
meta: dict[str, t.Any] = cls._ghapi.repos.get_latest_release()
|
||||
version_str = meta["name"].lstrip("v")
|
||||
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")["object"]["sha"], version_str)
|
||||
else:
|
||||
version = ("", version_str)
|
||||
else: version = ("", version_str)
|
||||
if openllm.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12): raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
|
||||
return _RefTuple((*version, "release" if _use_base_strategy else "custom"))
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | str | None = None) -> RefResolver:
|
||||
if strategy_or_version is None or strategy_or_version == "release":
|
||||
logger.debug("Using default strategy 'release' for resolving base image version.")
|
||||
return cls(*cls._release_ref())
|
||||
elif strategy_or_version == "latest":
|
||||
return cls("latest", "0.0.0", "latest")
|
||||
# using default strategy
|
||||
if strategy_or_version is None or strategy_or_version == "release": return cls(*cls._release_ref())
|
||||
elif strategy_or_version == "latest": return cls("latest", "0.0.0", "latest")
|
||||
elif strategy_or_version == "nightly":
|
||||
_ref = cls._nightly_ref()
|
||||
return cls(_ref[0], "0.0.0", _ref[-1])
|
||||
else:
|
||||
logger.warning("Using custom %s. Make sure that it is at lease 0.2.12 for base container support.", strategy_or_version)
|
||||
return cls(*cls._release_ref(version_str=strategy_or_version))
|
||||
|
||||
@property
|
||||
def tag(self) -> str:
|
||||
# NOTE: latest tag can also be nightly, but discouraged to use it. For nightly refer to use sha-<git_hash_short>
|
||||
@@ -122,33 +96,24 @@ class RefResolver:
|
||||
else: return repr(self.version)
|
||||
|
||||
@functools.lru_cache(maxsize=256)
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
|
||||
return RefResolver.from_strategy(strategy).tag
|
||||
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str: return RefResolver.from_strategy(strategy).tag
|
||||
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
|
||||
"""This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
|
||||
|
||||
Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry.
|
||||
"""
|
||||
try:
|
||||
if not _BUILDER.health(): raise openllm.exceptions.Error
|
||||
except (openllm.exceptions.Error, subprocess.CalledProcessError):
|
||||
raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
|
||||
except (openllm.exceptions.Error, subprocess.CalledProcessError): raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
|
||||
if openllm.utils.device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
|
||||
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernel in container.")
|
||||
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
|
||||
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
|
||||
if t.TYPE_CHECKING: tags: dict[str | LiteralContainerRegistry, str]
|
||||
if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
|
||||
if not registries: tags: dict[str | LiteralContainerRegistry, str] = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
|
||||
else:
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if openllm.utils.get_debug_mode() else "auto", quiet=machine)
|
||||
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
|
||||
except Exception as err: raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
|
||||
return tags
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -156,10 +121,7 @@ if t.TYPE_CHECKING:
|
||||
supported_registries: list[str]
|
||||
|
||||
__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def __dir__() -> list[str]: return sorted(__all__)
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
|
||||
|
||||
@@ -1,37 +1,23 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import importlib.util
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import click_option_group as cog
|
||||
import inflection
|
||||
import orjson
|
||||
import functools, importlib.util, os, typing as t
|
||||
import click, click_option_group as cog, inflection, orjson, bentoml, openllm
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click.shell_completion import CompletionItem
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
|
||||
from openllm._typing_compat import LiteralString, DictStrAny, ParamSpec, Concatenate
|
||||
from . import termui
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import subprocess
|
||||
from openllm._configuration import LLMConfig
|
||||
|
||||
from .._configuration import LLMConfig
|
||||
from .._types import DictStrAny, P
|
||||
TupleStr = tuple[str, ...]
|
||||
else:
|
||||
TupleStr = tuple
|
||||
|
||||
P = ParamSpec("P")
|
||||
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
|
||||
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny:
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}']
|
||||
@@ -44,17 +30,15 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
|
||||
|
||||
_adapter_mapping_key = "adapter_map"
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None:
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
|
||||
if not value: return None
|
||||
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
# currently only support one level of resolve path with current directory
|
||||
try:
|
||||
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
try: adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError: pass
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
return None
|
||||
|
||||
@@ -104,7 +88,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
@start_decorator(llm_config, serve_grpc=_serve_grpc)
|
||||
@click.pass_context
|
||||
def start_cmd(
|
||||
ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool,
|
||||
ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | LiteralString, device: t.Tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool,
|
||||
serialisation_format: t.Literal["safetensors", "legacy"], adapter_id: str | None, return_process: bool, **attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
fast = str(fast).upper() in openllm.utils.ENV_VARS_TRUE_VALUES
|
||||
@@ -152,7 +136,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
llm = openllm.utils.infer_auto_class(env["framework_value"]).for_model(model, model_id=start_env[env.model_id], model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format)
|
||||
start_env.update({env.config: llm.config.model_dump_json().decode()})
|
||||
|
||||
server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs)
|
||||
server = bentoml.GrpcServer("_service:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service:svc", **server_attrs)
|
||||
openllm.utils.analytics.track_start_init(llm.config)
|
||||
|
||||
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
|
||||
@@ -192,7 +176,7 @@ def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, *
|
||||
|
||||
return noop
|
||||
|
||||
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
|
||||
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
|
||||
if adapter_map and not openllm.utils.is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config['env']['framework']}=\"pt\"' to run with quantization.")
|
||||
requirements = llm_config["requirements"]
|
||||
@@ -201,8 +185,8 @@ def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.Li
|
||||
if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow")
|
||||
|
||||
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
return lambda fn: openllm.utils.compose(
|
||||
*[
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
@@ -246,13 +230,14 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
),
|
||||
cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"),
|
||||
click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True),
|
||||
]
|
||||
)(fn)
|
||||
)
|
||||
return composed(fn)
|
||||
return wrapper
|
||||
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> TupleStr | None:
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None: return value
|
||||
if not openllm.utils.LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
|
||||
el: TupleStr = tuple(i for k in value for i in k)
|
||||
if not isinstance(value, tuple): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == "all": return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
@@ -269,7 +254,7 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
|
||||
command = "serve" if not serve_grpc else "serve-grpc"
|
||||
group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",)
|
||||
|
||||
def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands[command]
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
@@ -285,7 +270,6 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
|
||||
param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts"))
|
||||
f = cog.optgroup.option(*param_decls, **attrs)(f)
|
||||
return group(f)
|
||||
|
||||
return decorator
|
||||
|
||||
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
|
||||
@@ -299,12 +283,10 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC |
|
||||
factory = attrs.pop("factory", click)
|
||||
factory_attr = attrs.pop("attr", "option")
|
||||
if factory_attr != "argument": attrs.setdefault("help", "General option for OpenLLM CLI.")
|
||||
|
||||
def decorator(f: FC | None) -> FC:
|
||||
callback = getattr(factory, factory_attr, None)
|
||||
if callback is None: raise ValueError(f"Factory {factory} has no attribute {factory_attr}.")
|
||||
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
|
||||
|
||||
return decorator
|
||||
|
||||
cli_option = functools.partial(_click_factory_type, attr="option")
|
||||
@@ -312,12 +294,8 @@ cli_argument = functools.partial(_click_factory_type, attr="argument")
|
||||
|
||||
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = "pretty", **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
output = ["json", "pretty", "porcelain"]
|
||||
|
||||
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
|
||||
return [CompletionItem(it) for it in output]
|
||||
|
||||
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: return [CompletionItem(it) for it in output]
|
||||
return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True, envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f)
|
||||
|
||||
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True, help="""Whether to skip checking if models is already in store.
|
||||
@@ -325,18 +303,10 @@ def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC
|
||||
This is useful if you already downloaded or setup the model beforehand.
|
||||
""", **attrs
|
||||
)(f)
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
|
||||
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]:
|
||||
return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required)(f)
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]: return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required)(f)
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
@@ -423,10 +393,8 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
|
||||
value = inflection.underscore(value)
|
||||
if value in _wpr_strategies: return value
|
||||
else:
|
||||
try:
|
||||
float(value) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
try: float(value) # type: ignore[arg-type]
|
||||
except ValueError: raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
@@ -1,34 +1,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import itertools, logging, os, re, subprocess, sys, typing as t
|
||||
import bentoml, openllm
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm.exceptions import OpenLLMException
|
||||
|
||||
from . import termui
|
||||
from ._factory import start_command_factory
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._typing_compat import LiteralString, LiteralRuntime
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from openllm._configuration import LiteralRuntime, LLMConfig
|
||||
from openllm._configuration import LLMConfig
|
||||
from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _start(
|
||||
model_name: str, /, *, model_id: str | None = None, timeout: int = 30, workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, device: tuple[str, ...] | t.Literal["all"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
fast: bool = False, adapter_map: dict[t.LiteralString, str | None] | None = None, framework: LiteralRuntime | None = None, additional_args: list[str] | None = None, _serve_grpc: bool = False, __test__: bool = False, **_: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
def _start(model_name: str, /, *, model_id: str | None = None, timeout: int = 30, workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, device: tuple[str, ...] | t.Literal["all"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", fast: bool = False, adapter_map: dict[LiteralString, str | None] | None = None, framework: LiteralRuntime | None = None, additional_args: list[str] | None = None, _serve_grpc: bool = False, __test__: bool = False, **_: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
|
||||
|
||||
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
|
||||
@@ -169,9 +156,7 @@ def _build(model_name: str, /, *, model_id: str | None = None, model_version: st
|
||||
if matched is None: raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
|
||||
return bentoml.get(matched.group(1), _bento_store=bento_store)
|
||||
|
||||
def _import_model(
|
||||
model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", implementation: LiteralRuntime = "pt", quantize: t.Literal["int8", "int4", "gptq"] | None = None, serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", additional_args: t.Sequence[str] | None = None,
|
||||
) -> bentoml.Model:
|
||||
def _import_model(model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", implementation: LiteralRuntime = "pt", quantize: t.Literal["int8", "int4", "gptq"] | None = None, serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
|
||||
"""Import a LLM into local store.
|
||||
|
||||
> [!NOTE]
|
||||
|
||||
@@ -20,36 +20,12 @@ bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct'
|
||||
```
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import http.client
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click
|
||||
import click_option_group as cog
|
||||
import fs
|
||||
import fs.copy
|
||||
import fs.errors
|
||||
import inflection
|
||||
import orjson
|
||||
import functools, http.client, inspect, itertools, logging, os, platform, re, subprocess, sys, time, traceback, typing as t
|
||||
import attr, click, click_option_group as cog, fs, fs.copy, fs.errors, inflection, orjson, bentoml, openllm
|
||||
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
|
||||
from . import termui
|
||||
from ._factory import (
|
||||
FC,
|
||||
@@ -69,9 +45,9 @@ from ._factory import (
|
||||
start_command_factory,
|
||||
workers_per_resource_option,
|
||||
)
|
||||
from .. import bundle, serialisation
|
||||
from ..exceptions import OpenLLMException
|
||||
from ..models.auto import (
|
||||
from openllm import bundle, serialisation
|
||||
from openllm.exceptions import OpenLLMException
|
||||
from openllm.models.auto import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_FLAX_MAPPING_NAMES,
|
||||
MODEL_MAPPING_NAMES,
|
||||
@@ -80,7 +56,8 @@ from ..models.auto import (
|
||||
AutoConfig,
|
||||
AutoLLM,
|
||||
)
|
||||
from ..utils import (
|
||||
from openllm._typing_compat import DictStrAny, ParamSpec, Concatenate, LiteralString, Self, LiteralRuntime
|
||||
from openllm.utils import (
|
||||
DEBUG,
|
||||
DEBUG_ENV_VAR,
|
||||
OPTIONAL_DEPENDENCIES,
|
||||
@@ -105,19 +82,15 @@ from ..utils import (
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.container import DefaultBuilder
|
||||
from openllm.client import BaseClient
|
||||
from openllm._schema import EmbeddingsOutput
|
||||
from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
else: torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
from .._schema import EmbeddingsOutput
|
||||
from .._types import DictStrAny, LiteralRuntime, P
|
||||
from ..bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
else:
|
||||
torch, jupytext, nbformat = LazyLoader("torch", globals(), "torch"), LazyLoader("jupytext", globals(), "jupytext"), LazyLoader("nbformat", globals(), "nbformat")
|
||||
|
||||
P = ParamSpec("P")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_FIGLET = """\
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
@@ -132,9 +105,7 @@ ServeCommand = t.Literal["serve", "serve-grpc"]
|
||||
@attr.define
|
||||
class GlobalOptions:
|
||||
cloud_context: str | None = attr.field(default=None)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> t.Self:
|
||||
return attr.evolve(self, **attrs)
|
||||
def with_options(self, **attrs: t.Any) -> Self: return attr.evolve(self, **attrs)
|
||||
|
||||
GrpType = t.TypeVar("GrpType", bound=click.Group)
|
||||
|
||||
@@ -173,7 +144,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[t.Concatenate[bool, P], t.Any]:
|
||||
def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[Concatenate[bool, P], t.Any]:
|
||||
command_name = attrs.get("name", func.__name__)
|
||||
|
||||
@functools.wraps(func)
|
||||
@@ -197,7 +168,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
event.return_code = 2 if isinstance(e, KeyboardInterrupt) else 1
|
||||
analytics.track(event)
|
||||
raise
|
||||
return t.cast("t.Callable[t.Concatenate[bool, P], t.Any]", wrapper)
|
||||
return t.cast(t.Callable[Concatenate[bool, P], t.Any], wrapper)
|
||||
|
||||
@staticmethod
|
||||
def exception_handling(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[P, t.Any]:
|
||||
@@ -266,20 +237,17 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
for subcommand in self.list_commands(ctx):
|
||||
cmd = self.get_command(ctx, subcommand)
|
||||
if cmd is None or cmd.hidden: continue
|
||||
if subcommand in _cached_extensions:
|
||||
extensions.append((subcommand, cmd))
|
||||
else:
|
||||
commands.append((subcommand, cmd))
|
||||
if subcommand in _cached_extensions: extensions.append((subcommand, cmd))
|
||||
else: commands.append((subcommand, cmd))
|
||||
# allow for 3 times the default spacing
|
||||
if len(commands):
|
||||
limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands)
|
||||
rows = []
|
||||
rows: list[tuple[str, str]]= []
|
||||
for subcommand, cmd in commands:
|
||||
help = cmd.get_short_help_str(limit)
|
||||
rows.append((subcommand, help))
|
||||
if rows:
|
||||
with formatter.section(_("Commands")):
|
||||
formatter.write_dl(rows)
|
||||
with formatter.section(_("Commands")): formatter.write_dl(rows)
|
||||
if len(extensions):
|
||||
limit = formatter.width - 6 - max(len(cmd[0]) for cmd in extensions)
|
||||
rows = []
|
||||
@@ -287,8 +255,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
help = cmd.get_short_help_str(limit)
|
||||
rows.append((inflection.dasherize(subcommand), help))
|
||||
if rows:
|
||||
with formatter.section(_("Extensions")):
|
||||
formatter.write_dl(rows)
|
||||
with formatter.section(_("Extensions")): formatter.write_dl(rows)
|
||||
|
||||
@click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="openllm")
|
||||
@click.version_option(None, "--version", "-v", message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}")
|
||||
@@ -524,10 +491,11 @@ def build_command(
|
||||
_previously_built = True
|
||||
except bentoml.exceptions.NotFound:
|
||||
bento = bundle.create_bento(
|
||||
bento_tag, llm_fs, llm, workers_per_resource=workers_per_resource, device=device, adapter_map=adapter_map, quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=enable_features, dockerfile_template=dockerfile_template_path, runtime=runtime, container_registry=container_registry, container_version_strategy=container_version_strategy
|
||||
bento_tag, llm_fs, llm, workers_per_resource=workers_per_resource, adapter_map=adapter_map,
|
||||
quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=enable_features, dockerfile_template=dockerfile_template_path, runtime=runtime,
|
||||
container_registry=container_registry, container_version_strategy=container_version_strategy
|
||||
)
|
||||
except Exception as err:
|
||||
raise err from None
|
||||
except Exception as err: raise err from None
|
||||
|
||||
if machine: termui.echo(f"__tag__:{bento.tag}", fg="white")
|
||||
elif output == "pretty":
|
||||
@@ -535,26 +503,17 @@ def build_command(
|
||||
termui.echo("\n" + OPENLLM_FIGLET, fg="white")
|
||||
if not _previously_built: termui.echo(f"Successfully built {bento}.", fg="green")
|
||||
elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg="yellow")
|
||||
termui.echo(
|
||||
"📖 Next steps:\n\n" + "* Push to BentoCloud with 'bentoml push':\n" + f" $ bentoml push {bento.tag}\n\n" + "* Containerize your Bento with 'bentoml containerize':\n" + f" $ bentoml containerize {bento.tag} --opt progress=plain" + "\n\n" + " Tip: To enable additional BentoML features for 'containerize', " + "use '--enable-features=FEATURE[,FEATURE]' " +
|
||||
"[see 'bentoml containerize -h' for more advanced usage]\n", fg="blue",
|
||||
)
|
||||
elif output == "json":
|
||||
termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
termui.echo(bento.tag)
|
||||
termui.echo("📖 Next steps:\n\n" + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" + f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" + "\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n", fg="blue",)
|
||||
elif output == "json": termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
|
||||
else: termui.echo(bento.tag)
|
||||
|
||||
if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push)
|
||||
elif containerize:
|
||||
backend = t.cast("DefaultBuilder", os.environ.get("BENTOML_CONTAINERIZE_BACKEND", "docker"))
|
||||
try:
|
||||
bentoml.container.health(backend)
|
||||
except subprocess.CalledProcessError:
|
||||
raise OpenLLMException(f"Failed to use backend {backend}") from None
|
||||
try:
|
||||
bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io"))
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
|
||||
try: bentoml.container.health(backend)
|
||||
except subprocess.CalledProcessError: raise OpenLLMException(f"Failed to use backend {backend}") from None
|
||||
try: bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io"))
|
||||
except Exception as err: raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
|
||||
return bento
|
||||
|
||||
@cli.command()
|
||||
@@ -613,7 +572,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
|
||||
tabulate.PRESERVE_WHITESPACE = True
|
||||
# llm, architecture, url, model_id, installation, cpu, gpu, runtime_impl
|
||||
data: list[str | tuple[str, str, list[str], str, t.LiteralString, t.LiteralString, tuple[LiteralRuntime, ...]]] = []
|
||||
data: list[str | tuple[str, str, list[str], str, LiteralString, LiteralString, tuple[LiteralRuntime, ...]]] = []
|
||||
for m, v in json_data.items():
|
||||
data.extend([(m, v["architecture"], v["model_id"], v["installation"], "❌" if not v["cpu"] else "✅", "✅", v["runtime_impl"],)])
|
||||
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 4),]
|
||||
@@ -622,7 +581,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
termui.echo("Exception found while parsing models:\n", fg="yellow")
|
||||
for m, err in failed_initialized:
|
||||
termui.echo(f"- {m}: ", fg="yellow", nl=False)
|
||||
termui.echo(traceback.print_exception(err, limit=3), fg="red")
|
||||
termui.echo(traceback.print_exception(None, err, None, limit=5), fg="red") # type: ignore[func-returns-value]
|
||||
sys.exit(1)
|
||||
|
||||
table = tabulate.tabulate(data, tablefmt="fancy_grid", headers=["LLM", "Architecture", "Models Id", "pip install", "CPU", "GPU", "Runtime"], maxcolwidths=column_widths)
|
||||
@@ -699,7 +658,7 @@ def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal
|
||||
@click.option("--remote", is_flag=True, default=False, help="Whether or not to use remote tools (inference endpoints) instead of local ones.", show_default=True)
|
||||
@click.option("--opt", help="Define prompt options. "
|
||||
"(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]")
|
||||
def instruct_command(endpoint: str, timeout: int, agent: t.LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str:
|
||||
def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str:
|
||||
"""Instruct agents interactively for given tasks, from a terminal.
|
||||
|
||||
\b
|
||||
|
||||
@@ -1,20 +1,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import typing as t, bentoml, openllm, orjson, inflection ,click
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
|
||||
from .. import termui
|
||||
from .._factory import LiteralOutput, model_name_argument, output_option
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import LiteralOutput, model_name_argument, output_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..._types import DictStrAny
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny
|
||||
|
||||
@click.command("list_models", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False)
|
||||
|
||||
@@ -1,26 +1,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib.machinery
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import yaml
|
||||
|
||||
from .. import termui
|
||||
from ... import playground
|
||||
from ...utils import is_jupyter_available, is_jupytext_available, is_notebook_available
|
||||
import importlib.machinery, logging, os, pkgutil, subprocess, sys, tempfile, typing as t
|
||||
import click, yaml
|
||||
from openllm.cli import termui
|
||||
from openllm import playground
|
||||
from openllm.utils import is_jupyter_available, is_jupytext_available, is_notebook_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import jupytext
|
||||
import nbformat
|
||||
|
||||
from openllm._types import DictStrAny
|
||||
import jupytext, nbformat
|
||||
from openllm._typing_compat import DictStrAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,22 +1,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import inflection
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .._types import DictStrAny
|
||||
import os, typing as t, click, inflection, openllm
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny
|
||||
|
||||
def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None:
|
||||
attrs["fg"] = fg if not openllm.utils.get_debug_mode() else None
|
||||
if not openllm.utils.get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs)
|
||||
|
||||
COLUMNS: int = int(os.environ.get("COLUMNS", str(120)))
|
||||
|
||||
CONTEXT_SETTINGS: DictStrAny = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore}
|
||||
|
||||
__all__ = ["echo", "COLUMNS", "CONTEXT_SETTINGS"]
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
"""The actual client implementation.
|
||||
"""OpenLLM Python client.
|
||||
|
||||
Use ``openllm.client`` instead.
|
||||
This holds the implementation of the client, which is used to communicate with the
|
||||
OpenLLM server. It is used to send requests to the server, and receive responses.
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
|
||||
If the server has embedding supports, use it via `client.embed`:
|
||||
```python
|
||||
client.embed("What is the difference between gather and scatter?")
|
||||
```
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from .runtimes import (
|
||||
from openllm.client.runtimes import (
|
||||
AsyncGrpcClient as AsyncGrpcClient,
|
||||
AsyncHTTPClient as AsyncHTTPClient,
|
||||
BaseAsyncClient as BaseAsyncClient,
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Client that supports REST/gRPC protocol to interact with a LLMServer."""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import (
|
||||
from openllm.client.runtimes.base import (
|
||||
BaseAsyncClient as BaseAsyncClient,
|
||||
BaseClient as BaseClient,
|
||||
)
|
||||
from .grpc import (
|
||||
from openllm.client.runtimes.grpc import (
|
||||
AsyncGrpcClient as AsyncGrpcClient,
|
||||
GrpcClient as GrpcClient,
|
||||
)
|
||||
from .http import (
|
||||
from openllm.client.runtimes.http import (
|
||||
AsyncHTTPClient as AsyncHTTPClient,
|
||||
HTTPClient as HTTPClient,
|
||||
)
|
||||
|
||||
@@ -1,51 +1,35 @@
|
||||
# mypy: disable-error-code="name-defined"
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import typing as t
|
||||
import asyncio, logging, typing as t
|
||||
import bentoml, bentoml.client, openllm, httpx
|
||||
from abc import abstractmethod
|
||||
from http import HTTPStatus
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.client import Client
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
from openllm._typing_compat import overload, LiteralString
|
||||
|
||||
T = t.TypeVar("T")
|
||||
T_co = t.TypeVar("T_co", covariant=True)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
from openllm._typing_compat import DictStrAny, LiteralRuntime
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") # noqa: F811
|
||||
|
||||
from openllm._types import DictStrAny, LiteralRuntime
|
||||
|
||||
class AnnotatedClient(Client, t.Generic[T]):
|
||||
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
async def async_health(self) -> t.Any:
|
||||
...
|
||||
|
||||
def generate_v1(self, qa: openllm.GenerationInput) -> T:
|
||||
...
|
||||
|
||||
def metadata_v1(self) -> T:
|
||||
...
|
||||
|
||||
def embeddings_v1(self) -> t.Sequence[float]:
|
||||
...
|
||||
else:
|
||||
AnnotatedClient = Client
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
DictStrAny = dict
|
||||
class AnnotatedClient(t.Protocol[T_co]):
|
||||
server_url: str
|
||||
_svc: bentoml.Service
|
||||
endpoints: list[str]
|
||||
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ...
|
||||
async def async_health(self) -> t.Any: ...
|
||||
def generate_v1(self, qa: openllm.GenerationInput) -> T_co: ...
|
||||
def metadata_v1(self) -> T_co: ...
|
||||
def embeddings_v1(self) -> t.Sequence[float]: ...
|
||||
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ...
|
||||
async def async_call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ...
|
||||
@staticmethod
|
||||
def wait_until_server_ready(host: str, port: int, timeout: float = 30, **kwargs: t.Any) -> None: ...
|
||||
@staticmethod
|
||||
def from_url(server_url: str) -> AnnotatedClient[t.Any]: ...
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,12 +37,11 @@ def in_async_context() -> bool:
|
||||
try:
|
||||
_ = asyncio.get_running_loop()
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
except RuntimeError: return False
|
||||
|
||||
class ClientMeta(t.Generic[T]):
|
||||
_api_version: str
|
||||
_client_class: type[bentoml.client.Client]
|
||||
_client_type: t.Literal["GrpcClient", "HTTPClient"]
|
||||
_host: str
|
||||
_port: str
|
||||
|
||||
@@ -66,15 +49,8 @@ class ClientMeta(t.Generic[T]):
|
||||
__agent__: transformers.HfAgent | None = None
|
||||
__llm__: openllm.LLM[t.Any, t.Any] | None = None
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._address = address
|
||||
self._timeout = timeout
|
||||
|
||||
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
|
||||
"""Initialise subclass for HTTP and gRPC client type."""
|
||||
cls._client_class = t.cast(t.Type[bentoml.client.Client], bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient)
|
||||
cls._api_version = api_version
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30): self._address,self._timeout = address,timeout
|
||||
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): cls._client_type, cls._api_version = "HTTPClient" if client_type == "http" else "GrpcClient", api_version
|
||||
@property
|
||||
def _hf_agent(self) -> transformers.HfAgent:
|
||||
if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.")
|
||||
@@ -82,99 +58,62 @@ class ClientMeta(t.Generic[T]):
|
||||
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
|
||||
self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent"))
|
||||
return self.__agent__
|
||||
|
||||
@property
|
||||
def _metadata(self) -> T:
|
||||
if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
|
||||
return self.call("metadata")
|
||||
|
||||
def _metadata(self) -> T: return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() if in_async_context() else self.call("metadata")
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def model_name(self) -> str: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def framework(self) -> LiteralRuntime:
|
||||
raise NotImplementedError
|
||||
|
||||
def framework(self) -> LiteralRuntime: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def timeout(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def timeout(self) -> int: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def model_id(self) -> str: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def configuration(self) -> dict[str, t.Any]: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def supports_embeddings(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_embeddings(self) -> bool: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def supports_hf_agent(self) -> bool:
|
||||
raise NotImplementedError
|
||||
def supports_hf_agent(self) -> bool: raise NotImplementedError
|
||||
@abstractmethod
|
||||
def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ...
|
||||
@abstractmethod
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
|
||||
|
||||
@property
|
||||
def config(self) -> openllm.LLMConfig: return self.llm.config
|
||||
@property
|
||||
def llm(self) -> openllm.LLM[t.Any, t.Any]:
|
||||
# XXX: if the server runs vllm or any framework that is not available from the user client, client will fail.
|
||||
if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
|
||||
return self.__llm__
|
||||
|
||||
@property
|
||||
def config(self) -> openllm.LLMConfig:
|
||||
return self.llm.config
|
||||
|
||||
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T:
|
||||
return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
|
||||
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> T:
|
||||
return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
|
||||
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
@property
|
||||
def _cached(self) -> AnnotatedClient[T]:
|
||||
client_class = t.cast(AnnotatedClient[T], getattr(bentoml.client, self._client_type))
|
||||
if self.__client__ is None:
|
||||
self._client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
|
||||
self.__client__ = t.cast("AnnotatedClient[T]", self._client_class.from_url(self._address))
|
||||
client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
|
||||
self.__client__ = client_class.from_url(self._address)
|
||||
return self.__client__
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self, result: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
class BaseClient(ClientMeta[T]):
|
||||
def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
def health(self) -> t.Any: raise NotImplementedError
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
|
||||
return_raw_response = attrs.pop("return_raw_response", None)
|
||||
if return_raw_response is not None:
|
||||
@@ -197,21 +136,14 @@ class BaseClient(ClientMeta[T]):
|
||||
|
||||
# NOTE: Scikit interface
|
||||
@overload
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
|
||||
...
|
||||
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs))
|
||||
|
||||
def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
|
||||
return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs))
|
||||
|
||||
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
@@ -220,34 +152,21 @@ class BaseClient(ClientMeta[T]):
|
||||
task = kwargs.pop("task", args[0])
|
||||
return_code = kwargs.pop("return_code", False)
|
||||
remote = kwargs.pop("remote", False)
|
||||
try:
|
||||
return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
|
||||
try: return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
|
||||
except Exception as err:
|
||||
logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
|
||||
logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address)
|
||||
logger.info("Tip: LLMServer at '%s' might not support 'generate_one'.", self._address)
|
||||
|
||||
class BaseAsyncClient(ClientMeta[T]):
|
||||
async def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
async def health(self) -> t.Any: raise NotImplementedError
|
||||
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
|
||||
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
|
||||
...
|
||||
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
|
||||
return_raw_response = attrs.pop("return_raw_response", None)
|
||||
if return_raw_response is not None:
|
||||
@@ -270,25 +189,16 @@ class BaseAsyncClient(ClientMeta[T]):
|
||||
|
||||
# NOTE: Scikit interface
|
||||
@overload
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
|
||||
...
|
||||
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
|
||||
return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs))
|
||||
|
||||
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs))
|
||||
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
"""Async version of agent.run."""
|
||||
if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
@@ -317,9 +227,7 @@ class BaseAsyncClient(ClientMeta[T]):
|
||||
|
||||
# the below have the same logic as agent.run API
|
||||
explanation, code = clean_code_for_run(result)
|
||||
|
||||
_hf_agent.log(f"==Explanation from the agent==\n{explanation}")
|
||||
|
||||
_hf_agent.log(f"\n\n==Code generated by the agent==\n{code}")
|
||||
if not return_code:
|
||||
_hf_agent.log("\n\n==Result==")
|
||||
|
||||
@@ -1,102 +1,93 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import orjson
|
||||
|
||||
import openllm
|
||||
|
||||
import asyncio, logging, typing as t
|
||||
import orjson, openllm
|
||||
from openllm._typing_compat import LiteralRuntime
|
||||
from .base import BaseAsyncClient, BaseClient
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from grpc_health.v1 import health_pb2
|
||||
|
||||
from bentoml.grpc.v1.service_pb2 import Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
|
||||
|
||||
class GrpcClientMixin:
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def _metadata(self) -> Response:
|
||||
...
|
||||
|
||||
class GrpcClient(BaseClient["Response"], client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
def health(self) -> health_pb2.HealthCheckResponse: return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata.json.struct_value.fields["model_name"].string_value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata.json.struct_value.fields["model_name"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value)
|
||||
if value not in ("pt", "flax", "tf", "vllm"): raise KeyError
|
||||
return value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try:
|
||||
return int(self._metadata.json.struct_value.fields["timeout"].number_value)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return int(self._metadata.json.struct_value.fields["timeout"].number_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try:
|
||||
return self._metadata.json.struct_value.fields["model_id"].string_value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata.json.struct_value.fields["model_id"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try:
|
||||
v = self._metadata.json.struct_value.fields["configuration"].string_value
|
||||
return orjson.loads(v)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try:
|
||||
return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try:
|
||||
return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
if isinstance(result, dict):
|
||||
return openllm.GenerationOutput(**result)
|
||||
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
if isinstance(result, dict): return openllm.GenerationOutput(**result)
|
||||
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
|
||||
|
||||
class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"):
|
||||
class AsyncGrpcClient(BaseAsyncClient["Response"], client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
|
||||
def health(self) -> health_pb2.HealthCheckResponse:
|
||||
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
|
||||
|
||||
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
|
||||
async def health(self) -> health_pb2.HealthCheckResponse:
|
||||
return await self._cached.health("bentoml.grpc.v1.BentoService")
|
||||
async def health(self) -> health_pb2.HealthCheckResponse: return await self._cached.health("bentoml.grpc.v1.BentoService")
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try: return self._metadata.json.struct_value.fields["model_name"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value)
|
||||
if value not in ("pt", "flax", "tf", "vllm"): raise KeyError
|
||||
return value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try: return int(self._metadata.json.struct_value.fields["timeout"].number_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try: return self._metadata.json.struct_value.fields["model_id"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
if isinstance(result, dict): return openllm.GenerationOutput(**result)
|
||||
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
|
||||
|
||||
@@ -1,110 +1,98 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
import logging, typing as t
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
import orjson
|
||||
|
||||
import openllm
|
||||
|
||||
import httpx, orjson, openllm
|
||||
from .base import BaseAsyncClient, BaseClient, in_async_context
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._types import DictStrAny, LiteralRuntime
|
||||
else:
|
||||
DictStrAny = dict
|
||||
from openllm._typing_compat import DictStrAny, LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
def process_address(self: AsyncHTTPClient | HTTPClient, address: str) -> None:
|
||||
address = address if "://" in address else "http://" + address
|
||||
parsed = urlparse(address)
|
||||
self._host, *_port = parsed.netloc.split(":")
|
||||
if len(_port) == 0: self._port = "80" if parsed.scheme == "http" else "443"
|
||||
else: self._port = next(iter(_port))
|
||||
|
||||
class HTTPClientMixin:
|
||||
if t.TYPE_CHECKING:
|
||||
class HTTPClient(BaseClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
process_address(self, address)
|
||||
super().__init__(address, timeout)
|
||||
|
||||
@property
|
||||
def _metadata(self) -> DictStrAny:
|
||||
...
|
||||
def health(self) -> t.Any: return self._cached.health()
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
|
||||
if not self.supports_embeddings: raise ValueError("This model does not support embeddings.")
|
||||
if isinstance(prompt, str): prompt = [prompt]
|
||||
result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout).json() if in_async_context() else self.call("embeddings", list(prompt))
|
||||
return openllm.EmbeddingsOutput(**result)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
return self._metadata["framework"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["framework"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try:
|
||||
return self._metadata["timeout"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["timeout"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try:
|
||||
return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try:
|
||||
return self._metadata.get("supports_embeddings", False)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata.get("supports_embeddings", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try:
|
||||
return self._metadata.get("supports_hf_agent", False)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
try: return self._metadata.get("supports_hf_agent", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result)
|
||||
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return openllm.GenerationOutput(**result)
|
||||
|
||||
class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]):
|
||||
class AsyncHTTPClient(BaseAsyncClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
process_address(self, address)
|
||||
super().__init__(address, timeout)
|
||||
|
||||
def health(self) -> t.Any:
|
||||
return self._cached.health()
|
||||
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
|
||||
if not self.supports_embeddings:
|
||||
raise ValueError("This model does not support embeddings.")
|
||||
if isinstance(prompt, str): prompt = [prompt]
|
||||
if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout).json()
|
||||
else: result = self.call("embeddings", list(prompt))
|
||||
return openllm.EmbeddingsOutput(**result)
|
||||
|
||||
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
super().__init__(address, timeout)
|
||||
|
||||
async def health(self) -> t.Any:
|
||||
return await self._cached.async_health()
|
||||
|
||||
async def health(self) -> t.Any: return await self._cached.async_health()
|
||||
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
|
||||
if not self.supports_embeddings:
|
||||
raise ValueError("This model does not support embeddings.")
|
||||
if not self.supports_embeddings: raise ValueError("This model does not support embeddings.")
|
||||
if isinstance(prompt, str): prompt = [prompt]
|
||||
res = await self.acall("embeddings", list(prompt))
|
||||
return openllm.EmbeddingsOutput(**res)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try: return self._metadata["framework"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try: return self._metadata["timeout"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try: return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try: return self._metadata.get("supports_embeddings", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try: return self._metadata.get("supports_hf_agent", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result)
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
|
||||
from __future__ import annotations
|
||||
|
||||
import bentoml
|
||||
|
||||
class OpenLLMException(bentoml.exceptions.BentoMLException):
|
||||
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
|
||||
|
||||
class GpuNotAvailableError(OpenLLMException):
|
||||
"""Raised when there is no GPU available in given system."""
|
||||
|
||||
class ValidationError(OpenLLMException):
|
||||
"""Raised when a validation fails."""
|
||||
|
||||
class ForbiddenAttributeError(OpenLLMException):
|
||||
"""Raised when using an _internal field."""
|
||||
|
||||
class MissingAnnotationAttributeError(OpenLLMException):
|
||||
"""Raised when a field under openllm.LLMConfig is missing annotations."""
|
||||
|
||||
class MissingDependencyError(BaseException):
|
||||
"""Raised when a dependency is missing."""
|
||||
|
||||
class Error(BaseException):
|
||||
"""To be used instead of naked raise."""
|
||||
|
||||
class FineTuneStrategyNotSupportedError(OpenLLMException):
|
||||
"""Raised when a fine-tune strategy is not supported for given LLM."""
|
||||
|
||||
4
src/openllm/models/__init__.py
generated
4
src/openllm/models/__init__.py
generated
@@ -1,11 +1,11 @@
|
||||
# This file is generated by tools/update-models-import.py. DO NOT EDIT MANUALLY!
|
||||
# To update this, run ./tools/update-models-import.py
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
import typing as t, os
|
||||
from openllm.utils import LazyModule
|
||||
_MODELS: set[str] = {"auto", "baichuan", "chatglm", "dolly_v2", "falcon", "flan_t5", "gpt_neox", "llama", "mpt", "opt", "stablelm", "starcoder"}
|
||||
if t.TYPE_CHECKING: from . import auto as auto,baichuan as baichuan,chatglm as chatglm,dolly_v2 as dolly_v2,falcon as falcon,flan_t5 as flan_t5,gpt_neox as gpt_neox,llama as llama,mpt as mpt,opt as opt,stablelm as stablelm,starcoder as starcoder
|
||||
__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS})
|
||||
__lazy=LazyModule(__name__, os.path.abspath("__file__"), {k: [] for k in _MODELS})
|
||||
__all__=__lazy.__all__
|
||||
__dir__=__lazy.__dir__
|
||||
__getattr__=__lazy.__getattr__
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import typing as t, os
|
||||
import openllm
|
||||
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available
|
||||
|
||||
@@ -40,4 +39,7 @@ else:
|
||||
_import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
|
||||
if t.TYPE_CHECKING: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
|
||||
__lazy=LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
|
||||
__all__=__lazy.__all__
|
||||
__dir__=__lazy.__dir__
|
||||
__getattr__=__lazy.__getattr__
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
import inflection
|
||||
|
||||
import openllm
|
||||
import inflection, openllm
|
||||
from openllm.utils import ReprMixin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
from openllm._typing_compat import LiteralString
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
|
||||
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
|
||||
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
|
||||
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
|
||||
@@ -18,8 +17,8 @@ if t.TYPE_CHECKING:
|
||||
# NOTE: This is the entrypoint when adding new model config
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), ("llama", "LlamaConfig"), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig")])
|
||||
|
||||
class _LazyConfigMapping(OrderedDict, ReprMixin): # type: ignore[type-arg]
|
||||
def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]):
|
||||
class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
self._mapping = mapping
|
||||
self._extra_content: dict[str, t.Any] = {}
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
|
||||
@@ -1,23 +1,16 @@
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import typing as t
|
||||
import importlib, inspect, logging, typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
import inflection
|
||||
|
||||
import openllm
|
||||
import inflection, openllm
|
||||
from openllm.utils import ReprMixin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._typing_compat import LiteralString, LLMRunner
|
||||
import types
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
|
||||
from _typeshed import SupportsIter
|
||||
|
||||
from ..._llm import LLMRunner
|
||||
ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
@@ -25,7 +18,7 @@ if t.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseAutoLLMClass:
|
||||
_model_mapping: _LazyAutoMapping
|
||||
_model_mapping: t.ClassVar[_LazyAutoMapping]
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
|
||||
@classmethod
|
||||
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
|
||||
@@ -83,13 +76,13 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
|
||||
except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
|
||||
raise ValueError(f"Could not find {attr} in {openllm_module}!")
|
||||
|
||||
class _LazyAutoMapping(OrderedDict, ReprMixin): # type: ignore[type-arg]
|
||||
class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
|
||||
|
||||
This OrderedDict values() and keys() returns the list instead, so you don't
|
||||
have to do list(mapping.values()) to get the list of values.
|
||||
"""
|
||||
def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
|
||||
def __init__(self, config_mapping: OrderedDict[LiteralString, LiteralString], model_mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
self._config_mapping = config_mapping
|
||||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
||||
self._model_mapping = model_mapping
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
|
||||
MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), ("opt", "OPT"), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan")])
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
class AutoLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_MAPPING
|
||||
_model_mapping: t.ClassVar = MODEL_MAPPING
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
|
||||
MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT")])
|
||||
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
class AutoFlaxLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_FLAX_MAPPING
|
||||
_model_mapping: t.ClassVar = MODEL_FLAX_MAPPING
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
|
||||
MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT")])
|
||||
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
class AutoTFLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_TF_MAPPING
|
||||
_model_mapping: t.ClassVar = MODEL_TF_MAPPING
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
|
||||
MODEL_VLLM_MAPPING_NAMES = OrderedDict([("baichuan", "VLLMBaichuan"), ("dolly_v2", "VLLMDollyV2"), ("gpt_neox", "VLLMGPTNeoX"), ("mpt", "VLLMMPT"), ("opt", "VLLMOPT"), ("stablelm", "VLLMStableLM"), ("starcoder", "VLLMStarCoder"), ("llama", "VLLMLlama")])
|
||||
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
|
||||
class AutoVLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_VLLM_MAPPING
|
||||
_model_mapping: t.ClassVar = MODEL_VLLM_MAPPING
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class BaichuanConfig(openllm.LLMConfig):
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
@@ -15,6 +12,6 @@ class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrai
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
|
||||
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMBaichuan(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizerBase"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class ChatGLMConfig(openllm.LLMConfig):
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class DollyV2Config(openllm.LLMConfig):
|
||||
|
||||
@@ -1,105 +1,107 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, re, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from openllm._typing_compat import overload
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE, END_KEY, RESPONSE_KEY, get_special_token_id
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
|
||||
else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
|
||||
class InstructionTextGenerationPipeline(transformers.Pipeline):
|
||||
def __init__(self, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, /, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
|
||||
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
preprocess_params: dict[str, t.Any] = {}
|
||||
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
|
||||
# append a newline to yield a single token. find whatever token is configured for the response key.
|
||||
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
|
||||
response_key_token_id = None
|
||||
end_key_token_id = None
|
||||
if tokenizer_response_key:
|
||||
try:
|
||||
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
|
||||
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
|
||||
# Ensure generation stops once it generates "### End"
|
||||
generate_kwargs["eos_token_id"] = end_key_token_id
|
||||
except ValueError: pass
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
|
||||
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
def preprocess(self, input_: str, **generate_kwargs: t.Any) -> dict[str, t.Any]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
|
||||
inputs = self.tokenizer(prompt_text, return_tensors="pt")
|
||||
inputs["prompt_text"] = prompt_text
|
||||
inputs["instruction_text"] = input_
|
||||
return inputs
|
||||
def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any) -> dict[str, t.Any]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None)
|
||||
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
|
||||
else: in_b = input_ids.shape[0]
|
||||
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
instruction_text = model_inputs.pop("instruction_text")
|
||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
|
||||
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
_generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
|
||||
generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist()
|
||||
records: list[dict[t.Literal["generated_text"], str]] = []
|
||||
for sequence in generated_sequence:
|
||||
# The response will be set to this variable if we can identify it.
|
||||
decoded = None
|
||||
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
|
||||
if response_key_token_id and end_key_token_id:
|
||||
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
|
||||
# prompt, we should definitely find it. We will return the tokens found after this token.
|
||||
try: response_pos = sequence.index(response_key_token_id)
|
||||
except ValueError: response_pos = None
|
||||
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
|
||||
if response_pos:
|
||||
# Next find where "### End" is located. The model has been trained to end its responses with this
|
||||
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
|
||||
# this token, as the response could be truncated. If we don't find it then just return everything
|
||||
# to the end. Note that even though we set eos_token_id, we still see the this token at the end.
|
||||
try: end_pos = sequence.index(end_key_token_id)
|
||||
except ValueError: end_pos = None
|
||||
decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip()
|
||||
if not decoded:
|
||||
# Otherwise we'll decode everything and use a regex to find the response and end.
|
||||
fully_decoded = self.tokenizer.decode(sequence)
|
||||
# The response appears after "### Response:". The model has been trained to append "### End" at the
|
||||
# end.
|
||||
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else:
|
||||
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
|
||||
# return everything after "### Response:".
|
||||
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline: ...
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]: ...
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
|
||||
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
|
||||
class InstructionTextGenerationPipeline(transformers.Pipeline):
|
||||
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
|
||||
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
preprocess_params: dict[str, t.Any] = {}
|
||||
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
|
||||
# append a newline to yield a single token. find whatever token is configured for the response key.
|
||||
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
|
||||
response_key_token_id = None
|
||||
end_key_token_id = None
|
||||
if tokenizer_response_key:
|
||||
try:
|
||||
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
|
||||
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
|
||||
# Ensure generation stops once it generates "### End"
|
||||
generate_kwargs["eos_token_id"] = end_key_token_id
|
||||
except ValueError: pass
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
|
||||
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
def preprocess(self, input_: str, **generate_kwargs: t.Any) -> t.Dict[str, t.Any]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
|
||||
inputs = self.tokenizer(prompt_text, return_tensors="pt")
|
||||
inputs["prompt_text"] = prompt_text
|
||||
inputs["instruction_text"] = input_
|
||||
return t.cast(t.Dict[str, t.Any], inputs)
|
||||
def _forward(self, input_tensors: dict[str, t.Any], **generate_kwargs: t.Any) -> transformers.utils.generic.ModelOutput:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
input_ids, attention_mask = input_tensors["input_ids"], input_tensors.get("attention_mask", None)
|
||||
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
|
||||
else: in_b = input_ids.shape[0]
|
||||
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
instruction_text = input_tensors.pop("instruction_text")
|
||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
|
||||
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
_generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
|
||||
generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist()
|
||||
records: list[dict[t.Literal["generated_text"], str]] = []
|
||||
for sequence in generated_sequence:
|
||||
# The response will be set to this variable if we can identify it.
|
||||
decoded = None
|
||||
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
|
||||
if response_key_token_id and end_key_token_id:
|
||||
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
|
||||
# prompt, we should definitely find it. We will return the tokens found after this token.
|
||||
try: response_pos = sequence.index(response_key_token_id)
|
||||
except ValueError: response_pos = None
|
||||
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
|
||||
if response_pos:
|
||||
# Next find where "### End" is located. The model has been trained to end its responses with this
|
||||
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
|
||||
# this token, as the response could be truncated. If we don't find it then just return everything
|
||||
# to the end. Note that even though we set eos_token_id, we still see the this token at the end.
|
||||
try: end_pos = sequence.index(end_key_token_id)
|
||||
except ValueError: end_pos = None
|
||||
decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip()
|
||||
if not decoded:
|
||||
# Otherwise we'll decode everything and use a regex to find the response and end.
|
||||
fully_decoded = self.tokenizer.decode(sequence)
|
||||
# The response appears after "### Response:". The model has been trained to append "### End" at the
|
||||
# end.
|
||||
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
|
||||
# If the full text is requested, then append the decoded text to the original instruction.
|
||||
# This technically isn't the full text, as we format the instruction in the prompt the model has been
|
||||
# trained on, but to the client it will appear to be the full text.
|
||||
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
|
||||
records.append({"generated_text": t.cast(str, decoded)})
|
||||
return records
|
||||
else:
|
||||
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
|
||||
# return everything after "### Response:".
|
||||
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
|
||||
# If the full text is requested, then append the decoded text to the original instruction.
|
||||
# This technically isn't the full text, as we format the instruction in the prompt the model has been
|
||||
# trained on, but to the client it will appear to be the full text.
|
||||
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
|
||||
records.append({"generated_text": t.cast(str, decoded)})
|
||||
return records
|
||||
return InstructionTextGenerationPipeline() if _init else InstructionTextGenerationPipeline
|
||||
|
||||
class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedTokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16}, {}
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return InstructionTextGenerationPipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, return_full_text=self.config.return_full_text)
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: return generation_result[0]["generated_text"]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class FalconConfig(openllm.LLMConfig):
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda_device_count() > 1 else None}, {}
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
|
||||
return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True)
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class FlanT5Config(openllm.LLMConfig):
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class GPTNeoXConfig(openllm.LLMConfig):
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm, logging
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
import typing as t, openllm
|
||||
|
||||
class LlamaConfig(openllm.LLMConfig):
|
||||
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
@@ -25,4 +22,4 @@ class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaToke
|
||||
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
|
||||
masked_embeddings = data * mask
|
||||
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
|
||||
return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item())
|
||||
return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from openllm.utils import generate_labels, is_triton_available
|
||||
|
||||
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType
|
||||
|
||||
if t.TYPE_CHECKING: import transformers, torch
|
||||
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True) -> transformers.PretrainedConfig:
|
||||
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
|
||||
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
|
||||
@@ -62,7 +57,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
|
||||
attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()}
|
||||
with torch.inference_mode():
|
||||
if torch.cuda.is_available():
|
||||
with torch.autocast("cuda", torch.float16):
|
||||
with torch.autocast("cuda", torch.float16): # type: ignore[attr-defined]
|
||||
generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
else: generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True)
|
||||
|
||||
@@ -1,28 +1,8 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
import vllm
|
||||
if t.TYPE_CHECKING: import transformers, vllm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMMPT(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class OPTConfig(openllm.LLMConfig):
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from openllm.utils import generate_labels
|
||||
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from openllm.utils import generate_labels
|
||||
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class StableLMConfig(openllm.LLMConfig):
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT
|
||||
|
||||
if t.TYPE_CHECKING: import transformers, torch
|
||||
|
||||
@@ -1,25 +1,7 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
|
||||
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT
|
||||
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm.utils import generate_labels
|
||||
|
||||
from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
@@ -33,7 +27,6 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
|
||||
# default starcoder doesn't include the same value as santacoder EOD
|
||||
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
with torch.inference_mode():
|
||||
|
||||
@@ -1,27 +1,9 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
import logging, typing as t, openllm
|
||||
from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
|
||||
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VLLMStarCoder(openllm.LLM["vllm.LLMEngine", "transformers.GPT2TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
|
||||
@@ -23,24 +23,20 @@ llm.save_pretrained("./path/to/local-dolly")
|
||||
```
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import typing as t
|
||||
|
||||
import cloudpickle
|
||||
import fs
|
||||
|
||||
import openllm
|
||||
import importlib, typing as t
|
||||
import cloudpickle, fs, openllm
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
from openllm._typing_compat import M, T, ParamSpec, Concatenate
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._llm import T
|
||||
|
||||
import bentoml
|
||||
from . import (
|
||||
constants as constants,
|
||||
ggml as ggml,
|
||||
transformers as transformers,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
@@ -67,9 +63,8 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
return tokenizer
|
||||
|
||||
_extras = ["get", "import_model", "save_pretrained", "load_model"]
|
||||
|
||||
def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]:
|
||||
def caller(llm: openllm.LLM[t.Any, t.Any], *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[openllm.LLM[t.Any, t.Any], P], t.Any]:
|
||||
def caller(llm: openllm.LLM[t.Any, t.Any], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.runtime="transformers"'
|
||||
@@ -77,11 +72,15 @@ def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]:
|
||||
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.runtime="ggml"'
|
||||
"""
|
||||
return getattr(importlib.import_module(f".{llm.runtime}", __name__), fn)(llm, *args, **kwargs)
|
||||
|
||||
return caller
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []}
|
||||
if t.TYPE_CHECKING:
|
||||
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
|
||||
def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
|
||||
def save_pretrained(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> None: ...
|
||||
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []}
|
||||
__all__ = ["ggml", "transformers", "constants", "load_tokenizer", *_extras]
|
||||
def __dir__() -> list[str]: return sorted(__all__)
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
|
||||
# NOTE: vllm will use PyTorch implementation of transformers for serialisation
|
||||
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")
|
||||
}
|
||||
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")}
|
||||
HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token"]
|
||||
|
||||
@@ -4,18 +4,13 @@ This requires ctransformers to be installed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
import bentoml, openllm
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._llm import M
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import M
|
||||
|
||||
_conversion_strategy = {"pt": "ggml"}
|
||||
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model: raise NotImplementedError("Currently work in progress.")
|
||||
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
@@ -35,14 +30,5 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
|
||||
if auto_import:
|
||||
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise
|
||||
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
"""Load the model from BentoML store.
|
||||
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: raise NotImplementedError("Currently work in progress.")
|
||||
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None: raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
@@ -1,19 +1,12 @@
|
||||
# mypy: disable-error-code="name-defined,misc"
|
||||
"""Serialisation related implementation for Transformers-based implementation."""
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import importlib, logging, typing as t
|
||||
import bentoml, openllm
|
||||
from huggingface_hub import snapshot_download
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelOptions
|
||||
from openllm.serialisation.transformers.weights import HfIgnore
|
||||
|
||||
from .weights import HfIgnore
|
||||
from ._helpers import (
|
||||
check_unintialised_params,
|
||||
infer_autoclass_from_llm,
|
||||
@@ -26,16 +19,16 @@ from ._helpers import (
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
|
||||
import vllm, auto_gptq as autogptq, transformers ,torch
|
||||
import torch.nn
|
||||
|
||||
from bentoml._internal.models import ModelStore
|
||||
from openllm._llm import M, T
|
||||
from openllm._types import DictStrAny
|
||||
|
||||
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
|
||||
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
from openllm._typing_compat import DictStrAny, M, T
|
||||
else:
|
||||
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
|
||||
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,25 +1,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import copy, typing as t, openllm
|
||||
from bentoml._internal.models.model import ModelInfo, ModelSignature
|
||||
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, HUB_ATTRS
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
import transformers
|
||||
import torch, transformers, bentoml
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.models.model import ModelSignaturesType
|
||||
|
||||
from ..._llm import M, T
|
||||
from ..._types import DictStrAny
|
||||
else:
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
from openllm._typing_compat import DictStrAny, M, T
|
||||
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@@ -1,22 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import typing as t, attr
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
from openllm._llm import M, T
|
||||
from openllm._typing_compat import M, T
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool: return any(s.rfilename.endswith(".safetensors") for s in HfApi().model_info(model_id, revision=revision).siblings)
|
||||
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
safetensors = "*.safetensors"
|
||||
pt = "*.bin"
|
||||
tf = "*.h5"
|
||||
flax = "*.msgpack"
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
if llm.__llm_implementation__ == "vllm": base = [cls.tf, cls.flax, cls.safetensors]
|
||||
|
||||
@@ -1,37 +1,24 @@
|
||||
"""Tests utilities for OpenLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ._configuration import LiteralRuntime
|
||||
import contextlib, logging, shutil, subprocess, typing as t, bentoml, openllm
|
||||
if t.TYPE_CHECKING: from ._typing_compat import LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False,) -> t.Iterator[bentoml.Bento]:
|
||||
def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False) -> t.Iterator[bentoml.Bento]:
|
||||
logger.info("Building BentoML for %s", model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
|
||||
yield bento
|
||||
if cleanup:
|
||||
logger.info("Deleting %s", bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any,) -> t.Iterator[str]:
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
|
||||
else: bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None: image_tag = str(bento_tag)
|
||||
executable = shutil.which("docker")
|
||||
if not executable: raise RuntimeError("docker executable not found")
|
||||
|
||||
try:
|
||||
logger.info("Building container for %s", bento_tag)
|
||||
bentoml.container.build(bento_tag, backend="docker", image_tag=(image_tag,), progress="plain", **attrs,)
|
||||
@@ -40,19 +27,15 @@ def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | N
|
||||
if cleanup:
|
||||
logger.info("Deleting container %s", image_tag)
|
||||
subprocess.check_output([executable, "rmi", "-f", image_tag])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True,) -> t.Iterator[str]:
|
||||
def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True) -> t.Iterator[str]:
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
|
||||
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
|
||||
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
|
||||
|
||||
if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
|
||||
else: bento = bentoml.get(bento_tag)
|
||||
|
||||
container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")
|
||||
if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
yield container_name
|
||||
|
||||
@@ -4,19 +4,9 @@ User can import these function for convenience, but
|
||||
we won't ensure backward compatibility for these functions. So use with caution.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import typing as t
|
||||
import contextlib, functools, hashlib, logging, logging.config, os, sys, types, typing as t, openllm
|
||||
from pathlib import Path
|
||||
|
||||
from circus.exc import ConflictError
|
||||
|
||||
from bentoml._internal.configuration import (
|
||||
DEBUG_ENV_VAR as DEBUG_ENV_VAR,
|
||||
GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR,
|
||||
@@ -41,20 +31,14 @@ from openllm.utils.lazy import (
|
||||
VersionInfo as VersionInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._typing_compat import AnyCallable, LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try: from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
except ImportError: _TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
|
||||
else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # _GenericAlias is the actual GenericAlias implementation
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11): from typing import overload as _overload
|
||||
else: from typing_extensions import overload as _overload
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
|
||||
from .._types import AnyCallable, LiteralRuntime
|
||||
|
||||
DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
|
||||
|
||||
@@ -194,9 +178,7 @@ def compose(*funcs: AnyCallable) -> AnyCallable:
|
||||
>>> [f(3*x, x+1) for x in range(1,10)]
|
||||
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
|
||||
"""
|
||||
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable:
|
||||
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
|
||||
|
||||
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: return lambda *args, **kwargs: f1(f2(*args, **kwargs))
|
||||
return functools.reduce(compose_two, funcs)
|
||||
|
||||
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
|
||||
@@ -241,7 +223,6 @@ def resolve_filepath(path: str, ctx: str | None = None) -> str:
|
||||
def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
|
||||
def generate_context(framework_name: str) -> _ModelContext:
|
||||
import openllm
|
||||
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
|
||||
if openllm.utils.is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
|
||||
if openllm.utils.is_tf_available():
|
||||
@@ -261,14 +242,6 @@ def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t
|
||||
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
|
||||
return attrs, tokenizer_attrs
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: ...
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: ...
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: ...
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]: ...
|
||||
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM] | type[openllm.AutoTFLLM] | type[openllm.AutoFlaxLLM] | type[openllm.AutoVLLM]:
|
||||
import openllm
|
||||
if implementation == "tf": return openllm.AutoTFLLM
|
||||
|
||||
@@ -3,33 +3,17 @@
|
||||
Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
import openllm
|
||||
import contextlib, functools, logging, os, re, typing as t, importlib.metadata
|
||||
import attr, openllm
|
||||
from bentoml._internal.utils import analytics as _internal_analytics
|
||||
|
||||
if sys.version_info[:2] >= (3, 10):
|
||||
from typing import ParamSpec
|
||||
else:
|
||||
from typing_extensions import ParamSpec
|
||||
from openllm._typing_compat import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = t.TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
|
||||
OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
|
||||
|
||||
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
@@ -75,18 +59,15 @@ class EventMeta:
|
||||
class ModelSaveEvent(EventMeta):
|
||||
module: str
|
||||
model_size_in_kb: float
|
||||
|
||||
@attr.define
|
||||
class OpenllmCliEvent(EventMeta):
|
||||
cmd_group: str
|
||||
cmd_name: str
|
||||
openllm_version: str = importlib.metadata.version("openllm")
|
||||
|
||||
# NOTE: reserved for the do_not_track logics
|
||||
duration_in_ms: t.Any = attr.field(default=None)
|
||||
error_type: str = attr.field(default=None)
|
||||
return_code: int = attr.field(default=None)
|
||||
|
||||
@attr.define
|
||||
class StartInitEvent(EventMeta):
|
||||
model_name: str
|
||||
|
||||
@@ -1,22 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import linecache
|
||||
import logging
|
||||
import string
|
||||
import types
|
||||
import typing as t
|
||||
import functools, inspect, linecache, os, logging, string, types, typing as t
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
import openllm
|
||||
|
||||
from .._types import AnyCallable, DictStrAny, ListStr
|
||||
from openllm._typing_compat import LiteralString, AnyCallable, DictStrAny, ListStr
|
||||
PartialAny = functools.partial[t.Any]
|
||||
|
||||
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
|
||||
@@ -24,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
OPENLLM_MODEL_NAME = "# openllm: model name"
|
||||
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: t.LiteralString = "__model_name__"
|
||||
model_keyword: LiteralString = "__model_name__"
|
||||
def __init__(self, model_name: str):
|
||||
"""The formatter that extends model_name to be formatted the 'service.py'."""
|
||||
super().__init__()
|
||||
@@ -36,14 +28,13 @@ class ModelNameFormatter(string.Formatter):
|
||||
return True
|
||||
except ValueError: return False
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: t.LiteralString = "__model_id__"
|
||||
model_keyword: LiteralString = "__model_id__"
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: t.LiteralString = "__model_adapter_map__"
|
||||
model_keyword: LiteralString = "__model_adapter_map__"
|
||||
|
||||
_service_file = Path(__file__).parent.parent / "_service.py"
|
||||
_service_file = Path(os.path.abspath("__file__")).parent.parent/"_service.py"
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
from . import DEBUG
|
||||
|
||||
from openllm.utils import DEBUG
|
||||
model_name = llm.config["model_name"]
|
||||
logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/"))
|
||||
with open(_service_file.__fspath__(), "r") as f: src_contents = f.readlines()
|
||||
@@ -119,33 +110,26 @@ def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.
|
||||
return globs[attr_class_name]
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
|
||||
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None,) -> AnyCallable:
|
||||
from . import SHOW_CODEGEN
|
||||
|
||||
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None) -> AnyCallable:
|
||||
from openllm.utils import SHOW_CODEGEN
|
||||
script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
|
||||
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
|
||||
if annotations: meth.__annotations__ = annotations
|
||||
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
|
||||
return meth
|
||||
|
||||
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: t.LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
|
||||
from . import dantic, field_env_key
|
||||
|
||||
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
|
||||
from openllm.utils import dantic, field_env_key
|
||||
def identity(_: str, x_value: t.Any) -> t.Any: return x_value
|
||||
default_callback = identity if default_callback is None else default_callback
|
||||
|
||||
globs = {} if globs is None else globs
|
||||
globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
|
||||
|
||||
lines: ListStr = ["__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },", " )", " for f in fields", "]"]
|
||||
fields_ann = "list[attr.Attribute[t.Any]]"
|
||||
|
||||
return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann})
|
||||
|
||||
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
"""Enhance sdk with nice repr that plays well with your brain."""
|
||||
from .representation import ReprMixin
|
||||
|
||||
from openllm.utils import ReprMixin
|
||||
if name is None: name = func.__name__.strip("_")
|
||||
_signatures = inspect.signature(func).parameters
|
||||
def _repr(self: ReprMixin) -> str: return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
|
||||
@@ -153,3 +137,5 @@ def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
|
||||
else: doc = func.__doc__
|
||||
return t.cast(_T, functools.update_wrapper(types.new_class(name, (t.cast("PartialAny", functools.partial), ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))
|
||||
|
||||
__all__ = ["gen_sdk", "make_attr_tuple_class", "make_env_transformer", "generate_unique_filename", "generate_function", "OPENLLM_MODEL_NAME", "OPENLLM_MODEL_ADAPTER_MAP"]
|
||||
|
||||
@@ -1,30 +1,22 @@
|
||||
"""An interface provides the best of pydantic and attrs."""
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
import functools, importlib, os, sys, typing as t
|
||||
from enum import Enum
|
||||
|
||||
import attr
|
||||
import click
|
||||
import click_option_group as cog
|
||||
import inflection
|
||||
import orjson
|
||||
import attr, click, click_option_group as cog, inflection, orjson
|
||||
from click import (
|
||||
ParamType,
|
||||
shell_completion as sc,
|
||||
types as click_types,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import _ValidatorType
|
||||
if t.TYPE_CHECKING: from attr import _ValidatorType
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
|
||||
|
||||
__all__ = ["FC", "attrs_to_options", "Field", "parse_type", "is_typing", "is_literal", "ModuleType", "EnumChoice", "LiteralChoice", "allows_multiple", "is_mapping", "is_container", "parse_container_args", "parse_single_arg", "CUDA", "JsonType", "BytesType"]
|
||||
def __dir__() -> list[str]: return sorted(__all__)
|
||||
|
||||
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata["env"]
|
||||
@@ -47,13 +39,11 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
if env is not None:
|
||||
value = os.environ.get(env, value)
|
||||
if value is not None and isinstance(value, str):
|
||||
try:
|
||||
return orjson.loads(value.lower())
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
try: return orjson.loads(value.lower())
|
||||
except orjson.JSONDecodeError as err: raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
return value
|
||||
|
||||
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any,) -> t.Any:
|
||||
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[t.Any] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any) -> t.Any:
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
|
||||
|
||||
By default, if both validator and ge are provided, then then ge will be
|
||||
|
||||
@@ -1,34 +1,16 @@
|
||||
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from abc import ABCMeta
|
||||
import importlib, importlib.metadata, importlib.util, logging, os, abc, typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
import inflection
|
||||
from packaging import version
|
||||
|
||||
import inflection, packaging.version
|
||||
from bentoml._internal.utils import LazyLoader, pkg
|
||||
from openllm._typing_compat import overload, LiteralString
|
||||
|
||||
from .representation import ReprMixin
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
|
||||
from .._types import LiteralRuntime
|
||||
else:
|
||||
BackendOrderredDict = OrderedDict
|
||||
BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]]
|
||||
from openllm._typing_compat import LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",}
|
||||
@@ -104,10 +86,10 @@ def is_tf_available() -> bool:
|
||||
try:
|
||||
_tf_version = importlib.metadata.version(_pkg)
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError: pass
|
||||
except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
_tf_available = _tf_version is not None
|
||||
if _tf_available:
|
||||
if _tf_version and version.parse(_tf_version) < version.parse("2"):
|
||||
if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse("2"):
|
||||
logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
|
||||
_tf_available = False
|
||||
else:
|
||||
@@ -232,11 +214,13 @@ You can install it with pip: `pip install fairscale`. Please note that you may n
|
||||
your runtime after installation.
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING = BackendOrderredDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
|
||||
("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
|
||||
("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))])
|
||||
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
||||
("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
|
||||
("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
|
||||
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
|
||||
("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))])
|
||||
|
||||
class DummyMetaclass(ABCMeta):
|
||||
class DummyMetaclass(abc.ABCMeta):
|
||||
"""Metaclass for dummy object.
|
||||
|
||||
It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
|
||||
@@ -258,19 +242,17 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
|
||||
if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
|
||||
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
||||
failed = [msg.format(name) for available, msg in checks if not available()]
|
||||
failed = [msg.format(name) for available, msg in (BACKENDS_MAPPING[backend] for backend in backends) if not available()]
|
||||
if failed: raise ImportError("".join(failed))
|
||||
|
||||
class EnvVarMixin(ReprMixin):
|
||||
model_name: str
|
||||
if t.TYPE_CHECKING:
|
||||
config: str
|
||||
model_id: str
|
||||
quantize: str
|
||||
framework: str
|
||||
bettertransformer: str
|
||||
runtime: str
|
||||
config: str
|
||||
model_id: str
|
||||
quantize: str
|
||||
framework: str
|
||||
bettertransformer: str
|
||||
runtime: str
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["config"]) -> str: ...
|
||||
@overload
|
||||
@@ -297,9 +279,9 @@ class EnvVarMixin(ReprMixin):
|
||||
if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")()
|
||||
elif hasattr(self, item): return getattr(self, item)
|
||||
raise KeyError(f"Key {item} not found in {self}")
|
||||
def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
|
||||
def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
|
||||
"""EnvVarMixin is a mixin class that returns the value extracted from environment variables."""
|
||||
from .._configuration import field_env_key
|
||||
from openllm._configuration import field_env_key
|
||||
self.model_name = inflection.underscore(model_name)
|
||||
self._implementation = implementation
|
||||
self._model_id = model_id
|
||||
|
||||
@@ -1,19 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import importlib
|
||||
import importlib.machinery
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
import typing as t
|
||||
import warnings
|
||||
|
||||
import attr
|
||||
|
||||
import openllm
|
||||
import functools, importlib, importlib.machinery, importlib.metadata, importlib.util, itertools, os, time, types, warnings, typing as t
|
||||
import attr, openllm
|
||||
|
||||
__all__ = ["VersionInfo", "LazyModule"]
|
||||
# vendorred from attrs
|
||||
@@ -68,7 +55,7 @@ class LazyModule(types.ModuleType):
|
||||
for key, values in import_structure.items():
|
||||
for value in values: self._class_to_module[value] = key
|
||||
# Needed for autocompletion in an IDE
|
||||
self.__all__ = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
|
||||
self.__all__: list[str] = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
|
||||
self.__file__ = module_file
|
||||
self.__spec__ = module_spec or importlib.util.find_spec(name)
|
||||
self.__path__ = [os.path.dirname(module_file)]
|
||||
@@ -93,7 +80,7 @@ class LazyModule(types.ModuleType):
|
||||
if name in dunder_to_metadata:
|
||||
if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2)
|
||||
meta = importlib.metadata.metadata("openllm")
|
||||
project_url = dict(url.split(", ") for url in meta.get_all("Project-URL"))
|
||||
project_url = dict(url.split(", ") for url in t.cast(t.List[str], meta.get_all("Project-URL")))
|
||||
if name == "__license__": return "Apache-2.0"
|
||||
elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
|
||||
elif name in ("__uri__", "__url__"): return project_url["GitHub"]
|
||||
|
||||
@@ -1,57 +1,32 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
import attr, orjson
|
||||
from openllm import utils
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import TypeAlias
|
||||
|
||||
import attr
|
||||
import orjson
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
|
||||
|
||||
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]
|
||||
class ReprMixin:
|
||||
"""This class display possible representation of given class.
|
||||
|
||||
It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
|
||||
Most subclass needs to implement a __repr_keys__ property.
|
||||
|
||||
Based on the design from Pydantic.
|
||||
The __repr__ will display the json representation of the object for easier interaction.
|
||||
The __str__ will display either __attrs_repr__ or __repr_str__.
|
||||
"""
|
||||
@property
|
||||
@abstractmethod
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
"""This can be overriden by base class using this mixin."""
|
||||
def __repr_keys__(self) -> set[str]: raise NotImplementedError
|
||||
"""This can be overriden by base class using this mixin."""
|
||||
def __repr__(self) -> str: return f"{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}"
|
||||
"""The `__repr__` for any subclass of Mixin.
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""The `__repr__` for any subclass of Mixin.
|
||||
It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
|
||||
"""
|
||||
def __str__(self) -> str: return self.__repr_str__(" ")
|
||||
"""The string representation of the given Mixin subclass.
|
||||
|
||||
It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
|
||||
"""
|
||||
from . import bentoml_cattr
|
||||
It will contains all of the attributes from __repr_keys__
|
||||
"""
|
||||
def __repr_name__(self) -> str: return self.__class__.__name__
|
||||
"""Name of the instance's class, used in __repr__."""
|
||||
def __repr_str__(self, join_str: str) -> str: return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
|
||||
"""To be used with __str__."""
|
||||
def __repr_args__(self) -> ReprArgs: return ((k, getattr(self, k)) for k in self.__repr_keys__)
|
||||
"""This can also be overriden by base class using this mixin.
|
||||
|
||||
serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
|
||||
return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""The string representation of the given Mixin subclass.
|
||||
|
||||
It will contains all of the attributes from __repr_keys__
|
||||
"""
|
||||
return self.__repr_str__(" ")
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
"""Name of the instance's class, used in __repr__."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
"""To be used with __str__."""
|
||||
return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
|
||||
|
||||
def __repr_args__(self) -> ReprArgs:
|
||||
"""This can also be overriden by base class using this mixin.
|
||||
|
||||
By default it does a getattr of the current object from __repr_keys__.
|
||||
"""
|
||||
return ((k, getattr(self, k)) for k in self.__repr_keys__)
|
||||
By default it does a getattr of the current object from __repr_keys__.
|
||||
"""
|
||||
|
||||
@@ -1,16 +1,3 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -1,17 +1,3 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user