feat(test): snapshot testing (#107)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-07-10 17:23:19 -04:00
committed by GitHub
parent d3e4b95e84
commit c7f4dc7bb2
205 changed files with 11633 additions and 2349 deletions

2
.gitattributes vendored
View File

@@ -1,3 +1,5 @@
nightly-requirements.txt linguist-generated=true
nightly-requirements-gpu.txt linguist-generated=true
tests/models/__snapshots__/* linguist-generated=true
typings/**/*.pyi linguist-generated=true
* text=auto eol=lf

View File

@@ -31,10 +31,6 @@ runs:
with:
python-version: ${{ inputs.python-version }}
architecture: ${{ inputs.architecture }}
- name: Setup node
uses: actions/setup-node@v3
with:
node-version: '17'
- name: Get cache key prefix
id: get-cache-key-prefix
shell: bash
@@ -54,10 +50,3 @@ runs:
- name: Install dependencies
shell: bash
run: pip install hatch towncrier
- name: Install pyright
shell: bash
run: npm install -g npm@^7 pyright
- name: Setup bufbuild/buf
uses: bufbuild/buf-setup-action@v1.20.0
with:
github_token: ${{ github.token }}

View File

@@ -29,6 +29,20 @@ defaults:
run:
shell: bash --noprofile --norc -exo pipefail {0}
jobs:
quality:
runs-on: ubuntu-latest
if: github.event_name == 'pull_request'
name: quality-check
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup CI
uses: ./.github/actions/setup-repo
with:
python-version: ${{ env.STABLE_PYTHON_VERSION }}
- name: Run type check
run: hatch run typing
tests:
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }}
@@ -47,7 +61,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Run tests
run: hatch run full
run: hatch run tests:python
- name: Disambiguate coverage filename
run: mv .coverage ".coverage.${{ matrix.os }}.${{ matrix.python-version }}"
- name: Upload coverage data
@@ -99,6 +113,7 @@ jobs:
needs:
- coverage
- tests
- quality
runs-on: ubuntu-latest
steps:
- name: Decide whether the needed jobs succeeded or failed

View File

@@ -1,46 +0,0 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: cache-cleanup
on:
pull_request:
types:
- closed
jobs:
cleanup:
runs-on: ubuntu-latest
if: github.repository_owner == 'bentoml'
steps:
- name: Check out code
uses: actions/checkout@v3
- name: Cleanup
run: |
gh extension install actions/gh-actions-cache
REPO=${{ github.repository }}
BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge"
echo "Fetching list of cache key"
cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 )
## Setting this to not fail the workflow while deleting cache keys.
set +e
echo "Deleting caches..."
for cacheKey in $cacheKeysForPR
do
gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm
done
echo "Done"
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -14,16 +14,16 @@
ci:
autoupdate_schedule: weekly
skip: [check-models-table-update, check-models-table-update, changelog-dry-run]
autofix_commit_msg: "ci: auto fixes from pre-commit.ci\nFor more information, see https://pre-commit.ci"
skip: [check-models-table-update, check-models-table-update, changelog-dry-run, typecheck]
autofix_commit_msg: "ci: auto fixes from pre-commit.ci\n\nFor more information, see https://pre-commit.ci"
autoupdate_commit_msg: 'ci: pre-commit autoupdate [pre-commit.ci]'
exclude: '.*\.(css|js|svg)$'
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.0.275'
rev: 'v0.0.277'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
args: [--exit-non-zero-on-fix, --show-fixes]
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
@@ -37,6 +37,13 @@ repos:
args: [--config=pyproject.toml]
- repo: local
hooks:
- id: typecheck
name: type-check
entry: pyright src/openllm --level error
types: [python]
language: node
pass_filenames: false
additional_dependencies: ['pyright@1.1.316']
- id: check-license-header
name: check for license headers
entry: ./tools/assert-license-headers
@@ -69,3 +76,10 @@ repos:
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude: |
(?x)^(
tests/models/.*
)$
- id: check-yaml
args: ['--unsafe']
- id: check-toml

View File

