mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 18:39:16 -04:00
fix(test): robustness (#64)
This commit is contained in:
2
.github/actions/setup-repo/action.yml
vendored
2
.github/actions/setup-repo/action.yml
vendored
@@ -53,7 +53,7 @@ runs:
|
||||
${{ steps.get-cache-key-prefix.outputs.prefix }}-pypi-
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: pip install hatch towncrier
|
||||
run: pip install hatch towncrier && pip install -e ".[all]"
|
||||
- name: Install pyright
|
||||
shell: bash
|
||||
run: npm install -g npm@^7 pyright
|
||||
|
||||
56
.github/workflows/ci.yml
vendored
56
.github/workflows/ci.yml
vendored
@@ -23,6 +23,7 @@ env:
|
||||
LINES: 120
|
||||
COLUMNS: 120
|
||||
OPENLLM_DO_NOT_TRACK: True
|
||||
PYTHONUNBUFFERED: '1'
|
||||
# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
|
||||
defaults:
|
||||
run:
|
||||
@@ -43,7 +44,60 @@ jobs:
|
||||
- name: Setup CI
|
||||
uses: ./.github/actions/setup-repo
|
||||
- name: Run tests
|
||||
run: hatch run test:p
|
||||
run: hatch run test:full
|
||||
- name: Disambiguate coverage filename
|
||||
run: mv .coverage ".coverage.${{ matrix.os }}"
|
||||
- name: Upload coverage data
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: coverage-data
|
||||
path: .coverage.*
|
||||
coverage:
|
||||
name: Coverage
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- tests
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup CI
|
||||
uses: ./.github/actions/setup-repo
|
||||
- name: Download coverage data
|
||||
uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: coverage-data
|
||||
- name: Combine coverage data
|
||||
run: hatch run coverage:combine
|
||||
- name: Export coverage reports
|
||||
run: |
|
||||
hatch run coverage:report-xml
|
||||
hatch run coverage:report-uncovered-html
|
||||
- name: Upload uncovered HTML report
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: uncovered-html-report
|
||||
path: htmlcov
|
||||
- name: Generate coverage summary
|
||||
run: hatch run coverage:generate-summary
|
||||
- name: Write coverage summary report
|
||||
if: github.event_name == 'pull_request'
|
||||
run: hatch run coverage:write-summary-report
|
||||
- name: Update coverage pull request comment
|
||||
if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
with:
|
||||
path: coverage-report.md
|
||||
check: # https://github.com/marketplace/actions/alls-green#why
|
||||
if: always()
|
||||
needs:
|
||||
- coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Decide whether the needed jobs succeeded or failed
|
||||
uses: re-actors/alls-green@release/v1
|
||||
with:
|
||||
jobs: ${{ toJSON(needs) }}
|
||||
concurrency:
|
||||
group: ci-${{ github.event.pull_request.number || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
@@ -107,7 +107,7 @@ After setting up your environment, here's how you can start contributing:
|
||||
5. Run all tests to ensure your changes haven't broken anything:
|
||||
|
||||
```bash
|
||||
hatch run test:p
|
||||
hatch run test:full
|
||||
```
|
||||
|
||||
6. Commit your changes:
|
||||
@@ -141,7 +141,7 @@ directory and their filenames start with `test_`.
|
||||
Run all tests with:
|
||||
|
||||
```bash
|
||||
hatch run test:p
|
||||
hatch run test:full
|
||||
```
|
||||
|
||||
## Releasing a New Version
|
||||
|
||||
4
changelog.d/64.fix.md
Normal file
4
changelog.d/64.fix.md
Normal file
@@ -0,0 +1,4 @@
|
||||
Remove duplicated class instance of `generation_config` as it should be set via
|
||||
instance attributes.
|
||||
|
||||
fixes tests flakiness and one broken cases for parsing env
|
||||
18
hatch.toml
18
hatch.toml
@@ -20,6 +20,7 @@ typing = "pyright {args:src/openllm tests}"
|
||||
dependencies = [
|
||||
# NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy
|
||||
"coverage[toml]>=6.5",
|
||||
"lxml",
|
||||
"pytest",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
"pytest-xdist[psutil]",
|
||||
@@ -30,9 +31,20 @@ dependencies = [
|
||||
"hypothesis",
|
||||
"syrupy",
|
||||
]
|
||||
[envs.test.overrides]
|
||||
env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT="
|
||||
env.HERMETIC_TESTS.type = [{ value = "container", if = ["true"] }, "virtual"]
|
||||
[envs.test.scripts]
|
||||
cov = ["cov-test", "- coverage combine", "coverage report"]
|
||||
cov-test = "coverage run -m pytest {args:tests}"
|
||||
p = "pytest {args:tests}"
|
||||
_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml"
|
||||
full = "_run_script -n 3 --reruns 5 --reruns-delay 3 -r aR {args:tests}"
|
||||
[[envs.test.matrix]]
|
||||
python = ["3.8", "3.9", "3.10", "3.11"]
|
||||
[envs.coverage]
|
||||
dependencies = ["coverage[toml]>=6.5", "lxml", "orjson"]
|
||||
detached = true
|
||||
[envs.coverage.scripts]
|
||||
combine = "coverage combine {args}"
|
||||
generate-summary = "python tools/generate-coverage.py"
|
||||
report-uncovered-html = "coverage html --skip-covered --skip-empty"
|
||||
report-xml = "coverage xml"
|
||||
write-summary-report = "python tools/write-coverage-report.py"
|
||||
|
||||
@@ -220,13 +220,11 @@ typeCheckingMode = "strict"
|
||||
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
omit = ["src/openllm/__about__.py"]
|
||||
parallel = true
|
||||
source_pkgs = ["openllm", "tests"]
|
||||
omit = ["src/openllm/__about__.py", "src/openllm/__main__.py", "src/openllm/tests.py", "src/openllm/utils/dummy_*.py"]
|
||||
source_pkgs = ["openllm"]
|
||||
|
||||
[tool.coverage.paths]
|
||||
openllm = ["src/openllm", "*/openllm/src/openllm"]
|
||||
tests = ["tests", "*/openllm/tests"]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:"]
|
||||
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:", "@overload", "# pragma: no cover"]
|
||||
|
||||
@@ -35,6 +35,13 @@ from .__about__ import __version__ as __version__
|
||||
from .exceptions import MissingDependencyError
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if utils.DEBUG:
|
||||
utils.set_debug_mode(True)
|
||||
utils.set_quiet_mode(False)
|
||||
@@ -71,6 +78,7 @@ _import_structure = {
|
||||
"models": [],
|
||||
"client": [],
|
||||
"playground": [],
|
||||
"tests": [],
|
||||
"cli": ["start", "start_grpc"],
|
||||
# NOTE: models
|
||||
"models.auto": [
|
||||
@@ -164,6 +172,7 @@ if t.TYPE_CHECKING:
|
||||
from . import exceptions as exceptions
|
||||
from . import models as models
|
||||
from . import playground as playground
|
||||
from . import tests as tests
|
||||
|
||||
# Specific types import
|
||||
from ._configuration import LLMConfig as LLMConfig
|
||||
|
||||
@@ -50,8 +50,6 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -358,8 +356,8 @@ class FineTuneConfig:
|
||||
return klass
|
||||
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationConfig:
|
||||
@attr.frozen(slots=True, repr=False)
|
||||
class GenerationConfig(ReprMixin):
|
||||
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
|
||||
with some additional validation and environment constructor.
|
||||
|
||||
@@ -609,6 +607,10 @@ class GenerationConfig:
|
||||
return getattr(self, item)
|
||||
raise KeyError(f"GenerationConfig has no attribute {item}")
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
return {i.name for i in attr.fields(self.__class__)}
|
||||
|
||||
|
||||
bentoml_cattr.register_unstructure_hook_factory(
|
||||
lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig),
|
||||
@@ -623,7 +625,7 @@ bentoml_cattr.register_unstructure_hook_factory(
|
||||
)
|
||||
|
||||
|
||||
def _field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str:
|
||||
def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None = None) -> str:
|
||||
return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
|
||||
|
||||
|
||||
@@ -731,13 +733,13 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]):
|
||||
)
|
||||
|
||||
_cl_name = cl_.__name__.replace("Config", "")
|
||||
_settings_attr = _ModelSettingsAttr.default()
|
||||
try:
|
||||
cls(**t.cast(DictStrAny, cl_.__config__))
|
||||
_settings_attr = attr.evolve(_settings_attr, **t.cast(DictStrAny, cl_.__config__))
|
||||
except TypeError:
|
||||
_settings_attr = cls.default()
|
||||
|
||||
if any(i not in cl_.__config__ for i in {"default_id", "model_ids"}):
|
||||
raise ValueError("Either 'default_id' or 'model_ids' are emptied under '__config__' (required fields).")
|
||||
|
||||
_settings_attr = attr.evolve(_settings_attr, **t.cast(DictStrAny, cl_.__config__))
|
||||
|
||||
_final_value_dct: DictStrAny = {
|
||||
"model_name": inflection.underscore(_cl_name)
|
||||
if _settings_attr["name_type"] == "dasherize"
|
||||
@@ -805,10 +807,9 @@ def _make_env_transformer(
|
||||
globs = {} if globs is None else globs
|
||||
globs.update(
|
||||
{
|
||||
"functools": functools,
|
||||
"__populate_env": dantic.env_converter,
|
||||
"__default_callback": default_callback,
|
||||
"__field_env": _field_env_key,
|
||||
"__field_env": field_env_key,
|
||||
"__suffix": suffix or "",
|
||||
"__model_name": model_name,
|
||||
}
|
||||
@@ -1289,9 +1290,9 @@ class LLMConfig(_ConfigAttr):
|
||||
|
||||
def add_repr(self):
|
||||
for key, fn in ReprMixin.__dict__.items():
|
||||
if key not in ("__module__", "__doc__", "__repr_keys__"):
|
||||
if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"):
|
||||
self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
|
||||
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs})
|
||||
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"})
|
||||
return self
|
||||
|
||||
def __init_subclass__(cls: type[LLMConfig]):
|
||||
@@ -1326,7 +1327,7 @@ class LLMConfig(_ConfigAttr):
|
||||
slots=True,
|
||||
weakref_slot=True,
|
||||
frozen=True,
|
||||
repr=True,
|
||||
repr=False,
|
||||
collect_by_mro=True,
|
||||
field_transformer=_make_env_transformer(
|
||||
cls,
|
||||
@@ -1356,9 +1357,9 @@ class LLMConfig(_ConfigAttr):
|
||||
val = cd.get(attr_name, attr.NOTHING)
|
||||
if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val):
|
||||
if val is attr.NOTHING:
|
||||
val = cls.Field(env=_field_env_key(model_name, attr_name))
|
||||
val = cls.Field(env=field_env_key(model_name, attr_name))
|
||||
else:
|
||||
val = cls.Field(default=val, env=_field_env_key(model_name, attr_name))
|
||||
val = cls.Field(default=val, env=field_env_key(model_name, attr_name))
|
||||
these[attr_name] = val
|
||||
unannotated = ca_names - annotated_names
|
||||
if len(unannotated) > 0:
|
||||
@@ -1371,13 +1372,6 @@ class LLMConfig(_ConfigAttr):
|
||||
cls.__openllm_accepted_keys__ = set(these.keys()) | {
|
||||
a.name for a in attr.fields(cls.__openllm_generation_class__)
|
||||
}
|
||||
# 'generation_config' wraps the GenerationConfig class
|
||||
# which is handled in _make_assignment_script
|
||||
these["generation_config"] = cls.Field(
|
||||
default=cls.__openllm_generation_class__(),
|
||||
description=inspect.cleandoc(cls.__openllm_generation_class__.__doc__ or ""),
|
||||
type=GenerationConfig,
|
||||
)
|
||||
|
||||
cls = cls._ConfigBuilder(cls, model_name, these).add_attrs_init().add_repr().build_class()
|
||||
# auto assignment attributes generated from __config__ after create the new slot class.
|
||||
@@ -1390,7 +1384,6 @@ class LLMConfig(_ConfigAttr):
|
||||
if cls.__module__ in sys.modules:
|
||||
globs.update(sys.modules[cls.__module__].__dict__)
|
||||
attr.resolve_types(cls.__openllm_generation_class__, globalns=globs)
|
||||
|
||||
cls = attr.resolve_types(cls, globalns=globs)
|
||||
# the hint cache for easier access
|
||||
cls.__openllm_hints__ = {
|
||||
@@ -1428,20 +1421,15 @@ class LLMConfig(_ConfigAttr):
|
||||
for k in _cached_keys:
|
||||
if k in generation_config or attrs.get(k) is None:
|
||||
del attrs[k]
|
||||
_cached_keys = tuple(k for k in _cached_keys if k in attrs)
|
||||
|
||||
self.__openllm_extras__ = config_merger.merge(
|
||||
first_not_none(__openllm_extras__, default={}),
|
||||
{k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__},
|
||||
)
|
||||
|
||||
for k in _cached_keys:
|
||||
if k in self.__openllm_extras__:
|
||||
del attrs[k]
|
||||
_cached_keys = tuple(k for k in _cached_keys if k in attrs)
|
||||
self.generation_config = self["generation_class"](**generation_config)
|
||||
|
||||
# The rest of attrs should only be the attributes to be passed to __attrs_init__
|
||||
self.__attrs_init__(generation_config=self["generation_class"](**generation_config), **attrs)
|
||||
self.__attrs_init__(**attrs)
|
||||
|
||||
# NOTE: These required fields should be at the top, as it will be kw_only
|
||||
|
||||
@@ -1487,6 +1475,98 @@ class LLMConfig(_ConfigAttr):
|
||||
def __getitem__(self, item: t.Literal["generation_class"] = ...) -> t.Type[GenerationConfig]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["extras"] = ...) -> t.Dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["max_new_tokens"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["min_length"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["min_new_tokens"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["early_stopping"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["max_time"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["num_beams"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["num_beam_groups"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["penalty_alpha"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["use_cache"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["temperature"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["top_k"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["top_p"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["typical_p"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["epsilon_cutoff"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["eta_cutoff"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["diversity_penalty"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["repetition_penalty"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["encoder_repetition_penalty"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["length_penalty"] = ...) -> float: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["no_repeat_ngram_size"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bad_words_ids"] = ...) -> t.List[t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["force_words_ids"] = ...) -> t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["renormalize_logits"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["constraints"] = ...) -> t.List[Constraint]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["forced_bos_token_id"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["forced_eos_token_id"] = ...) -> t.Union[int, t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["remove_invalid_values"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["exponential_decay_length_penalty"] = ...) -> t.Tuple[int, float]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["suppress_tokens"] = ...) -> t.List[int]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["begin_suppress_tokens"] = ...) -> t.List[int]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["forced_decoder_ids"] = ...) -> t.List[t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["num_return_sequences"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["output_attentions"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["output_hidden_states"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["output_scores"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["pad_token_id"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bos_token_id"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["eos_token_id"] = ...) -> t.Union[int, t.List[int]]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["encoder_no_repeat_ngram_size"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["decoder_start_token_id"] = ...) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["prompt_tuning"] = ...) -> dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["p_tuning"] = ...) -> dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["prefix_tuning"] = ...) -> dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["lora"] = ...) -> dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["adalora"] = ...) -> dict[str, t.Any]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["adaption_prompt"] = ...) -> dict[str, t.Any]: ...
|
||||
# update-config-stubs.py: stop
|
||||
|
||||
# fmt: on
|
||||
@@ -1556,6 +1636,7 @@ class LLMConfig(_ConfigAttr):
|
||||
lambda ns: ns.update(
|
||||
{
|
||||
"__config__": config_merger.merge(copy.deepcopy(cls.__dict__["__config__"]), _new_cfg),
|
||||
"__base_config__": cls, # keep a reference for easy access
|
||||
}
|
||||
),
|
||||
)
|
||||
@@ -1573,9 +1654,11 @@ class LLMConfig(_ConfigAttr):
|
||||
|
||||
def model_dump(self, flatten: bool = False, **_: t.Any):
|
||||
dumped = bentoml_cattr.unstructure(self)
|
||||
generation_config = bentoml_cattr.unstructure(self.generation_config)
|
||||
if flatten:
|
||||
generation_config = dumped.pop("generation_config")
|
||||
dumped.update(generation_config)
|
||||
else:
|
||||
dumped["generation_config"] = generation_config
|
||||
return dumped
|
||||
|
||||
def model_dump_json(self, **kwargs: t.Any):
|
||||
|
||||
@@ -1356,6 +1356,7 @@ def Runner(
|
||||
model_name: str,
|
||||
ensure_available: bool | None = None,
|
||||
init_local: bool = False,
|
||||
implementation: t.Literal["pt", "flax", "tf"] | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> LLMRunner:
|
||||
"""Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'
|
||||
@@ -1380,6 +1381,10 @@ def Runner(
|
||||
|
||||
Args:
|
||||
model_name: Supported model name from 'openllm models'
|
||||
ensure_available: If True, it will download the model if it is not available. If False, it will skip downloading the model.
|
||||
If False, make sure the model is available locally.
|
||||
implementation: The given Runner implementation one choose for this Runner. By default, it is retrieved from the enviroment variable
|
||||
of the respected model_name. For example: 'flan-t5' -> "OPENLLM_FLAN_T5_FRAMEWORK"
|
||||
init_local: If True, it will initialize the model locally. This is useful if you want to
|
||||
run the model locally. (Symmetrical to bentoml.Runner.init_local())
|
||||
**attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs
|
||||
@@ -1387,9 +1392,11 @@ def Runner(
|
||||
"""
|
||||
runner = t.cast(
|
||||
"_BaseAutoLLMClass",
|
||||
openllm[EnvVarMixin(model_name)["framework_value"]], # type: ignore (internal API)
|
||||
openllm[implementation if implementation is not None else EnvVarMixin(model_name)["framework_value"]], # type: ignore (internal API)
|
||||
).create_runner(
|
||||
model_name, ensure_available=ensure_available if ensure_available is not None else init_local, **attrs
|
||||
model_name,
|
||||
ensure_available=ensure_available if ensure_available is not None else init_local,
|
||||
**attrs,
|
||||
)
|
||||
|
||||
if init_local:
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
class PeftTrainer(transformers.Trainer):
|
||||
...
|
||||
|
||||
|
||||
class PeftSaveCallback(transformers.TrainerCallback):
|
||||
...
|
||||
39
src/openllm/tests.py
Normal file
39
src/openllm/tests.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from .utils import is_flax_available
|
||||
from .utils import is_tf_available
|
||||
from .utils import is_torch_available
|
||||
|
||||
|
||||
try:
|
||||
import pytest
|
||||
except ImportError:
|
||||
raise ImportError("You need to install pytest to use 'openllm.tests' utilities: 'pip install pytest'")
|
||||
|
||||
|
||||
def require_tf(f: t.Callable[..., t.Any]):
|
||||
return pytest.mark.skipif(not is_tf_available(), reason="requires TensorFlow")(f)
|
||||
|
||||
|
||||
def require_flax(f: t.Callable[..., t.Any]):
|
||||
return pytest.mark.skipif(not is_flax_available(), reason="requires Flax")(f)
|
||||
|
||||
|
||||
def require_torch(f: t.Callable[..., t.Any]):
|
||||
return pytest.mark.skipif(not is_torch_available(), reason="requires PyTorch")(f)
|
||||
@@ -343,7 +343,7 @@ class EnvVarMixin(ReprMixin):
|
||||
raise KeyError(f"Key {item} not found in {self}")
|
||||
|
||||
def __new__(cls, model_name: str, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None):
|
||||
from .._configuration import _field_env_key
|
||||
from .._configuration import field_env_key
|
||||
from . import codegen
|
||||
|
||||
model_name = inflection.underscore(model_name)
|
||||
@@ -354,7 +354,7 @@ class EnvVarMixin(ReprMixin):
|
||||
# gen properties env key
|
||||
attributes = {"config", "model_id", "quantize", "framework", "bettertransformer"}
|
||||
for att in attributes:
|
||||
setattr(res, att, _field_env_key(model_name, att.upper()))
|
||||
setattr(res, att, field_env_key(model_name, att.upper()))
|
||||
|
||||
# gen properties env value
|
||||
attributes_with_values = {
|
||||
|
||||
@@ -62,7 +62,7 @@ def make_llm_config(
|
||||
lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
|
||||
if fields is not None:
|
||||
for field, type_, default in fields:
|
||||
lines.append(f" {field}: {type_} = {repr(default)}")
|
||||
lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({repr(default)})")
|
||||
if generation_fields is not None:
|
||||
generation_lines = ["class GenerationConfig:"]
|
||||
for field, default in generation_fields:
|
||||
|
||||
43
tests/models/flan_t5/test_modeling_flax_flan_t5.py
Normal file
43
tests/models/flan_t5/test_modeling_flax_flan_t5.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qa_prompt() -> str:
|
||||
return "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flan_t5_id() -> str:
|
||||
return "google/flan-t5-small"
|
||||
|
||||
|
||||
@openllm.tests.require_flax
|
||||
def test_small_flax_flan(qa_prompt: str, flan_t5_id: str):
|
||||
llm = openllm.AutoFlaxLLM.for_model("flan-t5", model_id=flan_t5_id, ensure_available=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
|
||||
|
||||
@openllm.tests.require_flax
|
||||
def test_small_flax_runner_flan(qa_prompt: str, flan_t5_id: str):
|
||||
llm = openllm.Runner("flan-t5", model_id=flan_t5_id, implementation="flax", init_local=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
43
tests/models/flan_t5/test_modeling_tf_flan_t5.py
Normal file
43
tests/models/flan_t5/test_modeling_tf_flan_t5.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qa_prompt() -> str:
|
||||
return "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flan_t5_id() -> str:
|
||||
return "google/flan-t5-small"
|
||||
|
||||
|
||||
@openllm.tests.require_tf
|
||||
def test_small_tf_flan(qa_prompt: str, flan_t5_id: str):
|
||||
llm = openllm.AutoTFLLM.for_model("flan-t5", model_id=flan_t5_id, ensure_available=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
|
||||
|
||||
@openllm.tests.require_tf
|
||||
def test_small_tf_runner_flan(qa_prompt: str, flan_t5_id: str):
|
||||
llm = openllm.Runner("flan-t5", model_id=flan_t5_id, implementation="tf", init_local=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
43
tests/models/opt/test_modeling_flax_opt.py
Normal file
43
tests/models/opt/test_modeling_flax_opt.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qa_prompt() -> str:
|
||||
return "Answer the following yes/no question by reasoning step-by-step. What is the weather in SF?"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opt_id() -> str:
|
||||
return "facebook/opt-125m"
|
||||
|
||||
|
||||
@openllm.tests.require_flax
|
||||
def test_small_opt(qa_prompt: str, opt_id: str):
|
||||
llm = openllm.AutoFlaxLLM.for_model("opt", model_id=opt_id, ensure_available=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
|
||||
|
||||
@openllm.tests.require_flax
|
||||
def test_small_runner_opt(qa_prompt: str, opt_id: str):
|
||||
llm = openllm.Runner("opt", implementation="flax", model_id=opt_id, init_local=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
43
tests/models/opt/test_modeling_tf_opt.py
Normal file
43
tests/models/opt/test_modeling_tf_opt.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qa_prompt() -> str:
|
||||
return "Answer the following yes/no question by reasoning step-by-step. What is the weather in SF?"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opt_id() -> str:
|
||||
return "facebook/opt-125m"
|
||||
|
||||
|
||||
@openllm.tests.require_tf
|
||||
def test_small_opt(qa_prompt: str, opt_id: str):
|
||||
llm = openllm.AutoTFLLM.for_model("opt", model_id=opt_id, ensure_available=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
|
||||
|
||||
@openllm.tests.require_tf
|
||||
def test_small_runner_opt(qa_prompt: str, opt_id: str):
|
||||
llm = openllm.Runner("opt", implementation="tf", model_id=opt_id, init_local=True)
|
||||
generate = llm(qa_prompt)
|
||||
assert generate
|
||||
@@ -16,18 +16,23 @@
|
||||
for ModelEnv construction and parsing environment variables."""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
from unittest import mock
|
||||
|
||||
import attr
|
||||
import pytest
|
||||
from hypothesis import assume
|
||||
from hypothesis import given
|
||||
from hypothesis import strategies as st
|
||||
|
||||
import openllm
|
||||
import transformers
|
||||
from openllm._configuration import GenerationConfig
|
||||
from openllm._configuration import ModelSettings
|
||||
from openllm._configuration import _field_env_key
|
||||
from openllm.utils import DEBUG
|
||||
from openllm._configuration import field_env_key
|
||||
|
||||
from ._strategies._configuration import make_llm_config
|
||||
from ._strategies._configuration import model_settings
|
||||
@@ -35,6 +40,11 @@ from ._strategies._configuration import model_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
DictStrAny = dict[str, t.Any]
|
||||
else:
|
||||
DictStrAny = dict
|
||||
|
||||
|
||||
def test_missing_default():
|
||||
with pytest.raises(ValueError, match="Either 'default_id' or 'model_ids'*"):
|
||||
@@ -92,6 +102,12 @@ def test_config_derivation(gen_settings: ModelSettings, field1: int):
|
||||
assert new_cls.__openllm_default_id__ == "asdfasdf"
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config("AttrsProtocolLLM", gen_settings)
|
||||
assert attr.has(cl_)
|
||||
|
||||
|
||||
@given(
|
||||
model_settings(),
|
||||
st.integers(max_value=283473),
|
||||
@@ -134,7 +150,37 @@ def test_complex_struct_dump(
|
||||
)
|
||||
|
||||
|
||||
def test_struct_envvar(monkeypatch: pytest.MonkeyPatch):
|
||||
@contextlib.contextmanager
|
||||
def patch_env(**attrs: t.Any):
|
||||
with mock.patch.dict(os.environ, attrs, clear=True):
|
||||
yield
|
||||
|
||||
|
||||
def test_struct_envvar():
|
||||
with patch_env(
|
||||
**{
|
||||
field_env_key("env_llm", "field1"): "4",
|
||||
field_env_key("env_llm", "temperature", suffix="generation"): "0.2",
|
||||
}
|
||||
):
|
||||
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]}
|
||||
field1: int = 2
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.8
|
||||
|
||||
sent = EnvLLM.model_construct_env()
|
||||
assert sent.field1 == 4
|
||||
assert sent["temperature"] == 0.2
|
||||
|
||||
overwrite_default = EnvLLM()
|
||||
assert overwrite_default.field1 == 4
|
||||
assert overwrite_default["temperature"] == 0.2
|
||||
|
||||
|
||||
def test_struct_provided_fields():
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]}
|
||||
field1: int = 2
|
||||
@@ -142,23 +188,39 @@ def test_struct_envvar(monkeypatch: pytest.MonkeyPatch):
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.8
|
||||
|
||||
f1_env = _field_env_key(EnvLLM.__openllm_model_name__, "field1")
|
||||
temperature_env = _field_env_key(EnvLLM.__openllm_model_name__, "temperature", suffix="generation")
|
||||
sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
|
||||
assert sent.field1 == 20
|
||||
assert sent.generation_config.temperature == 0.4
|
||||
|
||||
if DEBUG:
|
||||
logger.info(f"Env keys: {f1_env}, {temperature_env}")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(f1_env, "4")
|
||||
m.setenv(temperature_env, "0.2")
|
||||
sent = EnvLLM()
|
||||
assert sent.field1 == 4
|
||||
assert sent.generation_config.temperature == 0.8
|
||||
|
||||
# NOTE: This is the expected behaviour, where users pass in value, we respect it over envvar.
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(f1_env, "4")
|
||||
m.setenv(temperature_env, "0.2")
|
||||
sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
|
||||
assert sent.field1 == 4
|
||||
def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mk:
|
||||
mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0))
|
||||
mk.setenv(field_env_key("overwrite_with_env_available", "temperature", suffix="generation"), str(0.2))
|
||||
sent = make_llm_config(
|
||||
"OverwriteWithEnvAvailable",
|
||||
{"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"]},
|
||||
fields=(("field1", "float", 3.0),),
|
||||
).model_construct_env(field1=20.0, temperature=0.4)
|
||||
assert sent.generation_config.temperature == 0.4
|
||||
assert sent.field1 == 20.0
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
@pytest.mark.parametrize("return_dict,typ", [(True, DictStrAny), (False, transformers.GenerationConfig)])
|
||||
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config("ConversionLLM", gen_settings)
|
||||
assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
def test_click_conversion(gen_settings: ModelSettings):
|
||||
# currently our conversion omit Union type.
|
||||
def cli_mock(**attrs: t.Any):
|
||||
return attrs
|
||||
|
||||
cl_ = make_llm_config("ClickConversionLLM", gen_settings)
|
||||
wrapped = cl_.to_click_options(cli_mock)
|
||||
filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
|
||||
click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")]
|
||||
assert len(filtered) == len(click_options_filtered)
|
||||
|
||||
54
tools/generate-coverage.py
Executable file
54
tools/generate-coverage.py
Executable file
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import orjson
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from lxml import etree
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
PACKAGES = {
|
||||
'src/openllm/': 'openllm'
|
||||
}
|
||||
|
||||
def main() -> int:
|
||||
coverage_report = ROOT / "coverage.xml"
|
||||
root = etree.fromstring(coverage_report.read_text())
|
||||
|
||||
raw_package_data: defaultdict[str, dict[str, int]] = defaultdict(lambda: {'hits': 0, 'misses': 0})
|
||||
for package in root.find('packages'):
|
||||
for module in package.find('classes'):
|
||||
filename = module.attrib['filename']
|
||||
for relative_path, package_name in PACKAGES.items():
|
||||
if filename.startswith(relative_path):
|
||||
data = raw_package_data[package_name]
|
||||
break
|
||||
else:
|
||||
message = f'unknown package: {module}'
|
||||
raise ValueError(message)
|
||||
|
||||
for line in module.find('lines'):
|
||||
if line.attrib['hits'] == '1':
|
||||
data['hits'] += 1
|
||||
else:
|
||||
data['misses'] += 1
|
||||
|
||||
total_statements_covered = 0
|
||||
total_statements = 0
|
||||
coverage_data = {}
|
||||
for package_name, data in sorted(raw_package_data.items()):
|
||||
statements_covered = data['hits']
|
||||
statements = statements_covered + data['misses']
|
||||
total_statements_covered += statements_covered
|
||||
total_statements += statements
|
||||
|
||||
coverage_data[package_name] = {'statements_covered': statements_covered, 'statements': statements}
|
||||
coverage_data['total'] = {'statements_covered': total_statements_covered, 'statements': total_statements}
|
||||
|
||||
coverage_summary = ROOT / 'coverage-summary.json'
|
||||
coverage_summary.write_text(orjson.dumps(coverage_data,option=orjson.OPT_INDENT_2).decode(), encoding='utf-8')
|
||||
|
||||
return 0
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,3 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,14 +15,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import openllm
|
||||
import importlib
|
||||
from openllm._configuration import ModelSettings
|
||||
from openllm._configuration import ModelSettings, GenerationConfig, PeftType
|
||||
|
||||
# currently we are assuming the indentatio level is 4 for comments
|
||||
START_COMMENT = f"# {os.path.basename(__file__)}: start\n"
|
||||
@@ -76,6 +76,31 @@ def main() -> int:
|
||||
)
|
||||
)
|
||||
)
|
||||
for keys, type_pep563 in openllm.utils.codegen.get_annotations(GenerationConfig).items():
|
||||
lines.extend(
|
||||
list(
|
||||
map(
|
||||
lambda line: " " * 8 + line,
|
||||
[
|
||||
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
|
||||
f'def __getitem__(self, item: t.Literal["{keys}"] = ...) -> {type_pep563}: ...\n',
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
for keys in PeftType._member_names_:
|
||||
lines.extend(
|
||||
list(
|
||||
map(
|
||||
lambda line: " " * 8 + line,
|
||||
[
|
||||
"@overload\n" if "overload" in dir(_imported) else "@t.overload\n",
|
||||
f'def __getitem__(self, item: t.Literal["{keys.lower()}"] = ...) -> dict[str, t.Any]: ...\n',
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
processed = (
|
||||
processed[:start_idx] + [" " * 4 + START_COMMENT] + lines + [" " * 4 + END_COMMENT] + processed[end_idx + 1 :]
|
||||
|
||||
53
tools/write-coverage-report.py
Executable file
53
tools/write-coverage-report.py
Executable file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
import orjson
|
||||
from decimal import ROUND_DOWN, Decimal
|
||||
from pathlib import Path
|
||||
|
||||
PRECISION = Decimal('.01')
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def main():
|
||||
coverage_summary = ROOT / 'coverage-summary.json'
|
||||
|
||||
coverage_data = orjson.loads(coverage_summary.read_text(encoding='utf-8'))
|
||||
total_data = coverage_data.pop('total')
|
||||
|
||||
lines = [
|
||||
'\n',
|
||||
'Package | Statements\n',
|
||||
'------- | ----------\n',
|
||||
]
|
||||
|
||||
for package, data in sorted(coverage_data.items()):
|
||||
statements_covered = data['statements_covered']
|
||||
statements = data['statements']
|
||||
|
||||
rate = Decimal(statements_covered) / Decimal(statements) * 100
|
||||
rate = rate.quantize(PRECISION, rounding=ROUND_DOWN)
|
||||
lines.append(
|
||||
f'{package} | {100 if rate == 100 else rate}% ({statements_covered} / {statements})\n' # noqa: PLR2004
|
||||
)
|
||||
|
||||
total_statements_covered = total_data['statements_covered']
|
||||
total_statements = total_data['statements']
|
||||
total_rate = Decimal(total_statements_covered) / Decimal(total_statements) * 100
|
||||
total_rate = total_rate.quantize(PRECISION, rounding=ROUND_DOWN)
|
||||
color = 'ok' if float(total_rate) >= 95 else 'critical'
|
||||
lines.insert(0, f'\n')
|
||||
|
||||
lines.append(
|
||||
f'**Summary** | {100 if total_rate == 100 else total_rate}% '
|
||||
f'({total_statements_covered} / {total_statements})\n'
|
||||
)
|
||||
|
||||
coverage_report = ROOT / 'coverage-report.md'
|
||||
with coverage_report.open('w', encoding='utf-8') as f:
|
||||
f.write(''.join(lines))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user