From 3593c764f09bf8218bc46de17a4207ec22e13b07 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Sat, 24 Jun 2023 11:10:07 -0400 Subject: [PATCH] fix(test): robustness (#64) --- .github/actions/setup-repo/action.yml | 2 +- .github/workflows/ci.yml | 56 ++++++- DEVELOPMENT.md | 4 +- changelog.d/64.fix.md | 4 + hatch.toml | 18 ++- pyproject.toml | 8 +- src/openllm/__init__.py | 9 ++ src/openllm/_configuration.py | 149 ++++++++++++++---- src/openllm/_llm.py | 11 +- src/openllm/_trainer.py | 25 --- src/openllm/tests.py | 39 +++++ src/openllm/utils/import_utils.py | 4 +- tests/_strategies/_configuration.py | 2 +- .../flan_t5/test_modeling_flax_flan_t5.py | 43 +++++ .../flan_t5/test_modeling_tf_flan_t5.py | 43 +++++ tests/models/opt/test_modeling_flax_opt.py | 43 +++++ tests/models/opt/test_modeling_tf_opt.py | 43 +++++ tests/test_configuration.py | 102 +++++++++--- tools/generate-coverage.py | 54 +++++++ tools/update-config-stubs.py | 29 +++- tools/write-coverage-report.py | 53 +++++++ 21 files changed, 644 insertions(+), 97 deletions(-) create mode 100644 changelog.d/64.fix.md delete mode 100644 src/openllm/_trainer.py create mode 100644 src/openllm/tests.py create mode 100644 tests/models/flan_t5/test_modeling_flax_flan_t5.py create mode 100644 tests/models/flan_t5/test_modeling_tf_flan_t5.py create mode 100644 tests/models/opt/test_modeling_flax_opt.py create mode 100644 tests/models/opt/test_modeling_tf_opt.py create mode 100755 tools/generate-coverage.py create mode 100755 tools/write-coverage-report.py diff --git a/.github/actions/setup-repo/action.yml b/.github/actions/setup-repo/action.yml index 1afa5cc9..3aa0cf66 100644 --- a/.github/actions/setup-repo/action.yml +++ b/.github/actions/setup-repo/action.yml @@ -53,7 +53,7 @@ runs: ${{ steps.get-cache-key-prefix.outputs.prefix }}-pypi- - name: Install dependencies shell: bash - run: pip install hatch towncrier + run: pip install hatch towncrier && pip install -e ".[all]" - name: Install pyright shell: bash run: npm install -g npm@^7 pyright diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b34ac240..9b3e045d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,6 +23,7 @@ env: LINES: 120 COLUMNS: 120 OPENLLM_DO_NOT_TRACK: True + PYTHONUNBUFFERED: '1' # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun defaults: run: @@ -43,7 +44,60 @@ jobs: - name: Setup CI uses: ./.github/actions/setup-repo - name: Run tests - run: hatch run test:p + run: hatch run test:full + - name: Disambiguate coverage filename + run: mv .coverage ".coverage.${{ matrix.os }}" + - name: Upload coverage data + uses: actions/upload-artifact@v3 + with: + name: coverage-data + path: .coverage.* + coverage: + name: Coverage + runs-on: ubuntu-latest + needs: + - tests + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Setup CI + uses: ./.github/actions/setup-repo + - name: Download coverage data + uses: actions/download-artifact@v3 + with: + name: coverage-data + - name: Combine coverage data + run: hatch run coverage:combine + - name: Export coverage reports + run: | + hatch run coverage:report-xml + hatch run coverage:report-uncovered-html + - name: Upload uncovered HTML report + uses: actions/upload-artifact@v3 + with: + name: uncovered-html-report + path: htmlcov + - name: Generate coverage summary + run: hatch run coverage:generate-summary + - name: Write coverage summary report + if: github.event_name == 'pull_request' + run: hatch run coverage:write-summary-report + - name: Update coverage pull request comment + if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork + uses: marocchino/sticky-pull-request-comment@v2 + with: + path: coverage-report.md + check: # https://github.com/marketplace/actions/alls-green#why + if: always() + needs: + - coverage + runs-on: ubuntu-latest + steps: + - name: Decide whether the needed jobs succeeded or failed + uses: re-actors/alls-green@release/v1 + with: + jobs: ${{ toJSON(needs) }} concurrency: group: ci-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 786a9b88..7269fce5 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -107,7 +107,7 @@ After setting up your environment, here's how you can start contributing: 5. Run all tests to ensure your changes haven't broken anything: ```bash - hatch run test:p + hatch run test:full ``` 6. Commit your changes: @@ -141,7 +141,7 @@ directory and their filenames start with `test_`. Run all tests with: ```bash -hatch run test:p +hatch run test:full ``` ## Releasing a New Version diff --git a/changelog.d/64.fix.md b/changelog.d/64.fix.md new file mode 100644 index 00000000..0f4fc32b --- /dev/null +++ b/changelog.d/64.fix.md @@ -0,0 +1,4 @@ +Remove duplicated class instance of `generation_config` as it should be set via +instance attributes. + +fixes tests flakiness and one broken cases for parsing env diff --git a/hatch.toml b/hatch.toml index 83b5d850..896b7d33 100644 --- a/hatch.toml +++ b/hatch.toml @@ -20,6 +20,7 @@ typing = "pyright {args:src/openllm tests}" dependencies = [ # NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy "coverage[toml]>=6.5", + "lxml", "pytest", "pytest-asyncio>=0.21.0", "pytest-xdist[psutil]", @@ -30,9 +31,20 @@ dependencies = [ "hypothesis", "syrupy", ] +[envs.test.overrides] +env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT=" +env.HERMETIC_TESTS.type = [{ value = "container", if = ["true"] }, "virtual"] [envs.test.scripts] -cov = ["cov-test", "- coverage combine", "coverage report"] -cov-test = "coverage run -m pytest {args:tests}" -p = "pytest {args:tests}" +_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml" +full = "_run_script -n 3 --reruns 5 --reruns-delay 3 -r aR {args:tests}" [[envs.test.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] +[envs.coverage] +dependencies = ["coverage[toml]>=6.5", "lxml", "orjson"] +detached = true +[envs.coverage.scripts] +combine = "coverage combine {args}" +generate-summary = "python tools/generate-coverage.py" +report-uncovered-html = "coverage html --skip-covered --skip-empty" +report-xml = "coverage xml" +write-summary-report = "python tools/write-coverage-report.py" diff --git a/pyproject.toml b/pyproject.toml index 76c7cf67..7039d976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,13 +220,11 @@ typeCheckingMode = "strict" [tool.coverage.run] branch = true -omit = ["src/openllm/__about__.py"] -parallel = true -source_pkgs = ["openllm", "tests"] +omit = ["src/openllm/__about__.py", "src/openllm/__main__.py", "src/openllm/tests.py", "src/openllm/utils/dummy_*.py"] +source_pkgs = ["openllm"] [tool.coverage.paths] openllm = ["src/openllm", "*/openllm/src/openllm"] -tests = ["tests", "*/openllm/tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:"] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:", "@overload", "# pragma: no cover"] diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py index 94a66f88..711eb609 100644 --- a/src/openllm/__init__.py +++ b/src/openllm/__init__.py @@ -35,6 +35,13 @@ from .__about__ import __version__ as __version__ from .exceptions import MissingDependencyError +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if hasattr(t, "get_overloads"): + from typing import overload +else: + from typing_extensions import overload + if utils.DEBUG: utils.set_debug_mode(True) utils.set_quiet_mode(False) @@ -71,6 +78,7 @@ _import_structure = { "models": [], "client": [], "playground": [], + "tests": [], "cli": ["start", "start_grpc"], # NOTE: models "models.auto": [ @@ -164,6 +172,7 @@ if t.TYPE_CHECKING: from . import exceptions as exceptions from . import models as models from . import playground as playground + from . import tests as tests # Specific types import from ._configuration import LLMConfig as LLMConfig diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index e1a82de2..8a886d32 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -50,8 +50,6 @@ from __future__ import annotations import copy import enum -import functools -import inspect import logging import os import sys @@ -358,8 +356,8 @@ class FineTuneConfig: return klass -@attr.frozen(slots=True) -class GenerationConfig: +@attr.frozen(slots=True, repr=False) +class GenerationConfig(ReprMixin): """Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``, with some additional validation and environment constructor. @@ -609,6 +607,10 @@ class GenerationConfig: return getattr(self, item) raise KeyError(f"GenerationConfig has no attribute {item}") + @property + def __repr_keys__(self) -> set[str]: + return {i.name for i in attr.fields(self.__class__)} + bentoml_cattr.register_unstructure_hook_factory( lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig), @@ -623,7 +625,7 @@ bentoml_cattr.register_unstructure_hook_factory( ) -def _field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str: +def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str: return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key]))) @@ -731,13 +733,13 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]): ) _cl_name = cl_.__name__.replace("Config", "") - _settings_attr = _ModelSettingsAttr.default() - try: - cls(**t.cast(DictStrAny, cl_.__config__)) - _settings_attr = attr.evolve(_settings_attr, **t.cast(DictStrAny, cl_.__config__)) - except TypeError: + _settings_attr = cls.default() + + if any(i not in cl_.__config__ for i in {"default_id", "model_ids"}): raise ValueError("Either 'default_id' or 'model_ids' are emptied under '__config__' (required fields).") + _settings_attr = attr.evolve(_settings_attr, **t.cast(DictStrAny, cl_.__config__)) + _final_value_dct: DictStrAny = { "model_name": inflection.underscore(_cl_name) if _settings_attr["name_type"] == "dasherize" @@ -805,10 +807,9 @@ def _make_env_transformer( globs = {} if globs is None else globs globs.update( { - "functools": functools, "__populate_env": dantic.env_converter, "__default_callback": default_callback, - "__field_env": _field_env_key, + "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name, } @@ -1289,9 +1290,9 @@ class LLMConfig(_ConfigAttr): def add_repr(self): for key, fn in ReprMixin.__dict__.items(): - if key not in ("__module__", "__doc__", "__repr_keys__"): + if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn) - self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs}) + self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"}) return self def __init_subclass__(cls: type[LLMConfig]): @@ -1326,7 +1327,7 @@ class LLMConfig(_ConfigAttr): slots=True, weakref_slot=True, frozen=True, - repr=True, + repr=False, collect_by_mro=True, field_transformer=_make_env_transformer( cls, @@ -1356,9 +1357,9 @@ class LLMConfig(_ConfigAttr): val = cd.get(attr_name, attr.NOTHING) if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val): if val is attr.NOTHING: - val = cls.Field(env=_field_env_key(model_name, attr_name)) + val = cls.Field(env=field_env_key(model_name, attr_name)) else: - val = cls.Field(default=val, env=_field_env_key(model_name, attr_name)) + val = cls.Field(default=val, env=field_env_key(model_name, attr_name)) these[attr_name] = val unannotated = ca_names - annotated_names if len(unannotated) > 0: @@ -1371,13 +1372,6 @@ class LLMConfig(_ConfigAttr): cls.__openllm_accepted_keys__ = set(these.keys()) | { a.name for a in attr.fields(cls.__openllm_generation_class__) } - # 'generation_config' wraps the GenerationConfig class - # which is handled in _make_assignment_script - these["generation_config"] = cls.Field( - default=cls.__openllm_generation_class__(), - description=inspect.cleandoc(cls.__openllm_generation_class__.__doc__ or ""), - type=GenerationConfig, - ) cls = cls._ConfigBuilder(cls, model_name, these).add_attrs_init().add_repr().build_class() # auto assignment attributes generated from __config__ after create the new slot class. @@ -1390,7 +1384,6 @@ class LLMConfig(_ConfigAttr): if cls.__module__ in sys.modules: globs.update(sys.modules[cls.__module__].__dict__) attr.resolve_types(cls.__openllm_generation_class__, globalns=globs) - cls = attr.resolve_types(cls, globalns=globs) # the hint cache for easier access cls.__openllm_hints__ = { @@ -1428,20 +1421,15 @@ class LLMConfig(_ConfigAttr): for k in _cached_keys: if k in generation_config or attrs.get(k) is None: del attrs[k] - _cached_keys = tuple(k for k in _cached_keys if k in attrs) self.__openllm_extras__ = config_merger.merge( first_not_none(__openllm_extras__, default={}), {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}, ) - - for k in _cached_keys: - if k in self.__openllm_extras__: - del attrs[k] - _cached_keys = tuple(k for k in _cached_keys if k in attrs) + self.generation_config = self["generation_class"](**generation_config) # The rest of attrs should only be the attributes to be passed to __attrs_init__ - self.__attrs_init__(generation_config=self["generation_class"](**generation_config), **attrs) + self.__attrs_init__(**attrs) # NOTE: These required fields should be at the top, as it will be kw_only @@ -1487,6 +1475,98 @@ class LLMConfig(_ConfigAttr): def __getitem__(self, item: t.Literal["generation_class"] = ...) -> t.Type[GenerationConfig]: ... @overload def __getitem__(self, item: t.Literal["extras"] = ...) -> t.Dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["max_new_tokens"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["min_length"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["min_new_tokens"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["early_stopping"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["max_time"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["num_beams"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["num_beam_groups"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["penalty_alpha"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["use_cache"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["temperature"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["top_k"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["top_p"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["typical_p"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["epsilon_cutoff"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["eta_cutoff"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["diversity_penalty"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["repetition_penalty"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["encoder_repetition_penalty"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["length_penalty"] = ...) -> float: ... + @overload + def __getitem__(self, item: t.Literal["no_repeat_ngram_size"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["bad_words_ids"] = ...) -> t.List[t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["force_words_ids"] = ...) -> t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]]: ... + @overload + def __getitem__(self, item: t.Literal["renormalize_logits"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["constraints"] = ...) -> t.List[Constraint]: ... + @overload + def __getitem__(self, item: t.Literal["forced_bos_token_id"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["forced_eos_token_id"] = ...) -> t.Union[int, t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["remove_invalid_values"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["exponential_decay_length_penalty"] = ...) -> t.Tuple[int, float]: ... + @overload + def __getitem__(self, item: t.Literal["suppress_tokens"] = ...) -> t.List[int]: ... + @overload + def __getitem__(self, item: t.Literal["begin_suppress_tokens"] = ...) -> t.List[int]: ... + @overload + def __getitem__(self, item: t.Literal["forced_decoder_ids"] = ...) -> t.List[t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["num_return_sequences"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["output_attentions"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["output_hidden_states"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["output_scores"] = ...) -> bool: ... + @overload + def __getitem__(self, item: t.Literal["pad_token_id"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["bos_token_id"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["eos_token_id"] = ...) -> t.Union[int, t.List[int]]: ... + @overload + def __getitem__(self, item: t.Literal["encoder_no_repeat_ngram_size"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["decoder_start_token_id"] = ...) -> int: ... + @overload + def __getitem__(self, item: t.Literal["prompt_tuning"] = ...) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["p_tuning"] = ...) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["prefix_tuning"] = ...) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["lora"] = ...) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["adalora"] = ...) -> dict[str, t.Any]: ... + @overload + def __getitem__(self, item: t.Literal["adaption_prompt"] = ...) -> dict[str, t.Any]: ... # update-config-stubs.py: stop # fmt: on @@ -1556,6 +1636,7 @@ class LLMConfig(_ConfigAttr): lambda ns: ns.update( { "__config__": config_merger.merge(copy.deepcopy(cls.__dict__["__config__"]), _new_cfg), + "__base_config__": cls, # keep a reference for easy access } ), ) @@ -1573,9 +1654,11 @@ class LLMConfig(_ConfigAttr): def model_dump(self, flatten: bool = False, **_: t.Any): dumped = bentoml_cattr.unstructure(self) + generation_config = bentoml_cattr.unstructure(self.generation_config) if flatten: - generation_config = dumped.pop("generation_config") dumped.update(generation_config) + else: + dumped["generation_config"] = generation_config return dumped def model_dump_json(self, **kwargs: t.Any): diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 33fb0844..d6aa07f5 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -1356,6 +1356,7 @@ def Runner( model_name: str, ensure_available: bool | None = None, init_local: bool = False, + implementation: t.Literal["pt", "flax", "tf"] | None = None, **attrs: t.Any, ) -> LLMRunner: """Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models' @@ -1380,6 +1381,10 @@ def Runner( Args: model_name: Supported model name from 'openllm models' + ensure_available: If True, it will download the model if it is not available. If False, it will skip downloading the model. + If False, make sure the model is available locally. + implementation: The given Runner implementation one choose for this Runner. By default, it is retrieved from the enviroment variable + of the respected model_name. For example: 'flan-t5' -> "OPENLLM_FLAN_T5_FRAMEWORK" init_local: If True, it will initialize the model locally. This is useful if you want to run the model locally. (Symmetrical to bentoml.Runner.init_local()) **attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs @@ -1387,9 +1392,11 @@ def Runner( """ runner = t.cast( "_BaseAutoLLMClass", - openllm[EnvVarMixin(model_name)["framework_value"]], # type: ignore (internal API) + openllm[implementation if implementation is not None else EnvVarMixin(model_name)["framework_value"]], # type: ignore (internal API) ).create_runner( - model_name, ensure_available=ensure_available if ensure_available is not None else init_local, **attrs + model_name, + ensure_available=ensure_available if ensure_available is not None else init_local, + **attrs, ) if init_local: diff --git a/src/openllm/_trainer.py b/src/openllm/_trainer.py deleted file mode 100644 index 905f2bea..00000000 --- a/src/openllm/_trainer.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import transformers - - -class PeftTrainer(transformers.Trainer): - ... - - -class PeftSaveCallback(transformers.TrainerCallback): - ... diff --git a/src/openllm/tests.py b/src/openllm/tests.py new file mode 100644 index 00000000..201d6804 --- /dev/null +++ b/src/openllm/tests.py @@ -0,0 +1,39 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import typing as t + +from .utils import is_flax_available +from .utils import is_tf_available +from .utils import is_torch_available + + +try: + import pytest +except ImportError: + raise ImportError("You need to install pytest to use 'openllm.tests' utilities: 'pip install pytest'") + + +def require_tf(f: t.Callable[..., t.Any]): + return pytest.mark.skipif(not is_tf_available(), reason="requires TensorFlow")(f) + + +def require_flax(f: t.Callable[..., t.Any]): + return pytest.mark.skipif(not is_flax_available(), reason="requires Flax")(f) + + +def require_torch(f: t.Callable[..., t.Any]): + return pytest.mark.skipif(not is_torch_available(), reason="requires PyTorch")(f) diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py index ff5d03ec..3930c164 100644 --- a/src/openllm/utils/import_utils.py +++ b/src/openllm/utils/import_utils.py @@ -343,7 +343,7 @@ class EnvVarMixin(ReprMixin): raise KeyError(f"Key {item} not found in {self}") def __new__(cls, model_name: str, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None): - from .._configuration import _field_env_key + from .._configuration import field_env_key from . import codegen model_name = inflection.underscore(model_name) @@ -354,7 +354,7 @@ class EnvVarMixin(ReprMixin): # gen properties env key attributes = {"config", "model_id", "quantize", "framework", "bettertransformer"} for att in attributes: - setattr(res, att, _field_env_key(model_name, att.upper())) + setattr(res, att, field_env_key(model_name, att.upper())) # gen properties env value attributes_with_values = { diff --git a/tests/_strategies/_configuration.py b/tests/_strategies/_configuration.py index 2567f22b..e088e8ae 100644 --- a/tests/_strategies/_configuration.py +++ b/tests/_strategies/_configuration.py @@ -62,7 +62,7 @@ def make_llm_config( lines.append(f' __config__ = {{ {", ".join(_config_args)} }}') if fields is not None: for field, type_, default in fields: - lines.append(f" {field}: {type_} = {repr(default)}") + lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({repr(default)})") if generation_fields is not None: generation_lines = ["class GenerationConfig:"] for field, default in generation_fields: diff --git a/tests/models/flan_t5/test_modeling_flax_flan_t5.py b/tests/models/flan_t5/test_modeling_flax_flan_t5.py new file mode 100644 index 00000000..1cecce6e --- /dev/null +++ b/tests/models/flan_t5/test_modeling_flax_flan_t5.py @@ -0,0 +1,43 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +import openllm + + +@pytest.fixture +def qa_prompt() -> str: + return "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" + + +@pytest.fixture +def flan_t5_id() -> str: + return "google/flan-t5-small" + + +@openllm.tests.require_flax +def test_small_flax_flan(qa_prompt: str, flan_t5_id: str): + llm = openllm.AutoFlaxLLM.for_model("flan-t5", model_id=flan_t5_id, ensure_available=True) + generate = llm(qa_prompt) + assert generate + + +@openllm.tests.require_flax +def test_small_flax_runner_flan(qa_prompt: str, flan_t5_id: str): + llm = openllm.Runner("flan-t5", model_id=flan_t5_id, implementation="flax", init_local=True) + generate = llm(qa_prompt) + assert generate diff --git a/tests/models/flan_t5/test_modeling_tf_flan_t5.py b/tests/models/flan_t5/test_modeling_tf_flan_t5.py new file mode 100644 index 00000000..632c9a42 --- /dev/null +++ b/tests/models/flan_t5/test_modeling_tf_flan_t5.py @@ -0,0 +1,43 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +import openllm + + +@pytest.fixture +def qa_prompt() -> str: + return "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?" + + +@pytest.fixture +def flan_t5_id() -> str: + return "google/flan-t5-small" + + +@openllm.tests.require_tf +def test_small_tf_flan(qa_prompt: str, flan_t5_id: str): + llm = openllm.AutoTFLLM.for_model("flan-t5", model_id=flan_t5_id, ensure_available=True) + generate = llm(qa_prompt) + assert generate + + +@openllm.tests.require_tf +def test_small_tf_runner_flan(qa_prompt: str, flan_t5_id: str): + llm = openllm.Runner("flan-t5", model_id=flan_t5_id, implementation="tf", init_local=True) + generate = llm(qa_prompt) + assert generate diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py new file mode 100644 index 00000000..111cf359 --- /dev/null +++ b/tests/models/opt/test_modeling_flax_opt.py @@ -0,0 +1,43 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +import openllm + + +@pytest.fixture +def qa_prompt() -> str: + return "Answer the following yes/no question by reasoning step-by-step. What is the weather in SF?" + + +@pytest.fixture +def opt_id() -> str: + return "facebook/opt-125m" + + +@openllm.tests.require_flax +def test_small_opt(qa_prompt: str, opt_id: str): + llm = openllm.AutoFlaxLLM.for_model("opt", model_id=opt_id, ensure_available=True) + generate = llm(qa_prompt) + assert generate + + +@openllm.tests.require_flax +def test_small_runner_opt(qa_prompt: str, opt_id: str): + llm = openllm.Runner("opt", implementation="flax", model_id=opt_id, init_local=True) + generate = llm(qa_prompt) + assert generate diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py new file mode 100644 index 00000000..bb510f32 --- /dev/null +++ b/tests/models/opt/test_modeling_tf_opt.py @@ -0,0 +1,43 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +import openllm + + +@pytest.fixture +def qa_prompt() -> str: + return "Answer the following yes/no question by reasoning step-by-step. What is the weather in SF?" + + +@pytest.fixture +def opt_id() -> str: + return "facebook/opt-125m" + + +@openllm.tests.require_tf +def test_small_opt(qa_prompt: str, opt_id: str): + llm = openllm.AutoTFLLM.for_model("opt", model_id=opt_id, ensure_available=True) + generate = llm(qa_prompt) + assert generate + + +@openllm.tests.require_tf +def test_small_runner_opt(qa_prompt: str, opt_id: str): + llm = openllm.Runner("opt", implementation="tf", model_id=opt_id, init_local=True) + generate = llm(qa_prompt) + assert generate diff --git a/tests/test_configuration.py b/tests/test_configuration.py index b995ee8c..491c5350 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -16,18 +16,23 @@ for ModelEnv construction and parsing environment variables.""" from __future__ import annotations +import contextlib import logging +import os +import typing as t +from unittest import mock +import attr import pytest from hypothesis import assume from hypothesis import given from hypothesis import strategies as st import openllm +import transformers from openllm._configuration import GenerationConfig from openllm._configuration import ModelSettings -from openllm._configuration import _field_env_key -from openllm.utils import DEBUG +from openllm._configuration import field_env_key from ._strategies._configuration import make_llm_config from ._strategies._configuration import model_settings @@ -35,6 +40,11 @@ from ._strategies._configuration import model_settings logger = logging.getLogger(__name__) +if t.TYPE_CHECKING: + DictStrAny = dict[str, t.Any] +else: + DictStrAny = dict + def test_missing_default(): with pytest.raises(ValueError, match="Either 'default_id' or 'model_ids'*"): @@ -92,6 +102,12 @@ def test_config_derivation(gen_settings: ModelSettings, field1: int): assert new_cls.__openllm_default_id__ == "asdfasdf" +@given(model_settings()) +def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings): + cl_ = make_llm_config("AttrsProtocolLLM", gen_settings) + assert attr.has(cl_) + + @given( model_settings(), st.integers(max_value=283473), @@ -134,7 +150,37 @@ def test_complex_struct_dump( ) -def test_struct_envvar(monkeypatch: pytest.MonkeyPatch): +@contextlib.contextmanager +def patch_env(**attrs: t.Any): + with mock.patch.dict(os.environ, attrs, clear=True): + yield + + +def test_struct_envvar(): + with patch_env( + **{ + field_env_key("env_llm", "field1"): "4", + field_env_key("env_llm", "temperature", suffix="generation"): "0.2", + } + ): + + class EnvLLM(openllm.LLMConfig): + __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]} + field1: int = 2 + + class GenerationConfig: + temperature: float = 0.8 + + sent = EnvLLM.model_construct_env() + assert sent.field1 == 4 + assert sent["temperature"] == 0.2 + + overwrite_default = EnvLLM() + assert overwrite_default.field1 == 4 + assert overwrite_default["temperature"] == 0.2 + + +def test_struct_provided_fields(): class EnvLLM(openllm.LLMConfig): __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]} field1: int = 2 @@ -142,23 +188,39 @@ def test_struct_envvar(monkeypatch: pytest.MonkeyPatch): class GenerationConfig: temperature: float = 0.8 - f1_env = _field_env_key(EnvLLM.__openllm_model_name__, "field1") - temperature_env = _field_env_key(EnvLLM.__openllm_model_name__, "temperature", suffix="generation") + sent = EnvLLM.model_construct_env(field1=20, temperature=0.4) + assert sent.field1 == 20 + assert sent.generation_config.temperature == 0.4 - if DEBUG: - logger.info(f"Env keys: {f1_env}, {temperature_env}") - with monkeypatch.context() as m: - m.setenv(f1_env, "4") - m.setenv(temperature_env, "0.2") - sent = EnvLLM() - assert sent.field1 == 4 - assert sent.generation_config.temperature == 0.8 - - # NOTE: This is the expected behaviour, where users pass in value, we respect it over envvar. - with monkeypatch.context() as m: - m.setenv(f1_env, "4") - m.setenv(temperature_env, "0.2") - sent = EnvLLM.model_construct_env(field1=20, temperature=0.4) - assert sent.field1 == 4 +def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as mk: + mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0)) + mk.setenv(field_env_key("overwrite_with_env_available", "temperature", suffix="generation"), str(0.2)) + sent = make_llm_config( + "OverwriteWithEnvAvailable", + {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]}, + fields=(("field1", "float", 3.0),), + ).model_construct_env(field1=20.0, temperature=0.4) assert sent.generation_config.temperature == 0.4 + assert sent.field1 == 20.0 + + +@given(model_settings()) +@pytest.mark.parametrize("return_dict,typ", [(True, DictStrAny), (False, transformers.GenerationConfig)]) +def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings): + cl_ = make_llm_config("ConversionLLM", gen_settings) + assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ) + + +@given(model_settings()) +def test_click_conversion(gen_settings: ModelSettings): + # currently our conversion omit Union type. + def cli_mock(**attrs: t.Any): + return attrs + + cl_ = make_llm_config("ClickConversionLLM", gen_settings) + wrapped = cl_.to_click_options(cli_mock) + filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union} + click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")] + assert len(filtered) == len(click_options_filtered) diff --git a/tools/generate-coverage.py b/tools/generate-coverage.py new file mode 100755 index 00000000..3d229648 --- /dev/null +++ b/tools/generate-coverage.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import orjson +from collections import defaultdict +from pathlib import Path +from lxml import etree + +ROOT = Path(__file__).resolve().parent.parent + +PACKAGES = { + 'src/openllm/': 'openllm' +} + +def main() -> int: + coverage_report = ROOT / "coverage.xml" + root = etree.fromstring(coverage_report.read_text()) + + raw_package_data: defaultdict[str, dict[str, int]] = defaultdict(lambda: {'hits': 0, 'misses': 0}) + for package in root.find('packages'): + for module in package.find('classes'): + filename = module.attrib['filename'] + for relative_path, package_name in PACKAGES.items(): + if filename.startswith(relative_path): + data = raw_package_data[package_name] + break + else: + message = f'unknown package: {module}' + raise ValueError(message) + + for line in module.find('lines'): + if line.attrib['hits'] == '1': + data['hits'] += 1 + else: + data['misses'] += 1 + + total_statements_covered = 0 + total_statements = 0 + coverage_data = {} + for package_name, data in sorted(raw_package_data.items()): + statements_covered = data['hits'] + statements = statements_covered + data['misses'] + total_statements_covered += statements_covered + total_statements += statements + + coverage_data[package_name] = {'statements_covered': statements_covered, 'statements': statements} + coverage_data['total'] = {'statements_covered': total_statements_covered, 'statements': total_statements} + + coverage_summary = ROOT / 'coverage-summary.json' + coverage_summary.write_text(orjson.dumps(coverage_data,option=orjson.OPT_INDENT_2).decode(), encoding='utf-8') + + return 0 +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/update-config-stubs.py b/tools/update-config-stubs.py index a881738e..2b162a24 100755 --- a/tools/update-config-stubs.py +++ b/tools/update-config-stubs.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2023 BentoML Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,14 +15,13 @@ from __future__ import annotations -import typing as t import os from pathlib import Path import openllm import importlib -from openllm._configuration import ModelSettings +from openllm._configuration import ModelSettings, GenerationConfig, PeftType # currently we are assuming the indentatio level is 4 for comments START_COMMENT = f"# {os.path.basename(__file__)}: start\n" @@ -76,6 +76,31 @@ def main() -> int: ) ) ) + for keys, type_pep563 in openllm.utils.codegen.get_annotations(GenerationConfig).items(): + lines.extend( + list( + map( + lambda line: " " * 8 + line, + [ + "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", + f'def __getitem__(self, item: t.Literal["{keys}"] = ...) -> {type_pep563}: ...\n', + ], + ) + ) + ) + + for keys in PeftType._member_names_: + lines.extend( + list( + map( + lambda line: " " * 8 + line, + [ + "@overload\n" if "overload" in dir(_imported) else "@t.overload\n", + f'def __getitem__(self, item: t.Literal["{keys.lower()}"] = ...) -> dict[str, t.Any]: ...\n', + ], + ) + ) + ) processed = ( processed[:start_idx] + [" " * 4 + START_COMMENT] + lines + [" " * 4 + END_COMMENT] + processed[end_idx + 1 :] diff --git a/tools/write-coverage-report.py b/tools/write-coverage-report.py new file mode 100755 index 00000000..9047406f --- /dev/null +++ b/tools/write-coverage-report.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +from __future__ import annotations +import orjson +from decimal import ROUND_DOWN, Decimal +from pathlib import Path + +PRECISION = Decimal('.01') + +ROOT = Path(__file__).resolve().parent.parent + + +def main(): + coverage_summary = ROOT / 'coverage-summary.json' + + coverage_data = orjson.loads(coverage_summary.read_text(encoding='utf-8')) + total_data = coverage_data.pop('total') + + lines = [ + '\n', + 'Package | Statements\n', + '------- | ----------\n', + ] + + for package, data in sorted(coverage_data.items()): + statements_covered = data['statements_covered'] + statements = data['statements'] + + rate = Decimal(statements_covered) / Decimal(statements) * 100 + rate = rate.quantize(PRECISION, rounding=ROUND_DOWN) + lines.append( + f'{package} | {100 if rate == 100 else rate}% ({statements_covered} / {statements})\n' # noqa: PLR2004 + ) + + total_statements_covered = total_data['statements_covered'] + total_statements = total_data['statements'] + total_rate = Decimal(total_statements_covered) / Decimal(total_statements) * 100 + total_rate = total_rate.quantize(PRECISION, rounding=ROUND_DOWN) + color = 'ok' if float(total_rate) >= 95 else 'critical' + lines.insert(0, f'![Code Coverage](https://img.shields.io/badge/coverage-{total_rate}%25-{color}?style=flat)\n') + + lines.append( + f'**Summary** | {100 if total_rate == 100 else total_rate}% ' + f'({total_statements_covered} / {total_statements})\n' + ) + + coverage_report = ROOT / 'coverage-report.md' + with coverage_report.open('w', encoding='utf-8') as f: + f.write(''.join(lines)) + return 0 + + +if __name__ == '__main__': + raise SystemExit(main())