@@ -100,28 +100,25 @@ After setting up your environment, here's how you can start contributing:
3. Run all formatters and linters with `hatch`:
```bash
hatch run fmt
hatch run quality
```
4. Write tests that verify your feature or fix (see
[Writing Tests](#writing-tests) below).
5. Run all tests to ensure your changes haven't broken anything:
```bash
hatch run full
hatch run tests:python
```
6. Commit your changes:
```bash
git commit -m "Add my feature"
```
7. Push your changes to your fork:
```bash
git push origin feature/my-feature
```
8. Submit a Pull Request on GitHub.
## Using a custom fork
@@ -141,7 +138,13 @@ directory and their filenames start with `test_`.
Run all tests with:
```bash
hatch run full
hatch run tests:python
```
Run snapshot testing for model outputs:
```bash
hatch run tests:models
```
## Releasing a New Version

View File

@@ -11,6 +11,9 @@
</a><a href="https://l.bentoml.com/join-openllm-discord">
<img src="https://badgen.net/badge/icon/OpenLLM/7289da?icon=discord&label=Join%20Us" alt="Discord" />
</a><br>
</a><a href="https://pdm.fming.dev">
<img src="https://img.shields.io/badge/pdm-managed-blueviolet" alt="PDM" />
</a><br>
<p>An open platform for operating large language models (LLMs) in production.</br>
Fine-tune, serve, deploy, and monitor any LLMs with ease.</p>
<i></i>

26
changelog.d/107.fix.md Normal file
View File

@@ -0,0 +1,26 @@
Fixes relative model_id handling for running LLM within the container.
Added support for building a container directly with `openllm build`. Users can
now run `openllm build --format=container`:
```bash
openllm build flan-t5 --format=container
```
This is equivalent to:
```bash
openllm build flan-t5 && bentoml containerize google-flan-t5-large-service
```
Added Snapshot testing and more robust edge cases for model testing
General improvement in `openllm.LLM.import_model` where it will parse sanitised
parameters automatically.
Fixes `openllm start <bento>` to use correct `model_id`, ignoring `--model-id`
(The correct behaviour)
Fixes `--workers-per-resource conserved` to respect `--device`
Added initial interface for `LLM.embeddings`

View File

@@ -13,9 +13,6 @@
# limitations under the License.
from __future__ import annotations
import subprocess
import sys
import typing as t
from langchain.chains import LLMChain
@@ -36,11 +33,9 @@ class Query(BaseModel):
def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM:
args = [sys.executable, "-m", "openllm", "download", model_name]
if model_id:
args += ["--model-id", model_id]
subprocess.check_output(args)
return OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
lc_llm.runner.download_model()
return lc_llm
llm = gen_llm("dolly-v2", model_id="databricks/dolly-v2-7b")

View File

@@ -8,8 +8,6 @@ dependencies = [
"tomlkit",
# NOTE: Using under ./tools/update-readme.py
"markdown-it-py",
# NOTE: pyright for type
"pyright",
# NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy
"coverage[toml]>=6.5",
"filelock>=3.7.1",
@@ -26,25 +24,28 @@ dependencies = [
]
features = ['flan-t5']
[envs.default.scripts]
_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml"
changelog = "towncrier build --version main --draft"
fmt = ["tools", "pre-commit run --all-files"]
full = "_run_script --reruns 5 --reruns-delay 3 -r aR {args:tests}"
setup = "pre-commit install"
tools = [
quality = [
"./tools/update-readme.py",
"./tools/update-optional-dependencies.py",
"./tools/update-config-stubs.py",
"./tools/update-models-import.py",
"- ./tools/add-license-headers .",
"pre-commit run --all-files",
]
typing = "pyright {args:src/openllm tests}"
[envs.test.overrides]
setup = "pre-commit install"
typing = "pre-commit run typecheck --all-files"
[envs.tests]
extra-dependencies = [
# NOTE: interact with docker for container tests.
"docker",
]
[envs.tests.scripts]
_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml"
models = "_run_script -r aR {args:tests/models}"
python = "_run_script --reruns 5 --reruns-delay 3 --ignore tests/models -n 3 -r aR {args:tests}"
[envs.tests.overrides]
env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT="
env.HERMETIC_TESTS.type = [{ value = "container", if = ["true"] }, "virtual"]
[envs.test.scripts]
[[envs.test.matrix]]
python = ["3.8", "3.9", "3.10", "3.11"]
[envs.coverage]
dependencies = ["coverage[toml]>=6.5", "lxml", "orjson"]
detached = true

View File

@@ -72,12 +72,12 @@ all = [
"openllm[falcon]",
"openllm[mpt]",
"openllm[starcoder]",
"openllm[ggml]",
"openllm[playground]",
"openllm[fine-tune]",
"openllm[agents]",
"openllm[flan-t5]",
"openllm[openai]",
"openllm[playground]",
"openllm[flan-t5]",
"openllm[agents]",
"openllm[ggml]",
]
chatglm = ["cpm-kernels", "sentencepiece"]
falcon = ["einops", "xformers", "safetensors"]
@@ -155,7 +155,7 @@ verbose = 2
whitelist-regex = ["test_.*"]
[tool.pytest.ini_options]
addopts = ["-rfEX", "-pno:warnings"]
addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"]
python_files = ["test_*.py", "*_test.py"]
testpaths = ["tests"]
@@ -183,70 +183,125 @@ line-length = 119
target-version = ["py311"]
[tool.ruff]
exclude = ["tools"]
exclude = ["tools", "src/openllm/playground"]
extend-select = [
"B", # flake8-bugbear
"I", # isort
"G", # flake8-logging-format
"D", # pydocstyle
"W", # pycodestyle
"Q", # flake8-quotes
"FA", # flake8-future-annotations
"S", # flake8-bandit
"TCH", # flake8-type-checking
"PLW", # pylint-warning
"PLR", # pylint-refactor
"PT", # flake8-pytest-style
"PYI", # flake8-pyi
"PERF", # perflint
"FLY", # flynt
"RUF", # Ruff-specific rules
"YTT", # flake8-2020
]
fix = true
ignore = [
# Allow non-abstract empty methods in abstract base classes
"B027",
# Allow boolean positional values in function calls, like `dict.get(... True)`
"FBT003",
# Ignore checks for possible passwords
"S105",
"B027", # Allow non-abstract empty methods in abstract base classes
"FBT003", # Allow boolean positional values in function calls, like `dict.get(... True)`
"S105", # Ignore checks for possible passwords
"S106",
"S107",
# Ignore complexity
"C901",
"S603", # ignore subprocess.call
"PLR0911",
"PLR0912",
"PLR0913",
"PLR0915",
"E501",
"E741",
"PLR2004", # magic value to use constant
"E501", # ignore line length violation
"PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs.
"D103", # Just missing docstring for magic methods.
"D102",
"D101",
"D100",
"TCH004", # don't move runtime import out, just warn about it
"RUF012", # mutable attributes to be used with ClassVar
"B905", # zip warning about strict, only applicable for 3.10+
]
line-length = 119
target-version = "py311"
target-version = "py312"
unfixable = [
"F401", # Don't touch unused imports, just warn about it.
"F401", # Don't touch unused imports, just warn about it.
"TCH004", # Don't touch import outside of TYPE_CHECKING block
]
[tool.ruff.flake8-type-checking]
exempt-modules = ["typing", "typing_extensions", "."]
runtime-evaluated-base-classes = [
"pydantic.BaseModel",
"openllm._configuration.LLMConfig",
"openllm._configuration.GenerationConfig",
"openllm._configuration.ModelSettings",
]
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen"]
[tool.ruff.pydocstyle]
convention = "google"
[tool.ruff.pycodestyle]
ignore-overlong-task-comments = true
[tool.ruff.isort]
force-single-line = true
known-first-party = ["openllm", "bentoml", 'transformers']
lines-after-imports = 2
no-lines-before = ["future", "standard-library"]
relative-imports-order = "closest-to-furthest"
[tool.ruff.per-file-ignores]
[tool.ruff.flake8-quotes]
avoid-escape = false
[tool.ruff.extend-per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"__init__.py" = ["E402", "F401", "F403", "F811"]
"examples/**/*" = ["D"]
"src/openllm/_llm.py" = ["B010", "B009"]
"src/openllm/_strategies.py" = ["B904"]
"src/openllm/_types.py" = ["E402"]
"src/openllm/playground/**/*" = ["E402", "F401"]
"tests/**/*" = ["PLR2004", "S101", "TID252"]
"src/openllm/cli.py" = ["D301", "S101"]
"src/openllm/models/**/*" = ["D106", "S101", "D104"]
"src/openllm/playground/**/*" = ["E402", "F401", "PLR", "D"]
"src/openllm/utils/dummy_*" = ["D107"]
"src/openllm/utils/import_utils.py" = [
"PLW0603", # OK to ignore global access here
"D105", # magic docstring
]
"src/openllm_client/runtimes/*" = ["D107"]
"tests/**/*" = [
"S101",
"TID252",
"D", # No docstring in tests
"PT011", # ignore too broad raises, as it can be use pytest.raises().match()
"S307", # Ignore eval(compile) as it is a known script execution
]
"typings/**/*" = ["D", "F", "E", "PYI002", "I001"]
[tool.pyright]
analysis.useLibraryCodeForTypes = true
enableTypeIgnoreComments = true
include = ["src/", "tests/", "tools/", "examples/"]
exclude = ["src/openllm/playground", "src/openllm/models/"]
include = ["src/openllm", "src/openllm_client", "tests/", "tools/", "examples/"]
pythonVersion = "3.12"
reportMissingImports = "none"
reportMissingModuleSource = "warning"
reportMissingTypeStubs = "warning"
reportMissingTypeStubs = false
reportPrivateUsage = "warning"
reportUnknownArgumentType = "warning"
reportUnknownMemberType = "warning"
reportUnknownVariableType = "warning"
strictDictionaryInference = true
strictListInference = true
strictParameterNoneValue = true
strictSetInference = true
typeCheckingMode = "strict"
[tool.coverage.run]
branch = true
omit = [
"src/openllm/playground/",
"src/openllm/__about__.py",
"src/openllm/__main__.py",
"src/openllm/tests.py",
"src/openllm/utils/dummy_*.py",
]
source_pkgs = ["openllm"]
@@ -255,4 +310,14 @@ source_pkgs = ["openllm"]
openllm = ["src/openllm", "*/openllm/src/openllm"]
[tool.coverage.report]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:", "@overload", "# pragma: no cover"]
exclude_lines = [
"no cov",
"pragma: no cover",
"if __name__ == .__main__.:",
"if t.TYPE_CHECKING:",
'if TYPE_CHECKING:',
'if typing.TYPE_CHECKING:',
'@overload',
'@typing.overload',
'raise NotImplementedError',
]

View File

@@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenLLM
=======
"""OpenLLM.
An open platform for operating large language models in production. Fine-tune, serve,
deploy, and monitor any LLMs with ease.
@@ -24,7 +22,6 @@ deploy, and monitor any LLMs with ease.
* Native integration with BentoML and LangChain for custom LLM apps
"""
from __future__ import annotations
import logging
import os
import typing as t
@@ -39,7 +36,6 @@ if utils.DEBUG:
utils.set_debug_mode(True)
utils.set_quiet_mode(False)
utils.configure_logging()
logging.basicConfig(level=logging.NOTSET)
else:
# configuration for bitsandbytes before import
@@ -64,16 +60,15 @@ else:
_import_structure = {
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
"_configuration": ["LLMConfig"],
"_package": ["build"],
"exceptions": [],
"_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput"],
"utils": [],
"utils": ["infer_auto_class"],
"models": [],
"client": [],
"playground": [],
"tests": [],
"testing": [],
"serialisation": ["ggml", "transformers"],
"cli": ["start", "start_grpc"],
"cli": ["start", "start_grpc", "build", "import_model", "list_models"],
# NOTE: models
"models.auto": [
"AutoConfig",
@@ -182,23 +177,23 @@ if t.TYPE_CHECKING:
from . import exceptions as exceptions
from . import models as models
from . import playground as playground
from . import tests as tests
from . import serialisation as serialisation
from . import testing as testing
# Specific types import
from ._configuration import LLMConfig as LLMConfig
from ._llm import LLM as LLM
from ._llm import LLMRunner as LLMRunner
from ._llm import LLMRunnable as LLMRunnable
from ._llm import LLMRunner as LLMRunner
from ._llm import Runner as Runner
from ._package import build as build
from ._schema import GenerationInput as GenerationInput
from ._schema import GenerationOutput as GenerationOutput
from ._schema import MetadataOutput as MetadataOutput
from .cli import build as build
from .cli import import_model as import_model
from .cli import list_models as list_models
from .cli import start as start
from .cli import start_grpc as start_grpc
from .serialisation import ggml as ggml
from .serialisation import transformers as transformers
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
@@ -213,6 +208,9 @@ if t.TYPE_CHECKING:
from .models.opt import OPTConfig as OPTConfig
from .models.stablelm import StableLMConfig as StableLMConfig
from .models.starcoder import StarCoderConfig as StarCoderConfig
from .serialisation import ggml as ggml
from .serialisation import transformers as transformers
from .utils import infer_auto_class as infer_auto_class
# NOTE: torch and cpm_kernels
try:
@@ -286,6 +284,7 @@ else:
globals()["__file__"],
_import_structure,
module_spec=__spec__,
doc=__doc__,
extra_objects={
"__version__": __version__,
# The below is a special mapping that allows openllm to be used as a dictionary.

View File

@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CLI entrypoint for OpenLLM.
"""CLI entrypoint for OpenLLM.
Usage:
openllm --help

View File

@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``.
"""Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``.
Highlight feature: Each field in ``openllm.LLMConfig`` will also automatically generate an environment
variable based on its name field.
@@ -47,7 +46,6 @@ dynamically during serve, ahead-of-serve or per requests.
Refer to ``openllm.LLMConfig`` docstring for more information.
"""
from __future__ import annotations
import copy
import enum
import logging
@@ -68,56 +66,48 @@ from deepmerge.merger import Merger
import openllm
from .exceptions import ForbiddenAttributeError
from .utils import ENV_VARS_TRUE_VALUES
from .utils import LazyType
from .utils import ReprMixin, ENV_VARS_TRUE_VALUES
from .utils import ReprMixin
from .utils import bentoml_cattr
from .utils import codegen
from .utils import dantic
from .utils import first_not_none, field_env_key
from .utils import field_env_key
from .utils import first_not_none
from .utils import lenient_issubclass
from .utils import non_intrusive_setattr
from .utils import requires_dependencies
if hasattr(t, "Required"):
from typing import Required
else:
from typing_extensions import Required
if hasattr(t, "NotRequired"):
from typing import NotRequired
else:
from typing_extensions import NotRequired
if hasattr(t, "dataclass_transform"):
from typing import dataclass_transform
else:
from typing_extensions import dataclass_transform
# NOTE: We need to do this so that overload can register
# NOTE: We need to check the overload import
# so that it can register
# the correct overloads to the typing registry
if hasattr(t, "get_overloads"):
if sys.version_info[:2] >= (3, 11):
from typing import NotRequired
from typing import Required
from typing import dataclass_transform
from typing import overload
else:
from typing_extensions import NotRequired
from typing_extensions import Required
from typing_extensions import dataclass_transform
from typing_extensions import overload
_T = t.TypeVar("_T")
if t.TYPE_CHECKING:
import click
import peft
from attr import _CountingAttr # type: ignore
from attr import _make_init # type: ignore
from attr import _transform_attrs # type: ignore
from attr import _CountingAttr
from attr import _make_init
from attr import _transform_attrs
from attr._compat import set_closure_cell
import transformers
from transformers.generation.beam_constraints import Constraint
from ._types import ClickFunctionWrapper
from ._types import F
from ._types import O_co
from ._types import P
from ._types import AnyCallable
DictStrAny = dict[str, t.Any]
ListStr = list[str]
@@ -154,10 +144,10 @@ config_merger = Merger(
# case insensitive, but rename to conform with type
class _PeftEnumMeta(enum.EnumMeta):
def __getitem__(self, __key: str | t.Any) -> PeftType:
def __getitem__(self, __key: str | t.Any) -> enum.Enum:
if isinstance(__key, str):
__key = inflection.underscore(__key).upper()
return super().__getitem__(__key)
return self._member_map_[__key]
# vendored from peft.utils.config.PeftType
@@ -171,11 +161,11 @@ class PeftType(enum.Enum, metaclass=_PeftEnumMeta):
ADAPTION_PROMPT = "ADAPTION_PROMPT"
@classmethod
def _missing_(cls, value: object) -> PeftType | None:
def _missing_(cls, value: object) -> enum.Enum | None:
if isinstance(value, str):
normalized = inflection.underscore(value).upper()
if normalized in cls._member_map_:
return cls[normalized]
return cls._member_map_[normalized]
@classmethod
def supported(cls) -> set[str]:
@@ -184,6 +174,11 @@ class PeftType(enum.Enum, metaclass=_PeftEnumMeta):
def to_str(self) -> str:
return self.value
@staticmethod
def get(__key: str | t.Any) -> PeftType:
"""type-safe getitem."""
return t.cast(PeftType, PeftType[__key])
_PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"}
@@ -200,14 +195,33 @@ def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType:
raise ValueError("'AdapterType' cannot be None.")
if isinstance(value, PeftType):
return value
if isinstance(value, str) and value not in PeftType.supported():
if value not in PeftType.supported():
raise ValueError(f"Given '{value}' is not a supported adapter type.")
return PeftType[value]
return PeftType.get(value)
@attr.define(slots=True)
class FineTuneConfig:
"""FineTuneConfig defines a default value for fine-tuning this any given LLM. For example:
"""FineTuneConfig defines a default value for fine-tuning this any given LLM.
For example:
```python
class FalconConfig(openllm.LLMConfig):
__config__ = {
"fine_tune_strategies": (
{
"adapter_type": "lora",
"r": 64,
"lora_alpha": 16,
"lora_dropout": 0.1,
"bias": "none",
"target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
},
),
}
```
This is a lower level API that leverages `peft` as well as openllm.LLMConfig to create default
and customization
@@ -318,8 +332,7 @@ class FineTuneConfig:
docs: str | None = None,
**attrs: t.Any,
) -> type[FineTuneConfig]:
"""A loose codegen to create default subclass for given adapter config type"""
"""A loose codegen to create default subclass for given adapter config type."""
_new_default = {
"adapter_type": PeftType[adapter_type],
"adapter_config": attrs,
@@ -355,8 +368,7 @@ class FineTuneConfig:
@attr.frozen(slots=True, repr=False)
class GenerationConfig(ReprMixin):
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
with some additional validation and environment constructor.
"""GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and environment constructor.
Note that we always set `do_sample=True`. This class is not designed to be used directly, rather
to be used in conjunction with LLMConfig. The instance of the generation config can then be accessed
@@ -588,7 +600,7 @@ class GenerationConfig(ReprMixin):
if t.TYPE_CHECKING:
def __attrs_init__(self, **_: t.Any):
def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
...
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
@@ -628,6 +640,7 @@ _object_getattribute = object.__getattribute__
class ModelSettings(t.TypedDict, total=False):
"""ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__.
Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig.
If the field below changes, make sure to run ./tools/update-config-stubs.py to generate correct __getitem__
@@ -728,9 +741,9 @@ class _ModelSettingsAttr:
service_name: str
requirements: t.Optional[ListStr]
bettertransformer: bool
model_type: t.Literal['causal_lm', 'seq2seq_lm']
runtime: t.Literal['transformers', 'ggml']
name_type: t.Optional[t.Literal['dasherize', 'lowercase']]
model_type: t.Literal["causal_lm", "seq2seq_lm"]
runtime: t.Literal["transformers", "ggml"]
name_type: t.Optional[t.Literal["dasherize", "lowercase"]]
model_name: str
start_name: str
env: openllm.utils.EnvVarMixin
@@ -743,7 +756,6 @@ class _ModelSettingsAttr:
def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]):
assert cl_.__config__ is not None, f"'__config__' is required for {cls}."
if "generation_class" in cl_.__config__:
raise ValueError(
"'generation_class' shouldn't be defined in '__config__', rather defining "
@@ -813,8 +825,8 @@ bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings)
def _setattr_class(attr_name: str, value_var: t.Any):
"""
Use the builtin setattr to set *attr_name* to *value_var*.
"""Use the builtin setattr to set *attr_name* to *value_var*.
We can't use the cached object.__setattr__ since we are setting
attributes to a class.
@@ -828,7 +840,7 @@ def _setattr_class(attr_name: str, value_var: t.Any):
def _make_assignment_script(
cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: t.LiteralString = "openllm"
) -> t.Callable[..., None]:
"""Generate the assignment script with prefix attributes __openllm_<value>__"""
"""Generate the assignment script with prefix attributes __openllm_<value>__."""
args: ListStr = []
globs: DictStrAny = {
"cls": cls,
@@ -852,13 +864,14 @@ def _make_assignment_script(
_reserved_namespace = {"__config__", "GenerationConfig"}
@dataclass_transform(order_default=True, field_specifiers=(attr.field, dantic.Field))
@dataclass_transform(kw_only_default=True, order_default=True, field_specifiers=(attr.field, dantic.Field))
def llm_config_transform(cls: type[LLMConfig]) -> type[LLMConfig]:
non_intrusive_setattr(
cls,
"__dataclass_transform__",
{
"order_default": True,
"kw_only_default": True,
"field_specifiers": (attr.field, dantic.Field),
},
)
@@ -880,7 +893,7 @@ class _ConfigAttr:
# NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
if t.TYPE_CHECKING:
# NOTE: public attributes to override
__config__: ModelSettings | None = Field(None)
__config__: ModelSettings = Field(None)
"""Internal configuration for this LLM model. Each of the field in here will be populated
and prefixed with __openllm_<value>__"""
GenerationConfig: type = Field(None)
@@ -914,7 +927,7 @@ class _ConfigAttr:
to create the generation_config argument that can be used throughout the lifecycle.
This class will also be managed internally by OpenLLM."""
def __attrs_init__(self, **attrs: t.Any):
def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
"""Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract."""
# NOTE: The following will be populated from __config__ and also
@@ -955,14 +968,14 @@ class _ConfigAttr:
architecture. By default, we will use BetterTransformer for T5 and StableLM models,
and set to False for every other models.
"""
__openllm_model_type__: t.Literal['causal_lm', 'seq2seq_lm'] = Field(None)
__openllm_model_type__: t.Literal["causal_lm", "seq2seq_lm"] = Field(None)
"""The model type for this given LLM. By default, it should be causal language modeling.
Currently supported 'causal_lm' or 'seq2seq_lm'
"""
__openllm_runtime__: t.Literal['transformers', 'ggml'] = Field(None)
__openllm_runtime__: t.Literal["transformers", "ggml"] = Field(None)
"""The runtime to use for this model. Possible values are `transformers` or `ggml`. See
LlaMA for more information."""
__openllm_name_type__: t.Optional[t.Literal['dasherize', 'lowercase']] = Field(None)
__openllm_name_type__: t.Optional[t.Literal["dasherize", "lowercase"]] = Field(None)
"""The default name typed for this model. "dasherize" will convert the name to lowercase and
replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both
`model_name` and `start_name` must be specified."""
@@ -991,11 +1004,187 @@ class _ConfigAttr:
# fmt: on
@attr.define(slots=True)
class LLMConfig(_ConfigAttr):
class _ConfigBuilder:
"""A modified version of attrs internal _ClassBuilder, and should only be called within __init_subclass__ of LLMConfig.
Where:
- has_custom_setattr=True
- getstate_setstate=None (config class will always be a slotted class.)
- slots=True
- auto_attribs=False (We should handle it before _ConfigBuilder is invoked)
- cache_hash=False (We don't need to cache the hash code of this object for now.)
- collect_by_mro=True (The correct behaviour to resolve inheritance)
- field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable)
It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__
"""
``openllm.LLMConfig`` is somewhat a hybrid combination between the performance of `attrs` with the
easy-to-use interface that pydantic offer. It lives in between where it allows users to quickly formulate
__slots__ = (
"_cls",
"_cls_dict",
"_attr_names",
"_attrs",
"_model_name",
"_base_attr_map",
"_base_names",
"_has_pre_init",
"_has_post_init",
)
def __init__(
self,
cls: type[LLMConfig],
these: dict[str, _CountingAttr[t.Any]],
auto_attribs: bool = False,
kw_only: bool = False,
collect_by_mro: bool = True,
):
attrs, base_attrs, base_attr_map = _transform_attrs(
cls,
these,
auto_attribs,
kw_only,
collect_by_mro,
field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__),
)
self._cls = cls
self._model_name = cls.__openllm_model_name__
self._cls_dict = dict(cls.__dict__)
self._attrs = attrs
self._base_names = {a.name for a in base_attrs}
self._base_attr_map = base_attr_map
self._attr_names = tuple(a.name for a in attrs)
self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
self._cls_dict["__attrs_attrs__"] = self._attrs
def build_class(self) -> type[LLMConfig]:
    """Write the accumulated namespace back onto the wrapped class and return it.

    The builder must not be used again after this method has been called.

    > Unlike attrs._ClassBuilder, no new class is created from the collected
    > __dict__; the existing class is patched instead, because this builder is
    > invoked from within __init_subclass__.
    """
    excluded = (*self._attr_names, "__dict__", "__weakref__")
    namespace = {key: value for key, value in self._cls_dict.items() if key not in excluded}

    # Walk the MRO (without the class itself and the last entry) to collect the
    # slots declared by the bases and to detect an inherited __weakref__.
    inherited_slots: DictStrAny = {}
    has_inherited_weakref = False
    for base in self._cls.__mro__[1:-1]:
        if base.__dict__.get("__weakref__", None) is not None:
            has_inherited_weakref = True
        for slot in getattr(base, "__slots__", []):
            inherited_slots[slot] = getattr(base, slot, codegen._sentinel)

    candidates = self._attr_names
    weakref_already_present = (
        "__weakref__" in getattr(self._cls, "__slots__", ())
        or "__weakref__" in candidates
        or has_inherited_weakref
    )
    if not weakref_already_present:
        candidates = (*candidates, "__weakref__")

    # Only attributes that are not inherited get a slot; declaring a slot for
    # an inherited attribute wastes memory.
    inherited_names = set(self._base_names)
    own_slots = [name for name in candidates if name not in inherited_names]

    # Some of these names are already slots on a parent class. Their descriptors
    # may be overridden by the child class, so they are carried over into the
    # class dict instead of being declared a second time.
    reused = {name: descriptor for name, descriptor in inherited_slots.items() if name in own_slots}
    own_slots = [name for name in own_slots if name not in reused]

    namespace.update(reused)
    namespace["__slots__"] = tuple(own_slots)
    namespace["__qualname__"] = self._cls.__qualname__

    # The class can only be patched here rather than re-instantiated:
    # type.__new__ invokes __init_subclass__, and since _ConfigBuilder is used
    # inside __init_subclass__ that would raise a recursion error.
    # See https://peps.python.org/pep-0487/ for how __init_subclass__ works.
    target = self._cls
    for attribute, value in namespace.items():
        setattr(target, attribute, value)
    return self.make_closure(target)
def make_closure(self, cls: type):
    """Point ``__class__`` closure cells at ``cls`` and return the transformed class.

    This mirrors the fix for <https://github.com/python-attrs/attrs/issues/102>:
    a method that mentions ``__class__`` or calls the no-arg ``super()`` carries
    a reference to its class in ``method.__closure__``. Every cell that currently
    holds the wrapped class is rewritten to hold ``cls``.

    Args:
        cls: The class whose members are scanned and which the cells should reference.

    Returns:
        The result of ``llm_config_transform(cls)``.
    """
    for member in cls.__dict__.values():
        if isinstance(member, (classmethod, staticmethod)):
            # Class- and staticmethods keep the real function in __func__.
            function = member.__func__
        elif isinstance(member, property):
            # Only the getter is inspected; there is no universal way to
            # unwrap other descriptors.
            function = member.fget
        else:
            function = member

        cells = getattr(function, "__closure__", None)
        if not cells:  # None or empty
            continue

        for cell in cells:
            try:
                contents = cell.cell_contents
            except ValueError:  # the cell is empty
                continue
            if contents is self._cls:
                set_closure_cell(cell, cls)
    return llm_config_transform(cls)
def add_attrs_init(self) -> t.Self:
    """Generate the init method for the wrapped class and stage it as ``__attrs_init__``.

    Returns:
        The builder itself, so calls can be chained.
    """
    decorate = codegen.add_method_dunders
    generated_init = _make_init(
        self._cls,
        self._attrs,
        self._has_pre_init,
        self._has_post_init,
        False,  # frozen
        True,  # slots
        False,  # cache_hash
        self._base_attr_map,
        False,  # This is not an exception
        None,  # no on_setattr
        True,  # presumably `attrs_init` -- TODO confirm against the installed attrs `_make_init` signature
    )
    self._cls_dict["__attrs_init__"] = decorate(self._cls, generated_init)
    return self
def add_repr(self):
    """Stage the ReprMixin representation hooks and ``__repr_keys__`` in the class dict.

    Returns the builder itself, so calls can be chained.
    """
    repr_hooks = ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__")
    for name, func in ReprMixin.__dict__.items():
        if name not in repr_hooks:
            continue
        self._cls_dict[name] = codegen.add_method_dunders(self._cls, func)

    # The property closes over this builder's attribute list and always adds
    # "generation_config" to the reported keys.
    self._cls_dict["__repr_keys__"] = property(
        lambda _: {field.name for field in self._attrs} | {"generation_config"}
    )
    return self
@attr.define(slots=True, init=False)
class LLMConfig(_ConfigAttr):
"""``openllm.LLMConfig`` is a pydantic-like ``attrs`` interface that offers fast and easy-to-use APIs.
It lives in between the nice UX of `pydantic` and fast performance of `attrs` where it allows users to quickly formulate
a LLMConfig for any LLM without worrying too much about performance. It does a few things:
- Automatic environment conversion: Each fields will automatically be provisioned with an environment
@@ -1077,182 +1266,16 @@ class LLMConfig(_ConfigAttr):
),
}
```
Future work:
- Support pydantic-core as validation backend.
"""
class _ConfigBuilder:
"""A modified version of attrs internal _ClassBuilder, should only be called
within __init_subclass__ of LLMConfig.
Where:
- has_custom_setattr=True
- getstate_setstate=None (config class will always be a slotted class.)
- slots=True
- auto_attribs=False (We should handle it before _ConfigBuilder is invoked)
- cache_hash=False (We don't need to cache the hash code of this object for now.)
- collect_by_mro=True (The correct behaviour to resolve inheritance)
- field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable)
It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__
"""
__slots__ = (
"_cls",
"_cls_dict",
"_attr_names",
"_attrs",
"_model_name",
"_base_attr_map",
"_base_names",
"_has_pre_init",
"_has_post_init",
)
def __init__(
self,
cls: type[LLMConfig],
these: dict[str, _CountingAttr[t.Any]],
auto_attribs: bool = False,
kw_only: bool = False,
collect_by_mro: bool = True,
):
"""Run ``_transform_attrs`` on ``cls`` and cache the results needed by the other builder methods."""
attrs, base_attrs, base_attr_map = _transform_attrs(
cls,
these,
auto_attribs,
kw_only,
collect_by_mro,
field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__),
)
self._cls = cls
self._model_name = cls.__openllm_model_name__
# Work on a copy of the class namespace; build_class() writes it back onto the class.
self._cls_dict = dict(cls.__dict__)
self._attrs = attrs
self._base_names = {a.name for a in base_attrs}
self._base_attr_map = base_attr_map
self._attr_names = tuple(a.name for a in attrs)
self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
self._cls_dict["__attrs_attrs__"] = self._attrs
def build_class(self) -> type[LLMConfig]:
"""
Finalize class based on the accumulated configuration.
Builder cannot be used after calling this method.
> A difference between this and attrs._ClassBuilder is that we don't
> create a new class after constructing all __dict__. This has to do
> with the recursive call that would happen within __init_subclass__
"""
# Drop the attribute names plus __dict__/__weakref__ from the namespace that is written back.
cd = {
k: v
for k, v in self._cls_dict.items()
if k not in tuple(self._attr_names) + ("__dict__", "__weakref__")
}
# Traverse the MRO to collect existing slots
# and check for an existing __weakref__.
existing_slots: DictStrAny = {}
weakref_inherited = False
for base_cls in self._cls.__mro__[1:-1]:
if base_cls.__dict__.get("__weakref__", None) is not None:
weakref_inherited = True
existing_slots.update(
{name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])}
)
base_names = set(self._base_names)
names = self._attr_names
# Add a __weakref__ slot only when neither this class, its attributes, nor a base already provides one.
if (
"__weakref__" not in getattr(self._cls, "__slots__", ())
and "__weakref__" not in names
and not weakref_inherited
):
names += ("__weakref__",)
# We only add the names of attributes that aren't inherited.
# Setting __slots__ to inherited attributes wastes memory.
slot_names = [name for name in names if name not in base_names]
# There are slots for attributes from current class
# that are defined in parent classes.
# As their descriptors may be overridden by a child class,
# we collect them here and update the class dict
reused_slots = {
slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names
}
# Reused slots are removed from the list so they are not declared a second time;
# their parent descriptors are carried over through the class dict instead.
slot_names = [name for name in slot_names if name not in reused_slots]
cd.update(reused_slots)
cd["__slots__"] = tuple(slot_names)
# Patch the existing class in place instead of creating a new one (see the docstring).
for k, value in cd.items():
setattr(self._cls, k, value)
# The following is a fix for
# <https://github.com/python-attrs/attrs/issues/102>.
# If a method mentions `__class__` or uses the no-arg super(), the
# compiler will bake a reference to the class in the method itself
# as `method.__closure__`. In attrs the class is replaced with a
# clone, so these references are rewritten to keep working.
# NOTE(review): this builder patches the class in place rather than cloning it.
for item in self._cls.__dict__.values():
if isinstance(item, (classmethod, staticmethod)):
# Class- and staticmethods hide their functions inside.
# These might need to be rewritten as well.
closure_cells = getattr(item.__func__, "__closure__", None)
elif isinstance(item, property):
# Workaround for property `super()` shortcut (PY3-only).
# There is no universal way for other descriptors.
closure_cells = getattr(item.fget, "__closure__", None)
else:
closure_cells = getattr(item, "__closure__", None)
if not closure_cells: # Catch None or the empty list.
continue
for cell in closure_cells:
try:
match = cell.cell_contents is self._cls
except ValueError: # ValueError: Cell is empty
pass
else:
if match:
# NOTE(review): the cell already holds self._cls when `match` is true,
# so this assignment looks like a no-op -- confirm.
set_closure_cell(cell, self._cls)
return llm_config_transform(self._cls)
def add_attrs_init(self) -> t.Self:
"""Generate ``__attrs_init__`` via ``_make_init`` and stage it in the class dict; returns the builder."""
self._cls_dict["__attrs_init__"] = codegen.add_method_dunders(
self._cls,
_make_init(
self._cls,
self._attrs,
self._has_pre_init,
self._has_post_init,
False, # frozen
True, # slots
False, # cache_hash
self._base_attr_map,
False, # This is not an exception
None, # no on_setattr
attrs_init=True,
),
)
return self
def add_repr(self):
"""Stage the ReprMixin representation hooks and ``__repr_keys__`` in the class dict; returns the builder."""
for key, fn in ReprMixin.__dict__.items():
if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"):
self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
# The property closes over this builder's attribute list and always includes "generation_config".
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"})
return self
def __init_subclass__(cls: type[LLMConfig]):
"""The purpose of this __init_subclass__ is that we want all subclass of LLMConfig
to adhere to the attrs contract, and have pydantic-like interface. This means we will
construct all fields and metadata and hack into how attrs use some of the 'magic' construction
to generate the fields.
"""The purpose of this ``__init_subclass__`` is to offer pydantic UX while adhering to attrs contract.
This means we will construct all fields and metadata and hack into
how attrs use some of the 'magic' construction to generate the fields.
It also does a few more extra features: It also generate all __openllm_*__ config from
ModelSettings (derived from __config__) to the class.
@@ -1261,7 +1284,7 @@ class LLMConfig(_ConfigAttr):
logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__)
cls.__name__ = f"{cls.__name__}Config"
if not hasattr(cls, "__config__") or cls.__config__ is None:
if not hasattr(cls, "__config__"):
raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.")
# auto assignment attributes generated from __config__ after create the new slot class.
@@ -1320,7 +1343,7 @@ class LLMConfig(_ConfigAttr):
a.name for a in attr.fields(cls.__openllm_generation_class__)
}
cls = cls._ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class()
cls = _ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class()
# Finally, resolve the types
if getattr(cls, "__attrs_types_resolved__", None) != cls:
@@ -1398,11 +1421,11 @@ class LLMConfig(_ConfigAttr):
@overload
def __getitem__(self, item: t.Literal["bettertransformer"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["model_type"] = ...) -> t.Literal['causal_lm', 'seq2seq_lm']: ...
def __getitem__(self, item: t.Literal["model_type"] = ...) -> t.Literal["causal_lm", "seq2seq_lm"]: ...
@overload
def __getitem__(self, item: t.Literal["runtime"] = ...) -> t.Literal['transformers', 'ggml']: ...
def __getitem__(self, item: t.Literal["runtime"] = ...) -> t.Literal["transformers", "ggml"]: ...
@overload
def __getitem__(self, item: t.Literal["name_type"] = ...) -> t.Optional[t.Literal['dasherize', 'lowercase']]: ...
def __getitem__(self, item: t.Literal["name_type"] = ...) -> t.Optional[t.Literal["dasherize", "lowercase"]]: ...
@overload
def __getitem__(self, item: t.Literal["model_name"] = ...) -> str: ...
@overload
@@ -1516,7 +1539,7 @@ class LLMConfig(_ConfigAttr):
# fmt: on
def __getitem__(self, item: t.LiteralString | t.Any = None) -> t.Any:
"""Allowing access LLMConfig as a dictionary. The order will always evaluate as
"""Allowing access LLMConfig as a dictionary. The order will always evaluate as.
__openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__
@@ -1599,7 +1622,6 @@ class LLMConfig(_ConfigAttr):
**attrs: The attributes to be added to the new class. This will override
any existing attributes with the same name.
"""
assert cls.__config__ is not None, "Cannot derivate a LLMConfig without __config__"
_new_cfg = {k: v for k, v in attrs.items() if k in attr.fields_dict(_ModelSettingsAttr)}
attrs = {k: v for k, v in attrs.items() if k not in _new_cfg}
new_cls = types.new_class(
@@ -1642,14 +1664,12 @@ class LLMConfig(_ConfigAttr):
try:
attrs = orjson.loads(json_str)
except orjson.JSONDecodeError as err:
raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}")
raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None
return bentoml_cattr.structure(attrs, cls)
@classmethod
def model_construct_env(cls, **attrs: t.Any) -> t.Self:
"""A helpers that respect configuration values that
sets from environment variables for any given configuration class.
"""
"""A helpers that respect configuration values environment variables."""
attrs = {k: v for k, v in attrs.items() if v is not None}
model_config = cls.__openllm_env__.config
@@ -1696,7 +1716,7 @@ class LLMConfig(_ConfigAttr):
return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove}
@overload
def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
def to_generation_config(self, return_as_dict: t.Literal[False] = False) -> transformers.GenerationConfig:
...
@overload
@@ -1708,22 +1728,12 @@ class LLMConfig(_ConfigAttr):
return config.to_dict() if return_as_dict else config
@classmethod
@overload
def to_click_options(
cls, f: t.Callable[..., openllm.LLMConfig]
) -> F[P, ClickFunctionWrapper[..., openllm.LLMConfig]]:
...
def to_click_options(cls, f: AnyCallable) -> click.Command:
"""Convert current configuration to click options.
@classmethod
@overload
def to_click_options(cls, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]:
...
This can be used as a decorator for click commands.
@classmethod
def to_click_options(cls, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
"""
Convert current model to click options. This can be used as a decorator for click commands.
Note that the identifier for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
> **Note**: that the identifier for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
will be prefixed with '<model_name>_generation_*'.
"""
for name, field in attr.fields_dict(cls.__openllm_generation_class__).items():
@@ -1769,8 +1779,7 @@ bentoml_cattr.register_unstructure_hook_factory(
def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig:
"""
Structure a dictionary to a LLMConfig object.
"""Structure a dictionary to a LLMConfig object.
Essentially, if the given dictionary contains a 'generation_config' key, then we will
use it for LLMConfig.generation_config

View File

@@ -14,14 +14,15 @@
"""Generation utilities to be reused throughout."""
from __future__ import annotations
import typing as t
import torch
import transformers
if t.TYPE_CHECKING:
import torch
class StopSequenceCriteria(transformers.StoppingCriteria):
"""This class used to stop generation when a seq of tokens are met.
@@ -42,6 +43,6 @@ class StopSequenceCriteria(transformers.StoppingCriteria):
class StopOnTokens(transformers.StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool:
stop_ids = {50278, 50279, 50277, 1, 0}
return input_ids[0][-1] in stop_ids
return t.cast(int, input_ids[0][-1]) in stop_ids

View File

File diff suppressed because it is too large Load Diff

View File

@@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Any build-related utilities. This is used for CI.
"""Build-related utilities. Some of these utilities are mainly used for 'openllm.build'.
These utilities will stay internal, and their API can be changed or updated without backward-compatibility guarantees.
"""
from __future__ import annotations
import importlib.metadata
import logging
import os
import re
import subprocess
import sys
import typing as t
from pathlib import Path
@@ -34,7 +31,6 @@ from simple_di import Provide
from simple_di import inject
import bentoml
import openllm
from bentoml._internal.bento.build_config import BentoBuildConfig
from bentoml._internal.bento.build_config import DockerOptions
from bentoml._internal.bento.build_config import PythonOptions
@@ -56,7 +52,7 @@ from .utils import resolve_user_filepath
if t.TYPE_CHECKING:
from fs.base import FS
from bentoml._internal.bento import BentoStore
import openllm
logger = logging.getLogger(__name__)
@@ -135,7 +131,8 @@ def construct_python_options(
env: EnvVarMixin = llm.config["env"]
framework_envvar = env["framework_value"]
if framework_envvar == "flax":
assert is_flax_available(), f"Flax is not available, while {env.framework} is set to 'flax'"
if not is_flax_available():
raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
packages.extend(
[
handle_package_version("flax", has_dockerfile_template),
@@ -144,7 +141,8 @@ def construct_python_options(
]
)
elif framework_envvar == "tf":
assert is_tf_available(), f"TensorFlow is not available, while {env.framework} is set to 'tf'"
if not is_tf_available():
raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
candidates = (
"tensorflow",
"tensorflow-cpu",
@@ -170,7 +168,8 @@ def construct_python_options(
except importlib.metadata.PackageNotFoundError:
pass
else:
assert is_torch_available(), "PyTorch is not available. Make sure to have it locally installed."
if not is_torch_available():
raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
packages.extend([handle_package_version("torch", has_dockerfile_template)])
wheels: list[str] = []
@@ -206,7 +205,7 @@ def construct_docker_options(
"OPENLLM_MODEL": llm.config["model_name"],
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'",
"BENTOML_DEBUG": str(get_debug_mode()),
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
"BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
}
if adapter_map:
@@ -257,7 +256,8 @@ def create_bento(
logger.info("Building Bento for '%s'", llm.config["start_name"])
if adapter_map is not None:
assert build_ctx is not None, "build_ctx is required when 'adapter_map' is not None"
if build_ctx is None:
raise ValueError("build_ctx is required when 'adapter_map' is not None")
updated_mapping: dict[str, str | None] = {}
for adapter_id, name in adapter_map.items():
try:
@@ -321,7 +321,7 @@ def create_bento(
# new behaviour with BentoML models
model = _model_store.get(f"{model_framework}-{model_type}")
except bentoml.exceptions.NotFound:
raise OpenLLMException(f"Failed to find models for {llm.config['start_name']}")
raise OpenLLMException(f"Failed to find models for {llm.config['start_name']}") from None
# NOTE: the model_id_path here are only used for setting this environment variable within the container
# built with for BentoLLM.
@@ -330,10 +330,12 @@ def create_bento(
with open(service_path, "r") as f:
service_contents = f.readlines()
rel_path = f"../models/{model.tag.path()}"
for it in service_contents:
if codegen.OPENLLM_MODEL_ID in it:
service_contents[service_contents.index(it)] = (
codegen.ModelIdFormatter(str(model.tag)).vformat(it)[: -(len(codegen.OPENLLM_MODEL_ID) + 3)] + "\n"
codegen.ModelIdFormatter(rel_path).vformat(it)[: -(len(codegen.OPENLLM_MODEL_ID) + 3)] + "\n"
)
if "__bento_name__" in it:
service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
@@ -346,70 +348,3 @@ def create_bento(
bento._fs.writetext(service_fs_path, script)
return bento.save()
@inject
def build(
    model_name: str,
    *,
    model_id: str | None = None,
    model_version: str | None = None,
    quantize: t.Literal["int8", "int4", "gptq"] | None = None,
    bettertransformer: bool | None = None,
    adapter_map: dict[str, str | None] | None = None,
    build_ctx: str | None = None,
    extra_dependencies: tuple[str, ...] | None = None,
    workers_per_resource: int | float | None = None,
    overwrite_existing_bento: bool = False,
    runtime: t.Literal["ggml", "transformers"] = "transformers",
    dockerfile_template: str | None = None,
    bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
) -> bentoml.Bento:
    """Package a LLM into a Bento by running ``openllm build`` in a subprocess.

    The call is translated into ``python -m openllm build <model_name> --machine ...``
    and the resulting Bento tag is parsed from the subprocess output.

    Args:
        model_name: Name of the model, passed as the positional CLI argument.
        model_id: Forwarded as ``--model-id`` when set.
        model_version: Forwarded as ``--model-version`` when set.
        quantize: Forwarded as ``--quantize`` when set. Mutually exclusive with ``bettertransformer``.
        bettertransformer: Adds ``--bettertransformer`` when truthy.
        adapter_map: Each entry is forwarded as ``--adapter-id=<key>[:<value>]``.
        build_ctx: Forwarded as ``--build-ctx`` and used as the working directory of the
            subprocess; defaults to the current working directory.
        extra_dependencies: Each entry is forwarded as ``--enable-features=<entry>``.
        workers_per_resource: Forwarded as ``--workers-per-resource`` when truthy.
        overwrite_existing_bento: Adds ``--overwrite`` when true.
        runtime: Forwarded as ``--runtime``.
        dockerfile_template: Forwarded as ``--dockerfile-template`` when set.
        bento_store: Store the built Bento is looked up from.

    Returns:
        The Bento whose tag was reported by the build subprocess.

    Raises:
        OpenLLMException: If ``quantize`` and ``bettertransformer`` are both set, if the
            subprocess exits with a non-zero status, or if no ``__tag__:`` line can be
            found in its output.
    """
    args = [sys.executable, "-m", "openllm", "build", model_name, "--machine", "--runtime", runtime]
    if quantize and bettertransformer:
        raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")

    if quantize:
        args.extend(["--quantize", quantize])
    if bettertransformer:
        args.append("--bettertransformer")
    if model_id:
        args.extend(["--model-id", model_id])
    if build_ctx:
        args.extend(["--build-ctx", build_ctx])
    if extra_dependencies:
        args.extend([f"--enable-features={f}" for f in extra_dependencies])
    if workers_per_resource:
        args.extend(["--workers-per-resource", str(workers_per_resource)])
    if overwrite_existing_bento:
        args.append("--overwrite")
    if adapter_map:
        args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
    if model_version:
        args.extend(["--model-version", model_version])
    if dockerfile_template:
        args.extend(["--dockerfile-template", dockerfile_template])

    try:
        output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
    except subprocess.CalledProcessError as e:
        logger.error("Exception caught while building %s", model_name, exc_info=e)
        # NOTE(review): check_output() only captures stdout, so e.stderr is None unless
        # stderr is redirected; this branch looks unreachable as written -- confirm.
        if e.stderr:
            raise OpenLLMException(e.stderr.decode("utf-8")) from None
        raise OpenLLMException(str(e)) from None

    # NOTE: This usually only concern BentoML devs.
    pattern = r"^__tag__:[^:\n]+:[^:\n]+"
    matched = re.search(pattern, output.decode("utf-8").strip(), re.MULTILINE)
    if matched is None:
        # An explicit raise instead of `assert`: assertions are stripped under `python -O`,
        # which would turn a missing tag into an AttributeError on None below.
        raise OpenLLMException(f"Failed to find tag from output: {output}")
    # Everything after the first ':' is the Bento tag.
    _, _, tag = matched.group(0).partition(":")
    return bentoml.get(tag, _bento_store=bento_store)

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import string
import typing as t
@@ -20,10 +19,10 @@ import typing as t
class PromptFormatter(string.Formatter):
"""This PromptFormatter is largely based on langchain's implementation."""
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str:
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.LiteralString:
if len(args) > 0:
raise ValueError("Positional arguments are not supported")
return super().vformat(format_string, args, kwargs)
return t.cast("t.LiteralString", super().vformat(format_string, args, kwargs))
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]

View File

@@ -11,3 +11,92 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
from .utils import LazyLoader
from .utils import is_bitsandbytes_available
from .utils import is_transformers_supports_kbit
from .utils import pkg
if t.TYPE_CHECKING:
import torch
import openllm
import transformers
from ._types import DictStrAny
else:
torch = LazyLoader("torch", globals(), "torch")
transformers = LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
QuantiseMode = t.Literal["int8", "int4", "gptq"]
def infer_quantisation_config(
    cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any
) -> tuple[transformers.BitsAndBytesConfig | t.Any, DictStrAny]:
    """Build the quantisation config for ``cls`` and strip the consumed keys from ``attrs``.

    Args:
        cls: The LLM class being quantised. Its ``config_class.__openllm_model_type__``
            decides whether ``lm_head`` is added to the skipped int8 modules.
        quantise: One of ``"int8"``, ``"int4"`` or ``"gptq"``.
        **attrs: Extra keyword arguments. The ``llm_int8_*`` and ``bnb_4bit_*`` keys are
            popped and used for the quantisation config; everything else is returned untouched.

    Returns:
        A tuple of the quantisation config and the remaining ``attrs``.

    Raises:
        RuntimeError: If bitsandbytes is not available.
        NotImplementedError: If ``quantise`` is ``"gptq"``.
        ValueError: If ``quantise`` is not a supported mode.
    """
    # 8 bit configuration
    # 'llm_int8_threshhold' is a historical misspelling of 'llm_int8_threshold'. It is still
    # consumed so existing callers keep working, but the correctly spelled key takes precedence.
    legacy_int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
    int8_threshold = attrs.pop("llm_int8_threshold", legacy_int8_threshold)
    int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
    int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
    int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)

    def create_int8_config(int8_skip_modules: list[str] | None):
        # Work on a copy so the list supplied by the caller is never mutated.
        skip_modules = list(int8_skip_modules) if int8_skip_modules is not None else []
        if "lm_head" not in skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
            logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
            skip_modules.append("lm_head")
        return transformers.BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
            # BitsAndBytesConfig spells this parameter 'llm_int8_threshold'.
            llm_int8_threshold=int8_threshold,
            llm_int8_skip_modules=skip_modules,
            llm_int8_has_fp16_weight=int8_has_fp16_weight,
        )

    # 4 bit configuration
    int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16)
    int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4")
    int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True)

    # NOTE: Quantization setup
    # quantize is a openllm.LLM feature, where we can quantize the model
    # with bitsandbytes or quantization aware training.
    if not is_bitsandbytes_available():
        raise RuntimeError(
            "Quantization requires bitsandbytes to be installed. Make "
            "sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
        )

    if quantise == "int8":
        quantisation_config = create_int8_config(int8_skip_modules)
    elif quantise == "int4":
        if is_transformers_supports_kbit():
            quantisation_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=int4_compute_dtype,
                bnb_4bit_quant_type=int4_quant_type,
                bnb_4bit_use_double_quant=int4_use_double_quant,
            )
        else:
            # No k-bit support in the installed transformers: warn and fall back to int8.
            logger.warning(
                "'quantize' is set to int4, while the current transformers version %s does not support "
                "k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
                "make sure to install the latest version of transformers either via PyPI or "
                "from git source: 'pip install git+https://github.com/huggingface/transformers'.",
                pkg.pkg_version_info("transformers"),
            )
            logger.warning("OpenLLM will fallback to 8-bit quantization.")
            quantisation_config = create_int8_config(int8_skip_modules)
    elif quantise == "gptq":
        # TODO: support GPTQ loading quantization
        raise NotImplementedError("GPTQ is not supported yet.")
    else:
        raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.")

    return quantisation_config, attrs

View File

@@ -11,11 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Schema definition for OpenLLM. This can be use for client interaction.
"""
"""Schema definition for OpenLLM. This can be use for client interaction."""
from __future__ import annotations
import functools
import typing as t
@@ -23,6 +20,8 @@ import attr
import inflection
import openllm
from openllm._configuration import GenerationConfig
from openllm.utils import bentoml_cattr
if t.TYPE_CHECKING:
@@ -79,6 +78,14 @@ class GenerationOutput:
configuration: t.Dict[str, t.Any]
"""A mapping of configuration values for given system."""
@property
def marshaled_config(self) -> GenerationConfig:
    """Structure ``self.configuration`` into a ``GenerationConfig`` using ``bentoml_cattr``."""
    generation_config = bentoml_cattr.structure(self.configuration, GenerationConfig)
    return generation_config
@property
def unmarshaled(self) -> dict[str, t.Any]:
    """Unstructure this object with ``bentoml_cattr`` and return the result."""
    as_builtin = bentoml_cattr.unstructure(self)
    return as_builtin
@attr.frozen(slots=True)
class MetadataOutput:

View File

@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The service definition for running any LLMService.
"""The service definition for running any LLMService.
Note that the line `model = ...` is a special line and should not be modified. This will be handled by openllm
internally to generate the correct model service when bundling the LLM to a Bento.
@@ -22,7 +21,6 @@ This will ensure that 'bentoml serve llm-bento' will work accordingly.
The generation code lives under utils/codegen.py
"""
from __future__ import annotations
import os
import typing as t
import warnings
@@ -136,7 +134,7 @@ async def hf_agent(request: Request) -> Response:
except orjson.JSONDecodeError as err:
raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
stop = input_data.parameters.pop("stop", "\n")
stop = input_data.parameters.pop("stop", ["\n"])
try:
resp = await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters)
return JSONResponse(resp, status_code=200)
@@ -150,9 +148,11 @@ svc.mount_asgi_app(hf_app, path="/hf")
async def list_adapter_v1(_: Request) -> Response:
res = runner.peft_adapters
if res["success"]:
res["result"] = {k: v.to_dict() for k, v in res["result"].items()}
res: dict[str, t.Any] = {}
if runner.peft_adapters["success"] is True:
res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()}
res["success"] = runner.peft_adapters["success"]
res["error_msg"] = runner.peft_adapters["error_msg"]
return JSONResponse(res, status_code=200)

View File

@@ -13,8 +13,6 @@
# limitations under the License.
from __future__ import annotations
import functools
import logging
import math
import os
@@ -23,19 +21,19 @@ import typing as t
import psutil
import bentoml
import openllm
from bentoml._internal.resource import Resource
from bentoml._internal.resource import get_resource
from bentoml._internal.resource import system_resources
from bentoml._internal.runner.strategy import THREAD_ENVS
from bentoml._internal.runner.strategy import Strategy
from .utils import LazyType
from .exceptions import OpenLLMException
from .utils import ReprMixin
if t.TYPE_CHECKING:
import bentoml
ListIntStr = list[int | str]
else:
ListIntStr = list
@@ -43,28 +41,46 @@ else:
logger = logging.getLogger(__name__)
class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"):
class AmdGpuResource(Resource[t.List[str]], resource_id="amd.com/gpu"):
@classmethod
def from_spec(cls, spec: int | str | list[str | int]) -> list[int]:
if not isinstance(spec, (int, str)) and not LazyType(ListIntStr).isinstance(spec):
def from_spec(cls, spec: t.Any) -> list[str]:
if not isinstance(spec, (int, str, list)):
raise TypeError("AMD GPU device IDs must be int, str or a list specifing the exact GPUs to use.")
try:
if isinstance(spec, int):
if spec == -1:
return []
if spec < -1:
raise ValueError
return list(range(spec))
return [str(i) for i in range(spec)]
elif isinstance(spec, str):
return cls.from_spec(int(spec))
try:
return cls.from_spec(int(spec))
except ValueError:
if spec.startswith("GPU"):
return [spec]
raise ValueError
else:
return [int(x) for x in spec]
return [str(x) for x in spec]
except ValueError:
raise openllm.exceptions.OpenLLMException(f"Invalid AMD GPU resource limit '{spec}'. ")
raise OpenLLMException(f"Invalid AMD GPU resource limit '{spec}'.")
@classmethod # type: ignore (overload)
@functools.lru_cache(maxsize=1)
def from_system(cls) -> list[int]:
@classmethod
def from_system(cls) -> list[str]:
"""Retrieve AMD GPUs from the system; currently only supported on Linux.
This assumes that ROCm is setup correctly."""
This assumes that ROCm is set up correctly.
"""
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if cuda_visible_devices in ("", "-1"):
return []
if cuda_visible_devices is not None:
cuda_visible_devices = cuda_visible_devices.split(",")
if "-1" in cuda_visible_devices:
cuda_visible_devices = cuda_visible_devices[: cuda_visible_devices.index("-1")]
return cuda_visible_devices
if not psutil.LINUX:
logger.debug("AMD GPU resource is only supported on Linux.")
return []
@@ -84,7 +100,7 @@ class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"):
num = c_uint32(0)
ret = rocmsmi.rsmi_num_monitor_devices(byref(num))
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS:
return list(range(num.value))
return [str(i) for i in range(num.value)]
return []
except Exception as err:
logger.debug("Failed to setup AMD GPU resource: %s", err)
@@ -93,18 +109,22 @@ class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"):
sys.path.remove("/opt/rocm/libexec/rocm_smi")
@classmethod
def validate(cls, val: list[int]):
if any(gpu_index < 0 for gpu_index in val):
raise openllm.exceptions.OpenLLMException(f"Negative GPU device in {val}.")
if any(gpu_index >= len(cls.from_system()) for gpu_index in val):
raise openllm.exceptions.OpenLLMException(
f"GPU device index in {val} is greater than the system available: {cls.from_system()}"
)
def validate(cls, val: list[str]):
for gpu_index_or_literal in val:
try:
idx = int(gpu_index_or_literal)
except ValueError:
raise OpenLLMException(f"Invalid AMD GPU device index: {val}")
if int(idx) < 0:
raise OpenLLMException(f"Negative GPU device in {val}.")
if int(idx) >= len(cls.from_system()):
raise OpenLLMException(
f"GPU device index in {val} is greater than the system available: {cls.from_system()}"
)
class CascadingResourceStrategy(Strategy, ReprMixin):
"""This is rather an extension of bentoml._internal.runner.strategy.DefaultStrategy
where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
"""This extends the default BentoML strategy: we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
It also respects CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPUs.
See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices
@@ -147,9 +167,9 @@ class CascadingResourceStrategy(Strategy, ReprMixin):
)
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
if isinstance(workers_per_resource, float):
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0:
raise ValueError("Fractional CPU multi threading support is not yet supported.")
return workers_per_resource
return int(workers_per_resource)
return math.ceil(cpus) * workers_per_resource
@@ -167,11 +187,13 @@ class CascadingResourceStrategy(Strategy, ReprMixin):
workers_per_resource: int | float,
worker_index: int,
) -> dict[str, t.Any]:
"""
"""Get worker env for this given worker_index.
Args:
runnable_class : The runnable class to be run.
resource_request : The resource request of the runnable.
worker_index : The index of the worker, start from 0.
runnable_class: The runnable class to be run.
resource_request: The resource request of the runnable.
workers_per_resource: # of workers per resource.
worker_index: The index of the worker, start from 0.
"""
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
disabled = cuda_env in ("", "-1")

View File

@@ -11,35 +11,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Types definition for OpenLLM.
"""Types definition for OpenLLM.
Note that this module SHOULD NOT BE IMPORTED DURING RUNTIME, as it serves only typing purposes.
It will raise a RuntimeError if it is imported eagerly.
"""
from __future__ import annotations
import typing as t
if not t.TYPE_CHECKING:
raise RuntimeError(f"{__name__} should not be imported during runtime")
import click
import bentoml
import openllm
import transformers
from ._configuration import AdapterType
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
if t.TYPE_CHECKING:
import click
import peft
import openllm
import transformers
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
AnyCallable = t.Callable[..., t.Any]
DictStrAny = dict[str, t.Any]
ListAny = list[t.Any]
ListStr = list[str]
TupleAny = tuple[t.Any, ...]
P = t.ParamSpec("P")
O_co = t.TypeVar("O_co", covariant=True)
LiteralRuntime: t.TypeAlias = t.Literal["pt", "tf", "flax"]
T = t.TypeVar("T")
Ts = t.TypeVarTuple("Ts")
class ClickFunctionWrapper(t.Protocol[P, O_co]):
@@ -83,9 +91,19 @@ class TokenizerProtocol(_StubsMixin[_MT], t.Protocol):
...
PeftAdapterOutput = dict[t.Literal["success", "result", "error_msg"], bool | str | dict[t.Any, t.Any]]
class PeftAdapterOutput(t.TypedDict):
success: bool
result: dict[str, peft.PeftConfig]
error_msg: str
AdaptersMapping = dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None
class AdaptersTuple(TupleAny):
adapter_id: str
name: str | None
config: DictStrAny
AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]] | None
class LLMRunnable(bentoml.Runnable):

View File

File diff suppressed because it is too large Load Diff

View File

@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenLLM client.
"""OpenLLM client.
To start interact with the server, you can do the following:
@@ -21,7 +20,6 @@ To start interact with the server, you can do the following:
>>> client.query("What is the meaning of life?")
"""
from __future__ import annotations
import importlib
import itertools
import typing as t

View File

@@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base exceptions for OpenLLM. This extends BentoML exceptions.
"""
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
from __future__ import annotations
import bentoml

View File

@@ -15,12 +15,14 @@
"""This module is derived from HuggingFace's AutoConfig, AutoModel, etc."""
from __future__ import annotations
import typing as t
import openllm
from ...utils import is_torch_available, is_flax_available, is_tf_available, LazyModule
from ...utils import LazyModule
from ...utils import is_flax_available
from ...utils import is_tf_available
from ...utils import is_torch_available
_import_structure = {

View File

@@ -13,8 +13,6 @@
# limitations under the License.
from __future__ import annotations
import types
import typing as t
from collections import OrderedDict
@@ -24,6 +22,7 @@ import openllm
if t.TYPE_CHECKING:
import types
from collections import _odict_items
from collections import _odict_keys
from collections import _odict_values
@@ -93,9 +92,7 @@ class _LazyConfigMapping(ConfigOrderedDict):
return item in self._mapping or item in self._extra_content
def register(self, key: str, value: t.Any):
"""
Register a new configuration in this mapping.
"""
"""Register a new configuration in this mapping."""
if key in self._mapping.keys():
raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
self._extra_content[key] = value
@@ -115,7 +112,10 @@ CONFIG_NAME_ALIASES: dict[str, str] = {
class AutoConfig:
def __init__(self, *_: t.Any, **__: t.Any):
raise EnvironmentError("Cannot instantiate Config. Please use `Config.for_model(model_name)` instead.")
"""This class should be initialised via `AutoConfig.for_model`."""
raise EnvironmentError(
"Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead."
)
@classmethod
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:

View File

@@ -13,11 +13,10 @@
# limitations under the License.
from __future__ import annotations
import importlib
import inspect
import logging
import types
import sys
import typing as t
from collections import OrderedDict
@@ -30,13 +29,14 @@ from .configuration_auto import AutoConfig
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
import types
from collections import _odict_items
from collections import _odict_keys
from collections import _odict_values
@@ -54,10 +54,10 @@ else:
logger = logging.getLogger(__name__)
class _BaseAutoLLMClass:
class BaseAutoLLMClass:
_model_mapping: _LazyAutoMapping
def __init__(self, *args: t.Any, **attrs: t.Any):
def __init__(self, *args: t.Any, **attrs: t.Any): # noqa
raise EnvironmentError(
f"Cannot instantiate {self.__class__.__name__} directly. "
"Please use '{self.__class__.__name__}.Runner(model_name)' instead."
@@ -129,7 +129,7 @@ class _BaseAutoLLMClass:
llm = model_class.from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs)
if ensure_available:
logger.debug(
"'ensure_available=True', Downloading '%s' with 'model_id=%s' to local model store.",
"'ensure_available=True', OpenLLM will automatically import '%s' with 'model_id=%s' to local store if the entry does not exists.",
model,
llm.model_id,
)
@@ -144,8 +144,7 @@ class _BaseAutoLLMClass:
@classmethod
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner:
"""
Create a LLM Runner for the given model name.
"""Create a LLM Runner for the given model name.
Args:
model: The model name to instantiate.
@@ -160,8 +159,7 @@ class _BaseAutoLLMClass:
@classmethod
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]):
"""
Register a new model for this class.
"""Register a new model for this class.
Args:
config_class: The configuration corresponding to the model to register.
@@ -191,13 +189,14 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
try:
return getattribute_from_module(openllm_module, attr)
except ValueError:
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!")
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
else:
raise ValueError(f"Could not find {attr} in {openllm_module}!")
class _LazyAutoMapping(ConfigModelOrderedDict):
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
This OrderedDict values() and keys() returns the list instead, so you don't
have to do list(mapping.values()) to get the list of values.
"""
@@ -281,9 +280,7 @@ class _LazyAutoMapping(ConfigModelOrderedDict):
return model_type in self._model_mapping
def register(self, key: t.Any, value: t.Any):
"""
Register a new model in this mapping.
"""
"""Register a new model in this mapping."""
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping.keys():
@@ -292,4 +289,4 @@ class _LazyAutoMapping(ConfigModelOrderedDict):
self._extra_content[key] = value
__all__ = ["_BaseAutoLLMClass", "_LazyAutoMapping"]
__all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"]

View File

@@ -13,12 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import _BaseAutoLLMClass
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
@@ -45,5 +44,5 @@ MODEL_MAPPING: dict[
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
class AutoLLM(_BaseAutoLLMClass):
class AutoLLM(BaseAutoLLMClass):
_model_mapping = MODEL_MAPPING

View File

@@ -13,12 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import _BaseAutoLLMClass
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
@@ -38,5 +37,5 @@ MODEL_FLAX_MAPPING: dict[
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
class AutoFlaxLLM(_BaseAutoLLMClass):
class AutoFlaxLLM(BaseAutoLLMClass):
_model_mapping = MODEL_FLAX_MAPPING

View File

@@ -13,12 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import _BaseAutoLLMClass
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
@@ -38,5 +37,5 @@ MODEL_TF_MAPPING: dict[
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
class AutoTFLLM(_BaseAutoLLMClass):
class AutoTFLLM(BaseAutoLLMClass):
_model_mapping = MODEL_TF_MAPPING

View File

@@ -13,11 +13,12 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, is_cpm_kernels_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_cpm_kernels_available
from ...utils import is_torch_available
_import_structure = {

View File

@@ -17,9 +17,7 @@ import openllm
class ChatGLMConfig(openllm.LLMConfig):
"""
ChatGLM is an open bilingual language model based on
[General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
With the quantization technique, users can deploy locally on consumer-grade graphics cards
(only 6GB of GPU memory is required at the INT4 quantization level).

View File

@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import bentoml
import openllm
from ...utils import generate_labels
if t.TYPE_CHECKING:
import torch
import transformers
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -34,15 +36,15 @@ class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrain
self.device = torch.device("cuda")
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
(_, model_attrs), tokenizer_kwds = self.llm_parameters
attrs = {**model_attrs, **attrs}
_, tokenizer_attrs = self.llm_parameters
return bentoml.transformers.save_model(
self.tag,
transformers.AutoModel.from_pretrained(self.model_id, trust_remote_code=trust_remote_code),
labels=generate_labels(self),
custom_objects={
"tokenizer": transformers.AutoTokenizer.from_pretrained(
self.model_id, trust_remote_code=trust_remote_code, **tokenizer_kwds
self.model_id, trust_remote_code=trust_remote_code, **tokenizer_attrs
)
},
)
@@ -62,8 +64,8 @@ class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrain
if use_default_prompt_template and chat_history is not None:
for i, (old_query, response) in enumerate(chat_history):
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n"
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:"
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001
else:
prompt_text = prompt

View File

@@ -13,11 +13,12 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure = {
"configuration_dolly_v2": ["DollyV2Config", "START_DOLLY_V2_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"],

View File

@@ -11,12 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The following includes OpenLLM configuration and excerpt from
[instruct_pipeline.py](https://huggingface.co/databricks/dolly-v2-3b/blob/main/instruct_pipeline.py)
"""
from __future__ import annotations
import typing as t
import openllm
@@ -27,8 +22,7 @@ if t.TYPE_CHECKING:
class DollyV2Config(openllm.LLMConfig):
"""Databricks Dolly is an instruction-following large language model trained on the Databricks
machine learning platform that is licensed for commercial use.
"""Databricks' Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k
generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming,
@@ -103,15 +97,19 @@ DEFAULT_PROMPT_TEMPLATE = """{intro}
def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
Args:
tokenizer (PreTrainedTokenizer): the tokenizer
key (str): the key to convert to a single token
tokenizer: the tokenizer
key: the key to convert to a single token
Raises:
RuntimeError: if more than one ID was generated
Returns:
int: the token ID for the given key
int: the token ID for the given key.
"""
token_ids = tokenizer.encode(key)
if len(token_ids) > 1:

View File

@@ -12,24 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import re
import typing as t
import bentoml
import openllm
from ...utils import normalize_attrs_to_model_tokenizer_pair
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE
from .configuration_dolly_v2 import END_KEY
from .configuration_dolly_v2 import RESPONSE_KEY
from .configuration_dolly_v2 import get_special_token_id
from ...utils import normalize_attrs_to_model_tokenizer_pair
if t.TYPE_CHECKING:
import tensorflow as tf
import torch
import bentoml
import transformers
else:
tf = openllm.utils.LazyLoader("tf", globals(), "tensorflow")
@@ -75,14 +75,16 @@ def get_pipeline(
top_k: int = 0,
**kwargs: t.Any,
):
"""Initialize the pipeline
"""Initialize the pipeline.
Args:
do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 128.
top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with
probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering.
Defaults to 0.
do_sample: Whether or not to use sampling. Defaults to True.
max_new_tokens: Max new tokens after the prompt to generate. Defaults to 128.
top_p: If set to float < 1, only the smallest set of most probable tokens with
probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to 0.
*args: Additional positional arguments to be passed to ``transformers.Pipeline``.
**kwargs: Additional keyword arguments to be passed to ``transformers.Pipeline``.
"""
super().__init__(
*args,
@@ -195,7 +197,7 @@ def get_pipeline(
try:
response_pos = sequence.index(response_key_token_id)
except ValueError:
logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}")
logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
response_pos = None
if response_pos:
@@ -228,7 +230,7 @@ def get_pipeline(
if m:
decoded = m.group(1).strip()
else:
logger.warn(f"Failed to find response in:\n{fully_decoded}")
logger.warning("Failed to find response in:\n%s", fully_decoded)
# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure = {

View File

@@ -17,9 +17,9 @@ import openllm
class FalconConfig(openllm.LLMConfig):
"""Falcon-7B is a 7B parameters causal decoder-only model built by
TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
enhanced with curated corpora. It is made available under the TII Falcon LLM License.
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
It is made available under the TII Falcon LLM License.
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
"""

View File

@@ -13,19 +13,19 @@
# limitations under the License.
from __future__ import annotations
import typing as t
import bentoml
import openllm
from ..._prompt import default_formatter
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
if t.TYPE_CHECKING:
import torch
import torch.amp
import bentoml
import transformers
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -81,7 +81,7 @@ class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTraine
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -13,11 +13,13 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, is_tf_available, is_flax_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_flax_available
from ...utils import is_tf_available
from ...utils import is_torch_available
_import_structure = {

View File

@@ -46,8 +46,9 @@ DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruct
class FlanT5Config(openllm.LLMConfig):
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf)
- it is an enhanced version of T5 that has been finetuned in a mixture of tasks.
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
It is an enhanced version of T5 that has been finetuned in a mixture of tasks.
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
"""

View File

@@ -12,19 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import openllm
from ..._prompt import default_formatter
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
if t.TYPE_CHECKING:
import torch
import transformers # noqa
import transformers # noqa: F401
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -60,7 +59,7 @@ class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformer
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -12,17 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import openllm
from ..._prompt import default_formatter
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
if t.TYPE_CHECKING:
import transformers # noqa
import transformers # noqa: F401
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
@@ -54,7 +53,7 @@ class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "tra
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -12,17 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import openllm
from ..._prompt import default_formatter
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
if t.TYPE_CHECKING:
import transformers # noqa
import transformers # noqa: F401
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
@@ -40,17 +39,20 @@ class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transfo
**attrs: t.Any,
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if use_default_prompt_template:
prompt_variables = {
k: v
for k, v in attrs.items()
if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
}
template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
if "instruction" in prompt_variables:
raise RuntimeError(
"'instruction' should be passed as the first argument "
"instead of kwargs when 'use_default_prompt_template=True'"
)
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
try:
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
except KeyError as e:
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
) from None
else:
prompt_text = prompt

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure = {

View File

@@ -14,14 +14,13 @@
from __future__ import annotations
from __future__ import annotations
import openllm
class GPTNeoXConfig(openllm.LLMConfig):
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and
openly available to the public through a permissive license. It is, to the best of our knowledge, the largest dense autoregressive model
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
It is, to the best of our knowledge, the largest dense autoregressive model
that has publicly available weights at the time of submission. The training and evaluation code, as well as the model weights,
can be found at https://github.com/EleutherAI/gpt-neox.

View File

@@ -13,21 +13,21 @@
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import openllm
from ..._prompt import default_formatter
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
if t.TYPE_CHECKING:
import bentoml
import transformers # noqa
import torch
import torch.amp
import bentoml
import transformers
else:
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -62,7 +62,7 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure = {

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
import openllm
@@ -27,10 +26,11 @@ else:
class MPTConfig(openllm.LLMConfig):
"""MPT is a decoder-style transformer pretrained from scratch on
English text and code. This model was trained by [MosaicML](https://www.mosaicml.com/).
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
`openllm.MPT` encapsulate a family of MPT variants that is publicly available
This model was trained by [MosaicML](https://www.mosaicml.com/).
``openllm.MPT`` encapsulates a family of MPT variants that are publicly available
on HuggingFace. Refer to [HuggingFace's MosaicML page](https://huggingface.co/mosaicml)
for more details on specific models.
"""

View File

@@ -13,16 +13,16 @@
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
from ..._prompt import default_formatter
from ...utils import is_triton_available
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
from ...utils import generate_labels
from ...utils import is_triton_available
if t.TYPE_CHECKING:
@@ -78,8 +78,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
return model_kwds, tokenizer_kwds
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
(_, model_attrs), tokenizer_kwds = self.llm_parameters
attrs = {**model_attrs, **attrs}
_, tokenizer_attrs = self.llm_parameters
torch_dtype = attrs.pop("torch_dtype", self.dtype)
device_map = attrs.pop("device_map", None)
@@ -93,7 +92,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
trust_remote_code=trust_remote_code,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
if tokenizer.pad_token_id is None:
logger.warning("pad_token_id is not set. Setting it to eos_token")
tokenizer.pad_token = tokenizer.eos_token
@@ -107,7 +106,12 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
**attrs,
)
try:
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
return bentoml.transformers.save_model(
self.tag,
model,
custom_objects={"tokenizer": tokenizer},
labels=generate_labels(self),
)
finally:
torch.cuda.empty_cache()
@@ -169,7 +173,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -13,11 +13,13 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, LazyModule, is_flax_available, is_tf_available
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_flax_available
from ...utils import is_tf_available
from ...utils import is_torch_available
_import_structure = {

View File

@@ -18,9 +18,7 @@ import openllm
class OPTConfig(openllm.LLMConfig):
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068)
and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq)
on May 3rd 2022 by Meta AI.
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
OPT was predominantly pretrained with English text, but a small amount of non-English data is still present
within the training corpus via CommonCrawl. The model was pretrained using a causal language modeling (CLM)

View File

@@ -13,15 +13,15 @@
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
from ..._prompt import default_formatter
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
from ...utils import generate_labels
if t.TYPE_CHECKING:
@@ -44,17 +44,21 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok
return {}, tokenizer_kwds
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
(_, model_attrs), tokenizer_kwds = self.llm_parameters
attrs = {**model_attrs, **attrs}
_, tokenizer_attrs = self.llm_parameters
config = transformers.AutoConfig.from_pretrained(self.model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
tokenizer.pad_token_id = config.pad_token_id
model = t.cast(
"transformers.FlaxOPTForCausalLM",
transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs),
)
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
return bentoml.transformers.save_model(
self.tag,
model,
custom_objects={"tokenizer": tokenizer},
labels=generate_labels(self),
)
def sanitize_parameters(
self,
@@ -81,7 +85,7 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -13,21 +13,21 @@
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
from ..._prompt import default_formatter
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
from ...utils import generate_labels
if t.TYPE_CHECKING:
import torch
import transformers # noqa
import transformers
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
@@ -55,13 +55,12 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
return model_kwds, tokenizer_kwds
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
(_, model_attrs), tokenizer_kwds = self.llm_parameters
attrs = {**model_attrs, **attrs}
_, tokenizer_attrs = self.llm_parameters
torch_dtype = attrs.pop("torch_dtype", self.dtype)
config = transformers.AutoConfig.from_pretrained(self.model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
tokenizer.pad_token_id = config.pad_token_id
model = t.cast(
"transformers.OPTForCausalLM",
@@ -69,7 +68,12 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
self.model_id, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, **attrs
),
)
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
return bentoml.transformers.save_model(
self.tag,
model,
custom_objects={"tokenizer": tokenizer},
labels=generate_labels(self),
)
def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
torch_dtype = attrs.pop("torch_dtype", self.dtype)
@@ -105,7 +109,7 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -13,15 +13,15 @@
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
from ..._prompt import default_formatter
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
from ..._prompt import default_formatter
from ...utils import generate_labels
if t.TYPE_CHECKING:
@@ -45,16 +45,20 @@ class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Token
return {}, tokenizer_kwds
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
(_, model_attrs), tokenizer_kwds = self.llm_parameters
attrs = {**model_attrs, **attrs}
_, tokenizer_attrs = self.llm_parameters
config = transformers.AutoConfig.from_pretrained(self.model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
tokenizer.pad_token_id = config.pad_token_id
model: transformers.TFOPTForCausalLM = transformers.TFOPTForCausalLM.from_pretrained(
self.model_id, trust_remote_code=trust_remote_code, **attrs
)
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
return bentoml.transformers.save_model(
self.tag,
model,
custom_objects={"tokenizer": tokenizer},
labels=generate_labels(self),
)
def sanitize_parameters(
self,
@@ -80,7 +84,7 @@ class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Token
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
) from None
else:
prompt_text = prompt

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure = {

View File

@@ -17,8 +17,9 @@ import openllm
class StableLMConfig(openllm.LLMConfig):
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models
pre-trained on a diverse collection of English datasets with a sequence
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
It is pre-trained on a diverse collection of English datasets with a sequence
length of 4096 to push beyond the context window limitations of existing open-source language models.
StableLM-Tuned-Alpha is a suite of 3B and 7B parameter decoder-only language models
@@ -74,6 +75,6 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""" # noqa
"""
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"""

View File

@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import openllm
from ..._prompt import default_formatter
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE
from .configuration_stablelm import SYSTEM_PROMPT
from ..._prompt import default_formatter
if t.TYPE_CHECKING:

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ...utils import is_torch_available, LazyModule
from ...exceptions import MissingDependencyError
from ...utils import LazyModule
from ...utils import is_torch_available
_import_structure = {

View File

@@ -17,8 +17,7 @@ import openllm
class StarCoderConfig(openllm.LLMConfig):
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from
[The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
The model uses [Multi Query Attention](https://arxiv.org/abs/1911.02150),
[a context window of 8192 tokens](https://arxiv.org/abs/2205.14135), and was trained using the

View File

@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
import bentoml
import openllm
from ...utils import generate_labels
if t.TYPE_CHECKING:
import torch
@@ -54,13 +55,12 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
return model_kwds, tokenizer_kwds
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
(_, model_attrs), tokenizer_kwds = self.llm_parameters
attrs = {**model_attrs, **attrs}
_, tokenizer_attrs = self.llm_parameters
torch_dtype = attrs.pop("torch_dtype", torch.float16)
device_map = attrs.pop("device_map", "auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
tokenizer.add_special_tokens(
{
"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
@@ -72,7 +72,12 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs
)
try:
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
return bentoml.transformers.save_model(
self.tag,
model,
custom_objects={"tokenizer": tokenizer},
labels=generate_labels(self),
)
finally:
# NOTE: We need to free the cache after saving here so that we can load it back later on.
torch.cuda.empty_cache()

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import dataclasses
import logging
import os

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import argparse
import logging
import typing as t

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import dataclasses
import logging
import os

View File

@@ -37,15 +37,20 @@ llm.save_pretrained("./path/to/local-dolly")
"""
from __future__ import annotations
import typing as t
import openllm
from ..utils import LazyModule
import typing as t
import openllm
if t.TYPE_CHECKING:
import bentoml
from .._types import ModelProtocol, TokenizerProtocol
from .transformers import _M, _T
from .._llm import M
from .._llm import T
from .._types import ModelProtocol
from .._types import TokenizerProtocol
def import_model(
@@ -80,7 +85,7 @@ def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]:
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]:
if llm.runtime == "transformers":
return openllm.transformers.load_model(llm, *decls, **attrs)
elif llm.runtime == "ggml":
@@ -89,7 +94,7 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]:
if llm.runtime == "transformers":
return openllm.transformers.load_tokenizer(llm)
elif llm.runtime == "ggml":
@@ -109,11 +114,6 @@ _extras = {
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": []}
if t.TYPE_CHECKING:
from . import import_model as import_model
from . import get as get
from . import save_pretrained as save_pretrained
from . import load_model as load_model
from . import load_tokenizer as load_tokenizer
from . import ggml as ggml
from . import transformers as transformers
else:

View File

@@ -16,19 +16,25 @@
This requires ctransformers to be installed.
"""
from __future__ import annotations
import openllm
import typing as t
import bentoml
import cloudpickle
from ..exceptions import OpenLLMException
from ..utils import LazyLoader
import bentoml
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from ..exceptions import OpenLLMException
from ..utils import LazyLoader
if t.TYPE_CHECKING:
from .._types import ModelProtocol, TokenizerProtocol
from .transformers import _M, _T
import openllm
import transformers
from .._llm import M
from .._llm import T
from .._types import ModelProtocol
from .._types import TokenizerProtocol
else:
transformers = LazyLoader("transformers", globals(), "transformers")
@@ -44,6 +50,7 @@ def import_model(
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
@@ -66,15 +73,16 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
raise
def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]:
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]:
"""Load the model from BentoML store.
By default, it will try to find the model in the local store.
If the model is not found, it raises a ``bentoml.exceptions.NotFound``.
"""
raise NotImplementedError("Currently work in progress.")
def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]:
"""Load the tokenizer from BentoML store.
By default, it will check whether the bentomodel is in the store.
@@ -95,14 +103,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
"Model does not have tokenizer. Make sure to save \
the tokenizer within the model via 'custom_objects'.\
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
)
) from None
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(
bentomodel_fs.getsyspath("/"),
trust_remote_code=llm.__llm_trust_remote_code__,
**tokenizer_attrs,
)
return t.cast("TokenizerProtocol[_T]", tokenizer)
return t.cast("TokenizerProtocol[T]", tokenizer)
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any):

View File

@@ -11,32 +11,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serialisation related implementation for Transformers-based implementation.
"""
"""Serialisation related implementation for Transformers-based implementation."""
from __future__ import annotations
import copy
import openllm
import typing as t
import importlib
import typing as t
import cloudpickle
import bentoml
from bentoml._internal.frameworks.transformers import make_default_signatures
from bentoml._internal.models.model import ModelOptions
from ..exceptions import OpenLLMException
import cloudpickle
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from ..utils import LazyLoader, is_torch_available
from ..utils import generate_context, normalize_attrs_to_model_tokenizer_pair
from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, MODEL_TO_AUTOCLASS_MAPPING
from bentoml._internal.models.model import ModelOptions
from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from .constants import MODEL_TO_AUTOCLASS_MAPPING
from ..exceptions import OpenLLMException
from ..utils import LazyLoader
from ..utils import first_not_none
from ..utils import generate_context
from ..utils import generate_labels
from ..utils import is_torch_available
from ..utils import normalize_attrs_to_model_tokenizer_pair
if t.TYPE_CHECKING:
import transformers
import torch
from .._types import P
from .._llm import _M, _T
from .._types import DictStrAny, ModelProtocol, TokenizerProtocol
import openllm
import transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from .._llm import M
from .._llm import T
from .._types import DictStrAny
from .._types import ModelProtocol
from .._types import TokenizerProtocol
else:
transformers = LazyLoader("transformers", globals(), "transformers")
torch = LazyLoader("torch", globals(), "torch")
@@ -46,7 +57,6 @@ def process_transformers_config(
model_id: str, trust_remote_code: bool, **attrs: t.Any
) -> tuple[transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]:
"""Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs."""
config: transformers.PretrainedConfig = attrs.pop("config", None)
# this logic below is synonymous to handling `from_pretrained` attrs.
@@ -123,7 +133,7 @@ def import_model(
attrs = {**model_attrs, **attrs}
tokenizer = t.cast(
transformers.PreTrainedTokenizer,
"transformers.PreTrainedTokenizer",
transformers.AutoTokenizer.from_pretrained(
llm.model_id,
config=config,
@@ -133,13 +143,16 @@ def import_model(
),
)
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
llm.model_id,
*decls,
config=config,
trust_remote_code=trust_remote_code,
**hub_attrs,
**attrs,
model = t.cast(
"transformers.PreTrainedModel",
infer_autoclass_from_llm_config(llm, config).from_pretrained(
llm.model_id,
*decls,
config=config,
trust_remote_code=trust_remote_code,
**hub_attrs,
**attrs,
),
)
try:
@@ -148,7 +161,7 @@ def import_model(
module="openllm.serialisation.transformers",
api_version="v1",
context=generate_context(framework_name="openllm"),
labels={"runtime": llm.runtime},
labels=generate_labels(llm),
options=ModelOptions(),
signatures=make_default_signatures(model),
external_modules=[
@@ -173,6 +186,7 @@ def import_model(
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
@@ -201,8 +215,9 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
raise
def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]:
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]:
"""Load the model from BentoML store.
By default, it will try to find the model in the local store.
If the model is not found, it raises a ``bentoml.exceptions.NotFound``.
"""
@@ -224,11 +239,11 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo
# BetterTransformer is currently only supported on PyTorch.
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)
return t.cast("ModelProtocol[_M]", model)
model = BetterTransformer.transform(model) # type: ignore
return t.cast("ModelProtocol[M]", model)
def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]:
"""Load the tokenizer from BentoML store.
By default, it will check whether the bentomodel is in the store.
@@ -249,14 +264,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
"Model does not have tokenizer. Make sure to save \
the tokenizer within the model via 'custom_objects'.\
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
)
) from None
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(
bentomodel_fs.getsyspath("/"),
trust_remote_code=llm.__llm_trust_remote_code__,
**tokenizer_attrs,
)
return t.cast("TokenizerProtocol[_T]", tokenizer)
return tokenizer
def save_pretrained(
@@ -264,7 +279,7 @@ def save_pretrained(
save_directory: str,
is_main_process: bool = True,
state_dict: DictStrAny | None = None,
save_function: t.Callable[P, None] | None = None,
save_function: t.Callable[..., None] | None = None,
push_to_hub: bool = False,
max_shard_size: int | str = "10GB",
safe_serialization: bool = False,
@@ -272,8 +287,7 @@ def save_pretrained(
**attrs: t.Any,
):
"""Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``."""
if save_function is None:
save_function = torch.save
save_function = first_not_none(save_function, default=torch.save)
model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs)

113
src/openllm/testing.py Normal file
View File

@@ -0,0 +1,113 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests utilities for OpenLLM."""
from __future__ import annotations
import contextlib
import logging
import shutil
import subprocess
import typing as t
import bentoml
import openllm
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from ._types import LiteralRuntime
@contextlib.contextmanager
def build_bento(
    model: str,
    model_id: str | None = None,
    quantize: t.Literal["int4", "int8", "gptq"] | None = None,
    runtime: t.Literal["ggml", "transformers"] = "transformers",
    cleanup: bool = False,
):
    """Build a Bento for ``model`` and yield it for the duration of the ``with`` block.

    Args:
        model: Name of the model, forwarded to ``openllm.build``.
        model_id: Optional model id, forwarded to ``openllm.build``.
        quantize: Optional quantisation scheme, forwarded to ``openllm.build``.
        runtime: Runtime implementation, forwarded to ``openllm.build``.
        cleanup: When True, the built Bento is deleted from the store on exit.

    Yields:
        The object returned by ``openllm.build``.
    """
    logger.info("Building BentoML for %s", model)
    bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
    try:
        yield bento
    finally:
        # Clean up even when the body of the ``with`` block raises (e.g. a failing
        # test assertion); otherwise the Bento would be left behind in the store.
        if cleanup:
            logger.info("Deleting %s", bento.tag)
            bentoml.bentos.delete(bento.tag)
@contextlib.contextmanager
def build_container(
    bento: bentoml.Bento | str | bentoml.Tag,
    image_tag: str | None = None,
    cleanup: bool = False,
    **attrs: t.Any,
):
    """Build a docker image for ``bento`` and yield the image tag.

    Args:
        bento: A ``bentoml.Bento`` instance, or anything ``bentoml.Tag.from_taglike`` accepts.
        image_tag: Tag for the resulting image. Defaults to the string form of the Bento tag.
        cleanup: When True, the image is force-removed with ``docker rmi -f`` on exit.
        **attrs: Extra keyword arguments forwarded to ``bentoml.container.build``.

    Yields:
        The image tag that was built.

    Raises:
        RuntimeError: If no ``docker`` executable is found on PATH.
    """
    # Normalise the input to a tag, then derive the default image name from it.
    bento_tag = bento.tag if isinstance(bento, bentoml.Bento) else bentoml.Tag.from_taglike(bento)
    image_tag = str(bento_tag) if image_tag is None else image_tag

    docker = shutil.which("docker")
    if not docker:
        raise RuntimeError("docker executable not found")

    try:
        logger.info("Building container for %s", bento_tag)
        bentoml.container.build(bento_tag, backend="docker", image_tag=(image_tag,), progress="plain", **attrs)
        yield image_tag
    finally:
        # Runs on both success and failure so a half-used image is not left around.
        if cleanup:
            logger.info("Deleting container %s", image_tag)
            subprocess.check_output([docker, "rmi", "-f", image_tag])
@contextlib.contextmanager
def prepare(
    model: str,
    model_id: str | None = None,
    implementation: LiteralRuntime = "pt",
    deployment_mode: t.Literal["container", "local"] = "local",
    clean_context: contextlib.ExitStack | None = None,
    cleanup: bool = False,
):
    """Make sure a Bento (and optionally a container image) exists for ``model``.

    Args:
        model: Name of the model to prepare.
        model_id: Optional model id passed to the LLM factory and to ``build_bento``.
        implementation: Implementation key passed to ``openllm.infer_auto_class``.
        deployment_mode: ``"container"`` additionally builds a docker image via
            ``build_container``; ``"local"`` only ensures the Bento exists.
        clean_context: Exit stack that owns the build contexts. When omitted, a private
            stack is created and ``cleanup`` is forced to True.
        cleanup: When True, built artifacts are removed and ``clean_context`` is closed on exit.

    Yields:
        The container name derived from ``model`` and the LLM type. In ``"container"``
        mode this is the image tag returned by ``build_container``.
    """
    if clean_context is None:
        clean_context = contextlib.ExitStack()
        cleanup = True

    # Setup and yield are both guarded: if a later build step fails, or the body of
    # the ``with`` block raises, anything already registered is still released.
    try:
        llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
        bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
        if not bentoml.list(bento_tag):
            bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
        else:
            bento = bentoml.get(bento_tag)

        container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")

        if deployment_mode == "container":
            container_name = clean_context.enter_context(
                build_container(bento, image_tag=container_name, cleanup=cleanup)
            )

        yield container_name
    finally:
        if cleanup:
            clean_context.close()

View File

@@ -1,39 +0,0 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
from .utils import is_flax_available
from .utils import is_tf_available
from .utils import is_torch_available
try:
import pytest
except ImportError:
raise ImportError("You need to install pytest to use 'openllm.tests' utilities: 'pip install pytest'")
def require_tf(f: t.Callable[..., t.Any]):
    """Mark ``f`` so pytest skips it when TensorFlow is not available."""
    skip_without_tf = pytest.mark.skipif(not is_tf_available(), reason="requires TensorFlow")
    return skip_without_tf(f)
def require_flax(f: t.Callable[..., t.Any]):
    """Mark ``f`` so pytest skips it when Flax is not available."""
    skip_without_flax = pytest.mark.skipif(not is_flax_available(), reason="requires Flax")
    return skip_without_flax(f)
def require_torch(f: t.Callable[..., t.Any]):
    """Mark ``f`` so pytest skips it when PyTorch is not available."""
    skip_without_torch = pytest.mark.skipif(not is_torch_available(), reason="requires PyTorch")
    return skip_without_torch(f)

View File

@@ -11,30 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities function for OpenLLM. User can import these function for convenience, but
"""Utilities function for OpenLLM.
User can import these function for convenience, but
we won't ensure backward compatibility for these functions. So use with caution.
"""
from __future__ import annotations as _annotations
from __future__ import annotations
import contextlib
import functools
import logging
import logging.config
import os
import sys
import platform
import types
import typing as t
from pathlib import Path
from circus.exc import ConflictError
from bentoml._internal.configuration import DEBUG_ENV_VAR as _DEBUG_ENV_VAR
from bentoml._internal.configuration import GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR
from bentoml._internal.configuration import get_debug_mode
from bentoml._internal.configuration import get_quiet_mode
from bentoml._internal.configuration import set_debug_mode
from bentoml._internal.configuration import set_quiet_mode
from bentoml._internal.log import configure_server_logging
from bentoml._internal.models.model import ModelContext as _ModelContext
from bentoml._internal.log import CLI_LOGGING_CONFIG as _CLI_LOGGING_CONFIG
from bentoml._internal.types import LazyType
from bentoml._internal.utils import LazyLoader
from bentoml._internal.utils import bentoml_cattr
from bentoml._internal.utils import cached_contextmanager
from bentoml._internal.utils import copy_file_to_fs_folder
from bentoml._internal.utils import first_not_none
from bentoml._internal.utils import pkg
@@ -62,8 +67,26 @@ else:
types.UnionType,
)
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload as _overload
else:
from typing_extensions import overload as _overload
if t.TYPE_CHECKING:
import openllm
from .._types import DictStrAny
from .._types import LiteralRuntime
from .._types import P
from ..models.auto.factory import BaseAutoLLMClass
def set_debug_mode(enabled: bool):
    """Switch debug mode on or off by writing the two debug environment variables.

    Stands in for ``bentoml._internal.configuration.set_debug_mode`` so that the
    unused logs are not emitted.
    """
    grpc_level = "DEBUG" if enabled else "ERROR"
    os.environ[_DEBUG_ENV_VAR] = str(enabled)
    os.environ[_GRPC_DEBUG_ENV_VAR] = grpc_level
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
@@ -75,16 +98,12 @@ def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.An
raise
def gpu_count() -> tuple[int, ...]:
def gpu_count() -> tuple[str, ...]:
from bentoml._internal.resource import NvidiaGpuResource
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices is not None:
if "," in cuda_visible_devices:
available_gpu = tuple(int(i) for i in cuda_visible_devices.split(","))
else:
available_gpu = tuple(int(i) for i in cuda_visible_devices.split())
return available_gpu
return tuple(i for i in cuda_visible_devices.split(","))
return tuple(NvidiaGpuResource.from_system())
@@ -94,7 +113,7 @@ _object_setattr = object.__setattr__
def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
"""This makes sure that we don't overwrite any existing attributes on the object"""
"""This makes sure that we don't overwrite any existing attributes on the object."""
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
if not hasattr(obj, name):
@@ -107,22 +126,65 @@ def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None
DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get("OPENLLMDEVDEBUG")))
SHOW_CODEGEN = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3
_LOGGING_CONFIG = _CLI_LOGGING_CONFIG.copy()
_LOGGING_CONFIG["loggers"].update(
{
"openllm": {
"level": logging.INFO,
class _ExceptionFilter(logging.Filter):
    """Logging filter that drops records carrying an excluded exception type.

    ``ConflictError`` is always excluded, in addition to any types passed in.
    """

    def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
        """Set up the filter.

        Args:
            exclude_exceptions: Extra exception types whose records should be dropped.
            **kwargs: Forwarded to ``logging.Filter.__init__``.
        """
        # Work on a copy: appending to the caller's list would leak ConflictError into
        # it and grow it on every construction if the same list is reused.
        excluded = list(exclude_exceptions) if exclude_exceptions is not None else []
        if ConflictError not in excluded:
            excluded.append(ConflictError)
        super().__init__(**kwargs)
        self.EXCLUDE_EXCEPTIONS = excluded

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False for records whose exception type is excluded, True otherwise."""
        if record.exc_info:
            etype, _, _ = record.exc_info
            # issubclass accepts a tuple of classes, so no explicit loop is needed.
            if etype is not None and issubclass(etype, tuple(self.EXCLUDE_EXCEPTIONS)):
                return False
        return True
_LOGGING_CONFIG: DictStrAny = {
"version": 1,
"disable_existing_loggers": True,
"filters": {"excfilter": {"()": _ExceptionFilter}},
"handlers": {
"bentomlhandler": {
"class": "logging.StreamHandler",
"filters": ["excfilter"],
"stream": "ext://sys.stdout",
},
"defaulthandler": {
"class": "logging.StreamHandler",
"level": logging.WARNING,
},
},
"loggers": {
"bentoml": {
"handlers": ["bentomlhandler", "defaulthandler"],
"level": logging.INFO,
"propagate": False,
}
}
)
},
"openllm": {
"handlers": ["bentomlhandler", "defaulthandler"],
"level": logging.INFO,
"propagate": False,
},
},
"root": {"level": logging.WARNING},
}
def configure_logging() -> None:
"""Configure logging for OpenLLM. Behaves similar to how BentoML loggers
are being configured."""
"""Configure logging for OpenLLM.
Behaves similar to how BentoML loggers are being configured.
"""
if get_quiet_mode():
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
@@ -144,7 +206,7 @@ def in_notebook() -> bool:
try:
from IPython.core.getipython import get_ipython
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
if "IPKernelApp" not in get_ipython().config: # type: ignore
return False
except ImportError:
return False
@@ -153,8 +215,91 @@ def in_notebook() -> bool:
return True
_dockerenv = Path("/.dockerenv")
_cgroup = Path("/proc/self/cgroup")
class suppress(contextlib.suppress, contextlib.ContextDecorator):
    """A version of contextlib.suppress with decorator support.

    The class only combines ``contextlib.suppress`` with ``contextlib.ContextDecorator``
    and adds no behaviour of its own. It can be used as a ``with`` block or as a
    function decorator; in the example below the ``KeyError`` raised inside the
    decorated function is swallowed.

    >>> @suppress(KeyError)
    ... def key_error():
    ...     {}['']
    >>> key_error()
    """
def compose(*funcs: t.Callable[..., t.Any]):
    """Chain callables right-to-left into a single callable.

    ``compose(f, g, h)(*args, **kwargs)`` evaluates ``f(g(h(*args, **kwargs)))``: only the
    last function receives the call arguments, every other function receives the previous
    result as its sole argument. With a single function that function is returned as-is,
    and with no function at all ``functools.reduce`` raises ``TypeError``.

    >>> import textwrap
    >>> expected = str.strip(textwrap.dedent(compose.__doc__))
    >>> strip_and_dedent = compose(str.strip, textwrap.dedent)
    >>> strip_and_dedent(compose.__doc__) == expected
    True

    The innermost function may take arbitrary arguments.

    >>> round_three = lambda x: round(x, ndigits=3)
    >>> f = compose(round_three, int.__truediv__)
    >>> [f(3*x, x+1) for x in range(1,10)]
    [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
    """

    def _chain(outer: t.Callable[..., t.Any], inner: t.Callable[P, t.Any]):
        def _(*args: P.args, **kwargs: P.kwargs) -> t.Any:
            intermediate = inner(*args, **kwargs)
            return outer(intermediate)

        return _

    # reduce folds left-to-right, so the first function ends up outermost.
    composed = functools.reduce(_chain, funcs)
    return composed
def apply(transform: t.Callable[..., t.Any]):
    """Build a decorator that post-processes a function's return value with ``transform``.

    The decorated function keeps its own metadata (name, docstring, ...), while every
    call returns ``transform(func(*args, **kwargs))``.

    ```python
    @apply(reversed)
    def get_numbers(start):
        "doc for get_numbers"
        return range(start, start+3)

    list(get_numbers(4))
    # [6, 5, 4]

    get_numbers.__doc__
    # 'doc for get_numbers'
    ```
    """

    def wrap(func: t.Callable[P, t.Any]):
        piped = compose(transform, func)
        # Copy func's metadata onto the composed callable and hand it back.
        return functools.update_wrapper(piped, func)

    return wrap
@apply(bool)
@suppress(FileNotFoundError)
def _text_in_file(text: str, filename: Path):
    """Tell whether any line of ``filename`` contains ``text``.

    A missing file is swallowed by ``suppress(FileNotFoundError)``; the resulting value is
    then passed through ``bool`` by ``apply``, so callers always receive a boolean.

    Args:
        text: Substring to look for.
        filename: Path of the file to scan line by line.
    """
    # Open through a context manager so the handle is closed deterministically,
    # including when any() short-circuits on the first matching line.
    with filename.open() as handle:
        return any(text in line for line in handle)
def in_docker() -> bool:
    """Report whether the current process appears to run inside a Docker container.

    Returns:
        ``True`` when ``/.dockerenv`` exists or when ``/proc/self/cgroup`` contains the
        text ``docker``; otherwise the result of that cgroup lookup.
    """
    if _dockerenv.exists():
        return True
    return _text_in_file("docker", _cgroup)
T = t.TypeVar("T")
K = t.TypeVar("K")
def resolve_filepath(path: str) -> str:
"""Resolve a file path to an absolute path, expand user and environment variables"""
"""Resolve a file path to an absolute path, expand user and environment variables."""
try:
return resolve_user_filepath(path, None)
except FileNotFoundError:
@@ -166,7 +311,9 @@ def validate_is_path(maybe_path: str) -> bool:
def generate_context(framework_name: str) -> _ModelContext:
from .import_utils import is_torch_available, is_flax_available, is_tf_available
from .import_utils import is_flax_available
from .import_utils import is_tf_available
from .import_utils import is_torch_available
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
if is_torch_available():
@@ -174,7 +321,7 @@ def generate_context(framework_name: str) -> _ModelContext:
if is_tf_available():
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
framework_versions["tensorflow-macos" if platform.system() == "Darwin" else "tensorflow"] = get_tf_version()
framework_versions["tensorflow"] = get_tf_version()
if is_flax_available():
framework_versions.update(
{
@@ -186,6 +333,10 @@ def generate_context(framework_name: str) -> _ModelContext:
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> DictStrAny:
    """Build the label mapping for ``llm``: its ``runtime`` plus the fixed framework name."""
    return dict(runtime=llm.runtime, framework="openllm")
_TOKENIZER_PREFIX = "_tokenizer_"
@@ -198,6 +349,33 @@ def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[DictStrAny,
return attrs, tokenizer_attrs
@_overload
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]:
    ...


@_overload
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]:
    ...


@_overload
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]:
    ...


def infer_auto_class(implementation: LiteralRuntime) -> type[BaseAutoLLMClass]:
    """Return the Auto class that matches the given framework implementation.

    Args:
        implementation: One of ``"pt"``, ``"tf"`` or ``"flax"``.

    Returns:
        ``AutoLLM`` for ``"pt"``, ``AutoTFLLM`` for ``"tf"`` and ``AutoFlaxLLM`` for ``"flax"``.

    Raises:
        RuntimeError: If ``implementation`` is none of the supported values.
    """
    # Each import lives in its own branch, so it is only executed for the matching value.
    if implementation == "pt":
        from ..models.auto import AutoLLM

        return AutoLLM
    if implementation == "flax":
        from ..models.auto import AutoFlaxLLM

        return AutoFlaxLLM
    if implementation == "tf":
        from ..models.auto import AutoTFLLM

        return AutoTFLLM
    raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf')")
# NOTE: The set marks contains a set of modules name
# that are available above and are whitelisted
# to be included in the extra_objects map.
@@ -218,6 +396,7 @@ _import_structure = {
"codegen": [],
"dantic": [],
"representation": ["ReprMixin"],
"lazy": ["LazyModule"],
"import_utils": [
"OPTIONAL_DEPENDENCIES",
"ENV_VARS_TRUE_VALUES",
@@ -248,28 +427,18 @@ if t.TYPE_CHECKING:
from . import LazyType as LazyType
from . import analytics as analytics
from . import bentoml_cattr as bentoml_cattr
from . import cached_contextmanager as cached_contextmanager
from . import codegen as codegen
from . import configure_logging as configure_logging
from . import configure_server_logging as configure_server_logging
from . import copy_file_to_fs_folder as copy_file_to_fs_folder
from . import dantic as dantic
from . import first_not_none as first_not_none
from . import get_debug_mode as get_debug_mode
from . import get_quiet_mode as get_quiet_mode
from . import gpu_count as gpu_count
from . import lenient_issubclass as lenient_issubclass
from . import non_intrusive_setattr as non_intrusive_setattr
from . import pkg as pkg
from . import reserve_free_port as reserve_free_port
from . import resolve_user_filepath as resolve_user_filepath
from . import set_debug_mode as set_debug_mode
from . import set_quiet_mode as set_quiet_mode
from . import in_notebook as in_notebook
from . import validate_or_create_dir as validate_or_create_dir
from . import validate_is_path as validate_is_path
from . import resolve_filepath as resolve_filepath
from . import normalize_attrs_to_model_tokenizer_pair as normalize_attrs_to_model_tokenizer_pair
from . import generate_context as generate_context
from . import field_env_key as field_env_key
from . import validate_or_create_dir as validate_or_create_dir
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES
from .import_utils import DummyMetaclass as DummyMetaclass
@@ -279,6 +448,9 @@ if t.TYPE_CHECKING:
from .import_utils import is_datasets_available as is_datasets_available
from .import_utils import is_einops_available as is_einops_available
from .import_utils import is_flax_available as is_flax_available
from .import_utils import is_jupyter_available as is_jupyter_available
from .import_utils import is_jupytext_available as is_jupytext_available
from .import_utils import is_notebook_available as is_notebook_available
from .import_utils import is_peft_available as is_peft_available
from .import_utils import is_tf_available as is_tf_available
from .import_utils import is_torch_available as is_torch_available
@@ -287,10 +459,6 @@ if t.TYPE_CHECKING:
from .import_utils import is_triton_available as is_triton_available
from .import_utils import require_backends as require_backends
from .import_utils import requires_dependencies as requires_dependencies
from .import_utils import is_jupyter_available as is_jupyter_available
from .import_utils import is_jupytext_available as is_jupytext_available
from .import_utils import is_notebook_available as is_notebook_available
from .lazy import LazyModule as LazyModule
from .representation import ReprMixin as ReprMixin
else:
import sys

View File

@@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Telemetry related for OpenLLM tracking.
"""Telemetry related for OpenLLM tracking.
Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
"""
from __future__ import annotations
import contextlib
import functools
import os

View File

@@ -13,11 +13,14 @@
# limitations under the License.
from __future__ import annotations
import functools
import inspect
import linecache
import logging
import os
import string
import types
import typing as t
from operator import itemgetter
from pathlib import Path
import orjson
@@ -28,17 +31,16 @@ if t.TYPE_CHECKING:
import openllm
DictStrAny = dict[str, t.Any]
ListStr = list[str]
from .._types import AnyCallable
from .._types import DictStrAny
from .._types import ListStr
from .._types import P
from attr import _make_method
PartialAny = functools.partial[t.Any]
else:
# NOTE: Using internal API from attr here, since we are actually
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
from attr._make import _make_method
DictStrAny = dict
ListStr = list
PartialAny = functools.partial
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
@@ -53,11 +55,12 @@ class ModelNameFormatter(string.Formatter):
model_keyword: t.LiteralString = "__model_name__"
def __init__(self, model_name: str):
"""The formatter that extends model_name to be formatted the 'service.py'."""
super().__init__()
self.model_name = model_name
def vformat(self, format_string: str) -> str:
return super().vformat(format_string, (), {self.model_keyword: self.model_name})
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.LiteralString:
return t.cast("t.LiteralString", super().vformat(format_string, (), {self.model_keyword: self.model_name}))
def can_format(self, value: str) -> bool:
try:
@@ -117,9 +120,7 @@ _sentinel = object()
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any):
"""
Check whether *cls* defines *attrib_name* (and doesn't just inherit it).
"""
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
attr = getattr(cls, attrib_name, _sentinel)
if attr is _sentinel:
return False
@@ -133,9 +134,7 @@ def has_own_attribute(cls: type[t.Any], attrib_name: t.Any):
def get_annotations(cls: type[t.Any]) -> DictStrAny:
"""
Get annotations for *cls*.
"""
"""Get annotations for *cls*."""
if has_own_attribute(cls, "__annotations__"):
return cls.__annotations__
@@ -151,8 +150,7 @@ _classvar_prefixes = (
def is_class_var(annot: str | t.Any) -> bool:
"""
Check whether *annot* is a typing.ClassVar.
"""Check whether *annot* is a typing.ClassVar.
The string comparison hack is used to avoid evaluating all string
annotations which would put attrs-based classes at a performance
@@ -168,16 +166,14 @@ def is_class_var(annot: str | t.Any) -> bool:
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
"""
Add __module__ and __qualname__ to a *method* if possible.
"""
"""Add __module__ and __qualname__ to a *method* if possible."""
try:
method_or_cls.__module__ = cls.__module__
except AttributeError:
pass
try:
method_or_cls.__qualname__ = ".".join((cls.__qualname__, method_or_cls.__name__))
method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
except AttributeError:
pass
@@ -191,6 +187,64 @@ def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str
return method_or_cls
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = ""):
"""Exec the script with the given global (globs) and local (locs) variables."""
bytecode = compile(script, filename, "exec")
eval(bytecode, globs, locs) # noqa: S307
# ported from attrs
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
    """Execute ``script`` and return the object it binds to ``name``.

    Args:
        name: Name that ``script`` defines; looked up in the local namespace afterwards.
        script: Source code of the method to create.
        filename: Pseudo filename under which the source is registered in ``linecache``.
        globs: Global namespace the script is executed with.

    Returns:
        The object stored under ``name`` after running the script.
    """
    locs: DictStrAny = {}

    # Register the source in linecache so that debuggers such as PDB can step through
    # the generated code. If the name is already taken by different source, retry with
    # "-1", "-2", ... inserted before the last character of the requested filename
    # (presumably the closing ">" of a "<...>" style name).
    source_lines = script.splitlines(True)
    unique_name = filename
    attempt = 1
    while True:
        entry = (len(script), None, source_lines, unique_name)
        if linecache.cache.setdefault(unique_name, entry) == entry:
            break
        unique_name = f"{filename[:-1]}-{attempt}>"
        attempt += 1

    _compile_and_eval(script, globs, locs, unique_name)
    return locs[name]
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]):
    """Generate a ``tuple`` subclass exposing its items as named read-only properties.

    For ``cls_name="MyClass"`` and ``attr_names=["x"]`` the generated source is equivalent to:

        class MyClassAttributes(tuple):
            __slots__ = ()
            x = property(itemgetter(0))

    Args:
        cls_name: Prefix of the generated class name (``"Attributes"`` is appended).
        attr_names: Property names, mapped to tuple positions in order.

    Returns:
        The newly created class.
    """
    attr_class_name = f"{cls_name}Attributes"
    # One property line per attribute; an empty attribute list falls back to ``pass``.
    members = [
        f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))" for i, attr_name in enumerate(attr_names)
    ] or [" pass"]
    source = "\n".join([f"class {attr_class_name}(tuple):", " __slots__ = ()", *members])
    globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
    _compile_and_eval(source, globs)
    return globs[attr_class_name]
def generate_unique_filename(cls: type[t.Any], func_name: str):
    """Build the pseudo filename ``<Name generated func module.qualname>`` for generated code."""
    # Fall back to __name__ when the class does not expose __qualname__.
    qualified = getattr(cls, "__qualname__", cls.__name__)
    return f"<{cls.__name__} generated {func_name} {cls.__module__}.{qualified}>"
@@ -203,7 +257,7 @@ def generate_function(
globs: dict[str, t.Any],
annotations: dict[str, t.Any] | None = None,
):
from . import DEBUG
from . import SHOW_CODEGEN
script = "def %s(%s):\n %s\n" % (
func_name,
@@ -214,7 +268,7 @@ def generate_function(
if annotations:
meth.__annotations__ = annotations
if DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3:
if SHOW_CODEGEN:
logger.info("Generated script for %s:\n\n%s", typ, script)
return meth
@@ -227,7 +281,8 @@ def make_env_transformer(
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
globs: DictStrAny | None = None,
):
from . import dantic, field_env_key
from . import dantic
from . import field_env_key
def identity(_: str, x_value: t.Any) -> t.Any:
return x_value
@@ -268,3 +323,35 @@ def make_env_transformer(
globs=globs,
annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann},
)
def gen_sdk(func: t.Callable[P, t.Any], name: str | None = None, **attrs: t.Any):
    """Wrap ``func`` in a generated partial-like object with a descriptive repr.

    A new class called ``name`` is created from ``PartialAny`` and ``ReprMixin``,
    instantiated with ``func`` and ``attrs`` as pre-bound keyword arguments, and then
    updated with ``func``'s metadata through ``functools.update_wrapper``.

    Args:
        func: The callable to wrap.
        name: Name of the generated class. Defaults to ``func.__name__`` with leading and
            trailing underscores stripped.
        **attrs: Keyword arguments bound to the wrapped callable.

    Returns:
        The instance of the generated class wrapping ``func``.
    """
    from .representation import ReprMixin

    if name is None:
        name = func.__name__.strip("_")

    _signatures = inspect.signature(func).parameters

    def _repr(self: ReprMixin) -> str:
        return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"

    def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
        return ((k, _signatures[k].annotation) for k in self.__repr_keys__)

    def _populate(ns: DictStrAny) -> None:
        ns.update(
            {
                # Only public parameters (no leading underscore) show up in the repr.
                "__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]),
                "__repr_args__": _repr_args,
                "__repr__": _repr,
                # func.__doc__ is None for undocumented callables and inspect.cleandoc
                # cannot handle None, so fall back to an empty docstring.
                "__doc__": inspect.cleandoc(func.__doc__ or ""),
                "__module__": "openllm",
            }
        )

    generated_cls = types.new_class(name, (PartialAny, ReprMixin), exec_body=_populate)
    return functools.update_wrapper(generated_cls(func, **attrs), func)

View File

@@ -14,71 +14,40 @@
"""A shim provides usable transition from pydantic to attrs."""
from __future__ import annotations
import functools
import importlib
import os
import sys
import typing as t
from enum import Enum
import attr
import click
import sys
import click_option_group as cog
import inflection
import orjson
from click import ParamType, shell_completion as sc, types as click_types
from click import ParamType
from click import shell_completion as sc
from click import types as click_types
import openllm
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
from attr import _ValidatorType
from .._types import ClickFunctionWrapper
from .._types import F
from .._types import O_co
from .._types import P
from .._types import ListAny
_T = t.TypeVar("_T")
@overload
def attrs_to_options(
name: str,
field: attr.Attribute[t.Any],
model_name: str,
typ: type[t.Any] | None = None,
suffix_generation: bool = False,
) -> F[..., F[..., openllm.LLMConfig]]:
...
@overload
def attrs_to_options( # type: ignore (overlapping overload)
name: str,
field: attr.Attribute[O_co],
model_name: str,
typ: type[t.Any] | None = None,
suffix_generation: bool = False,
) -> F[..., F[P, O_co]]:
...
def attrs_to_options(
name: str,
field: attr.Attribute[t.Any],
model_name: str,
typ: type[t.Any] | None = None,
suffix_generation: bool = False,
) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]:
):
# TODO: support parsing nested attrs class and Union
envvar = field.metadata["env"]
dasherized = inflection.dasherize(name)
@@ -86,6 +55,8 @@ def attrs_to_options(
if typ in (None, attr.NOTHING):
typ = field.type
if typ is None:
raise RuntimeError(f"Failed to parse type for {name}")
full_option_name = f"--{dasherized}"
if field.type is bool:
@@ -116,7 +87,7 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any:
try:
return orjson.loads(value.lower())
except orjson.JSONDecodeError as err:
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}")
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
return value
@@ -132,14 +103,16 @@ def Field(
use_default_converter: bool = True,
**attrs: t.Any,
):
"""A decorator that extends attr.field with additional arguments, which provides the same
interface as pydantic's Field.
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
By default, if both validator and ge are provided, then ge will be
applied first, and all of the other validators will be run afterwards.
Args:
default: The default value for ``dantic.Field``. Defaults to ``None``.
ge: Greater than or equal to. Defaults to None.
le: Less than or equal to. Defaults to None.
validator: Optional attrs-compatible validators type. Default to None
description: the documentation for the field. Defaults to None.
env: the environment variable to read from. Defaults to None.
auto_default: a bool indicating whether to use the default value as the environment.
@@ -150,7 +123,7 @@ def Field(
to True. If set to False, then the default converter will not be used.
The default converter converts a given value from the environment variable
for this given Field.
**kwargs: The rest of the arguments are passed to attr.field
**attrs: The rest of the arguments are passed to attr.field
"""
metadata = attrs.pop("metadata", {})
if description is None:
@@ -205,13 +178,14 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]:
"""
from . import lenient_issubclass
assert t.get_origin(field_type) is not t.Union, "Unions are not supported"
if t.get_origin(field_type) is t.Union:
raise NotImplementedError("Unions are not supported")
# enumeration strings or other Enum derivatives
if lenient_issubclass(field_type, Enum):
return EnumChoice(enum=field_type, case_sensitive=True)
# literals are enum-like with way less functionality
if is_literal(field_type):
return LiteralChoice(enum=field_type, case_sensitive=True)
return LiteralChoice(value=field_type, case_sensitive=True)
# modules, classes, functions
if is_typing(field_type):
return ModuleType()
@@ -248,6 +222,7 @@ def is_typing(field_type: type) -> bool:
def is_literal(field_type: type) -> bool:
"""Checks whether the given field type is a Literal type or not.
Literals are weird: isinstance and subclass do not work, so you compare
the origin with the Literal declaration itself.
@@ -266,63 +241,76 @@ class ModuleType(ParamType):
def _import_object(self, value: str) -> t.Any:
module_name, class_name = value.rsplit(".", maxsplit=1)
assert all(s.isidentifier() for s in module_name.split(".")), f"'{value}' is not a valid module name"
assert class_name.isidentifier(), f"Variable '{class_name}' is not a valid identifier"
if not all(s.isidentifier() for s in module_name.split(".")):
raise ValueError(f"'{value}' is not a valid module name")
if not class_name.isidentifier():
raise ValueError(f"Variable '{class_name}' is not a valid identifier")
module = importlib.import_module(module_name)
if class_name:
try:
return getattr(module, class_name)
except AttributeError:
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.")
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
return None
def convert(self, value: str, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
try:
if isinstance(value, str):
return self._import_object(value)
return value
except Exception as exc:
self.fail(f"'{value}' is not a valid object ({type(exc)}: {str(exc)})", param, ctx)
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
class EnumChoice(click.Choice):
    """``click.Choice`` whose choices are the member names of an Enum class."""

    name = "enum"

    def __init__(self, enum: type[Enum] | Enum, case_sensitive: bool = False):
        """Enum type support for click that extends ``click.Choice``.

        Args:
            enum: The Enum class to choose from. A member of the Enum is accepted as
                well, in which case its class is used.
            case_sensitive: Whether this choice should be case sensitive.
        """
        # ``parse_type`` passes the Enum *class*. Taking ``type()`` of a class yields the
        # Enum metaclass, which can neither be iterated for members nor indexed by name,
        # so only fall back to ``type()`` when an actual member was given.
        enum_cls = enum if isinstance(enum, type) else type(enum)
        self.mapping = enum_cls
        self.internal_type = enum_cls
        choices: ListAny = [member.name for member in enum_cls]
        super().__init__(choices, case_sensitive)

    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
        """Convert a member name into the matching value; already-converted values pass through."""
        if isinstance(value, self.internal_type):
            return value
        result = super().convert(value, param, ctx)
        if isinstance(result, str):
            # Look the name up through ``mapping`` rather than ``internal_type``:
            # subclasses such as LiteralChoice set ``internal_type`` to a primitive type
            # and keep the name -> value table in ``mapping``.
            result = self.mapping[result]
        return result
class LiteralChoice(EnumChoice):
name = "literal"
def __init__(self, enum: t.LiteralString, case_sensitive: bool = False):
def __init__(self, value: t.Any, case_sensitive: bool = False):
"""Literal support for click."""
# expect every literal value to belong to the same primitive type
values = list(enum.__args__)
values = list(value.__args__)
item_type = type(values[0])
assert all(isinstance(v, item_type) for v in values), f"Field {enum} contains items of different types"
if not all(isinstance(v, item_type) for v in values):
raise ValueError(f"Field {value} contains items of different types.")
self.internal_type = item_type
self.mapping = {str(v): v for v in values}
super(EnumChoice, self).__init__(list(self.mapping.keys()), case_sensitive)
def allows_multiple(field_type: t.Any) -> bool:
def allows_multiple(field_type: type) -> bool:
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
For containers, it exploits click's support for lists and such to use the same option multiple times
to create a complex object: `python run.py --subsets train --subsets test`
# becomes `subsets: ["train", "test"]`.
Args:
field_type (type): pydantic type
field_type: pydantic type.
Returns:
bool: true if it's a composite field (lists, containers and so on), false otherwise
@@ -360,8 +348,7 @@ def is_mapping(field_type: type) -> bool:
def is_container(field_type: type) -> bool:
"""Checks whether the current type is a container type ('contains' other types), like
lists and tuples.
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
Args:
field_type: pydantic field type
@@ -391,12 +378,13 @@ def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType
Returns:
ParamType | tuple[ParamType]: single click-compatible type or a tuple
"""
assert is_container(field_type), "Field type is not a container"
if not is_container(field_type):
raise ValueError("Field type is not a container type.")
args = t.get_args(field_type)
# Early out for untyped containers: standard lists, tuples, List[Any]
# Use strings when the type is unknown, avoid click's type guessing
if len(args) == 0:
return str
return click_types.convert_type(str)
# Early out for homogenous containers: Tuple[int], List[str]
if len(args) == 1:
return parse_single_arg(args[0])
@@ -409,6 +397,7 @@ def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType
def parse_single_arg(arg: type) -> ParamType:
"""Returns the click-compatible type for container origin types.
In this case, returns string when it's not inferrable, a JSON for mappings
and the original type itself in every other case (ints, floats and so on).
Bytes is a special case, not natively handled by click.
@@ -421,13 +410,13 @@ def parse_single_arg(arg: type) -> ParamType:
"""
# When we don't know the type, we choose 'str'
if arg is t.Any:
return str
return click_types.convert_type(str)
# For containers and nested models, we use JSON
if is_container(arg):
return JsonType()
if openllm.utils.lenient_issubclass(arg, bytes):
return BytesType()
return arg
return click_types.convert_type(arg)
class BytesType(ParamType):
@@ -439,7 +428,7 @@ class BytesType(ParamType):
try:
return str.encode(value)
except Exception as exc:
self.fail(f"'{value}' is not a valid string ({str(exc)})", param, ctx)
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
CYGWIN = sys.platform.startswith("cygwin")
@@ -470,17 +459,14 @@ class CudaValueType(ParamType):
return var
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
"""Return a list of
:class:`~click.shell_completion.CompletionItem` objects for the
incomplete value. Most types do not provide completions, but
some do, and this allows custom types to provide custom
completions as well.
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
:param ctx: Invocation context for this command.
:param param: The parameter that is requesting completion.
:param incomplete: Value being completed. May be empty.
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
.. versionadded:: 8.0
Args:
ctx: Invocation context for this command.
param: The parameter that is requesting completion.
incomplete: Value being completed. May be empty.
"""
from ..utils import gpu_count
@@ -506,6 +492,7 @@ class CudaValueType(ParamType):
return tuple(self.typ(x, param, ctx) for x in value.split(","))
def __repr__(self) -> str:
"""CUDA is a click.STRING extension."""
return "STRING"
@@ -516,6 +503,11 @@ class JsonType(ParamType):
name = "json"
def __init__(self, should_load: bool = True) -> None:
"""Support JSON type for click.ParamType.
Args:
should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
"""
super().__init__()
self.should_load = should_load
@@ -525,4 +517,4 @@ class JsonType(ParamType):
try:
return orjson.loads(value)
except orjson.JSONDecodeError as exc:
self.fail(f"'{value}' is not a valid JSON string ({str(exc)})", param, ctx)
self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ..utils import DummyMetaclass

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ..utils import DummyMetaclass

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ..utils import DummyMetaclass

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ..utils import DummyMetaclass

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ..utils import DummyMetaclass

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from ..utils import DummyMetaclass

View File

@@ -12,17 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons.
"""
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
from __future__ import annotations
import functools
import importlib
import importlib.metadata
import importlib.util
import logging
import os
import sys
import typing as t
from abc import ABCMeta
from collections import OrderedDict
@@ -38,15 +36,21 @@ from .representation import ReprMixin
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
from .._types import LiteralRuntime
from .._types import P
class _AnnotatedLazyLoader(LazyLoader):
DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[..., t.LiteralString]
else:
_AnnotatedLazyLoader = LazyLoader
BackendOrderredDict = OrderedDict
logger = logging.getLogger(__name__)
@@ -188,7 +192,7 @@ def is_tf_available():
_tf_available = _tf_version is not None
if _tf_available:
if _tf_version and version.parse(_tf_version) < version.parse("2"):
logger.info(f"TensorFlow found but with version {_tf_version}. OpenLLM only supports TF 2.x")
logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
_tf_available = False
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
@@ -321,8 +325,9 @@ BACKENDS_MAPPING = BackendOrderredDict(
class DummyMetaclass(ABCMeta):
"""Metaclass for dummy object. It will raises ImportError
generated by ``require_backends`` if users try to access attributes from given class
"""Metaclass for dummy object.
It raises the ImportError generated by ``require_backends`` if users try to access attributes from the given class.
"""
_backends: t.List[str]
@@ -368,7 +373,7 @@ class EnvVarMixin(ReprMixin):
bettertransformer: str
runtime: t.Literal["ggml", "transformers"]
framework_value: t.Literal["pt", "tf", "flax"]
framework_value: LiteralRuntime
quantize_value: str | None
bettertransformer_value: str | None
runtime_value: t.Literal["ggml", "transformers"]
@@ -385,17 +390,17 @@ class EnvVarMixin(ReprMixin):
@overload
def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
@overload
def __getitem__(self, item: t.Literal['runtime']) -> str: ...
def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
@overload
def __getitem__(self, item: t.Literal['framework_value']) -> t.Literal['pt', 'tf', 'flax']: ...
def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
@overload
def __getitem__(self, item: t.Literal['quantize_value']) -> str | None: ...
def __getitem__(self, item: t.Literal["quantize_value"]) -> str | None: ...
@overload
def __getitem__(self, item: t.Literal['model_id_value']) -> str | None: ...
def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
@overload
def __getitem__(self, item: t.Literal['bettertransformer_value']) -> str | None: ...
def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> str | None: ...
@overload
def __getitem__(self, item: t.Literal['runtime_value']) -> t.Literal['ggml', 'transformers']: ...
def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
# fmt: on
def __getitem__(self, item: str | t.Any) -> t.Any:
if hasattr(self, item):
@@ -409,8 +414,8 @@ class EnvVarMixin(ReprMixin):
quantize: t.LiteralString | None = None,
runtime: t.Literal["ggml", "transformers"] = "transformers",
):
from .._configuration import field_env_key
from . import codegen
from .._configuration import field_env_key
model_name = inflection.underscore(model_name)
@@ -464,5 +469,5 @@ class EnvVarMixin(ReprMixin):
return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
@property
def module(self) -> LazyLoader:
return LazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
def module(self):
return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import importlib
import importlib.machinery
import itertools
@@ -40,10 +39,9 @@ _reserved_namespace = {"__openllm_special__", "__openllm_migration__"}
class LazyModule(types.ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
This is a direct port from transformers.utils.import_utils._LazyModule for
backwards compatibility with transformers < 4.18
"""Module class that surfaces all objects but only performs associated imports when the objects are requested.
This is a direct port from transformers.utils.import_utils._LazyModule for backwards compatibility with transformers < 4.18.
This is an extension of LazyLoader that is more powerful.
"""
@@ -56,8 +54,22 @@ class LazyModule(types.ModuleType):
module_file: str,
import_structure: dict[str, list[str]],
module_spec: importlib.machinery.ModuleSpec | None = None,
doc: str | None = None,
extra_objects: dict[str, t.Any] | None = None,
):
"""Lazily load this module as an object.
It also provides ``__all__`` and ``__dir__`` for IDE support.
Args:
name: module name
module_file: the given file. Often defaults to ``globals()['__file__']``.
import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module'
module_spec: __spec__ of the lazily loaded module
doc: Optional docstring for this module.
extra_objects: Any additional objects that can also be accessed through this module. Useful for additional metadata as well
as any locals() functions.
"""
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module: dict[str, str] = {}
@@ -70,24 +82,22 @@ class LazyModule(types.ModuleType):
self.__file__ = module_file
self.__spec__ = module_spec
self.__path__ = [os.path.dirname(module_file)]
self.__doc__ = doc
self._objects = _extra_objects
self._name = name
self._import_structure = import_structure
# Needed for autocompletion in an IDE
def __dir__(self):
"""Needed for autocompletion in an IDE."""
result = t.cast("list[str]", super().__dir__())
# The elements of self.__all__ that are submodules may or
# may not be in the dir already, depending on whether
# they have been accessed or not. So we only add the
# elements of self.__all__ that are not already in the dir.
for attribute in self.__all__:
if attribute not in result:
result.append(attribute)
return result
return result + [i for i in self.__all__ if i not in result]
def __getitem__(self, key: str) -> t.Any:
# currently, this is reserved to only internal uses and users shouldn't use this.
"""This is reserved to only internal uses and users shouldn't use this."""
if self._objects.get("__openllm_special__") is None:
raise UsageNotAllowedError(f"'{self._name}' is not allowed to be used as a dict.")
_special_mapping = self._objects.get("__openllm_special__", {})
@@ -101,6 +111,10 @@ class LazyModule(types.ModuleType):
raise KeyError(f"Failed to lookup '{key}' in '{self._name}'") from e
def __getattr__(self, name: str) -> t.Any:
"""Equivocal __getattr__ implementation.
It checks from _objects > _modules and does it recursively.
"""
if name in _reserved_namespace:
raise ForbiddenAttributeError(
f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified."
@@ -111,6 +125,7 @@ class LazyModule(types.ModuleType):
warnings.warn(
f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead",
DeprecationWarning,
stacklevel=3,
)
return getattr(self, cur_value)
if name in self._objects:
@@ -136,4 +151,5 @@ class LazyModule(types.ModuleType):
) from e
def __reduce__(self):
"""This is to ensure any given module is pickle-able."""
return (self.__class__, (self._name, self.__file__, self._import_structure))

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import typing as t
from abc import abstractmethod
@@ -27,6 +26,7 @@ if t.TYPE_CHECKING:
class ReprMixin:
"""This class display possible representation of given class.
It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
Most subclasses need to implement a __repr_keys__ property.
@@ -41,12 +41,20 @@ class ReprMixin:
"""This can be overriden by base class using this mixin."""
def __repr__(self) -> str:
    """Render the class name followed by an indented JSON dump of the repr args.

    Each ``(key, value)`` pair from ``__repr_args__`` goes into a dict. Values
    for which ``attr.has`` is true are unstructured through ``bentoml_cattr``
    first; every other value is passed to ``orjson`` unchanged.
    """
    from . import bentoml_cattr

    payload = {}
    for key, value in self.__repr_args__():
        payload[key] = bentoml_cattr.unstructure(value) if attr.has(value) else value
    rendered = orjson.dumps(payload, option=orjson.OPT_INDENT_2).decode()
    return f"{self.__class__.__name__} {rendered}"
def __str__(self) -> str:
"""The string representation of the given Mixin subclass.
It will contains all of the attributes from __repr_keys__
"""
return self.__repr_str__(" ")
def __repr_name__(self) -> str:
@@ -54,7 +62,12 @@ class ReprMixin:
return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f"{a}={repr(v)}" for a, v in self.__repr_args__())
"""To be used with __str__."""
return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
def __repr_args__(self) -> ReprArgs:
"""This can also be overriden by base class using this mixin.
By default it does a getattr of the current object from __repr_keys__.
"""
return ((k, getattr(self, k)) for k in self.__repr_keys__)

View File

@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The actual client implementation. Use ``openllm.client`` instead.
"""The actual client implementation.
Use ``openllm.client`` instead.
This holds the implementation of the client, which is used to communicate with the
OpenLLM server. It is used to send requests to the server, and receive responses.
"""

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import attr
@@ -42,7 +41,7 @@ class PromptTemplate:
input_variables: t.Sequence[str]
def to_str(self, __partial_dict__: PartialDict | None = None, **attrs: str) -> str:
"""Generate a prompt from the template and input variables"""
"""Generate a prompt from the template and input variables."""
if __partial_dict__:
return _default_formatter.vformat(self.template, (), __partial_dict__)
if not attrs:
@@ -58,7 +57,7 @@ class PromptTemplate:
@classmethod
def from_default(cls, model: str, /, **prompt_attrs: t.Any) -> PromptTemplate:
template = getattr(openllm.utils.EnvVarMixin(model).module, "DEFAULT_PROMPT_TEMPLATE")
template = openllm.utils.EnvVarMixin(model).module.DEFAULT_PROMPT_TEMPLATE
if template is None:
raise ValueError(f"Model {model} does not have a default prompt template.")
if callable(template):

View File

@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client that supports REST/gRPC protocol to interact with a LLMServer."""
from .grpc import AsyncGrpcClient as AsyncGrpcClient
from .grpc import GrpcClient as GrpcClient
from .http import AsyncHTTPClient as AsyncHTTPClient
from .http import HTTPClient as HTTPClient

View File

@@ -13,29 +13,30 @@
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import sys
import typing as t
from abc import abstractmethod
from http import HTTPStatus
from urllib.parse import urljoin
import httpx
import bentoml
import openllm
import logging
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
if t.TYPE_CHECKING:
import transformers
from openllm.models.auto.factory import _BaseAutoLLMClass
from openllm._types import LiteralRuntime
class AnnotatedClient(bentoml.client.Client):
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any:
@@ -44,12 +45,6 @@ if t.TYPE_CHECKING:
async def async_health(self) -> t.Any:
...
def call(self, name: str, inputs: t.Any, **attrs: t.Any) -> t.Any:
...
async def acall(self, name: str, inputs: t.Any, **attrs: t.Any) -> t.Any:
...
def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]:
...
@@ -70,7 +65,10 @@ def in_async_context() -> bool:
return False
class ClientMixin:
T = t.TypeVar("T")
class ClientMeta(t.Generic[T]):
_api_version: str
_client_class: type[bentoml.client.Client]
@@ -84,9 +82,9 @@ class ClientMixin:
def __init__(self, address: str, timeout: int = 30):
self._address = address
self._timeout = timeout
assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation."
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
"""Initialise subclass for HTTP and gRPC client type."""
cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
cls._api_version = api_version
@@ -102,7 +100,7 @@ class ClientMixin:
return self.__agent__
@property
def _metadata(self) -> dict[str, t.Any]:
def _metadata(self) -> T:
if in_async_context():
return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
return self.call("metadata")
@@ -114,7 +112,7 @@ class ClientMixin:
@property
@abstractmethod
def framework(self) -> t.Literal["pt", "flax", "tf"]:
def framework(self) -> LiteralRuntime:
raise NotImplementedError
@property
@@ -135,10 +133,7 @@ class ClientMixin:
@property
def llm(self) -> openllm.LLM[t.Any, t.Any]:
if self.__llm__ is None:
self.__llm__ = t.cast(
"_BaseAutoLLMClass",
openllm[self.framework], # type: ignore (internal API)
).for_model(self.model_name)
self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
return self.__llm__
@property
@@ -171,7 +166,7 @@ class ClientMixin:
...
class BaseClient(ClientMixin):
class BaseClient(ClientMeta[T]):
def health(self) -> t.Any:
raise NotImplementedError
@@ -183,22 +178,32 @@ class BaseClient(ClientMixin):
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
...
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs)
@overload
def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput:
...
def query(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | dict[str, t.Any] | str:
# NOTE: We set use_default_prompt_template to False for now.
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
return_attrs = attrs.pop("return_attrs", False)
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(
prompt, use_default_prompt_template=use_default_prompt_template, **attrs
)
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
if in_async_context():
result = httpx.post(
urljoin(self._address, f"/{self._api_version}/generate"),
json=openllm.utils.bentoml_cattr.unstructure(inputs),
json=inputs.model_dump(),
timeout=self.timeout,
).json()
else:
result = self.call("generate", inputs)
result = self.call("generate", inputs.model_dump())
r = self.postprocess(result)
if return_attrs:
return r
if return_raw_response:
return openllm.utils.bentoml_cattr.unstructure(r)
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
def ask_agent(
@@ -235,10 +240,21 @@ class BaseClient(ClientMixin):
raise NotImplementedError
class BaseAsyncClient(ClientMixin):
class BaseAsyncClient(ClientMeta[T]):
async def health(self) -> t.Any:
raise NotImplementedError
@overload
async def query(
self,
prompt: str,
*,
return_attrs: t.Literal[True] = True,
return_raw_response: bool | None = ...,
**attrs: t.Any,
) -> openllm.GenerationOutput:
...
@overload
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
...
@@ -249,19 +265,21 @@ class BaseAsyncClient(ClientMixin):
) -> dict[str, t.Any]:
...
async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str | openllm.GenerationOutput:
# NOTE: We set use_default_prompt_template to False for now.
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
return_attrs = attrs.pop("return_attrs", False)
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(
prompt, use_default_prompt_template=use_default_prompt_template, **attrs
)
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
res = await self.acall("generate", inputs)
res = await self.acall("generate", inputs.model_dump())
r = self.postprocess(res)
if return_attrs:
return r
if return_raw_response:
return openllm.utils.bentoml_cattr.unstructure(r)
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
async def ask_agent(
@@ -273,13 +291,17 @@ class BaseAsyncClient(ClientMixin):
agent_type: t.LiteralString = "hf",
**attrs: t.Any,
) -> t.Any:
"""Async version of agent.run"""
"""Async version of agent.run."""
if agent_type == "hf":
return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else:
raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not openllm.utils.is_transformers_supports_agent():
raise RuntimeError(
"This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0"
)
if len(args) > 1:
raise ValueError("'args' should only take one positional argument.")
task = kwargs.pop("task", args[0])
@@ -293,7 +315,7 @@ class BaseAsyncClient(ClientMixin):
_hf_agent = self._hf_agent
prompt = _hf_agent.format_prompt(task)
prompt = t.cast(str, _hf_agent.format_prompt(task))
stop = ["Task:"]
async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client:
response = await client.post(
@@ -303,7 +325,7 @@ class BaseAsyncClient(ClientMixin):
"parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
},
)
if response.status_code != 200:
if response.status_code != HTTPStatus.OK:
raise ValueError(f"Error {response.status_code}: {response.json()}")
result = response.json()[0]["generated_text"]

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import typing as t
@@ -27,46 +26,51 @@ from .base import BaseClient
if t.TYPE_CHECKING:
import grpc_health.v1.health_pb2 as health_pb2
from grpc_health.v1 import health_pb2
from bentoml.grpc.v1.service_pb2 import Response
from openllm._types import LiteralRuntime
logger = logging.getLogger(__name__)
class GrpcClientMixin:
_metadata: Response
if t.TYPE_CHECKING:
@property
def _metadata(self) -> Response:
...
@property
def model_name(self) -> str:
try:
return self._metadata.json.struct_value.fields["model_name"].string_value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def framework(self) -> t.Literal["pt", "flax", "tf"]:
def framework(self) -> LiteralRuntime:
try:
value = self._metadata.json.struct_value.fields["framework"].string_value
if value not in ("pt", "flax", "tf"):
raise KeyError
return value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def timeout(self) -> int:
try:
return int(self._metadata.json.struct_value.fields["timeout"].number_value)
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def model_id(self) -> str:
try:
return self._metadata.json.struct_value.fields["model_id"].string_value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def configuration(self) -> dict[str, t.Any]:
@@ -74,7 +78,7 @@ class GrpcClientMixin:
v = self._metadata.json.struct_value.fields["configuration"].string_value
return orjson.loads(v)
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
if isinstance(result, dict):
@@ -85,7 +89,7 @@ class GrpcClientMixin:
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._host, self._port = address.split(":")
super().__init__(address, timeout)
@@ -94,7 +98,7 @@ class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"):
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._host, self._port = address.split(":")
super().__init__(address, timeout)

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
from urllib.parse import urlparse
@@ -26,52 +25,63 @@ from .base import BaseAsyncClient
from .base import BaseClient
if t.TYPE_CHECKING:
from openllm._types import DictStrAny
from openllm._types import LiteralRuntime
else:
DictStrAny = dict
logger = logging.getLogger(__name__)
class HTTPClientMixin:
_metadata: dict[str, t.Any]
if t.TYPE_CHECKING:
@property
def _metadata(self) -> DictStrAny:
...
@property
def model_name(self) -> str:
try:
return self._metadata["model_name"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def model_id(self) -> str:
    """Return the model id reported by the server metadata.

    Raises:
        RuntimeError: if the metadata has no ``model_id`` entry.
    """
    try:
        # Read the dedicated 'model_id' entry. 'model_name' is already exposed
        # by the ``model_name`` property and is a different value.
        return self._metadata["model_id"]
    except KeyError:
        raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def framework(self) -> t.Literal["pt", "flax", "tf"]:
def framework(self) -> LiteralRuntime:
try:
return self._metadata["framework"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def timeout(self) -> int:
try:
return self._metadata["timeout"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
@property
def configuration(self) -> dict[str, t.Any]:
try:
return orjson.loads(self._metadata["configuration"])
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
return openllm.GenerationOutput(**result)
class HTTPClient(HTTPClientMixin, BaseClient):
class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]):
def __init__(self, address: str, timeout: int = 30):
address = address if "://" in address else "http://" + address
self._host, self._port = urlparse(address).netloc.split(":")
@@ -81,7 +91,7 @@ class HTTPClient(HTTPClientMixin, BaseClient):
return self._cached.health()
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient):
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]):
def __init__(self, address: str, timeout: int = 30):
address = address if "://" in address else "http://" + address
self._host, self._port = urlparse(address).netloc.split(":")

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import logging
import typing as t
@@ -62,11 +61,11 @@ def make_llm_config(
lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
if fields is not None:
for field, type_, default in fields:
lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({repr(default)})")
lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})")
if generation_fields is not None:
generation_lines = ["class GenerationConfig:"]
for field, default in generation_fields:
generation_lines.append(f" {field} = {repr(default)}")
generation_lines.append(f" {field} = {default!r}")
lines.extend((" " + line for line in generation_lines))
script = "\n".join(lines)

View File

@@ -13,9 +13,9 @@
# limitations under the License.
"""All configuration-related tests for openllm.LLMConfig. This will include testing
for ModelEnv construction and parsing environment variables."""
for ModelEnv construction and parsing environment variables.
"""
from __future__ import annotations
import contextlib
import logging
import os
@@ -125,29 +125,20 @@ def test_complex_struct_dump(
generation_fields=(("temperature", temperature),),
)
sent = cl_()
assert (
sent.model_dump()["field1"] == field1 and sent.model_dump()["generation_config"]["temperature"] == temperature
)
assert (
sent.model_dump(flatten=True)["field1"] == field1
and sent.model_dump(flatten=True)["temperature"] == temperature
)
assert sent.model_dump()["field1"] == field1
assert sent.model_dump()["generation_config"]["temperature"] == temperature
assert sent.model_dump(flatten=True)["field1"] == field1
assert sent.model_dump(flatten=True)["temperature"] == temperature
passed = cl_(field1=input_field1, temperature=input_temperature)
assert (
passed.model_dump()["field1"] == input_field1
and passed.model_dump()["generation_config"]["temperature"] == input_temperature
)
assert (
passed.model_dump(flatten=True)["field1"] == input_field1
and passed.model_dump(flatten=True)["temperature"] == input_temperature
)
assert passed.model_dump()["field1"] == input_field1
assert passed.model_dump()["generation_config"]["temperature"] == input_temperature
assert passed.model_dump(flatten=True)["field1"] == input_field1
assert passed.model_dump(flatten=True)["temperature"] == input_temperature
pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1)
assert (
pas_nested.model_dump()["field1"] == input_field1
and pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature
)
assert pas_nested.model_dump()["field1"] == input_field1
assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature
@contextlib.contextmanager
@@ -207,7 +198,7 @@ def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPat
@given(model_settings())
@pytest.mark.parametrize("return_dict,typ", [(True, DictStrAny), (False, transformers.GenerationConfig)])
@pytest.mark.parametrize(("return_dict", "typ"), [(True, DictStrAny), (False, transformers.GenerationConfig)])
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
cl_ = make_llm_config("ConversionLLM", gen_settings)
assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)

View File

@@ -13,9 +13,55 @@
# limitations under the License.
from __future__ import annotations
import itertools
import typing as t
import pytest
import openllm
if t.TYPE_CHECKING:
from openllm._types import LiteralRuntime
_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m"}
_PROMPT_MAPPING = {
"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",
}
def parametrise_local_llm(
    model: str,
) -> t.Generator[tuple[str, openllm.LLMRunner | openllm.LLM[t.Any, t.Any]], None, None]:
    """Yield one ``(prompt, llm)`` pair per implementation and prompt for ``model``.

    The test is skipped when ``model`` has no entry in ``_FRAMEWORK_MAPPING``.
    An implementation ("pt", "flax", "tf") is included when ``model`` appears in
    the matching ``openllm.MODEL_*MAPPING_NAMES`` collection. A runner is built
    with ``init_local=True`` for each combination.
    """
    if model not in _FRAMEWORK_MAPPING:
        pytest.skip(f"'{model}' is not yet supported in framework testing.")

    candidates = (
        ("pt", openllm.MODEL_MAPPING_NAMES),
        ("flax", openllm.MODEL_FLAX_MAPPING_NAMES),
        ("tf", openllm.MODEL_TF_MAPPING_NAMES),
    )
    runtime_impl: tuple[LiteralRuntime, ...] = tuple(impl for impl, names in candidates if model in names)

    for framework in runtime_impl:
        # NOTE(review): the yielded prompt is the *key* of _PROMPT_MAPPING (e.g. "qa"),
        # not the prompt text stored under it. Confirm this is intended.
        for prompt in _PROMPT_MAPPING.keys():
            llm = openllm.Runner(
                model,
                model_id=_FRAMEWORK_MAPPING[model],
                ensure_available=True,
                implementation=framework,
                init_local=True,
            )
            yield prompt, llm
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
    """Parametrise tests that request both the ``prompt`` and ``llm`` fixtures.

    The model name is taken from the test function name with its first 5 and
    last 15 characters removed. Presumably this strips a ``test_`` prefix and
    an ``_implementation`` suffix; confirm against the test naming scheme.
    """
    requested = metafunc.fixturenames
    if "prompt" not in requested or "llm" not in requested:
        return
    model = metafunc.function.__name__[5:-15]
    metafunc.parametrize("prompt,llm", list(parametrise_local_llm(model)))
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
# If no tests are collected, pytest exits with code 5, which makes the CI fail.

View File

@@ -13,12 +13,16 @@
# limitations under the License.
from __future__ import annotations
import typing as t
import openllm
import pytest
from openllm._llm import make_tag
if t.TYPE_CHECKING:
import pytest
HF_INTERNAL_T5_TESTING = "hf-internal-testing/tiny-random-t5"
@@ -31,21 +35,6 @@ def patch_hash_from_file(_: str, algorithm: t.LiteralString = "sha1") -> str:
return "d88a1a40e354a0c7fa6f9055938594e6a4c712e0"
def test_tag_generation_from_custom_path(
tmp_path_factory: pytest.TempPathFactory, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
):
monkeypatch.setattr(openllm._llm, "generate_hash_from_file", patch_hash_from_file)
local_path = tmp_path_factory.mktemp("local_t5")
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
llm.save_pretrained(local_path)
with caplog.at_level("WARNING"):
tag = make_tag(local_path.resolve().__fspath__())
assert tag.version == "d88a1a40e354a0c7fa6f9055938594e6a4c712e0"
assert "Given 'model_id" in caplog.text
def test_tag_generation_quiet_log(tmp_path_factory: pytest.TempPathFactory, caplog: pytest.LogCaptureFixture):
local_path = tmp_path_factory.mktemp("local_t5")
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
@@ -54,12 +43,3 @@ def test_tag_generation_quiet_log(tmp_path_factory: pytest.TempPathFactory, capl
with caplog.at_level("WARNING"):
make_tag(local_path.resolve().__fspath__(), quiet=True)
assert not caplog.text
def test_tag_generation_debug_log(caplog: pytest.LogCaptureFixture):
with caplog.at_level("DEBUG"):
make_tag(HF_INTERNAL_T5_TESTING)
assert (
"The full tag to be saved under model store: 'pt-hf-internal-testing-tiny-random-t5:2f582cd79ed5795b71539951d237945bc1c5ac7e'"
in caplog.text
)

View File

@@ -0,0 +1,34 @@
{
"configuration": {
"format_outputs": false,
"generation_config": {
"diversity_penalty": 0.0,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"encoder_repetition_penalty": 1.0,
"epsilon_cutoff": 0.0,
"eta_cutoff": 0.0,
"length_penalty": 1.0,
"max_new_tokens": 20,
"min_length": 0,
"no_repeat_ngram_size": 0,
"num_beam_groups": 1,
"num_beams": 1,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"remove_invalid_values": false,
"renormalize_logits": false,
"repetition_penalty": 1.0,
"temperature": 0.75,
"top_k": 15,
"top_p": 1.0,
"typical_p": 1.0,
"use_cache": true
}
},
"responses": [
"What is Deep learning?\nDeep learning is a new way of studying the content and making an informed decision. It is the"
]
}

View File

@@ -13,45 +13,303 @@
# limitations under the License.
from __future__ import annotations
import types
import asyncio
import contextlib
import functools
import logging
import sys
import time
import typing as t
from abc import ABC
from abc import abstractmethod
import attr
import docker
import docker.errors
import docker.types
import orjson
import pytest
from syrupy.extensions.json import JSONSnapshotExtension
import openllm
from openllm._llm import normalise_model_name
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from openllm.models.auto.factory import _BaseAutoLLMClass
import subprocess
_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m"}
_PROMPT_MAPPING = {
"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",
"default": "What is the weather in SF?",
}
from openllm_client.runtimes.base import BaseAsyncClient
from syrupy.assertion import SnapshotAssertion
from syrupy.types import PropertyFilter
from syrupy.types import PropertyMatcher
from syrupy.types import SerializableData
from syrupy.types import SerializedData
from openllm._configuration import GenerationConfig
from openllm._types import DictStrAny
from openllm._types import ListAny
else:
DictStrAny = dict
ListAny = list
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
models, fname = t.cast(types.ModuleType, metafunc.module).__name__.partition(".")[-1].split(".")[1:]
class ResponseComparator(JSONSnapshotExtension):
def serialize(
    self,
    data: SerializableData,
    *,
    exclude: PropertyFilter | None = None,
    matcher: PropertyMatcher | None = None,
) -> SerializedData:
    """Serialise one output, or a list of outputs, into sorted and indented JSON text.

    Each item is first replaced by its ``unmarshaled`` attribute. The result is
    passed through the extension's ``_filter`` with the given ``exclude`` and
    ``matcher``, then dumped with ``orjson``.
    """
    is_batch = openllm.utils.LazyType(ListAny).isinstance(data)
    payload = [item.unmarshaled for item in data] if is_batch else data.unmarshaled
    filtered = self._filter(data=payload, depth=0, path=(), exclude=exclude, matcher=matcher)
    dump_options = orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS
    return orjson.dumps(filtered, option=dump_options).decode()
if "tf" in fname:
framework = "tf"
elif "flax" in fname:
framework = "flax"
def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
    """Compare freshly serialised data against the stored snapshot.

    Both inputs are JSON text. Each is decoded into one ``openllm.GenerationOutput``
    or a list of them, normalised to a list, and compared pairwise: the responses
    must be equal element by element and the ``marshaled_config`` values must be equal.

    Raises:
        ValueError: if either input is not valid JSON.
        NotImplementedError: if the decoded JSON is neither a dict nor a list.
    """

    def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
        try:
            data = orjson.loads(data)
        except orjson.JSONDecodeError as err:
            raise ValueError(f"Failed to decode JSON data: {data}") from err
        if openllm.utils.LazyType(DictStrAny).isinstance(data):
            return openllm.GenerationOutput(**data)
        elif openllm.utils.LazyType(ListAny).isinstance(data):
            return [openllm.GenerationOutput(**d) for d in data]
        else:
            raise NotImplementedError(f"Data {data} has unsupported type.")

    serialized_data = convert_data(serialized_data)
    snapshot_data = convert_data(snapshot_data)

    # Wrap a *single* output so both sides are lists. The previous condition was
    # inverted: it nested values that were already lists and left single outputs
    # unwrapped, which breaks the len()/zip() comparison below.
    if not openllm.utils.LazyType(ListAny).isinstance(serialized_data):
        serialized_data = [serialized_data]
    if not openllm.utils.LazyType(ListAny).isinstance(snapshot_data):
        snapshot_data = [snapshot_data]

    def eq_config(lhs: GenerationConfig, rhs: GenerationConfig) -> bool:
        return lhs == rhs

    def eq_output(lhs: openllm.GenerationOutput, rhs: openllm.GenerationOutput) -> bool:
        return (
            len(lhs.responses) == len(rhs.responses)
            and all(_l == _r for _l, _r in zip(lhs.responses, rhs.responses))
            and eq_config(lhs.marshaled_config, rhs.marshaled_config)
        )

    return len(serialized_data) == len(snapshot_data) and all(
        eq_output(lhs, rhs) for lhs, rhs in zip(serialized_data, snapshot_data)
    )
@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
    """Return the ``snapshot`` assertion switched to the ``ResponseComparator`` extension."""
    comparator_snapshot = snapshot.use_extension(ResponseComparator)
    return comparator_snapshot
@attr.define(init=False)
class _Handle(ABC):
    """Base class for a handle on a running server under test.

    Subclasses supply ``status`` to report whether the underlying process or
    container is still alive; ``health`` polls the server until it answers.
    """

    # Port on localhost that the HTTP client connects to.
    port: int
    # Which kind of deployment this handle wraps.
    deployment_mode: t.Literal["container", "local"]
    # Created in ``__attrs_post_init__``; never passed to the initialiser.
    client: BaseAsyncClient[t.Any] = attr.field(init=False)

    if t.TYPE_CHECKING:

        def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
            ...

    def __attrs_post_init__(self):
        self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}")

    @abstractmethod
    def status(self) -> bool:
        """Return True while the underlying deployment is still alive."""
        raise NotImplementedError

    async def health(self, timeout: int = 240):
        """Wait until the server answers a query, or fail.

        Args:
            timeout: Maximum number of seconds to keep polling.

        Raises:
            RuntimeError: If ``status()`` reports the deployment is no longer
                alive, or if no query succeeds within ``timeout`` seconds.
        """
        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            if not self.status():
                raise RuntimeError(f"Failed to initialise {self.__class__.__name__}")
            await self.client.health()
            try:
                await self.client.query("sanity")
                return
            except Exception:
                # Yield to the event loop between attempts; a blocking
                # time.sleep() here would stall every other task on the loop.
                await asyncio.sleep(1)
        raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")
@attr.define(init=False)
class LocalHandle(_Handle):
    """Handle for a server running as a local subprocess."""

    # The child process that runs the server.
    process: subprocess.Popen[bytes]

    def __init__(
        self,
        process: subprocess.Popen[bytes],
        port: int,
        deployment_mode: t.Literal["container", "local"],
    ):
        # Fields are passed by name; the attrs-generated initialiser takes
        # them in declaration order (port, deployment_mode, process).
        self.__attrs_init__(port=port, deployment_mode=deployment_mode, process=process)

    def status(self) -> bool:
        """Return True while the subprocess has not exited."""
        exit_code = self.process.poll()
        return exit_code is None
class HandleProtocol(t.Protocol):
    """Call signature of the factory returned by the ``handler`` fixture.

    Calling it yields a context manager that produces a ``_Handle`` for the
    requested model.
    """

    @contextlib.contextmanager
    def __call__(
        self,
        *,
        model: str,
        model_id: str,
        image_tag: str,
        quantize: t.Literal["int8", "int4", "gptq"] | None = None,
    ) -> t.Generator[_Handle, None, None]:
        ...
@attr.define(init=False)
class DockerHandle(_Handle):
    """Handle for a server running inside a Docker container."""

    # Name used to look the container up again on each status check.
    container_name: str
    # Docker client used for the lookups.
    docker_client: docker.DockerClient

    def __init__(
        self,
        docker_client: docker.DockerClient,
        container_name: str,
        port: int,
        deployment_mode: t.Literal["container", "local"],
    ):
        # Fields are passed by name; the attrs-generated initialiser takes
        # them in declaration order (port, deployment_mode, container_name,
        # docker_client).
        self.__attrs_init__(
            port=port,
            deployment_mode=deployment_mode,
            container_name=container_name,
            docker_client=docker_client,
        )

    def status(self) -> bool:
        """Return True when the container is reported as running or created."""
        current = self.docker_client.containers.get(self.container_name)
        alive_states = ("running", "created")
        return current.status in alive_states
@contextlib.contextmanager
def _local_handle(
    model: str,
    model_id: str,
    image_tag: str,
    deployment_mode: t.Literal["container", "local"],
    quantize: t.Literal["int8", "int4", "gptq"] | None = None,
    *,
    _serve_grpc: bool = False,
):
    """Start a server as a local subprocess and yield a ``LocalHandle`` for it.

    Args:
        model: Model name passed to ``openllm.start`` / ``openllm.start_grpc``.
        model_id: Model id forwarded to the start call.
        image_tag: Not used by this function.
            NOTE(review): presumably accepted only so the signature mirrors
            ``_container_handle`` - confirm.
        deployment_mode: Stored on the yielded handle.
        quantize: Optional quantisation mode forwarded to the start call.
        _serve_grpc: When true, start with ``openllm.start_grpc`` instead of
            ``openllm.start``.

    Yields:
        A ``LocalHandle`` wrapping the started process and its port.
    """
    # The reservation context is exited before the server starts, so only the
    # port number is carried forward.
    with openllm.utils.reserve_free_port() as port:
        pass

    if not _serve_grpc:
        proc = openllm.start(
            model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True
        )
    else:
        proc = openllm.start_grpc(
            model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True
        )

    try:
        yield LocalHandle(proc, port, deployment_mode)
    finally:
        # Runs even when the body of the ``with`` block raises, so a failing
        # test cannot leave the server process behind.
        proc.terminate()
        proc.wait(60)

        # stdout is only available when the process was started with a pipe.
        if proc.stdout is not None:
            process_output = proc.stdout.read()
            print(process_output, file=sys.stderr)
            proc.stdout.close()
        if proc.stderr is not None:
            proc.stderr.close()
@contextlib.contextmanager
def _container_handle(
    model: str,
    model_id: str,
    image_tag: str,
    deployment_mode: t.Literal["container", "local"],
    quantize: t.Literal["int8", "int4", "gptq"] | None = None,
    *,
    _serve_grpc: bool = False,
):
    """Run the server image in a Docker container and yield a ``DockerHandle``.

    Args:
        model: Model name; used for the environment-variable helper and the
            container name.
        model_id: Model id; used (normalised) in the container name.
        image_tag: Image to run.
        deployment_mode: Stored on the yielded handle.
        quantize: Optional quantisation mode, exported to the container through
            the environment variable named by ``envvar.quantize``.
        _serve_grpc: When true, the container command is ``serve-grpc`` instead
            of ``serve``.

    Yields:
        A ``DockerHandle`` for the started container.
    """
    envvar = openllm.utils.EnvVarMixin(model)

    # The reservation contexts are exited before the container starts, so only
    # the two port numbers are carried forward.
    with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
        pass
    container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_")
    client = docker.from_env()

    # Remove any container left over under the same name.
    try:
        container = client.containers.get(container_name)
        container.stop()
        container.wait()
        container.remove()
    except docker.errors.NotFound:
        pass

    args = ["serve" if not _serve_grpc else "serve-grpc"]

    env: DictStrAny = {}
    if quantize is not None:
        env[envvar.quantize] = quantize

    # Request GPUs only when at least one is reported.
    available = openllm.utils.gpu_count()
    gpus = len(available)
    devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None

    container = client.containers.run(
        image_tag,
        command=args,
        name=container_name,
        environment=env,
        auto_remove=False,
        detach=True,
        device_requests=devs,
        ports={"3000/tcp": port, "3001/tcp": prom_port},
    )

    try:
        yield DockerHandle(client, container.name, port, deployment_mode)
    finally:
        # Runs even when the body of the ``with`` block raises, so a failing
        # test cannot leave the container behind. Every step needs the
        # container to exist, so one NotFound guard covers them all.
        try:
            container.stop()
            container.wait()
            container_output = container.logs().decode("utf-8")
            print(container_output, file=sys.stderr)
            container.remove()
        except docker.errors.NotFound:
            pass
@pytest.fixture(scope="session", autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
    """Session-wide ``ExitStack``; everything registered on it is closed at session end."""
    # The ``with`` block closes the stack on every exit path of the generator,
    # not only when it is resumed normally after the yield.
    with contextlib.ExitStack() as stack:
        yield stack
@pytest.fixture(scope="module")
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a fresh event loop per test module and close it afterwards."""
    # A new loop is created explicitly: get_event_loop() without a running
    # loop is deprecated, and once a previous module has closed its loop it
    # would hand back that same closed loop.
    loop = asyncio.new_event_loop()
    # Keep the loop installed as the current one, as get_event_loop() did.
    asyncio.set_event_loop(loop)
    try:
        yield loop
    finally:
        loop.close()
@pytest.fixture(params=["container", "local"], scope="session")
def deployment_mode(request: pytest.FixtureRequest) -> str:
    """Parametrized fixture: yields ``"container"`` and ``"local"`` in turn."""
    mode = request.param
    return mode
@pytest.fixture(scope="module")
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
    """Return the handle factory for the current deployment mode.

    The result is ``_container_handle`` or ``_local_handle`` with
    ``deployment_mode`` already bound.

    Raises:
        ValueError: If ``deployment_mode`` is neither ``"container"`` nor ``"local"``.
    """
    if deployment_mode not in ("container", "local"):
        raise ValueError(f"Unknown deployment mode: {deployment_mode}")
    factory = _container_handle if deployment_mode == "container" else _local_handle
    return functools.partial(factory, deployment_mode=deployment_mode)

View File

@@ -1,13 +0,0 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,27 +0,0 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import openllm
def test_small_flan(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool):
assert llm(prompt)
def test_small_runner_flan(prompt: str, llm: openllm.LLMRunner, qa: bool):
assert llm(prompt)

View File

@@ -1,29 +0,0 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
import openllm
@openllm.tests.require_tf
def test_small_tf_flan(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool):
    """Smoke test: calling the LLM with the prompt must give a truthy result."""
    # ``qa`` is not read here.
    # NOTE(review): presumably requested only to steer parametrization - confirm.
    output = llm(prompt)
    assert output
@openllm.tests.require_tf
def test_small_tf_runner_flan(prompt: str, llm: openllm.LLMRunner, qa: bool):
    """Smoke test: calling the runner with the prompt must give a truthy result."""
    # ``qa`` is not read here.
    # NOTE(review): presumably requested only to steer parametrization - confirm.
    output = llm(prompt)
    assert output

Some files were not shown because too many files have changed in this diff Show More