fix(test): robustness (#64)

This commit is contained in:
Aaron Pham
2023-06-24 11:10:07 -04:00
committed by GitHub
parent 98328be394
commit 3593c764f0
21 changed files with 644 additions and 97 deletions

View File

@@ -53,7 +53,7 @@ runs:
${{ steps.get-cache-key-prefix.outputs.prefix }}-pypi-
- name: Install dependencies
shell: bash
run: pip install hatch towncrier
run: pip install hatch towncrier && pip install -e ".[all]"
- name: Install pyright
shell: bash
run: npm install -g npm@^7 pyright

View File

@@ -23,6 +23,7 @@ env:
LINES: 120
COLUMNS: 120
OPENLLM_DO_NOT_TRACK: True
PYTHONUNBUFFERED: '1'
# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
defaults:
run:
@@ -43,7 +44,60 @@ jobs:
- name: Setup CI
uses: ./.github/actions/setup-repo
- name: Run tests
run: hatch run test:p
run: hatch run test:full
- name: Disambiguate coverage filename
run: mv .coverage ".coverage.${{ matrix.os }}"
- name: Upload coverage data
uses: actions/upload-artifact@v3
with:
name: coverage-data
path: .coverage.*
coverage:
name: Coverage
runs-on: ubuntu-latest
needs:
- tests
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup CI
uses: ./.github/actions/setup-repo
- name: Download coverage data
uses: actions/download-artifact@v3
with:
name: coverage-data
- name: Combine coverage data
run: hatch run coverage:combine
- name: Export coverage reports
run: |
hatch run coverage:report-xml
hatch run coverage:report-uncovered-html
- name: Upload uncovered HTML report
uses: actions/upload-artifact@v3
with:
name: uncovered-html-report
path: htmlcov
- name: Generate coverage summary
run: hatch run coverage:generate-summary
- name: Write coverage summary report
if: github.event_name == 'pull_request'
run: hatch run coverage:write-summary-report
- name: Update coverage pull request comment
if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork
uses: marocchino/sticky-pull-request-comment@v2
with:
path: coverage-report.md
check: # https://github.com/marketplace/actions/alls-green#why
if: always()
needs:
- coverage
runs-on: ubuntu-latest
steps:
- name: Decide whether the needed jobs succeeded or failed
uses: re-actors/alls-green@release/v1
with:
jobs: ${{ toJSON(needs) }}
concurrency:
group: ci-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true

View File

@@ -107,7 +107,7 @@ After setting up your environment, here's how you can start contributing:
5. Run all tests to ensure your changes haven't broken anything:
```bash
hatch run test:p
hatch run test:full
```
6. Commit your changes:
@@ -141,7 +141,7 @@ directory and their filenames start with `test_`.
Run all tests with:
```bash
hatch run test:p
hatch run test:full
```
## Releasing a New Version

4
changelog.d/64.fix.md Normal file
View File

@@ -0,0 +1,4 @@
Remove duplicated class instance of `generation_config` as it should be set via
instance attributes.
Fixes test flakiness and one broken case when parsing environment variables.

View File

@@ -20,6 +20,7 @@ typing = "pyright {args:src/openllm tests}"
dependencies = [
# NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy
"coverage[toml]>=6.5",
"lxml",
"pytest",
"pytest-asyncio>=0.21.0",
"pytest-xdist[psutil]",
@@ -30,9 +31,20 @@ dependencies = [
"hypothesis",
"syrupy",
]
[envs.test.overrides]
env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT="
env.HERMETIC_TESTS.type = [{ value = "container", if = ["true"] }, "virtual"]
[envs.test.scripts]
cov = ["cov-test", "- coverage combine", "coverage report"]
cov-test = "coverage run -m pytest {args:tests}"
p = "pytest {args:tests}"
_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml"
full = "_run_script -n 3 --reruns 5 --reruns-delay 3 -r aR {args:tests}"
[[envs.test.matrix]]
python = ["3.8", "3.9", "3.10", "3.11"]
[envs.coverage]
dependencies = ["coverage[toml]>=6.5", "lxml", "orjson"]
detached = true
[envs.coverage.scripts]
combine = "coverage combine {args}"
generate-summary = "python tools/generate-coverage.py"
report-uncovered-html = "coverage html --skip-covered --skip-empty"
report-xml = "coverage xml"
write-summary-report = "python tools/write-coverage-report.py"

View File

@@ -220,13 +220,11 @@ typeCheckingMode = "strict"
[tool.coverage.run]
branch = true
omit = ["src/openllm/__about__.py"]
parallel = true
source_pkgs = ["openllm", "tests"]
omit = ["src/openllm/__about__.py", "src/openllm/__main__.py", "src/openllm/tests.py", "src/openllm/utils/dummy_*.py"]
source_pkgs = ["openllm"]
[tool.coverage.paths]
openllm = ["src/openllm", "*/openllm/src/openllm"]
tests = ["tests", "*/openllm/tests"]
[tool.coverage.report]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:"]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:", "@overload", "# pragma: no cover"]

View File

@@ -35,6 +35,13 @@ from .__about__ import __version__ as __version__
from .exceptions import MissingDependencyError
# NOTE: We need to do this so that overload can register
# correct overloads to typing registry
if hasattr(t, "get_overloads"):
from typing import overload
else:
from typing_extensions import overload
if utils.DEBUG:
utils.set_debug_mode(True)
utils.set_quiet_mode(False)
@@ -71,6 +78,7 @@ _import_structure = {
"models": [],
"client": [],
"playground": [],
"tests": [],
"cli": ["start", "start_grpc"],
# NOTE: models
"models.auto": [
@@ -164,6 +172,7 @@ if t.TYPE_CHECKING:
from . import exceptions as exceptions
from . import models as models
from . import playground as playground
from . import tests as tests
# Specific types import
from ._configuration import LLMConfig as LLMConfig

View File

@@ -50,8 +50,6 @@ from __future__ import annotations
import copy
import enum
import functools
import inspect
import logging
import os
import sys
@@ -358,8 +356,8 @@ class FineTuneConfig:
return klass
@attr.frozen(slots=True)
class GenerationConfig:
@attr.frozen(slots=True, repr=False)
class GenerationConfig(ReprMixin):
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
with some additional validation and environment constructor.
@@ -609,6 +607,10 @@ class GenerationConfig:
return getattr(self, item)
raise KeyError(f"GenerationConfig has no attribute {item}")
@property
def __repr_keys__(self) -> set[str]:
    """Return the names of all attrs fields declared on this instance's class.

    NOTE(review): presumably consumed by ``ReprMixin`` to decide which attributes
    appear in the repr -- confirm against ``ReprMixin``.
    """
    field_names: set[str] = set()
    for field in attr.fields(self.__class__):
        field_names.add(field.name)
    return field_names
bentoml_cattr.register_unstructure_hook_factory(
lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig),
@@ -623,7 +625,7 @@ bentoml_cattr.register_unstructure_hook_factory(
)
def _field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str:
def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str:
return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
@@ -731,13 +733,13 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]):
)
_cl_name = cl_.__name__.replace("Config", "")
_settings_attr = _ModelSettingsAttr.default()
try:
cls(**t.cast(DictStrAny, cl_.__config__))
_settings_attr = attr.evolve(_settings_attr, **t.cast(DictStrAny, cl_.__config__))
except TypeError:
_settings_attr = cls.default()
if any(i not in cl_.__config__ for i in {"default_id", "model_ids"}):
raise ValueError("Either 'default_id' or 'model_ids' are emptied under '__config__' (required fields).")
_settings_attr = attr.evolve(_settings_attr, **t.cast(DictStrAny, cl_.__config__))
_final_value_dct: DictStrAny = {
"model_name": inflection.underscore(_cl_name)
if _settings_attr["name_type"] == "dasherize"
@@ -805,10 +807,9 @@ def _make_env_transformer(
globs = {} if globs is None else globs
globs.update(
{
"functools": functools,
"__populate_env": dantic.env_converter,
"__default_callback": default_callback,
"__field_env": _field_env_key,
"__field_env": field_env_key,
"__suffix": suffix or "",
"__model_name": model_name,
}
@@ -1289,9 +1290,9 @@ class LLMConfig(_ConfigAttr):
def add_repr(self):
for key, fn in ReprMixin.__dict__.items():
if key not in ("__module__", "__doc__", "__repr_keys__"):
if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"):
self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs})
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"})
return self
def __init_subclass__(cls: type[LLMConfig]):
@@ -1326,7 +1327,7 @@ class LLMConfig(_ConfigAttr):
slots=True,
weakref_slot=True,
frozen=True,
repr=True,
repr=False,
collect_by_mro=True,
field_transformer=_make_env_transformer(
cls,
@@ -1356,9 +1357,9 @@ class LLMConfig(_ConfigAttr):
val = cd.get(attr_name, attr.NOTHING)
if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val):
if val is attr.NOTHING:
val = cls.Field(env=_field_env_key(model_name, attr_name))
val = cls.Field(env=field_env_key(model_name, attr_name))
else:
val = cls.Field(default=val, env=_field_env_key(model_name, attr_name))
val = cls.Field(default=val, env=field_env_key(model_name, attr_name))
these[attr_name] = val
unannotated = ca_names - annotated_names
if len(unannotated) > 0:
@@ -1371,13 +1372,6 @@ class LLMConfig(_ConfigAttr):
cls.__openllm_accepted_keys__ = set(these.keys()) | {
a.name for a in attr.fields(cls.__openllm_generation_class__)
}
# 'generation_config' wraps the GenerationConfig class
# which is handled in _make_assignment_script
these["generation_config"] = cls.Field(
default=cls.__openllm_generation_class__(),
description=inspect.cleandoc(cls.__openllm_generation_class__.__doc__ or ""),
type=GenerationConfig,
)
cls = cls._ConfigBuilder(cls, model_name, these).add_attrs_init().add_repr().build_class()
# auto assignment attributes generated from __config__ after create the new slot class.
@@ -1390,7 +1384,6 @@ class LLMConfig(_ConfigAttr):
if cls.__module__ in sys.modules:
globs.update(sys.modules[cls.__module__].__dict__)
attr.resolve_types(cls.__openllm_generation_class__, globalns=globs)
cls = attr.resolve_types(cls, globalns=globs)
# the hint cache for easier access
cls.__openllm_hints__ = {
@@ -1428,20 +1421,15 @@ class LLMConfig(_ConfigAttr):
for k in _cached_keys:
if k in generation_config or attrs.get(k) is None:
del attrs[k]
_cached_keys = tuple(k for k in _cached_keys if k in attrs)
self.__openllm_extras__ = config_merger.merge(
first_not_none(__openllm_extras__, default={}),
{k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__},
)
for k in _cached_keys:
if k in self.__openllm_extras__:
del attrs[k]
_cached_keys = tuple(k for k in _cached_keys if k in attrs)
self.generation_config = self["generation_class"](**generation_config)
# The rest of attrs should only be the attributes to be passed to __attrs_init__
self.__attrs_init__(generation_config=self["generation_class"](**generation_config), **attrs)
self.__attrs_init__(**attrs)
# NOTE: These required fields should be at the top, as it will be kw_only
@@ -1487,6 +1475,98 @@ class LLMConfig(_ConfigAttr):
def __getitem__(self, item: t.Literal["generation_class"] = ...) -> t.Type[GenerationConfig]: ...
@overload
def __getitem__(self, item: t.Literal["extras"] = ...) -> t.Dict[str, t.Any]: ...
@overload
def __getitem__(self, item: t.Literal["max_new_tokens"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["min_length"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["min_new_tokens"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["early_stopping"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["max_time"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["num_beams"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["num_beam_groups"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["penalty_alpha"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["use_cache"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["temperature"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["top_k"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["top_p"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["typical_p"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["epsilon_cutoff"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["eta_cutoff"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["diversity_penalty"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["repetition_penalty"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["encoder_repetition_penalty"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["length_penalty"] = ...) -> float: ...
@overload
def __getitem__(self, item: t.Literal["no_repeat_ngram_size"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["bad_words_ids"] = ...) -> t.List[t.List[int]]: ...
@overload
def __getitem__(self, item: t.Literal["force_words_ids"] = ...) -> t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]]: ...
@overload
def __getitem__(self, item: t.Literal["renormalize_logits"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["constraints"] = ...) -> t.List[Constraint]: ...
@overload
def __getitem__(self, item: t.Literal["forced_bos_token_id"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["forced_eos_token_id"] = ...) -> t.Union[int, t.List[int]]: ...
@overload
def __getitem__(self, item: t.Literal["remove_invalid_values"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["exponential_decay_length_penalty"] = ...) -> t.Tuple[int, float]: ...
@overload
def __getitem__(self, item: t.Literal["suppress_tokens"] = ...) -> t.List[int]: ...
@overload
def __getitem__(self, item: t.Literal["begin_suppress_tokens"] = ...) -> t.List[int]: ...
@overload
def __getitem__(self, item: t.Literal["forced_decoder_ids"] = ...) -> t.List[t.List[int]]: ...
@overload
def __getitem__(self, item: t.Literal["num_return_sequences"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["output_attentions"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["output_hidden_states"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["output_scores"] = ...) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["pad_token_id"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["bos_token_id"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["eos_token_id"] = ...) -> t.Union[int, t.List[int]]: ...
@overload
def __getitem__(self, item: t.Literal["encoder_no_repeat_ngram_size"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["decoder_start_token_id"] = ...) -> int: ...
@overload
def __getitem__(self, item: t.Literal["prompt_tuning"] = ...) -> dict[str, t.Any]: ...
@overload
def __getitem__(self, item: t.Literal["p_tuning"] = ...) -> dict[str, t.Any]: ...
@overload
def __getitem__(self, item: t.Literal["prefix_tuning"] = ...) -> dict[str, t.Any]: ...
@overload
def __getitem__(self, item: t.Literal["lora"] = ...) -> dict[str, t.Any]: ...
@overload
def __getitem__(self, item: t.Literal["adalora"] = ...) -> dict[str, t.Any]: ...
@overload
def __getitem__(self, item: t.Literal["adaption_prompt"] = ...) -> dict[str, t.Any]: ...
# update-config-stubs.py: stop
# fmt: on
@@ -1556,6 +1636,7 @@ class LLMConfig(_ConfigAttr):
lambda ns: ns.update(
{
"__config__": config_merger.merge(copy.deepcopy(cls.__dict__["__config__"]), _new_cfg),
"__base_config__": cls, # keep a reference for easy access
}
),
)
@@ -1573,9 +1654,11 @@ class LLMConfig(_ConfigAttr):
def model_dump(self, flatten: bool = False, **_: t.Any):
dumped = bentoml_cattr.unstructure(self)
generation_config = bentoml_cattr.unstructure(self.generation_config)
if flatten:
generation_config = dumped.pop("generation_config")
dumped.update(generation_config)
else:
dumped["generation_config"] = generation_config
return dumped
def model_dump_json(self, **kwargs: t.Any):

View File

@@ -1356,6 +1356,7 @@ def Runner(
model_name: str,
ensure_available: bool | None = None,
init_local: bool = False,
implementation: t.Literal["pt", "flax", "tf"] | None = None,
**attrs: t.Any,
) -> LLMRunner:
"""Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'
@@ -1380,6 +1381,10 @@ def Runner(
Args:
model_name: Supported model name from 'openllm models'
ensure_available: If True, it will download the model if it is not available. If False, it will skip downloading the model.
If False, make sure the model is available locally.
implementation: The framework implementation to use for this Runner. By default, it is retrieved from the environment variable
of the respective model_name. For example: 'flan-t5' -> "OPENLLM_FLAN_T5_FRAMEWORK"
init_local: If True, it will initialize the model locally. This is useful if you want to
run the model locally. (Symmetrical to bentoml.Runner.init_local())
**attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs
@@ -1387,9 +1392,11 @@ def Runner(
"""
runner = t.cast(
"_BaseAutoLLMClass",
openllm[EnvVarMixin(model_name)["framework_value"]], # type: ignore (internal API)
openllm[implementation if implementation is not None else EnvVarMixin(model_name)["framework_value"]], # type: ignore (internal API)
).create_runner(
model_name, ensure_available=ensure_available if ensure_available is not None else init_local, **attrs
model_name,
ensure_available=ensure_available if ensure_available is not None else init_local,
**attrs,
)
if init_local:

View File

@@ -1,25 +0,0 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import transformers
class PeftTrainer(transformers.Trainer):
...
class PeftSaveCallback(transformers.TrainerCallback):
...

39
src/openllm/tests.py Normal file
View File

@@ -0,0 +1,39 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
from .utils import is_flax_available
from .utils import is_tf_available
from .utils import is_torch_available
# pytest is an optional dependency: these helpers are only usable when it is installed,
# so fail early with an actionable message instead of a bare import failure.
try:
    import pytest
except ImportError as err:
    # Chain explicitly so the original import failure is reported as the direct cause.
    raise ImportError("You need to install pytest to use 'openllm.tests' utilities: 'pip install pytest'") from err
def require_tf(f: t.Callable[..., t.Any]):
    """Decorate ``f`` with a pytest ``skipif`` marker that skips it when TensorFlow is not available."""
    tf_missing = not is_tf_available()
    skip_marker = pytest.mark.skipif(tf_missing, reason="requires TensorFlow")
    return skip_marker(f)
def require_flax(f: t.Callable[..., t.Any]):
    """Decorate ``f`` with a pytest ``skipif`` marker that skips it when Flax is not available."""
    flax_missing = not is_flax_available()
    skip_marker = pytest.mark.skipif(flax_missing, reason="requires Flax")
    return skip_marker(f)
def require_torch(f: t.Callable[..., t.Any]):
    """Decorate ``f`` with a pytest ``skipif`` marker that skips it when PyTorch is not available."""
    torch_missing = not is_torch_available()
    skip_marker = pytest.mark.skipif(torch_missing, reason="requires PyTorch")
    return skip_marker(f)

View File

@@ -343,7 +343,7 @@ class EnvVarMixin(ReprMixin):
raise KeyError(f"Key {item} not found in {self}")
def __new__(cls, model_name: str, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None):
from .._configuration import _field_env_key
from .._configuration import field_env_key
from . import codegen
model_name = inflection.underscore(model_name)
@@ -354,7 +354,7 @@ class EnvVarMixin(ReprMixin):
# gen properties env key
attributes = {"config", "model_id", "quantize", "framework", "bettertransformer"}
for att in attributes:
setattr(res, att, _field_env_key(model_name, att.upper()))
setattr(res, att, field_env_key(model_name, att.upper()))
# gen properties env value
attributes_with_values = {

View File

@@ -62,7 +62,7 @@ def make_llm_config(
lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
if fields is not None:
for field, type_, default in fields:
lines.append(f" {field}: {type_} = {repr(default)}")
lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({repr(default)})")
if generation_fields is not None:
generation_lines = ["class GenerationConfig:"]
for field, default in generation_fields:

View File

@@ -0,0 +1,43 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
import openllm
@pytest.fixture
def qa_prompt() -> str:
    """Fixture providing the yes/no reasoning prompt passed to the model in these tests."""
    prompt = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
    return prompt
@pytest.fixture
def flan_t5_id() -> str:
    """Fixture providing the model id of the small flan-t5 checkpoint used in these tests."""
    model_id = "google/flan-t5-small"
    return model_id
@openllm.tests.require_flax
def test_small_flax_flan(qa_prompt: str, flan_t5_id: str):
    """Smoke test: a flan-t5 LLM built via ``AutoFlaxLLM`` returns a truthy result for the prompt."""
    model = openllm.AutoFlaxLLM.for_model("flan-t5", model_id=flan_t5_id, ensure_available=True)
    assert model(qa_prompt)
@openllm.tests.require_flax
def test_small_flax_runner_flan(qa_prompt: str, flan_t5_id: str):
    """Smoke test: a locally initialised flan-t5 Runner with the "flax" implementation returns a truthy result."""
    runner = openllm.Runner("flan-t5", model_id=flan_t5_id, implementation="flax", init_local=True)
    assert runner(qa_prompt)

View File

@@ -0,0 +1,43 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
import openllm
@pytest.fixture
def qa_prompt() -> str:
    """Fixture providing the yes/no reasoning prompt passed to the model in these tests."""
    prompt = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
    return prompt
@pytest.fixture
def flan_t5_id() -> str:
    """Fixture providing the model id of the small flan-t5 checkpoint used in these tests."""
    model_id = "google/flan-t5-small"
    return model_id
@openllm.tests.require_tf
def test_small_tf_flan(qa_prompt: str, flan_t5_id: str):
    """Smoke test: a flan-t5 LLM built via ``AutoTFLLM`` returns a truthy result for the prompt."""
    model = openllm.AutoTFLLM.for_model("flan-t5", model_id=flan_t5_id, ensure_available=True)
    assert model(qa_prompt)
@openllm.tests.require_tf
def test_small_tf_runner_flan(qa_prompt: str, flan_t5_id: str):
    """Smoke test: a locally initialised flan-t5 Runner with the "tf" implementation returns a truthy result."""
    runner = openllm.Runner("flan-t5", model_id=flan_t5_id, implementation="tf", init_local=True)
    assert runner(qa_prompt)

View File

@@ -0,0 +1,43 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
import openllm
@pytest.fixture
def qa_prompt() -> str:
    """Fixture providing the yes/no reasoning prompt passed to the model in these tests."""
    prompt = "Answer the following yes/no question by reasoning step-by-step. What is the weather in SF?"
    return prompt
@pytest.fixture
def opt_id() -> str:
    """Fixture providing the model id of the small OPT checkpoint used in these tests."""
    model_id = "facebook/opt-125m"
    return model_id
@openllm.tests.require_flax
def test_small_opt(qa_prompt: str, opt_id: str):
    """Smoke test: an OPT LLM built via ``AutoFlaxLLM`` returns a truthy result for the prompt."""
    model = openllm.AutoFlaxLLM.for_model("opt", model_id=opt_id, ensure_available=True)
    assert model(qa_prompt)
@openllm.tests.require_flax
def test_small_runner_opt(qa_prompt: str, opt_id: str):
    """Smoke test: a locally initialised OPT Runner with the "flax" implementation returns a truthy result."""
    runner = openllm.Runner("opt", implementation="flax", model_id=opt_id, init_local=True)
    assert runner(qa_prompt)

View File

@@ -0,0 +1,43 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
import openllm
@pytest.fixture
def qa_prompt() -> str:
    """Fixture providing the yes/no reasoning prompt passed to the model in these tests."""
    prompt = "Answer the following yes/no question by reasoning step-by-step. What is the weather in SF?"
    return prompt
@pytest.fixture
def opt_id() -> str:
    """Fixture providing the model id of the small OPT checkpoint used in these tests."""
    model_id = "facebook/opt-125m"
    return model_id
@openllm.tests.require_tf
def test_small_opt(qa_prompt: str, opt_id: str):
    """Smoke test: an OPT LLM built via ``AutoTFLLM`` returns a truthy result for the prompt."""
    model = openllm.AutoTFLLM.for_model("opt", model_id=opt_id, ensure_available=True)
    assert model(qa_prompt)
@openllm.tests.require_tf
def test_small_runner_opt(qa_prompt: str, opt_id: str):
    """Smoke test: a locally initialised OPT Runner with the "tf" implementation returns a truthy result."""
    runner = openllm.Runner("opt", implementation="tf", model_id=opt_id, init_local=True)
    assert runner(qa_prompt)

View File

@@ -16,18 +16,23 @@
for ModelEnv construction and parsing environment variables."""
from __future__ import annotations
import contextlib
import logging
import os
import typing as t
from unittest import mock
import attr
import pytest
from hypothesis import assume
from hypothesis import given
from hypothesis import strategies as st
import openllm
import transformers
from openllm._configuration import GenerationConfig
from openllm._configuration import ModelSettings
from openllm._configuration import _field_env_key
from openllm.utils import DEBUG
from openllm._configuration import field_env_key
from ._strategies._configuration import make_llm_config
from ._strategies._configuration import model_settings
@@ -35,6 +40,11 @@ from ._strategies._configuration import model_settings
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
DictStrAny = dict[str, t.Any]
else:
DictStrAny = dict
def test_missing_default():
with pytest.raises(ValueError, match="Either 'default_id' or 'model_ids'*"):
@@ -92,6 +102,12 @@ def test_config_derivation(gen_settings: ModelSettings, field1: int):
assert new_cls.__openllm_default_id__ == "asdfasdf"
@given(model_settings())
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
    """A config class produced by ``make_llm_config`` must be recognised by ``attr.has``."""
    assert attr.has(make_llm_config("AttrsProtocolLLM", gen_settings))
@given(
model_settings(),
st.integers(max_value=283473),
@@ -134,7 +150,37 @@ def test_complex_struct_dump(
)
def test_struct_envvar(monkeypatch: pytest.MonkeyPatch):
@contextlib.contextmanager
def patch_env(**attrs: t.Any):
    """Temporarily replace the whole of ``os.environ`` with exactly ``attrs``.

    The environment is cleared first (``clear=True``), so only the given
    variables are visible inside the ``with`` body; the previous environment
    is restored on exit, including when the body raises.
    """
    env_patcher = mock.patch.dict(os.environ, attrs, clear=True)
    env_patcher.start()
    try:
        yield
    finally:
        env_patcher.stop()
def test_struct_envvar():
    """Values supplied only through environment variables override the class-level defaults.

    Checked both via ``model_construct_env()`` and via plain instantiation.
    """
    env_overrides = {
        field_env_key("env_llm", "field1"): "4",
        field_env_key("env_llm", "temperature", suffix="generation"): "0.2",
    }
    with patch_env(**env_overrides):

        class EnvLLM(openllm.LLMConfig):
            __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]}
            field1: int = 2

            class GenerationConfig:
                temperature: float = 0.8

        from_env = EnvLLM.model_construct_env()
        assert from_env.field1 == 4
        assert from_env["temperature"] == 0.2

        plain_instance = EnvLLM()
        assert plain_instance.field1 == 4
        assert plain_instance["temperature"] == 0.2
def test_struct_provided_fields():
class EnvLLM(openllm.LLMConfig):
__config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]}
field1: int = 2
@@ -142,23 +188,39 @@ def test_struct_envvar(monkeypatch: pytest.MonkeyPatch):
class GenerationConfig:
temperature: float = 0.8
f1_env = _field_env_key(EnvLLM.__openllm_model_name__, "field1")
temperature_env = _field_env_key(EnvLLM.__openllm_model_name__, "temperature", suffix="generation")
sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
assert sent.field1 == 20
assert sent.generation_config.temperature == 0.4
if DEBUG:
logger.info(f"Env keys: {f1_env}, {temperature_env}")
with monkeypatch.context() as m:
m.setenv(f1_env, "4")
m.setenv(temperature_env, "0.2")
sent = EnvLLM()
assert sent.field1 == 4
assert sent.generation_config.temperature == 0.8
# NOTE: This is the expected behaviour, where users pass in value, we respect it over envvar.
with monkeypatch.context() as m:
m.setenv(f1_env, "4")
m.setenv(temperature_env, "0.2")
sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
assert sent.field1 == 4
def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
    """Explicitly passed values win over environment variables set for the same fields."""
    env_values = {
        field_env_key("overwrite_with_env_available", "field1"): str(4.0),
        field_env_key("overwrite_with_env_available", "temperature", suffix="generation"): str(0.2),
    }
    with monkeypatch.context() as mk:
        for env_name, env_value in env_values.items():
            mk.setenv(env_name, env_value)
        config_cls = make_llm_config(
            "OverwriteWithEnvAvailable",
            {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]},
            fields=(("field1", "float", 3.0),),
        )
        constructed = config_cls.model_construct_env(field1=20.0, temperature=0.4)
        assert constructed.generation_config.temperature == 0.4
        assert constructed.field1 == 20.0
@given(model_settings())
@pytest.mark.parametrize("return_dict,typ", [(True, DictStrAny), (False, transformers.GenerationConfig)])
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
    """``to_generation_config`` returns a dict or a ``transformers.GenerationConfig`` depending on ``return_as_dict``."""
    config_cls = make_llm_config("ConversionLLM", gen_settings)
    converted = config_cls().to_generation_config(return_as_dict=return_dict)
    assert isinstance(converted, typ)
@given(model_settings())
def test_click_conversion(gen_settings: ModelSettings):
    """The number of generated click options matches the number of non-Union type hints."""

    # currently our conversion omit Union type.
    def cli_mock(**attrs: t.Any):
        return attrs

    config_cls = make_llm_config("ClickConversionLLM", gen_settings)
    wrapped = config_cls.to_click_options(cli_mock)

    # Hints whose origin is ``t.Union`` are excluded from the expected count.
    non_union_keys = set()
    for key, hint in config_cls.__openllm_hints__.items():
        if t.get_origin(hint) is not t.Union:
            non_union_keys.add(key)

    # Ignore unnamed params and any whose name starts with "fake_".
    real_options = []
    for param in wrapped.__click_params__:
        if not param.name or param.name.startswith("fake_"):
            continue
        real_options.append(param)

    assert len(non_union_keys) == len(real_options)

54
tools/generate-coverage.py Executable file
View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python3
from __future__ import annotations
import orjson
from collections import defaultdict
from pathlib import Path
from lxml import etree
ROOT = Path(__file__).resolve().parent.parent
PACKAGES = {
'src/openllm/': 'openllm'
}
def main() -> int:
    """Summarise ``coverage.xml`` into ``coverage-summary.json``.

    Reads the XML report at ``ROOT / "coverage.xml"``, counts hit and missed
    lines per package listed in ``PACKAGES``, and writes per-package and total
    statement counts as indented JSON to ``ROOT / "coverage-summary.json"``.

    Returns:
        ``0`` on success (used as the process exit code).

    Raises:
        ValueError: if a reported file does not live under any path in ``PACKAGES``.
    """
    coverage_report = ROOT / "coverage.xml"
    # Parse raw bytes so lxml applies the encoding from the XML declaration instead of
    # relying on the platform's default text encoding.
    root = etree.fromstring(coverage_report.read_bytes())

    raw_package_data: defaultdict[str, dict[str, int]] = defaultdict(lambda: {'hits': 0, 'misses': 0})
    # Path-based lookups yield nothing (rather than failing on ``None``) when a
    # container element such as <packages>, <classes> or <lines> is absent.
    for module in root.iterfind('packages/package/classes/class'):
        filename = module.attrib['filename']
        for relative_path, package_name in PACKAGES.items():
            if filename.startswith(relative_path):
                data = raw_package_data[package_name]
                break
        else:
            # Report the offending file path; the element repr carries no useful detail.
            message = f'unknown package: {filename}'
            raise ValueError(message)

        for line in module.iterfind('lines/line'):
            # 'hits' is an execution count: any positive value means the line was covered.
            if int(line.attrib['hits']) > 0:
                data['hits'] += 1
            else:
                data['misses'] += 1

    total_statements_covered = 0
    total_statements = 0
    coverage_data = {}
    for package_name, data in sorted(raw_package_data.items()):
        statements_covered = data['hits']
        statements = statements_covered + data['misses']
        total_statements_covered += statements_covered
        total_statements += statements
        coverage_data[package_name] = {'statements_covered': statements_covered, 'statements': statements}
    coverage_data['total'] = {'statements_covered': total_statements_covered, 'statements': total_statements}

    coverage_summary = ROOT / 'coverage-summary.json'
    coverage_summary.write_text(orjson.dumps(coverage_data, option=orjson.OPT_INDENT_2).decode(), encoding='utf-8')
    return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,3 +1,4 @@
#!/usr/bin/env python3
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,14 +15,13 @@
from __future__ import annotations
import typing as t
import os
from pathlib import Path
import openllm
import importlib
from openllm._configuration import ModelSettings
from openllm._configuration import ModelSettings, GenerationConfig, PeftType
# currently we are assuming the indentation level is 4 for comments
START_COMMENT = f"# {os.path.basename(__file__)}: start\n"
@@ -76,6 +76,31 @@ def main() -> int:
)
)
)
for keys, type_pep563 in openllm.utils.codegen.get_annotations(GenerationConfig).items():
lines.extend(
list(
map(
lambda line: " " * 8 + line,
[
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
f'def __getitem__(self, item: t.Literal["{keys}"] = ...) -> {type_pep563}: ...\n',
],
)
)
)
for keys in PeftType._member_names_:
lines.extend(
list(
map(
lambda line: " " * 8 + line,
[
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
f'def __getitem__(self, item: t.Literal["{keys.lower()}"] = ...) -> dict[str, t.Any]: ...\n',
],
)
)
)
processed = (
processed[:start_idx] + [" " * 4 + START_COMMENT] + lines + [" " * 4 + END_COMMENT] + processed[end_idx + 1 :]

53
tools/write-coverage-report.py Executable file
View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
from __future__ import annotations
import orjson
from decimal import ROUND_DOWN, Decimal
from pathlib import Path
PRECISION = Decimal('.01')
ROOT = Path(__file__).resolve().parent.parent
def _coverage_rate(statements_covered: int, statements: int) -> Decimal:
    """Return the covered percentage, truncated (not rounded) to ``PRECISION``.

    When there are no statements at all there is nothing left uncovered, so the
    rate is reported as 100 instead of attempting a division by zero.
    """
    if statements == 0:
        return Decimal(100).quantize(PRECISION, rounding=ROUND_DOWN)
    rate = Decimal(statements_covered) / Decimal(statements) * 100
    return rate.quantize(PRECISION, rounding=ROUND_DOWN)


def main() -> int:
    """Render ``coverage-summary.json`` as a Markdown report in ``coverage-report.md``.

    The report consists of a coverage badge, a per-package table of statement
    coverage and a final summary row. Both files live under ``ROOT``.

    Returns:
        ``0`` on success (used as the process exit code).
    """
    coverage_summary = ROOT / 'coverage-summary.json'
    coverage_data = orjson.loads(coverage_summary.read_text(encoding='utf-8'))
    # The aggregate entry is rendered separately as the summary row and badge.
    total_data = coverage_data.pop('total')

    lines = [
        '\n',
        'Package | Statements\n',
        '------- | ----------\n',
    ]

    for package, data in sorted(coverage_data.items()):
        statements_covered = data['statements_covered']
        statements = data['statements']
        rate = _coverage_rate(statements_covered, statements)
        lines.append(
            f'{package} | {100 if rate == 100 else rate}% ({statements_covered} / {statements})\n'  # noqa: PLR2004
        )

    total_statements_covered = total_data['statements_covered']
    total_statements = total_data['statements']
    total_rate = _coverage_rate(total_statements_covered, total_statements)

    color = 'ok' if float(total_rate) >= 95 else 'critical'
    # The badge goes first; '%25' is the URL-encoded percent sign.
    lines.insert(0, f'![Code Coverage](https://img.shields.io/badge/coverage-{total_rate}%25-{color}?style=flat)\n')
    lines.append(
        f'**Summary** | {100 if total_rate == 100 else total_rate}% '
        f'({total_statements_covered} / {total_statements})\n'
    )

    coverage_report = ROOT / 'coverage-report.md'
    with coverage_report.open('w', encoding='utf-8') as f:
        f.write(''.join(lines))
    return 0
if __name__ == '__main__':
raise SystemExit(main())